# Source code for geomstats.geometry.nfold_manifold

```"""N-fold product manifold.

Lead author: Nicolas Guigui, John Harvey.
"""

import geomstats.backend as gs
import geomstats.errors
from geomstats.geometry.manifold import Manifold
from geomstats.geometry.riemannian_metric import RiemannianMetric
from geomstats.vectorization import get_batch_shape

[docs]
class NFoldManifold(Manifold):
    r"""Class for an n-fold product manifold :math:`M^n`.

    Define a manifold as the product manifold of n copies of a given base
    manifold M.

    Parameters
    ----------
    base_manifold : Manifold
        Base manifold.
    n_copies : int
        Number of copies of the base manifold.
    equip : bool
        Forwarded to the parent constructor. The metric class proposed by
        ``default_metric`` is ``NFoldMetric``.
        Optional, default: True.
    """

    def __init__(
        self,
        base_manifold,
        n_copies,
        equip=True,
    ):
        geomstats.errors.check_integer(n_copies, "n_copies")
        dim = n_copies * base_manifold.dim
        shape = (n_copies,) + base_manifold.shape

        # Kept ahead of the parent constructor call, as in the original
        # ordering (presumably so they are available while equipping).
        self.base_manifold = base_manifold
        self.n_copies = n_copies

        super().__init__(
            dim=dim,
            shape=shape,
            intrinsic=base_manifold.intrinsic,
            equip=equip,
        )

    @staticmethod
    def default_metric():
        """Metric to equip the space with if equip is True."""
        return NFoldMetric

    def belongs(self, point, atol=gs.atol):
        """Test if a point belongs to the manifold.

        A point belongs to the product if each of its n_copies components
        belongs to the base manifold.

        Parameters
        ----------
        point : array-like, shape=[..., n_copies, *base_shape]
            Point.
        atol : float
            Tolerance, forwarded to the base manifold.
            Optional, default: backend atol.

        Returns
        -------
        belongs : array-like, shape=[...,]
            Boolean evaluating if the point belongs to the manifold.
        """
        batch_shape = get_batch_shape(self.point_ndim, point)
        # Flatten batch and copy axes so the base manifold sees a flat batch.
        point_ = gs.reshape(point, (-1, *self.base_manifold.shape))

        each_belongs = self.base_manifold.belongs(point_, atol=atol)

        reshaped = gs.reshape(each_belongs, batch_shape + (self.n_copies,))
        return gs.all(reshaped, axis=-1)

    def is_tangent(self, vector, base_point, atol=gs.atol):
        """Check whether the vector is tangent at base_point.

        The tangent space of the product manifold is the direct sum of
        tangent spaces.

        Parameters
        ----------
        vector : array-like, shape=[..., n_copies, *base_shape]
            Vector.
        base_point : array-like, shape=[..., n_copies, *base_shape]
            Point on the manifold.
        atol : float
            Absolute tolerance.
            Optional, default: backend atol.

        Returns
        -------
        is_tangent : array-like, shape=[...,]
            Boolean denoting if vector is a tangent vector at the base point.
        """
        # Bring both inputs to a common shape before flattening, so that a
        # single vector can be paired with a batch of points (and vice versa).
        vector_, base_point_ = gs.broadcast_arrays(vector, base_point)
        batch_shape = get_batch_shape(self.point_ndim, vector_)
        base_point_ = gs.reshape(base_point_, (-1, *self.base_manifold.shape))
        vector_ = gs.reshape(vector_, (-1, *self.base_manifold.shape))

        each_tangent = self.base_manifold.is_tangent(vector_, base_point_, atol=atol)

        reshaped = gs.reshape(each_tangent, batch_shape + (self.n_copies,))
        return gs.all(reshaped, axis=-1)

    def to_tangent(self, vector, base_point):
        """Project a vector to a tangent space of the manifold.

        The tangent space of the product manifold is the direct sum of
        tangent spaces.

        Parameters
        ----------
        vector : array-like, shape=[..., n_copies, *base_shape]
            Vector.
        base_point : array-like, shape=[..., n_copies, *base_shape]
            Point on the manifold.

        Returns
        -------
        tangent_vec : array-like, shape=[..., n_copies, *base_shape]
            Tangent vector at base point.
        """
        # Bring both inputs to a common shape before flattening.
        vector_, base_point_ = gs.broadcast_arrays(vector, base_point)
        batch_shape = get_batch_shape(self.point_ndim, vector_)
        base_point_ = gs.reshape(base_point_, (-1, *self.base_manifold.shape))
        vector_ = gs.reshape(vector_, (-1, *self.base_manifold.shape))

        each_tangent = self.base_manifold.to_tangent(vector_, base_point_)

        return gs.reshape(each_tangent, batch_shape + self.shape)

    def random_point(self, n_samples=1, bound=1.0):
        """Sample in the product space from the product distribution.

        The distribution used is the product of the distributions that each
        copy of the manifold uses in its own random_point method.

        Parameters
        ----------
        n_samples : int
            Number of samples.
            Optional, default: 1.
        bound : float
            Bound of the interval in which to sample for non compact manifolds.
            Optional, default: 1.

        Returns
        -------
        samples : array-like, shape=[..., n_copies, *base_shape]
            Points sampled on the product manifold.
        """
        # Draw all components at once, then regroup them per sample.
        sample = self.base_manifold.random_point(n_samples * self.n_copies, bound)
        reshaped = gs.reshape(
            sample, (n_samples, self.n_copies) + self.base_manifold.shape
        )
        if n_samples > 1:
            return reshaped
        # Drop the leading sample axis when not more than one sample is asked.
        return gs.squeeze(reshaped, axis=0)

    def projection(self, point):
        """Project a point from product embedding manifold to the product manifold.

        Parameters
        ----------
        point : array-like, shape=[..., n_copies, *base_shape]
            Point in embedding manifold.

        Returns
        -------
        projected : array-like, shape=[..., n_copies, *base_shape]
            Projected point.

        Raises
        ------
        NotImplementedError
            If the base manifold has no ``projection`` attribute.
        """
        if hasattr(self.base_manifold, "projection"):
            batch_shape = get_batch_shape(self.point_ndim, point)
            point_ = gs.reshape(point, (-1, *self.base_manifold.shape))
            projected = self.base_manifold.projection(point_)
            return gs.reshape(projected, batch_shape + self.shape)
        raise NotImplementedError(
            "The base manifold does not implement a projection method."
        )

[docs]
class NFoldMetric(RiemannianMetric):
r"""Class for the product metric on an n-fold product manifold :math:`M^n`.

Define the metric on the product of n copies of a given base manifold M
from the metric of M, optionally rescaled on each copy.

Parameters
----------
space : NFoldManifold
Base space.
scales : array-like
Scale of each metric in the product.
"""

def __init__(self, space, scales=None, signature=None):
    """Validate the optional scales and initialize the metric.

    Parameters
    ----------
    space : NFoldManifold
        Space to equip; its ``n_copies`` fixes the expected number of scales.
    scales : array-like
        One positive scale per copy, or None for no rescaling.
        Optional, default: None.
    signature : tuple
        Forwarded unchanged to the parent constructor.
        Optional, default: None.

    Raises
    ------
    ValueError
        If the number of scales differs from ``space.n_copies``.
    """
    if scales is not None:
        # Positivity is checked first, then the count, as two separate steps.
        for value in scales:
            geomstats.errors.check_positive(value, "Each value in scales")

        n_scales = len(scales)
        if n_scales != space.n_copies:
            raise ValueError("Number of scales should be equal to number of factors")

    self.scales = scales
    super().__init__(space=space, signature=signature)

[docs]
def metric_matrix(self, base_point):
    """Compute the matrix of the inner-product.

    Matrix of the inner-product defined by the Riemannian metric
    at point base_point of the manifold, given as one matrix per copy.

    Parameters
    ----------
    base_point : array-like, shape=[..., n_copies, *base_shape]
        Point on the manifold at which to compute the inner-product matrix.

    Returns
    -------
    matrix : array-like, shape=[..., n_copies, dim, dim]
        Matrix of the inner-product at the base point, where the two
        trailing axes are those of the matrices returned by the base metric.
    """
    base_manifold = self._space.base_manifold
    batch_shape = get_batch_shape(self._space.point_ndim, base_point)

    # Flatten batch and copy axes so the base metric sees a flat batch.
    point_ = gs.reshape(base_point, (-1, *base_manifold.shape))
    matrices = base_manifold.metric.metric_matrix(point_)

    # Use the side length of the matrices actually returned rather than the
    # last axis of the point representation: the two need not coincide.
    matrix_shape = tuple(matrices.shape[-2:])
    reshaped = gs.reshape(
        matrices, batch_shape + (self._space.n_copies,) + matrix_shape
    )

    if self.scales is not None:
        # Scale the matrix of copy j by scales[j].
        reshaped = gs.einsum("j,...jkl->...jkl", self.scales, reshaped)

    return reshaped

[docs]
def pointwise_inner_product(self, tangent_vec_a, tangent_vec_b, base_point):
    """Compute the inner product on each copy of the base manifold.

    Parameters
    ----------
    tangent_vec_a : array-like, shape=[..., n_copies, *base_shape]
        First tangent vector at base point.
    tangent_vec_b : array-like, shape=[..., n_copies, *base_shape]
        Second tangent vector at base point.
    base_point : array-like, shape=[..., n_copies, *base_shape]
        Point on the manifold.

    Returns
    -------
    pointwise_inner_prod : array-like, shape=[..., n_copies]
        Inner-product of the two tangent vectors on each copy, multiplied
        by the corresponding scale when scales are set.
    """
    base_manifold = self._space.base_manifold

    # Bring the three inputs to a common shape so they can be flattened
    # consistently along the batch and copy axes.
    tangent_vec_a_, tangent_vec_b_, point_ = gs.broadcast_arrays(
        tangent_vec_a, tangent_vec_b, base_point
    )
    batch_shape = get_batch_shape(self._space.point_ndim, tangent_vec_a_)

    point_ = gs.reshape(point_, (-1, *base_manifold.shape))
    vector_a = gs.reshape(tangent_vec_a_, (-1, *base_manifold.shape))
    vector_b = gs.reshape(tangent_vec_b_, (-1, *base_manifold.shape))
    inner_each = base_manifold.metric.inner_product(vector_a, vector_b, point_)

    reshaped = gs.reshape(inner_each, batch_shape + (self._space.n_copies,))

    if self.scales is not None:
        # Scale the inner product of copy j by scales[j].
        reshaped = gs.einsum("j,...j->...j", self.scales, reshaped)

    return reshaped

[docs]
def inner_product(self, tangent_vec_a, tangent_vec_b, base_point):
    """Compute the inner-product of two tangent vectors at a base point.

    The result is the sum over the copies of the values returned by
    `pointwise_inner_product`.

    Parameters
    ----------
    tangent_vec_a : array-like, shape=[..., n_copies, *base_shape]
        First tangent vector at base point.
    tangent_vec_b : array-like, shape=[..., n_copies, *base_shape]
        Second tangent vector at base point.
    base_point : array-like, shape=[..., n_copies, *base_shape]
        Point on the manifold.

    Returns
    -------
    inner_prod : array-like, shape=[...,]
        Inner-product of the two tangent vectors.
    """
    per_copy = self.pointwise_inner_product(tangent_vec_a, tangent_vec_b, base_point)
    # Collapse the trailing copy axis.
    return gs.sum(per_copy, axis=-1)

[docs]
def pointwise_norm(self, tangent_vec, base_point):
    """Compute the norm of a tangent vector on each copy.

    Each entry is the square root of the pointwise inner product of the
    tangent vector with itself.

    Parameters
    ----------
    tangent_vec : array-like, shape=[..., n_copies, *base_shape]
        Tangent vector at base point.
    base_point : array-like, shape=[..., n_copies, *base_shape]
        Point on the manifold.

    Returns
    -------
    norm : array-like, shape=[..., n_copies]
        Point-wise norms.
    """
    squared_norms = self.pointwise_inner_product(
        tangent_vec_a=tangent_vec,
        tangent_vec_b=tangent_vec,
        base_point=base_point,
    )
    norms = gs.sqrt(squared_norms)
    return norms

[docs]
def exp(self, tangent_vec, base_point):
    """Compute the Riemannian exponential of a tangent vector.

    The exponential is computed on each copy with the base metric.

    Parameters
    ----------
    tangent_vec : array-like, shape=[..., n_copies, *base_shape]
        Tangent vector at a base point.
    base_point : array-like, shape=[..., n_copies, *base_shape]
        Point on the manifold.

    Returns
    -------
    exp : array-like, shape=[..., n_copies, *base_shape]
        Point on the manifold equal to the Riemannian exponential
        of tangent_vec at the base point.
    """
    base_manifold = self._space.base_manifold
    batch_shape = get_batch_shape(self._space.point_ndim, tangent_vec, base_point)

    # Bring both inputs to a common shape so that flattening pairs each
    # tangent component with its base point component.
    vector_, point_ = gs.broadcast_arrays(tangent_vec, base_point)
    point_ = gs.reshape(point_, (-1, *base_manifold.shape))
    vector_ = gs.reshape(vector_, (-1, *base_manifold.shape))

    each_exp = base_manifold.metric.exp(vector_, point_)
    return gs.reshape(each_exp, batch_shape + self._space.shape)

[docs]
def log(self, point, base_point):
"""Compute the Riemannian logarithm of a point.

Parameters
----------
point : array-like, shape=[..., n_copies, *base_shape]
Point on the manifold.
base_point : array-like, shape=[..., n_copies, *base_shape]
Point on the manifold.
Optional, default: None.

Returns
-------
log : array-like, shape=[..., n_copies, *base_shape]
Tangent vector at the base point equal to the Riemannian logarithm
of point at the base point.
"""
base_manifold = self._space.base_manifold
batch_shape = get_batch_shape(self._space.point_ndim, point, base_point)