"""Product of manifolds.
Lead author: Nicolas Guigui.
"""
import joblib
import geomstats.backend as gs
import geomstats.errors
from geomstats.geometry.manifold import Manifold
from geomstats.geometry.product_riemannian_metric import ProductRiemannianMetric
from geomstats.geometry.riemannian_metric import RiemannianMetric
class ProductManifold(Manifold):
    """Class for a product of manifolds M_1 x ... x M_n.

    In contrast to the classes NFoldManifold, Landmarks, or DiscretizedCurves,
    the manifolds M_1, ..., M_n need not be the same, nor of
    same dimension, but the list of manifolds needs to be provided.

    By default, a point is represented by an array of shape:
    [..., dim_1 + ... + dim_n_manifolds]
    where n_manifolds is the number of manifolds in the product.
    This type of representation is called 'vector'.

    Alternatively, a point can be represented by an array of shape:
    [..., n_manifolds, dim] if the n_manifolds have same dimension dim.
    This type of representation is called 'matrix'.

    Parameters
    ----------
    manifolds : list
        List of manifolds in the product.
    metrics : list
        List of metrics, one per manifold, used to build the product metric.
        Optional, default: None, in which case the `metric` attribute of
        each manifold is used.
    default_point_type : str, {'vector', 'matrix'}
        Default representation of points.
        Optional, default: 'vector'.
    n_jobs : int
        Number of jobs for parallel computing.
        Optional, default: 1.
    **kwargs
        Forwarded to the parent constructor. If no 'metric' entry is given,
        a ProductRiemannianMetric built from `metrics` is used.
    """

    # FIXME (nguigs): This only works for 1d points

    def __init__(
        self, manifolds, metrics=None, default_point_type="vector", n_jobs=1, **kwargs
    ):
        geomstats.errors.check_parameter_accepted_values(
            default_point_type, "default_point_type", ["vector", "matrix"]
        )

        self.dims = [manifold.dim for manifold in manifolds]

        if metrics is None:
            metrics = [manifold.metric for manifold in manifolds]
        kwargs.setdefault(
            "metric",
            ProductRiemannianMetric(
                metrics, default_point_type=default_point_type, n_jobs=n_jobs
            ),
        )

        dim = sum(self.dims)
        if default_point_type == "vector":
            # Vector points are the concatenation of the factors' coordinates.
            shape = (sum(m.shape[0] for m in manifolds),)
        else:
            # Matrix points stack one row per factor; the first factor's
            # shape is used for every row.
            shape = (len(manifolds), *manifolds[0].shape)

        super(ProductManifold, self).__init__(
            dim=dim,
            shape=shape,
            default_point_type=default_point_type,
            **kwargs,
        )
        self.manifolds = manifolds
        self.n_jobs = n_jobs

    @staticmethod
    def _get_method(manifold, method_name, metric_args):
        """Call the method `method_name` of `manifold` with keyword arguments.

        Parameters
        ----------
        manifold : object
            Object on which the method is looked up.
        method_name : str
            Name of the method to call.
        metric_args : dict
            Keyword arguments passed to the method.

        Returns
        -------
        result
            Whatever the called method returns.
        """
        return getattr(manifold, method_name)(**metric_args)

    def _iterate_over_manifolds(self, func, args, intrinsic=False):
        """Apply a method of every factor manifold to its slice of the inputs.

        Array arguments are split along their last axis into one chunk per
        manifold; scalar arguments (e.g. a tolerance) are passed unchanged to
        every manifold.

        Parameters
        ----------
        func : str
            Name of the method to call on each manifold.
        args : dict
            Keyword arguments of the method. Values that are Python ints or
            floats are broadcast as-is, all others are split.
        intrinsic : bool
            If True, manifold i is given `dims[i]` coordinates; otherwise it
            is given `dims[i] + 1` coordinates.
            Optional, default: False.

        Returns
        -------
        out : list
            Result of the method for each manifold, in order.
        """
        # In the extrinsic case the last cumulative index is kept, so the
        # split yields one extra trailing chunk that is never read below.
        cum_index = (
            gs.cumsum(self.dims)[:-1]
            if intrinsic
            else gs.cumsum([k + 1 for k in self.dims])
        )

        split_args = {}
        scalar_args = {}
        for key, value in args.items():
            # Integer scalars (e.g. atol=1) must not be split either.
            if isinstance(value, (int, float)):
                scalar_args[key] = value
            else:
                split_args[key] = gs.split(value, cum_index, axis=-1)

        args_list = [
            {key: chunks[j] for key, chunks in split_args.items()}
            for j in range(len(self.manifolds))
        ]

        pool = joblib.Parallel(n_jobs=self.n_jobs)
        out = pool(
            joblib.delayed(self._get_method)(
                manifold, func, {**manifold_args, **scalar_args}
            )
            for manifold, manifold_args in zip(self.manifolds, args_list)
        )
        return out

    def belongs(self, point, atol=gs.atol):
        """Test if a point belongs to the manifold.

        A point belongs to the product if each of its components belongs to
        the corresponding factor.

        Parameters
        ----------
        point : array-like, shape=[..., {dim, [n_manifolds, dim_each]}]
            Point.
        atol : float,
            Tolerance.

        Returns
        -------
        belongs : array-like, shape=[...,]
            Boolean evaluating if the point belongs to the manifold.
        """
        point_type = self.default_point_type

        if point_type == "vector":
            intrinsic = self.metric.is_intrinsic(point)
            belongs = self._iterate_over_manifolds(
                "belongs", {"point": point, "atol": atol}, intrinsic
            )
            belongs = gs.stack(belongs, axis=-1)
        else:
            belongs = gs.stack(
                [
                    space.belongs(point[..., i, :], atol)
                    for i, space in enumerate(self.manifolds)
                ],
                axis=-1,
            )

        # Reduce over the factor axis added by the stack.
        belongs = gs.all(belongs, axis=-1)
        return belongs

    def regularize(self, point):
        """Regularize the point into the manifold's canonical representation.

        Each component is regularized by the corresponding factor manifold.

        Parameters
        ----------
        point : array-like, shape=[..., {dim, [n_manifolds, dim_each]}]
            Point to be regularized.

        Returns
        -------
        regularized_point : array-like,
            shape=[..., {dim, [n_manifolds, dim_each]}]
            Point in the manifold's canonical representation.
        """
        point_type = self.default_point_type

        if point_type == "vector":
            intrinsic = self.metric.is_intrinsic(point)
            regularized_point = self._iterate_over_manifolds(
                "regularize", {"point": point}, intrinsic
            )
            regularized_point = gs.concatenate(regularized_point, axis=-1)
        elif point_type == "matrix":
            regularized_point = [
                manifold_i.regularize(point[..., i, :])
                for i, manifold_i in enumerate(self.manifolds)
            ]
            # Stack on the second-to-last axis so that a single point of
            # shape [n_manifolds, dim] keeps its layout (axis=1 would
            # transpose it), consistently with projection and to_tangent.
            regularized_point = gs.stack(regularized_point, axis=-2)
        return regularized_point

    def random_point(self, n_samples=1, bound=1.0):
        """Sample in the product space by sampling in each factor manifold.

        Parameters
        ----------
        n_samples : int, optional
            Number of samples.
        bound : float
            Bound of the interval in which to sample for non compact manifolds.
            Optional, default: 1.

        Returns
        -------
        samples : array-like, shape=[..., {dim, [n_manifolds, dim_each]}]
            Points sampled on the product manifold.
        """
        point_type = self.default_point_type
        geomstats.errors.check_parameter_accepted_values(
            point_type, "point_type", ["vector", "matrix"]
        )

        per_manifold = [
            space.random_point(n_samples, bound) for space in self.manifolds
        ]
        if point_type == "vector":
            # A single concatenation replaces the pairwise accumulation.
            return gs.concatenate(per_manifold, axis=-1)
        return gs.stack(per_manifold, axis=-2)

    def projection(self, point):
        """Project a point in product embedding manifold on each manifold.

        Parameters
        ----------
        point : array-like, shape=[..., {dim, [n_manifolds, dim_each]}]
            Point in embedding manifold.

        Returns
        -------
        projected : array-like, shape=[..., {dim, [n_manifolds, dim_each]}]
            Projected point.
        """
        point_type = self.default_point_type
        geomstats.errors.check_parameter_accepted_values(
            point_type, "point_type", ["vector", "matrix"]
        )

        if point_type == "vector":
            intrinsic = self.metric.is_intrinsic(point)
            projected_point = self._iterate_over_manifolds(
                "projection", {"point": point}, intrinsic
            )
            projected_point = gs.concatenate(projected_point, axis=-1)
        elif point_type == "matrix":
            projected_point = [
                manifold_i.projection(point[..., i, :])
                for i, manifold_i in enumerate(self.manifolds)
            ]
            projected_point = gs.stack(projected_point, axis=-2)
        return projected_point

    def to_tangent(self, vector, base_point):
        """Project a vector to a tangent space of the manifold.

        The tangent space of the product manifold is the direct sum of
        tangent spaces.

        Parameters
        ----------
        vector : array-like, shape=[..., dim]
            Vector.
        base_point : array-like, shape=[..., dim]
            Point on the manifold.

        Returns
        -------
        tangent_vec : array-like, shape=[..., dim]
            Tangent vector at base point.
        """
        point_type = self.default_point_type
        geomstats.errors.check_parameter_accepted_values(
            point_type, "point_type", ["vector", "matrix"]
        )

        if point_type == "vector":
            intrinsic = self.metric.is_intrinsic(base_point)
            tangent_vec = self._iterate_over_manifolds(
                "to_tangent", {"base_point": base_point, "vector": vector}, intrinsic
            )
            tangent_vec = gs.concatenate(tangent_vec, axis=-1)
        elif point_type == "matrix":
            tangent_vec = [
                manifold_i.to_tangent(vector[..., i, :], base_point[..., i, :])
                for i, manifold_i in enumerate(self.manifolds)
            ]
            tangent_vec = gs.stack(tangent_vec, axis=-2)
        return tangent_vec

    def is_tangent(self, vector, base_point, atol=gs.atol):
        """Check whether the vector is tangent at base_point.

        The tangent space of the product manifold is the direct sum of
        tangent spaces.

        Parameters
        ----------
        vector : array-like, shape=[..., dim]
            Vector.
        base_point : array-like, shape=[..., dim]
            Point on the manifold.
        atol : float
            Absolute tolerance.
            Optional, default: backend atol.

        Returns
        -------
        is_tangent : bool
            Boolean denoting if vector is a tangent vector at the base point.
        """
        point_type = self.default_point_type
        geomstats.errors.check_parameter_accepted_values(
            point_type, "point_type", ["vector", "matrix"]
        )

        if point_type == "vector":
            intrinsic = self.metric.is_intrinsic(base_point)
            is_tangent = self._iterate_over_manifolds(
                "is_tangent",
                {"base_point": base_point, "vector": vector, "atol": atol},
                intrinsic,
            )
            is_tangent = gs.stack(is_tangent, axis=-1)
        else:
            is_tangent = gs.stack(
                [
                    space.is_tangent(
                        vector[..., i, :], base_point[..., i, :], atol=atol
                    )
                    for i, space in enumerate(self.manifolds)
                ],
                axis=-1,
            )

        # Reduce over the factor axis added by the stack.
        is_tangent = gs.all(is_tangent, axis=-1)
        return is_tangent
class NFoldManifold(Manifold):
    r"""Class for an n-fold product manifold :math:`M^n`.

    Define a manifold as the product manifold of n copies of a given base manifold M.

    Parameters
    ----------
    base_manifold : Manifold
        Base manifold.
    n_copies : int
        Number of replication of the base manifold.
    metric : RiemannianMetric
        Metric object to use on the manifold.
        Optional, default: None, in which case an NFoldMetric built from the
        metric of the base manifold is used.
    default_point_type : str, {'vector', 'matrix'}
        Point type.
        Optional, default: 'matrix'.
    default_coords_type : str, {'intrinsic', 'extrinsic', etc}
        Coordinate type.
        Optional, default: 'intrinsic'.
    """

    def __init__(
        self,
        base_manifold,
        n_copies,
        metric=None,
        default_point_type="matrix",
        default_coords_type="intrinsic",
        **kwargs
    ):
        geomstats.errors.check_integer(n_copies, "n_copies")
        dim = n_copies * base_manifold.dim
        shape = (n_copies,) + base_manifold.shape

        super(NFoldManifold, self).__init__(
            dim=dim,
            shape=shape,
            default_point_type=default_point_type,
            default_coords_type=default_coords_type,
            **kwargs,
        )

        self.base_manifold = base_manifold
        self.base_shape = base_manifold.shape
        self.n_copies = n_copies

        # A metric supplied by the caller must be honoured; it is not
        # forwarded to the parent constructor, so it is assigned here.
        if metric is None:
            metric = NFoldMetric(base_manifold.metric, n_copies)
        self.metric = metric

    def belongs(self, point, atol=gs.atol):
        """Test if a point belongs to the manifold.

        A point belongs to the n-fold product if each of its n_copies
        components belongs to the base manifold.

        Parameters
        ----------
        point : array-like, shape=[..., n_copies, *base_shape]
            Point.
        atol : float,
            Tolerance.

        Returns
        -------
        belongs : array-like, shape=[...,]
            Boolean evaluating if the point belongs to the manifold.
        """
        # Flatten the batch and copy axes so the base manifold sees a plain
        # batch of base points.
        point_ = gs.reshape(point, (-1, *self.base_shape))
        each_belongs = self.base_manifold.belongs(point_, atol=atol)
        reshaped = gs.reshape(each_belongs, (-1, self.n_copies))
        return gs.squeeze(gs.all(reshaped, axis=1))

    def is_tangent(self, vector, base_point, atol=gs.atol):
        """Check whether the vector is tangent at base_point.

        The tangent space of the product manifold is the direct sum of
        tangent spaces.

        Parameters
        ----------
        vector : array-like, shape=[..., n_copies, *base_shape]
            Vector.
        base_point : array-like, shape=[..., n_copies, *base_shape]
            Point on the manifold.
        atol : float
            Absolute tolerance.
            Optional, default: backend atol.

        Returns
        -------
        is_tangent : array-like, shape=[n_points,]
            Boolean denoting, for each point of the flattened batch, if the
            vector is a tangent vector at the base point.
        """
        vector_, point_ = gs.broadcast_arrays(vector, base_point)
        point_ = gs.reshape(point_, (-1, *self.base_shape))
        vector_ = gs.reshape(vector_, (-1, *self.base_shape))
        # Forward the tolerance so that the caller's atol is actually used.
        each_tangent = self.base_manifold.is_tangent(vector_, point_, atol=atol)
        reshaped = gs.reshape(each_tangent, (-1, self.n_copies))
        return gs.all(reshaped, axis=1)

    def to_tangent(self, vector, base_point):
        """Project a vector to a tangent space of the manifold.

        The tangent space of the product manifold is the direct sum of
        tangent spaces.

        Parameters
        ----------
        vector : array-like, shape=[..., n_copies, *base_shape]
            Vector.
        base_point : array-like, shape=[..., n_copies, *base_shape]
            Point on the manifold.

        Returns
        -------
        tangent_vec : array-like, shape=[..., n_copies, *base_shape]
            Tangent vector at base point. The result is passed through
            `gs.squeeze` before being returned.
        """
        vector_, point_ = gs.broadcast_arrays(vector, base_point)
        point_ = gs.reshape(point_, (-1, *self.base_shape))
        vector_ = gs.reshape(vector_, (-1, *self.base_shape))
        each_tangent = self.base_manifold.to_tangent(vector_, point_)
        reshaped = gs.reshape(each_tangent, (-1, self.n_copies) + self.base_shape)
        return gs.squeeze(reshaped)

    def random_point(self, n_samples=1, bound=1.0):
        """Sample in the product space by sampling in the base manifold.

        Parameters
        ----------
        n_samples : int, optional
            Number of samples.
        bound : float
            Bound of the interval in which to sample for non compact manifolds.
            Optional, default: 1.

        Returns
        -------
        samples : array-like, shape=[..., n_copies, *base_shape]
            Points sampled on the product manifold. The result is passed
            through `gs.squeeze` before being returned.
        """
        # Draw all the base points at once, then group them by sample.
        sample = self.base_manifold.random_point(n_samples * self.n_copies, bound)
        reshaped = gs.reshape(sample, (n_samples, self.n_copies) + self.base_shape)
        return gs.squeeze(reshaped)

    def projection(self, point):
        """Project a point from product embedding manifold to the product manifold.

        Parameters
        ----------
        point : array-like, shape=[..., n_copies, *base_shape]
            Point in embedding manifold.

        Returns
        -------
        projected : array-like, shape=[..., n_copies, *base_shape]
            Projected point. The result is passed through `gs.squeeze`
            before being returned.

        Raises
        ------
        NotImplementedError
            If the base manifold has no `projection` attribute.
        """
        if not hasattr(self.base_manifold, "projection"):
            raise NotImplementedError(
                "The base manifold does not implement a projection method."
            )

        point_ = gs.reshape(point, (-1, *self.base_shape))
        projected = self.base_manifold.projection(point_)
        reshaped = gs.reshape(projected, (-1, self.n_copies) + self.base_shape)
        return gs.squeeze(reshaped)
class NFoldMetric(RiemannianMetric):
    r"""Class for the metric of an n-fold product manifold :math:`M^n`.

    Define the metric on the product of n copies of a base manifold M as the
    product of n copies of the base metric: every operation is applied
    copy-wise with the base metric.

    Parameters
    ----------
    base_metric : RiemannianMetric
        Base metric.
    n_copies : int
        Number of replication of the base metric.
    """

    def __init__(self, base_metric, n_copies):
        geomstats.errors.check_integer(n_copies, "n_copies")
        dim = n_copies * base_metric.dim
        base_shape = base_metric.shape
        super(NFoldMetric, self).__init__(dim=dim, shape=(n_copies, *base_shape))
        self.base_shape = base_shape
        self.base_metric = base_metric
        self.n_copies = n_copies

    def metric_matrix(self, base_point=None):
        """Compute the matrix of the inner-product.

        Matrix of the inner-product defined by the Riemannian metric
        at point base_point of the manifold.

        Parameters
        ----------
        base_point : array-like, shape=[..., n_copies, *base_shape]
            Point on the manifold at which to compute the inner-product matrix.
            The signature defaults to None, but the value is reshaped
            unconditionally, so a point should be provided.

        Returns
        -------
        matrix : array-like, shape=[..., n_copies, dim, dim]
            Matrix of the inner-product at the base point, one matrix per
            copy, where dim is the dimension of the base metric. The result
            is passed through `gs.squeeze` before being returned.
        """
        point_ = gs.reshape(base_point, (-1, *self.base_shape))
        matrices = self.base_metric.metric_matrix(point_)
        dim = self.base_metric.dim
        reshaped = gs.reshape(matrices, (-1, self.n_copies, dim, dim))
        return gs.squeeze(reshaped)

    def inner_product(self, tangent_vec_a, tangent_vec_b, base_point):
        """Compute the inner-product of two tangent vectors at a base point.

        Inner product defined by the Riemannian metric at point `base_point`
        between tangent vectors `tangent_vec_a` and `tangent_vec_b`. It is
        the sum over the copies of the inner-products of the base metric.

        Parameters
        ----------
        tangent_vec_a : array-like, shape=[..., n_copies, *base_shape]
            First tangent vector at base point.
        tangent_vec_b : array-like, shape=[..., n_copies, *base_shape]
            Second tangent vector at base point.
        base_point : array-like, shape=[..., n_copies, *base_shape]
            Point on the manifold.

        Returns
        -------
        inner_prod : array-like, shape=[...,]
            Inner-product of the two tangent vectors.
        """
        tangent_vec_a_, tangent_vec_b_, point_ = gs.broadcast_arrays(
            tangent_vec_a, tangent_vec_b, base_point
        )
        # Flatten the batch and copy axes so the base metric sees a plain
        # batch of base points and vectors.
        point_ = gs.reshape(point_, (-1, *self.base_shape))
        vector_a = gs.reshape(tangent_vec_a_, (-1, *self.base_shape))
        vector_b = gs.reshape(tangent_vec_b_, (-1, *self.base_shape))
        inner_each = self.base_metric.inner_product(vector_a, vector_b, point_)
        reshaped = gs.reshape(inner_each, (-1, self.n_copies))
        return gs.squeeze(gs.sum(reshaped, axis=-1))

    def exp(self, tangent_vec, base_point, **kwargs):
        """Compute the Riemannian exponential of a tangent vector.

        Parameters
        ----------
        tangent_vec : array-like, shape=[..., n_copies, *base_shape]
            Tangent vector at a base point.
        base_point : array-like, shape=[..., n_copies, *base_shape]
            Point on the manifold.
        **kwargs
            Additional keyword arguments forwarded to the exponential of the
            base metric.

        Returns
        -------
        exp : array-like, shape=[..., n_copies, *base_shape]
            Point on the manifold equal to the Riemannian exponential
            of tangent_vec at the base point. The result is passed through
            `gs.squeeze` before being returned.
        """
        tangent_vec, point_ = gs.broadcast_arrays(tangent_vec, base_point)
        point_ = gs.reshape(point_, (-1, *self.base_shape))
        vector_ = gs.reshape(tangent_vec, (-1, *self.base_shape))
        # Forward the options instead of silently discarding them.
        each_exp = self.base_metric.exp(vector_, point_, **kwargs)
        reshaped = gs.reshape(each_exp, (-1, self.n_copies) + self.base_shape)
        return gs.squeeze(reshaped)

    def log(self, point, base_point, **kwargs):
        """Compute the Riemannian logarithm of a point.

        Parameters
        ----------
        point : array-like, shape=[..., n_copies, *base_shape]
            Point on the manifold.
        base_point : array-like, shape=[..., n_copies, *base_shape]
            Point on the manifold.
        **kwargs
            Additional keyword arguments forwarded to the logarithm of the
            base metric.

        Returns
        -------
        log : array-like, shape=[..., n_copies, *base_shape]
            Tangent vector at the base point equal to the Riemannian logarithm
            of point at the base point. The result is passed through
            `gs.squeeze` before being returned.
        """
        point_, base_point_ = gs.broadcast_arrays(point, base_point)
        base_point_ = gs.reshape(base_point_, (-1, *self.base_shape))
        point_ = gs.reshape(point_, (-1, *self.base_shape))
        # Forward the options instead of silently discarding them.
        each_log = self.base_metric.log(point_, base_point_, **kwargs)
        reshaped = gs.reshape(each_log, (-1, self.n_copies) + self.base_shape)
        return gs.squeeze(reshaped)