"""The MDM classifier on manifolds.
Lead authors: Daniel Brooks and Quentin Barthelemy.
"""
from scipy.special import softmax
from sklearn.metrics import accuracy_score
import geomstats.backend as gs
from geomstats.learning.frechet_mean import FrechetMean
[docs]class RiemannianMinimumDistanceToMeanClassifier:
r"""Minimum Distance to Mean (MDM) classifier on manifolds.
Classification by nearest centroid. For each of the given classes, a
centroid is estimated according to the chosen metric. Then, for each new
point, the class is affected according to the nearest centroid (see
[BBCJ2012]_).
Parameters
----------
riemannian_metric : RiemannianMetric
Riemannian metric to be used.
n_classes : int
Number of classes.
point_type : str, {\'vector\', \'matrix\'}
Point type.
Optional, default: \'matrix\'.
Attributes
----------
mean_estimates_ : list
If fit, centroids computed on training set.
classes_ : list
If fit, classes of training set.
References
----------
.. [BBCJ2012] A. Barachant, S. Bonnet, M. Congedo and C. Jutten, Multiclass
Brain-Computer Interface Classification by Riemannian Geometry. IEEE
Trans. Biomed. Eng., vol. 59, pp. 920-928, 2012.
"""
def __init__(self, riemannian_metric, n_classes, point_type="matrix"):
self.riemannian_metric = riemannian_metric
self.n_classes = n_classes
self.point_type = point_type
self.mean_estimates_ = None
self.classes_ = None
[docs] def fit(self, X, y):
"""Compute Frechet mean of each class.
Parameters
----------
X : array-like, shape=[n_samples, dim]
if point_type='vector'
shape=[n_samples, n, n]
if point_type='matrix'
Training data, where n_samples is the number of samples
and n_features is the number of features.
y : array-like, shape=[n_samples,]
Training labels.
"""
self.classes_ = gs.unique(y)
mean_estimator = FrechetMean(
metric=self.riemannian_metric, point_type=self.point_type
)
frechet_means = []
for c in self.classes_:
X_c = X[gs.where(y == c, True, False)]
frechet_means.append(mean_estimator.fit(X_c).estimate_)
self.mean_estimates_ = gs.array(frechet_means)
[docs] def predict(self, X):
"""Compute closest neighbor according to riemannian_metric.
Parameters
----------
X : array-like, shape=[n_samples, dim]
if point_type='vector'
shape=[n_samples, n, n]
if point_type='matrix'
Test data, where n_samples is the number of samples
and n_features is the number of features.
Returns
-------
y : array-like, shape=[n_samples,]
Predicted labels.
"""
indices = self.riemannian_metric.closest_neighbor_index(X, self.mean_estimates_)
if gs.ndim(indices) == 0:
indices = gs.expand_dims(indices, 0)
return gs.take(self.classes_, indices)
[docs] def predict_proba(self, X):
"""Compute probabilities.
Compute probabilities to belong to classes according to
riemannian_metric.
Parameters
----------
X : array-like, shape=[n_samples, dim]
if point_type='vector'
shape=[n_samples, n, n]
if point_type='matrix'
Test data, where n_samples is the number of samples
and n_features is the number of features.
Returns
-------
probas : array-like, shape=[n_samples, n_classes]
Probability of the sample for each class in the model.
"""
n_samples = X.shape[0]
probas = []
for i in range(n_samples):
dist2 = self.riemannian_metric.squared_dist(X[i], self.mean_estimates_)
probas.append(softmax(-dist2))
return gs.array(probas)
[docs] def score(self, X, y, weights=None):
"""Compute score on the given test data and labels.
Compute the score defined as accuracy.
Parameters
----------
X : array-like, shape=[n_samples, dim]
if point_type='vector'
shape=[n_samples, n, n]
if point_type='matrix'
Test data, where n_samples is the number of samples
and n_features is the number of features.
y : array-like, shape=[n_samples,]
True labels for `X`.
weights : array-like, shape=[n_samples,]
Weights associated to the samples.
Optional, default: None.
Returns
-------
score : float
Mean accuracy of ``self.predict(X)`` wrt. `y`.
"""
return accuracy_score(y, self.predict(X), sample_weight=weights)