"""Product of manifolds.
Lead authors: Nicolas Guigui, John Harvey.
"""
import math
import geomstats.backend as gs
import geomstats.errors
from geomstats.geometry.complex_manifold import ComplexManifold
from geomstats.geometry.complex_riemannian_metric import ComplexRiemannianMetric
from geomstats.geometry.manifold import Manifold
from geomstats.geometry.riemannian_metric import RiemannianMetric
from geomstats.vectorization import get_batch_shape
# Classes whose instances mark a factor as complex-valued; used by
# _factor_is_complex to detect products mixing real and complex factors.
COMPLEX_OBJECTS = (ComplexRiemannianMetric, ComplexManifold)
def _factor_is_complex(factor):
    """Tell whether a product factor is defined over the complex field.

    A factor counts as complex when it is an instance of one of
    ``COMPLEX_OBJECTS`` itself, or when it carries an ``underlying_metric``
    attribute that is such an instance.

    Parameters
    ----------
    factor : object
        A manifold or metric taking part in a product.

    Returns
    -------
    is_complex : bool
        True if the factor is complex, False otherwise.
    """
    if isinstance(factor, COMPLEX_OBJECTS):
        return True
    # A missing attribute yields None, which is never a complex object.
    underlying_metric = getattr(factor, "underlying_metric", None)
    return isinstance(underlying_metric, COMPLEX_OBJECTS)
def _has_mixed_fields(factors):
    """Tell whether real and complex factors are mixed in a product.

    Parameters
    ----------
    factors : iterable
        Manifolds or metrics of the product.

    Returns
    -------
    has_mixed_fields : bool
        True if at least one factor is complex and at least one is real.
    """
    fields_seen = {_factor_is_complex(factor) for factor in factors}
    return fields_seen == {True, False}
def _all_equal(arg):
    """Check if all elements of arg are equal.

    Parameters
    ----------
    arg : list or tuple
        Sequence of comparable elements (e.g. shapes).

    Returns
    -------
    all_equal : bool
        True if every element equals the first one. An empty sequence is
        considered to have all its elements equal (vacuous truth) instead of
        raising an IndexError.
    """
    if not arg:
        return True
    return arg.count(arg[0]) == len(arg)
def _block_diagonal(factor_matrices):
    """Assemble a list of square matrices into one block-diagonal matrix.

    Each input matrix becomes a diagonal block; every off-diagonal block is
    filled with zeros. The leading (batch) dimensions of a block row are taken
    from the matrix on its diagonal, so all inputs must share the same leading
    dimensions for the final concatenation to succeed.

    Parameters
    ----------
    factor_matrices : list of array-like, shape=[..., n_i, n_i]
        Diagonal blocks, in order.

    Returns
    -------
    metric_matrix : array-like, shape=[..., sum_i n_i, sum_i n_i]
        Block-diagonal matrix.
    """
    block_rows = []
    for row_index, row_matrix in enumerate(factor_matrices):
        # All blocks of a block row share the rows (and batch axes) of the
        # diagonal block; their number of columns comes from the column's matrix.
        row_blocks = [
            row_matrix
            if col_index == row_index
            else gs.zeros(row_matrix.shape[:-1] + col_matrix.shape[-1:])
            for col_index, col_matrix in enumerate(factor_matrices)
        ]
        block_rows.append(gs.concatenate(row_blocks, axis=-1))
    # Stack the block rows on top of each other.
    return gs.concatenate(block_rows, axis=-2)
def _find_product_shape(factors, point_ndim):
    """Determine an appropriate shape for the product from the factors.

    Parameters
    ----------
    factors : sequence
        Factors of the product; each must expose a ``shape`` tuple.
    point_ndim : int or None
        Requested layout of a point of the product.
        If 1, factor points are flattened and concatenated, giving the shape
        ``(sum of prod(factor.shape),)``.
        If None, the stacked layout below is used when all factor shapes are
        equal, and the flattened layout otherwise.
        Any other value (normally an int greater than one) stacks the factor
        points along a new leading axis, giving ``(n_factors, *factor_shape)``;
        this requires all factors to share the same shape.

    Returns
    -------
    shape : tuple of int
        Shape of a point of the product manifold.

    Raises
    ------
    ValueError
        If the stacked layout is requested explicitly but the factors do not
        all have the same shape.
    """
    factor_shapes = [factor.shape for factor in factors]
    if point_ndim is None:
        if _all_equal(factor_shapes):
            return (len(factors), *factor_shapes[0])
        point_ndim = 1
    if point_ndim == 1:
        return (sum(math.prod(factor_shape) for factor_shape in factor_shapes),)
    if not _all_equal(factor_shapes):
        raise ValueError(
            "A point_ndim greater than one can only be used if all "
            "manifolds have the same shape."
        )
    return (len(factors), *factor_shapes[0])
class _IterateOverFactorsMixins:
    """Mixin dispatching calls over the factors of a product structure.

    Shared by the product manifold and the product metric. It converts between
    a point of the product (an array of shape ``(..., *self.shape)``) and the
    list of corresponding factor points, and applies a method to every factor.

    The host class must provide ``shape`` and ``point_ndim`` and, when
    ``pool_outputs`` is True, a ``_pool_outputs_from_function`` method.

    Parameters
    ----------
    factors : sequence
        Factors of the product (manifolds or metrics).
    cum_index : array-like
        Indices at which a flattened (``point_ndim == 1``) product point is
        split into its factor components.
    pool_outputs : bool
        Whether ``_iterate_over_factors`` combines the factor outputs with
        ``_pool_outputs_from_function`` instead of returning them as a list.
    has_mixed_fields : bool
        Whether real and complex factors are mixed; if so, the components
        extracted for real factors are cast back to real.
    *args, **kwargs
        Forwarded to the next class in the method resolution order.
    """

    def __init__(
        self, factors, cum_index, pool_outputs, has_mixed_fields, *args, **kwargs
    ):
        self.factors = factors
        self._cum_index = cum_index
        self._pool_outputs = pool_outputs
        self._has_mixed_fields = has_mixed_fields
        super().__init__(*args, **kwargs)

    def embed_to_product(self, points):
        """Map a point in each factor to a point in the product.

        Parameters
        ----------
        points : list
            A list of points, one from each factor, each array-like of shape
            (..., factor.shape)

        Returns
        -------
        point : array-like, shape (..., shape)

        Raises
        ------
        ShapeError
            If the points are not compatible with the shapes of the corresponding
            factors.
        """
        for point, factor in zip(points, self.factors):
            geomstats.errors.check_point_shape(point, factor)

        if self.point_ndim == 1:
            flat_points = []
            for point, factor in zip(points, self.factors):
                if gs.ndim(point) > len(factor.shape):
                    # Batched point: keep the batch axes, flatten the factor axes.
                    batch_shape = get_batch_shape(factor.point_ndim, point)
                    point = gs.reshape(point, batch_shape + (-1,))
                else:
                    point = gs.flatten(point)
                flat_points.append(point)
            return gs.concatenate(flat_points, axis=-1)

        # Stacked layout: the new axis indexing the factors sits right in front
        # of the factor axes.
        stacking_axis = -len(self.shape)
        return gs.stack(points, axis=stacking_axis)

    def project_from_product(self, point):
        """Map a point in the product to points in each factor.

        Parameters
        ----------
        point : array-like, shape (..., shape)
            The point to be projected to the factors

        Returns
        -------
        projected_points : list of array-like
            The points on each factor, of shape (..., factor.shape)

        Raises
        ------
        ShapeError
            If the point does not have a shape compatible with the product manifold.
        """
        geomstats.errors.check_point_shape(point, self)

        if self.point_ndim == 1:
            components = gs.split(point, self._cum_index, axis=-1)
            projected_points = [
                self._reshape_trailing(component, factor)
                for component, factor in zip(components, self.factors)
            ]
        else:
            splitting_axis = -len(self.shape)
            components = gs.split(point, len(self.factors), axis=splitting_axis)
            projected_points = [
                gs.squeeze(component, axis=splitting_axis) for component in components
            ]

        if self._has_mixed_fields:
            # The product array may have been promoted to complex; real factors
            # get their components cast back to real.
            projected_points = [
                (
                    projected_point
                    if _factor_is_complex(factor)
                    else gs.real(projected_point)
                )
                for factor, projected_point in zip(self.factors, projected_points)
            ]
        return projected_points

    @staticmethod
    def _reshape_trailing(argument, factor):
        """Convert the trailing dimensions to match the shape of a factor manifold.

        ``factor`` may be a manifold or a Riemannian metric; for a metric, the
        shape of its underlying space is used. Arguments destined to factors
        with ``point_ndim == 1`` are returned unchanged.
        """
        space = factor._space if isinstance(factor, RiemannianMetric) else factor
        if space.point_ndim == 1:
            return argument
        return gs.reshape(argument, argument.shape[:-1] + space.shape)

    def _iterate_over_factors(self, func, args):
        """Apply a function to each factor of the product.

        func is called on each factor of the product.
        Array-type arguments are separated out to be passed to func for each factor,
        but other arguments are passed unchanged.

        Parameters
        ----------
        func : str
            The name of a method which is defined for each factor of the product
            The method must return an array of shape (..., factor.shape) or a boolean
            array of shape (...,).
        args : dict
            Dict of arguments.
            Array-type arguments must be of type (..., shape)
            Other arguments are passed to each factor unchanged

        Returns
        -------
        out : list or array-like, shape = [..., {(), shape}]
            The list of factor outputs, or their pooled combination when the
            instance was created with ``pool_outputs=True``.
        """
        # TODO The user may prefer to provide the arguments as lists and receive them as
        # TODO lists, as this may be the form in which they are available. This should
        # TODO be allowed, rather than packing and unpacking them repeatedly.
        args_list, numerical_args = self._validate_and_prepare_args_for_iteration(args)
        out = [
            self._get_method(factor, func, factor_args, numerical_args)
            for factor, factor_args in zip(self.factors, args_list)
        ]
        if self._pool_outputs:
            return self._pool_outputs_from_function(out)
        return out

    def _validate_and_prepare_args_for_iteration(self, args):
        """Separate array arguments from the others and split them per factor.

        Parameters
        ----------
        args : dict
            Dict of arguments.
            Array-type arguments must be of shape (..., shape); they are split
            into factor components with ``project_from_product``.
            Any other value (numbers, None, ...) is passed unchanged to func
            for each factor.

        Returns
        -------
        args_list : list of dict
            One dict per factor of the product, mapping the name of each
            array argument to the component belonging to that factor.
        numerical_args : dict
            Dict of non-array arguments.
        """
        args_list = [{} for _ in self.factors]
        numerical_args = {}
        for key, value in args.items():
            if not gs.is_array(value):
                numerical_args[key] = value
                continue
            for args_dict, new_arg in zip(args_list, self.project_from_product(value)):
                args_dict[key] = new_arg
        return args_list, numerical_args

    @staticmethod
    def _get_method(factor, method_name, array_args, num_args):
        """Call ``factor.method_name`` with both groups of keyword arguments."""
        return getattr(factor, method_name)(**array_args, **num_args)
class ProductManifold(_IterateOverFactorsMixins, Manifold):
    """Class for a product of manifolds M_1 x ... x M_n.

    In contrast to the classes NFoldManifold, Landmarks, or DiscretizedCurves,
    the manifolds M_1, ..., M_n need not be the same, nor of
    same dimension, but the list of manifolds needs to be provided.

    Parameters
    ----------
    factors : tuple
        Collection of manifolds in the product.
    point_ndim : int or None
        If None, defaults to 1, unless all factors have the same shape.
        With 1, a point of the product is the concatenation of the flattened
        factor points; otherwise the factor points are stacked along a new
        axis placed in front of the factor axes.
    equip : bool
        Whether to equip the space with its default metric.
        Optional, default: True.
    """

    def __init__(self, factors, point_ndim=None, equip=True):
        factors = tuple(factors)
        dim = sum(factor.dim for factor in factors)
        shape = _find_product_shape(factors, point_ndim)

        intrinsic = all(factor.intrinsic for factor in factors)
        if not intrinsic:
            factor_embedding_spaces = [
                getattr(manifold, "embedding_space", manifold) for manifold in factors
            ]
            # TODO: need to revisit due to removal of scales
            self.embedding_space = ProductManifold(
                factor_embedding_spaces, point_ndim, equip=False
            )

        # When point_ndim == 1, a product point is the concatenation of the
        # flattened factor points, the i-th one holding prod(factors[i].shape)
        # entries. The split indices are therefore derived from these sizes:
        # the intrinsic dimensions (factor.dim) can differ from them, which
        # would make project_from_product cut the point at the wrong places.
        # The indices are not used when factor points are stacked.
        factor_sizes = [math.prod(factor.shape) for factor in factors]
        cum_index = gs.cumsum(factor_sizes)[:-1]

        super().__init__(
            factors=factors,
            cum_index=cum_index,
            pool_outputs=True,
            has_mixed_fields=_has_mixed_fields(factors),
            dim=dim,
            shape=shape,
            intrinsic=intrinsic,
            equip=equip,
        )

    @staticmethod
    def default_metric():
        """Metric to equip the space with if equip is True."""
        return ProductRiemannianMetric

    def _pool_outputs_from_function(self, outputs):
        """Collect outputs for each product to be returned.

        If each element of the output is a boolean array of the same shape, test along
        the list whether all elements are True and return a boolean array of the same
        shape. Non-array outputs are combined the same way.

        Otherwise, if each element of the output has a shape compatible with points of
        the corresponding factor, an attempt is made to map the list of points to a
        point in the product by embed_to_product.

        Parameters
        ----------
        outputs : list
            A list of outputs which must be pooled

        Returns
        -------
        pooled_output : array-like, shape {(...,), (..., self.shape)}

        Raises
        ------
        RuntimeError
            If the outputs can be neither reduced as booleans nor embedded in
            the product.
        """
        # TODO: simplify after cleaning gs.squeeze
        all_arrays = all(gs.is_array(factor_output) for factor_output in outputs)
        all_same_shape_booleans = (
            all_arrays
            and _all_equal([factor_output.shape for factor_output in outputs])
            and all(gs.is_bool(factor_output) for factor_output in outputs)
        )
        if all_same_shape_booleans or not all_arrays:
            # A property holds on the product iff it holds on every factor.
            # Non-array outputs are presumably Python booleans.
            stacked = gs.stack([gs.array(factor_output) for factor_output in outputs])
            return gs.all(stacked, axis=0)

        try:
            return self.embed_to_product(outputs)
        except geomstats.errors.ShapeError as err:
            raise RuntimeError(
                "Could not combine outputs - they are not points of the individual"
                " factors."
            ) from err
        except ValueError as err:
            raise RuntimeError(
                "Could not combine outputs, probably because they could"
                " not be concatenated or stacked."
            ) from err

    def belongs(self, point, atol=gs.atol):
        """Test if a point belongs to the manifold.

        Parameters
        ----------
        point : array-like, shape=[..., *shape]
            Point.
        atol : float
            Tolerance.
            Optional, default: backend atol.

        Returns
        -------
        belongs : array-like, shape=[...,]
            Boolean evaluating if the point belongs to the manifold, i.e. if
            each component belongs to its factor.
        """
        belongs = self._iterate_over_factors("belongs", {"point": point, "atol": atol})
        return belongs

    def regularize(self, point):
        """Regularize the point into the manifold's canonical representation.

        Parameters
        ----------
        point : array-like, shape=[..., *shape]
            Point to be regularized.

        Returns
        -------
        regularized_point : array-like, shape=[..., *shape]
            Point in the manifold's canonical representation.
        """
        regularized_point = self._iterate_over_factors("regularize", {"point": point})
        return regularized_point

    def random_point(self, n_samples=1, bound=1.0):
        """Sample in the product space from the product distribution.

        The distribution used is the product of the distributions used by the
        random_point methods of each individual factor manifold.

        Parameters
        ----------
        n_samples : int, optional
            Number of samples.
        bound : float
            Bound of the interval in which to sample for non compact manifolds.
            Optional, default: 1.

        Returns
        -------
        samples : array-like, shape=[..., *shape]
            Points sampled from the manifold.
        """
        samples = self._iterate_over_factors(
            "random_point", {"n_samples": n_samples, "bound": bound}
        )
        return samples

    def random_tangent_vec(self, base_point, n_samples=1):
        """Sample on the tangent space from the product distribution.

        The distribution used is the product of the distributions used by the
        random_tangent_vec methods of each individual factor manifold.

        Parameters
        ----------
        base_point : array-like, shape=[..., *shape]
            Base point of the tangent space.
        n_samples : int
            Number of samples.
            Optional, default: 1.

        Returns
        -------
        samples : array-like, shape=[..., *shape]
            Points sampled in the tangent space of the product manifold at base_point.
        """
        samples = self._iterate_over_factors(
            "random_tangent_vec", {"base_point": base_point, "n_samples": n_samples}
        )
        return samples

    def projection(self, point):
        """Project a point onto product manifold.

        Parameters
        ----------
        point : array-like, shape=[..., *shape]
            Point in product manifold.

        Returns
        -------
        projected : array-like, shape=[..., *shape]
            Projected point.
        """
        projected_point = self._iterate_over_factors("projection", {"point": point})
        return projected_point

    def to_tangent(self, vector, base_point):
        """Project a vector to a tangent space of the manifold.

        Parameters
        ----------
        vector : array-like, shape=[..., *shape]
            Vector.
        base_point : array-like, shape=[..., *shape]
            Point on the manifold.

        Returns
        -------
        tangent_vec : array-like, shape=[..., *shape]
            Tangent vector at base point.

        Notes
        -----
        The tangent space of the product manifold is the direct sum of
        tangent spaces.
        """
        tangent_vec = self._iterate_over_factors(
            "to_tangent", {"base_point": base_point, "vector": vector}
        )
        return tangent_vec

    def is_tangent(self, vector, base_point=None, atol=gs.atol):
        """Check whether the vector is tangent at base_point.

        The tangent space of the product manifold is the direct sum of
        tangent spaces.

        Parameters
        ----------
        vector : array-like, shape=[..., *shape]
            Vector.
        base_point : array-like, shape=[..., *shape]
            Point on the manifold.
            Optional, default: None
        atol : float
            Absolute tolerance.
            Optional, default: backend atol.

        Returns
        -------
        is_tangent : array-like, shape=[...,]
            Boolean denoting if vector is a tangent vector at the base point.
        """
        is_tangent = self._iterate_over_factors(
            "is_tangent", {"base_point": base_point, "vector": vector, "atol": atol}
        )
        return is_tangent
class ProductRiemannianMetric(_IterateOverFactorsMixins, RiemannianMetric):
    """Class for product of Riemannian metrics.

    Each component of a point or tangent vector of the product is handled by
    the metric of the corresponding factor of ``space``; the factor results
    are then summed, combined in norm, or re-embedded in the product.

    Parameters
    ----------
    space : ProductManifold
        Product space whose factors are each equipped with a ``metric``.
    """

    def __init__(self, space):
        factors = [factor.metric for factor in space.factors]
        factor_signatures = [metric.signature for metric in factors]
        super().__init__(
            space=space,
            factors=factors,
            cum_index=space._cum_index,
            pool_outputs=False,
            has_mixed_fields=space._has_mixed_fields,
            # The signature of the product is the sum of the factor signatures.
            signature=(
                sum(signature[0] for signature in factor_signatures),
                sum(signature[1] for signature in factor_signatures),
            ),
        )

    @property
    def shape(self):
        """Shape of space."""
        return self._space.shape

    @property
    def point_ndim(self):
        """Point type of space."""
        return self._space.point_ndim

    def metric_matrix(self, base_point=None):
        """Compute the matrix of the inner-product.

        Matrix of the inner-product defined by the Riemannian metric
        at point base_point of the manifold.

        Parameters
        ----------
        base_point : array-like, shape=[..., *shape]
            Point on the manifold at which to compute the inner-product matrix.
            Optional, default: None.

        Returns
        -------
        matrix : array-like, shape=[..., sum_i n_i, sum_i n_i]
            Matrix of the inner-product at the base point, where n_i x n_i is
            the shape of the metric matrix of the i-th factor.
            The matrix is in block diagonal form with a block for each factor.
            Each block is the same size as the metric_matrix for that factor.
        """
        factor_matrices = self._iterate_over_factors(
            "metric_matrix", {"base_point": base_point}
        )
        return _block_diagonal(factor_matrices)

    def inner_product(self, tangent_vec_a, tangent_vec_b, base_point):
        """Compute the inner-product of two tangent vectors at a base point.

        Inner product defined by the Riemannian metric at point `base_point`
        between tangent vectors `tangent_vec_a` and `tangent_vec_b`. It is the
        sum of the inner products computed by each factor metric.

        Parameters
        ----------
        tangent_vec_a : array-like, shape=[..., *shape]
            First tangent vector at base point.
        tangent_vec_b : array-like, shape=[..., *shape]
            Second tangent vector at base point.
        base_point : array-like, shape=[..., *shape]
            Point on the manifold.

        Returns
        -------
        inner_prod : array-like, shape=[...,]
            Inner-product of the two tangent vectors.
        """
        args = {
            "tangent_vec_a": tangent_vec_a,
            "tangent_vec_b": tangent_vec_b,
            "base_point": base_point,
        }
        inner_products = self._iterate_over_factors("inner_product", args)
        return sum(inner_products)

    def squared_norm(self, vector, base_point=None):
        """Compute the square of the norm of a vector.

        Squared norm of a vector associated to the inner product
        at the tangent space at a base point. It is the sum of the squared
        norms computed by each factor metric.

        Parameters
        ----------
        vector : array-like, shape=[..., *shape]
            Vector.
        base_point : array-like, shape=[..., *shape]
            Base point.
            Optional, default: None.

        Returns
        -------
        sq_norm : array-like, shape=[...,]
            Squared norm.
        """
        args = {
            "vector": vector,
            "base_point": base_point,
        }
        sq_norms = self._iterate_over_factors("squared_norm", args)
        return sum(sq_norms)

    def exp(self, tangent_vec, base_point):
        """Compute the Riemannian exponential of a tangent vector.

        Parameters
        ----------
        tangent_vec : array-like, shape=[..., *shape]
            Tangent vector at a base point.
        base_point : array-like, shape=[..., *shape]
            Point on the manifold.

        Returns
        -------
        exp : array-like, shape=[..., *shape]
            Point on the manifold equal to the Riemannian exponential
            of tangent_vec at the base point.
        """
        args = {"tangent_vec": tangent_vec, "base_point": base_point}
        exp = self._iterate_over_factors("exp", args)
        return self._space.embed_to_product(exp)

    def log(self, point, base_point):
        """Compute the Riemannian logarithm of a point.

        Parameters
        ----------
        point : array-like, shape=[..., *shape]
            Point on the manifold.
        base_point : array-like, shape=[..., *shape]
            Point on the manifold.

        Returns
        -------
        log : array-like, shape=[..., *shape]
            Tangent vector at the base point equal to the Riemannian logarithm
            of point at the base point.
        """
        args = {"point": point, "base_point": base_point}
        logs = self._iterate_over_factors("log", args)
        return self._space.embed_to_product(logs)

    def dist(self, point_a, point_b):
        """Geodesic distance between two points.

        The distance is the Euclidean (2-)norm of the vector of factor
        distances.

        Parameters
        ----------
        point_a : array-like, shape=[..., *shape]
            Point.
        point_b : array-like, shape=[..., *shape]
            Point.

        Returns
        -------
        dist : array-like, shape=[...,]
            Distance.
        """
        args = {"point_a": point_a, "point_b": point_b}
        dists = gs.array(self._iterate_over_factors("dist", args))
        return gs.linalg.norm(dists, ord=2, axis=0)

    def geodesic(self, initial_point, end_point=None, initial_tangent_vec=None):
        """Generate parameterized function for the geodesic curve.

        Geodesic curve defined by either:

        - an initial point and an initial tangent vector,
        - an initial point and an end point.

        Parameters
        ----------
        initial_point : array-like, shape=[..., *shape]
            Point on the manifold, initial point of the geodesic.
        end_point : array-like, shape=[..., *shape], optional
            Point on the manifold, end point of the geodesic. If None,
            an initial tangent vector must be given.
        initial_tangent_vec : array-like, shape=[..., *shape]
            Tangent vector at base point, the initial speed of the geodesics.
            Optional, default: None.
            If None, an end point must be given and a logarithm is computed.

        Returns
        -------
        path : callable
            Time parameterized geodesic curve. If a batch of initial
            conditions is passed, the output array's first dimension
            represents the different initial conditions, and the second
            corresponds to time.
        """
        args = {
            "initial_point": initial_point,
            "end_point": end_point,
            "initial_tangent_vec": initial_tangent_vec,
        }
        geodesics = self._iterate_over_factors("geodesic", args)

        def geod_fun(t):
            """Evaluate every factor geodesic at times t and embed the result."""
            t = gs.to_ndarray(t, to_ndim=1)
            values = [geodesic(t) for geodesic in geodesics]
            return self._space.embed_to_product(values)

        return geod_fun