Source code for geomstats.geometry.product_riemannian_metric

"""Product of Riemannian metrics.

Define the metric of a product manifold endowed with a product metric.

"""

import joblib

import geomstats.backend as gs
import geomstats.errors
from geomstats.geometry.riemannian_metric import RiemannianMetric

[docs]class ProductRiemannianMetric(RiemannianMetric): """Class for product of Riemannian metrics. Parameters ---------- metrics : list List of metrics in the product. default_point_type : str, {'vector', 'matrix'} Point type. Optional, default: 'vector'. n_jobs : int Number of jobs for parallel computing. Optional, default: 1. """ def __init__(self, metrics, default_point_type="vector", n_jobs=1): self.n_metrics = len(metrics) dims = [metric.dim for metric in metrics] signatures = [metric.signature for metric in metrics] sig_pos = sum(sig[0] for sig in signatures) sig_neg = sum(sig[1] for sig in signatures) super(ProductRiemannianMetric, self).__init__( dim=sum(dims), signature=(sig_pos, sig_neg), default_point_type=default_point_type, ) self.metrics = metrics self.dims = dims self.signatures = signatures self.n_jobs = n_jobs
[docs] def metric_matrix(self, base_point=None, point_type=None): """Compute the matrix of the inner-product. Matrix of the inner-product defined by the Riemmanian metric at point base_point of the manifold. Parameters ---------- base_point : array-like, shape=[..., n_metrics, dim] or [..., dim] Point on the manifold at which to compute the inner-product matrix. Optional, default: None. point_type : str, {'vector', 'matrix'} Type of representation used for points. Optional, default: None. Returns ------- matrix : array-like, shape=[..., dim, dim] or [..., dim + n_metrics, dim + n_metrics] Matrix of the inner-product at the base point. """ if point_type is None: point_type = self.default_point_type base_point = gs.to_ndarray(base_point, to_ndim=2) base_point = gs.to_ndarray(base_point, to_ndim=3) matrix = gs.zeros([len(base_point), self.dim, self.dim]) cum_dim = 0 for i in range(self.n_metrics): cum_dim_next = cum_dim + self.dims[i] if point_type == "matrix": matrix_next = self.metrics[i].metric_matrix(base_point[:, i]) elif point_type == "vector": matrix_next = self.metrics[i].metric_matrix( base_point[:, cum_dim:cum_dim_next, cum_dim:cum_dim_next] ) else: raise ValueError( "invalid point_type argument: {}, expected " "either matrix of vector".format(point_type) ) matrix[:, cum_dim:cum_dim_next, cum_dim:cum_dim_next] = matrix_next cum_dim = cum_dim_next return matrix[0] if len(base_point) == 1 else matrix
[docs] def is_intrinsic(self, point): """Test in a point is represented in intrinsic coordinates. This method is only useful for `point_type=vector`. Parameters ---------- point : array-like, shape=[..., dim] Point on the product manifold. Returns ------- intrinsic : array-like, shape=[...,] Whether intrinsic coordinates are used for all manifolds. """ if self.default_point_type != "vector": raise ValueError("Invalid default_point_type: 'vector' expected.") if point.shape[-1] == self.dim: return True if point.shape[-1] == sum(dim + 1 for dim in self.dims): return False raise ValueError("Input shape does not match the dimension of the manifold")
@staticmethod def _get_method(metric, method_name, metric_args): return getattr(metric, method_name)(**metric_args) def _iterate_over_metrics(self, func, args, intrinsic=False): cum_index = ( gs.cumsum(self.dims)[:-1] if intrinsic else gs.cumsum(gs.array([k + 1 for k in self.dims])) ) arguments = { key: gs.split(args[key], cum_index, axis=-1) for key in args.keys() } args_list = [ {key: arguments[key][j] for key in args.keys()} for j in range(self.n_metrics) ] pool = joblib.Parallel(n_jobs=self.n_jobs, prefer="threads") out = pool( joblib.delayed(self._get_method)(self.metrics[i], func, args_list[i]) for i in range(self.n_metrics) ) return out
[docs] def inner_product( self, tangent_vec_a, tangent_vec_b, base_point=None, point_type=None ): """Compute the inner-product of two tangent vectors at a base point. Inner product defined by the Riemannian metric at point `base_point` between tangent vectors `tangent_vec_a` and `tangent_vec_b`. Parameters ---------- tangent_vec_a : array-like, shape=[..., dim + 1] First tangent vector at base point. tangent_vec_b : array-like, shape=[..., dim + 1] Second tangent vector at base point. base_point : array-like, shape=[..., dim + 1] Point on the manifold. Optional, default: None. point_type : str, {'vector', 'matrix'} Type of representation used for points. Optional, default: None. Returns ------- inner_prod : array-like, shape=[...,] Inner-product of the two tangent vectors. """ if base_point is None: base_point = gs.empty((self.n_metrics, self.dim)) if point_type is None: point_type = self.default_point_type geomstats.errors.check_parameter_accepted_values( point_type, "point_type", ["vector", "matrix"] ) if point_type == "vector": intrinsic = self.is_intrinsic(tangent_vec_b) args = { "tangent_vec_a": tangent_vec_a, "tangent_vec_b": tangent_vec_b, "base_point": base_point, } inner_prod = self._iterate_over_metrics("inner_product", args, intrinsic) return gs.sum(gs.stack(inner_prod, axis=-2), axis=-2) inner_products = [ metric.inner_product( tangent_vec_a[..., i, :], tangent_vec_b[..., i, :], base_point[..., i, :], ) for i, metric in enumerate(self.metrics) ] return sum(inner_products)
[docs] def exp(self, tangent_vec, base_point=None, point_type=None, **kwargs): """Compute the Riemannian exponential of a tangent vector. Parameters ---------- tangent_vec : array-like, shape=[..., dim] Tangent vector at a base point. base_point : array-like, shape=[..., dim] Point on the manifold. Optional, default: None. point_type : str, {'vector', 'matrix'} Type of representation used for points. Optional, default: None. Returns ------- exp : array-like, shape=[..., dim] Point on the manifold equal to the Riemannian exponential of tangent_vec at the base point. """ if base_point is None: base_point = [ None, ] * self.n_metrics if point_type is None: point_type = self.default_point_type geomstats.errors.check_parameter_accepted_values( point_type, "point_type", ["vector", "matrix"] ) if point_type == "vector": intrinsic = self.is_intrinsic(base_point) args = {"tangent_vec": tangent_vec, "base_point": base_point} exp = self._iterate_over_metrics("exp", args, intrinsic) return gs.concatenate(exp, -1) exp = gs.stack( [ self.metrics[i].exp(tangent_vec[..., i, :], base_point[..., i, :]) for i in range(self.n_metrics) ], axis=-2, ) return exp[0] if len(tangent_vec) == 1 else exp
[docs] def log(self, point, base_point=None, point_type=None, **kwargs): """Compute the Riemannian logarithm of a point. Parameters ---------- point : array-like, shape=[..., dim] Point on the manifold. base_point : array-like, shape=[..., dim] Point on the manifold. Optional, default: None. point_type : str, {'vector', 'matrix'} Type of representation used for points. Optional, default: None. Returns ------- log : array-like, shape=[..., dim] Tangent vector at the base point equal to the Riemannian logarithm of point at the base point. """ if base_point is None: base_point = [None] * self.n_metrics if point_type is None: point_type = self.default_point_type geomstats.errors.check_parameter_accepted_values( point_type, "point_type", ["vector", "matrix"] ) if point_type == "vector": intrinsic = self.is_intrinsic(base_point) args = {"point": point, "base_point": base_point} logs = self._iterate_over_metrics("log", args, intrinsic) logs = gs.concatenate(logs, axis=-1) return logs logs = gs.stack( [ self.metrics[i].log(point[..., i, :], base_point[..., i, :]) for i in range(self.n_metrics) ], axis=-2, ) return logs