Source code for geomstats.learning._template

"""Module exposing estimator base class."""

from sklearn.base import BaseEstimator, ClassifierMixin, TransformerMixin
from sklearn.metrics import euclidean_distances
from sklearn.utils.multiclass import unique_labels
from sklearn.utils.validation import check_array, check_is_fitted, check_X_y

import geomstats.backend as gs


class TemplateEstimator(BaseEstimator):
    """Minimal reference estimator illustrating the scikit-learn API.

    See the :ref:`User Guide <user_guide>` for guidance on writing your
    own estimator.

    Parameters
    ----------
    demo_param : str, default='demo_param'
        Dummy parameter showing how constructor arguments are passed in
        and stored unchanged on the instance.
    """

    def __init__(self, demo_param="demo_param"):
        # scikit-learn requires constructor arguments to be stored as-is
        # so that ``get_params`` / ``set_params`` / ``clone`` work.
        self.demo_param = demo_param

    def fit(self, X, y):
        """Fit the estimator (validation only; nothing is learned).

        Parameters
        ----------
        X : {array-like, sparse matrix}, shape (n_samples, n_features)
            Training samples.
        y : array-like, shape (n_samples,) or (n_samples, n_outputs)
            Targets (class labels or real values).

        Returns
        -------
        self : object
            The fitted estimator.
        """
        # Validate inputs; the checked arrays are not otherwise used.
        X, y = check_X_y(X, y, accept_sparse=True)
        self.is_fitted_ = True
        # scikit-learn convention: ``fit`` returns the estimator itself.
        return self

    def predict(self, X):
        """Return a constant prediction of ones.

        Parameters
        ----------
        X : {array-like, sparse matrix}, shape (n_samples, n_features)
            Samples to predict.

        Returns
        -------
        y : ndarray, shape (n_samples,)
            Integer array of ones, one entry per sample.
        """
        validated = check_array(X, accept_sparse=True)
        check_is_fitted(self, "is_fitted_")
        n_samples = validated.shape[0]
        return gs.ones(n_samples, dtype=gs.int64)
class TemplateClassifier(BaseEstimator, ClassifierMixin):
    """Example classifier implementing 1-nearest-neighbour prediction.

    See the :ref:`User Guide <user_guide>` for guidance on writing your
    own classifier.

    Parameters
    ----------
    demo_param : str, default='demo'
        Dummy parameter showing how constructor arguments are passed in
        and stored unchanged on the instance.

    Attributes
    ----------
    X_ : ndarray, shape (n_samples, n_features)
        Training samples memorised by :meth:`fit`.
    y_ : ndarray, shape (n_samples,)
        Training labels memorised by :meth:`fit`.
    classes_ : ndarray, shape (n_classes,)
        Distinct labels observed by :meth:`fit`.
    """

    def __init__(self, demo_param="demo"):
        # Stored unchanged, as required by scikit-learn's parameter API.
        self.demo_param = demo_param

    def fit(self, X, y):
        """Memorise the labelled training data.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            Training samples.
        y : array-like, shape (n_samples,)
            Integer class labels.

        Returns
        -------
        self : object
            The fitted classifier.
        """
        # Shape/consistency validation of the training pair.
        X, y = check_X_y(X, y)
        self.classes_ = unique_labels(y)
        self.X_, self.y_ = X, y
        return self

    def predict(self, X):
        """Label each sample with the label of its nearest training sample.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            Samples to classify.

        Returns
        -------
        y : ndarray, shape (n_samples,)
            For each row of ``X``, the label of the closest (Euclidean)
            sample memorised during :meth:`fit`.
        """
        check_is_fitted(self, ["X_", "y_"])
        X = check_array(X)
        # Rows index query samples, columns index training samples.
        distances = euclidean_distances(X, self.X_)
        nearest_idx = gs.argmin(distances, axis=1)
        return self.y_[nearest_idx]
class TemplateTransformer(BaseEstimator, TransformerMixin):
    """Example transformer returning the element-wise square root.

    See the :ref:`User Guide <user_guide>` for guidance on writing your
    own transformer.

    Parameters
    ----------
    demo_param : str, default='demo'
        Dummy parameter showing how constructor arguments are passed in
        and stored unchanged on the instance.

    Attributes
    ----------
    n_features_ : int
        Number of features (columns) of the data passed to :meth:`fit`.
    """

    def __init__(self, demo_param="demo"):
        # Stored unchanged, as required by scikit-learn's parameter API.
        self.demo_param = demo_param

    def fit(self, X, y=None):
        """Record the number of input features.

        Parameters
        ----------
        X : {array-like, sparse matrix}, shape (n_samples, n_features)
            Training samples.
        y : None
            Ignored; accepted only so the transformer fits the pipeline API.

        Returns
        -------
        self : object
            The fitted transformer.
        """
        X = check_array(X, accept_sparse=True)
        self.n_features_ = X.shape[1]
        return self

    def transform(self, X):
        """Return the element-wise square root of ``X``.

        Parameters
        ----------
        X : {array-like, sparse matrix}, shape (n_samples, n_features)
            Samples to transform; must have the same number of features
            as the data seen in :meth:`fit`.

        Returns
        -------
        X_transformed : array, shape (n_samples, n_features)
            Element-wise square roots of the values in ``X``.

        Raises
        ------
        ValueError
            If ``X`` has a different number of features than seen in
            :meth:`fit`.
        """
        check_is_fitted(self, "n_features_")
        X = check_array(X, accept_sparse=True)
        # Input must have the same number of columns as the fitted data.
        if X.shape[1] != self.n_features_:
            # The original split literal lacked a space ("seenin").
            raise ValueError(
                "Shape of input is different from what was seen in `fit`"
            )
        return gs.sqrt(X)