Source code for geomstats.learning.preprocessing
"""Transformer for manifold-valued data.
Lead author: Nicolas Guigui.
"""
from sklearn.base import BaseEstimator, TransformerMixin
import geomstats.backend as gs
from geomstats.geometry.matrices import Matrices
from geomstats.geometry.skew_symmetric_matrices import SkewSymmetricMatrices
from geomstats.geometry.symmetric_matrices import SymmetricMatrices
from geomstats.learning.exponential_barycenter import ExponentialBarycenter
from geomstats.learning.frechet_mean import FrechetMean
[docs]
class ToTangentSpace(BaseEstimator, TransformerMixin):
    """Lift data to a tangent space.

    Compute the logs of all data points and reshape them to
    1d vectors if necessary. This means that all the data points, which belong
    to a possibly non-linear manifold, are lifted to one of the tangent spaces
    of the manifold, which is a vector space. By default, the mean of the data
    is computed (with the FrechetMean or the ExponentialBarycenter estimator,
    as appropriate) and the tangent space at the mean is used. Any other base
    point can be passed. The data points are then represented by the initial
    velocities of the geodesics that lead from base_point to each data point.
    Any machine learning algorithm can then be used with the output array.

    Parameters
    ----------
    space : Manifold
        Equipped manifold or unequipped space implementing `exp` and `log`.

    Attributes
    ----------
    mean_estimator : FrechetMean or ExponentialBarycenter
        Estimator used by `fit` when no base point is given. `FrechetMean`
        is used when `space` has a `metric` attribute, and
        `ExponentialBarycenter` otherwise.
    base_point_ : array-like or None
        Point at which logs and exps are taken. None until `fit` is called.

    Notes
    -----
    * Required geometry methods: `log`, `exp`.
    """

    def __init__(self, space):
        self.space = space
        if hasattr(self.space, "metric"):
            self.mean_estimator = FrechetMean(space)
        else:
            self.mean_estimator = ExponentialBarycenter(space)
        self.base_point_ = None

    @property
    def _geometry(self):
        """Object where `exp` and `log` are defined.

        This is `space.metric` when the space has a `metric` attribute,
        and the space itself otherwise.
        """
        if hasattr(self.space, "metric"):
            return self.space.metric
        return self.space

    def fit(self, X, y=None, weights=None, base_point=None):
        """Compute the central point at which to take the log.

        The mean of the input data is only computed if `base_point=None`;
        otherwise the given base point is stored as is.

        Parameters
        ----------
        X : array-like, shape=[n_samples, {dim, [n, n]}]
            The training input samples.
        y : None
            Ignored.
        weights : array-like, shape=[n_samples, 1]
            Weights associated to the points.
            Optional, default: None
        base_point : array-like, shape=[{dim, [n, n]}]
            Point similar to the input data from which to compute the logs.
            Optional, default: None.

        Returns
        -------
        self : object
            Returns self.
        """
        if base_point is None:
            self.base_point_ = self.mean_estimator.fit(X, y, weights).estimate_
        else:
            self.base_point_ = base_point
        return self

    def transform(self, X):
        """Lift data to a tangent space.

        Compute the logs of all data points at the fitted base point and
        reshape them to 1d vectors if necessary. Any machine learning
        algorithm can then be used with the output array.

        Parameters
        ----------
        X : array-like, shape=[n_samples, {dim, [n, n]}]
            Data to transform.

        Returns
        -------
        X_new : array-like, shape=[n_samples, dim]
            Lifted data.

        Raises
        ------
        Exception
            If `fit` has not been called (no base point is available).
        """
        if self.base_point_ is None:
            raise Exception("Not fitted")

        tangent_vecs = self._geometry.log(X, base_point=self.base_point_)

        # Points that are already vectors need no flattening.
        if self.space.point_ndim == 1:
            return tangent_vecs

        # Matrix-valued tangent vectors: use the compact basis coordinates
        # when they are all symmetric or all skew-symmetric, and fall back to
        # a plain row-wise flattening otherwise.
        if gs.all(Matrices.is_symmetric(tangent_vecs)):
            return SymmetricMatrices.basis_representation(tangent_vecs)
        if gs.all(Matrices.is_skew_symmetric(tangent_vecs)):
            return SkewSymmetricMatrices(
                tangent_vecs.shape[-1]
            ).basis_representation(tangent_vecs)
        return gs.reshape(tangent_vecs, (len(X), -1))

    def inverse_transform(self, X):
        """Reconstruction of X.

        The reconstruction will match X_original whose transform would be X.

        Parameters
        ----------
        X : array-like, shape=[n_samples, dim]
            New data, where dim is the dimension of the manifold data belong
            to.

        Returns
        -------
        X_original : array-like, shape=[n_samples, {dim, [n, n]}]
            Data lying on the manifold.

        Raises
        ------
        Exception
            If `fit` has not been called (no base point is available).
        """
        if self.base_point_ is None:
            raise Exception("Not fitted")

        if self.space.point_ndim > 1:
            # The vector length tells which flattening `transform` used:
            # n(n+1)/2 for symmetric, n(n-1)/2 for skew-symmetric, n*n else.
            n_base_point = self.base_point_.shape[-1]
            n_vecs = X.shape[-1]
            dim_sym = int(n_base_point * (n_base_point + 1) / 2)
            dim_skew = int(n_base_point * (n_base_point - 1) / 2)

            if gs.all(Matrices.is_symmetric(self.base_point_)) and dim_sym == n_vecs:
                tangent_vecs = SymmetricMatrices(n_base_point).matrix_representation(X)
            elif dim_skew == n_vecs:
                # The space is built from the matrix size n, as in
                # `transform`, not from the vector dimension n(n-1)/2; the
                # two only coincide when n == 3.
                tangent_vecs = SkewSymmetricMatrices(
                    n_base_point
                ).matrix_representation(X)
            else:
                tangent_vecs = gs.reshape(X, (len(X), n_base_point, n_base_point))
        else:
            tangent_vecs = X

        return self._geometry.exp(tangent_vecs, self.base_point_)