"""Module exposing estimator base class."""
from sklearn.base import BaseEstimator, ClassifierMixin, TransformerMixin
from sklearn.metrics import euclidean_distances
from sklearn.utils.multiclass import unique_labels
from sklearn.utils.validation import check_array, check_is_fitted, check_X_y
import geomstats.backend as gs
class TemplateEstimator(BaseEstimator):
    """A template estimator to be used as a reference implementation.

    For more information regarding how to build your own estimator, read more
    in the :ref:`User Guide <user_guide>`.

    Parameters
    ----------
    demo_param : str, default='demo_param'
        A parameter used for demonstration of how to pass and store
        parameters.

    Attributes
    ----------
    is_fitted_ : bool
        Set to True by :meth:`fit`; its presence is what :meth:`predict`
        checks to decide whether the estimator has been fitted.
    """

    def __init__(self, demo_param="demo_param"):
        # scikit-learn convention: store constructor arguments unmodified so
        # that get_params / set_params / clone round-trip them.
        self.demo_param = demo_param

    def fit(self, X, y):
        """Train estimator on labeled data.

        Parameters
        ----------
        X : {array-like, sparse matrix}, shape (n_samples, n_features)
            The training input samples.
        y : array-like, shape (n_samples,) or (n_samples, n_outputs)
            The target values (class labels in classification, real numbers in
            regression).

        Returns
        -------
        self : object
            Returns self.
        """
        # Validation only: this template learns nothing from the data.
        # multi_output=True so that a 2-D ``y`` of shape
        # (n_samples, n_outputs), as documented above, is accepted instead of
        # being rejected as "y should be a 1d array".
        X, y = check_X_y(X, y, accept_sparse=True, multi_output=True)
        self.is_fitted_ = True
        # `fit` should always return `self`
        return self

    def predict(self, X):
        """Perform prediction.

        Parameters
        ----------
        X : {array-like, sparse matrix}, shape (n_samples, n_features)
            The input samples.

        Returns
        -------
        y : ndarray, shape (n_samples,)
            Returns an array of ones.
        """
        # Check the fitted state first so that calling predict on an unfitted
        # estimator raises NotFittedError rather than an input error.
        check_is_fitted(self, "is_fitted_")
        X = check_array(X, accept_sparse=True)
        return gs.ones(X.shape[0], dtype=gs.int64)
class TemplateClassifier(ClassifierMixin, BaseEstimator):
    """An example classifier which implements a 1-NN algorithm.

    For more information regarding how to build your own classifier, read more
    in the :ref:`User Guide <user_guide>`.

    Parameters
    ----------
    demo_param : str, default='demo'
        A parameter used for demonstration of how to pass and store
        parameters.

    Attributes
    ----------
    X_ : ndarray, shape (n_samples, n_features)
        The input passed during :meth:`fit`.
    y_ : ndarray, shape (n_samples,)
        The labels passed during :meth:`fit`.
    classes_ : ndarray, shape (n_classes,)
        The classes seen at :meth:`fit`.
    """

    # NOTE: ClassifierMixin must precede BaseEstimator in the bases so that
    # its classifier-specific behaviour (estimator type / tags) is not
    # shadowed by BaseEstimator in the MRO.

    def __init__(self, demo_param="demo"):
        # scikit-learn convention: store constructor arguments unmodified so
        # that get_params / set_params / clone round-trip them.
        self.demo_param = demo_param

    def fit(self, X, y):
        """Train classifier on labeled data.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            The training input samples.
        y : array-like, shape (n_samples,)
            The target class labels.

        Returns
        -------
        self : object
            Returns self.
        """
        # Check that X and y have correct shape
        X, y = check_X_y(X, y)
        # Store the classes seen during fit
        self.classes_ = unique_labels(y)
        # 1-NN is a lazy learner: "training" is memorising the samples.
        self.X_ = X
        self.y_ = y
        # Return the classifier
        return self

    def predict(self, X):
        """Classify input data.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            The input samples.

        Returns
        -------
        y : ndarray, shape (n_samples,)
            The label for each sample is the label of the closest sample
            seen during fit.
        """
        # Check if fit has been called
        check_is_fitted(self, ["X_", "y_"])
        # Input validation
        X = check_array(X)
        # Pairwise distances have shape (n_query, n_train); the argmin over
        # axis 1 selects, for each query row, its nearest training sample.
        closest = gs.argmin(euclidean_distances(X, self.X_), axis=1)
        return self.y_[closest]