# Source code for geomstats.geometry.sasaki_metric

"""Class for the Sasaki metric.

A class implementing the Sasaki metric, the natural metric on the tangent
bundle TM of a Riemannian manifold M.

Lead authors: E. Nava-Yazdani, F. Ambellan, M. Hanik and C. von Tycowicz.
"""

from joblib import Parallel, delayed

import geomstats.backend as gs
from geomstats.geometry.base import Manifold
from geomstats.geometry.riemannian_metric import RiemannianMetric
from geomstats.vectorization import check_is_batch


class GradientDescent:
    """Fixed-step gradient descent driven by a user-supplied exponential map.

    Parameters
    ----------
    lrate : float
        Step size multiplying the gradient. Optional, default: 0.1.
    max_iter : int
        Maximum number of iterations. Optional, default: 100.
    tol : float
        The iteration stops as soon as the norm of the gradient falls below
        this value. Optional, default: 1e-6.
    """

    def __init__(self, lrate=0.1, max_iter=100, tol=1e-6):
        self.lrate = lrate
        self.max_iter = max_iter
        self.tol = tol

    def minimize(self, x_ini, i_pt, e_pt, grad, exp):
        """Run gradient descent until ``max_iter`` or ``tol`` is reached.

        Parameters
        ----------
        x_ini : array-like
            Starting iterate.
        i_pt, e_pt : array-like
            Extra arguments forwarded unchanged to ``grad``.
        grad : callable
            ``grad(x, i_pt, e_pt)`` returning the gradient at ``x``.
        exp : callable
            ``exp(step, x)`` mapping a step at ``x`` to the next iterate.

        Returns
        -------
        x : array-like
            Last iterate (``x_ini`` itself if no step was taken).
        """
        current = x_ini
        for _ in range(self.max_iter):
            direction = grad(current, i_pt, e_pt)
            if gs.linalg.norm(direction) < self.tol:
                break
            # Step against the gradient, mapped back onto the space by `exp`.
            current = exp(-self.lrate * direction, current)
        return current
class TangentBundle(Manifold):
    """Tangent bundle TM of a manifold M.

    A point of TM is stored as an array of shape ``(2, *M.shape)``: index 0
    along the leading point axis holds a base point ``p`` of M and index 1
    holds a tangent vector ``u`` of M at ``p``.

    Parameters
    ----------
    space : Manifold
        Base manifold M.
    equip : bool
        If True, equip the bundle with its default metric.
        Optional, default: True.
    """

    def __init__(self, space, equip=True):
        self.space = space
        super().__init__(dim=2 * space.dim, shape=(2,) + space.shape, equip=equip)

    @staticmethod
    def default_metric():
        """Metric to equip the space with if equip is True."""
        return SasakiMetric

    def _unstack(self, point):
        """Split points of TM into (base points, tangent vectors) of M.

        Parameters
        ----------
        point : array-like, shape=[..., 2, *M.shape]
            Point(s) of the tangent bundle.

        Returns
        -------
        space_point, space_tangent_vec : array-like, shape=[..., *M.shape]
        """
        # Select index 0 / 1 on the pair axis while keeping every one of the
        # ``space.point_ndim`` trailing axes intact, so vector-valued and
        # matrix-valued base manifolds are handled alike.
        trailing = (slice(None),) * self.space.point_ndim
        return (
            point[(Ellipsis, 0) + trailing],
            point[(Ellipsis, 1) + trailing],
        )

    def _stack(self, space_point, space_tangent_vec):
        """Assemble base points and tangent vectors of M into points of TM."""
        return gs.stack([space_point, space_tangent_vec], axis=-self.point_ndim)

    def belongs(self, point, atol=gs.atol):
        """Evaluate if a point belongs to the tangent bundle.

        The base point must belong to M and the fibre component must be
        tangent to M at that base point.

        Parameters
        ----------
        point : array-like, shape=[..., *point_shape]
            Point to evaluate.
        atol : float
            Absolute tolerance, forwarded to the base-manifold checks.
            Optional, default: backend atol.

        Returns
        -------
        belongs : array-like, shape=[...,]
            Boolean evaluating if point belongs to the manifold.
        """
        space_point, space_tangent_vec = self._unstack(point)
        return gs.logical_and(
            self.space.belongs(space_point, atol=atol),
            self.space.is_tangent(space_tangent_vec, space_point, atol=atol),
        )

    def random_point(self, n_samples=1, bound=1.0):
        """Sample random points on the tangent bundle.

        Base points are drawn with ``space.random_point`` and a random tangent
        vector is drawn at each of them with ``space.random_tangent_vec``.

        Parameters
        ----------
        n_samples : int
            Number of samples. Optional, default: 1.
        bound : float
            Bound of the interval in which to sample for non compact manifolds.
            Optional, default: 1.

        Returns
        -------
        samples : array-like, shape=[..., *point_shape]
            Points sampled on the tangent bundle.
        """
        space_point = self.space.random_point(n_samples, bound)
        space_tangent_vec = self.space.random_tangent_vec(space_point)
        return self._stack(space_point, space_tangent_vec)

    @staticmethod
    def projection(point):
        """Return a copy of `point`.

        This method exists for API compatibility and performs no projection;
        `point` should already have the right shape.

        Parameters
        ----------
        point : array-like, shape=[..., *point_shape]
            Point.

        Returns
        -------
        point : array-like, shape=[..., *point_shape]
            Copy of the input point.
        """
        return gs.copy(point)

    def is_tangent(self, vector, base_point=None, atol=gs.atol):
        """Check whether the vector is tangent at base_point (not implemented).

        Parameters
        ----------
        vector : array-like, shape=[..., *point_shape]
            Vector.
        base_point : array-like, shape=[..., *point_shape]
            Point of the tangent bundle.
        atol : float
            Absolute tolerance. Optional, default: backend atol.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError("`is_tangent` is not implemented")

    def to_tangent(self, vector, base_point=None):
        """Project a vector to a tangent space (not implemented).

        Parameters
        ----------
        vector : array-like, shape=[..., *point_shape]
            Vector.
        base_point : array-like, shape=[..., *point_shape]
            Point of the tangent bundle.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError("`to_tangent` is not implemented")

    def random_tangent_vec(self, base_point, n_samples=1):
        """Generate random tangent vectors (not implemented).

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError("`random_tangent_vec` is not implemented")
class SasakiMetric(RiemannianMetric):
    """Sasaki metric on the tangent bundle TM of a Riemannian manifold M.

    The Sasaki metric is characterized by the following three properties:

    * the canonical projection of TM becomes a Riemannian submersion,
    * parallel vector fields along curves are orthogonal to their fibres, and
    * its restriction to any tangent space is Euclidean.

    Geodesic computations are realized via a discrete formulation of the
    geodesic equation on TM that involves geodesics, parallel translation,
    and the curvature tensor on the base manifold M (see [1]_ for details).
    However, as the implemented energy in the discrete-geodesics optimization
    as well as the approximations of its gradient slightly differ from those
    proposed in [1]_, we also refer to [2]_ for additional details.

    Parameters
    ----------
    space : TangentBundle
        Tangent bundle TM; the metric of its base manifold M
        (``space.space.metric``) is used for all computations.
    n_jobs : int
        Number of jobs for parallel computing. Optional, default: 1.
    n_steps : int
        Number of discrete time steps. Optional, default: 3.

    References
    ----------
    .. [1] Muralidharan, P., & Fletcher, P. T.
        "Sasaki metrics for analysis of longitudinal data on manifolds",
        IEEE CVPR 2012, pp. 1027-1034
        https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4270017/
    .. [2] Nava-Yazdani, E., Hanik, M., Ambellan, F., & von Tycowicz, C.
        "On Gradient Formulas in an Algorithm for the Logarithm of the Sasaki
        Metric", Technical Report Zuse-Institut Berlin, 2022
        https://nbn-resolving.org/urn/resolver.pl?urn:nbn:de:0297-zib-87174
    """

    def __init__(self, space, n_jobs=1, n_steps=3):
        super().__init__(space=space)
        self.n_jobs = n_jobs
        self.n_steps = n_steps
        self._gradient_descent = GradientDescent()

    def exp(self, tangent_vec, base_point):
        """Compute the Riemannian exponential of a tangent vector.

        The Sasaki geodesic is shot from base_point with an explicit Euler
        integration on TM using ``n_steps`` steps of size ``1 / n_steps``.

        Parameters
        ----------
        tangent_vec : array-like, shape=[..., 2, *M.shape]
            Tangent vector in TTM at the base point in TM.
        base_point : array-like, shape=[..., 2, *M.shape]
            Point in the tangent bundle TM of manifold M.

        Returns
        -------
        exp : array-like, shape=[..., 2, *M.shape]
            Point on the tangent bundle TM.
        """
        par_trans = self._space.space.metric.parallel_transport
        eps = 1 / self.n_steps
        v0, w0 = self._space._unstack(tangent_vec)
        p0, u0 = self._space._unstack(base_point)
        for _ in range(self.n_steps):
            # Base point moves along the M-geodesic in direction v.
            p = self._space.space.metric.exp(eps * v0, p0)
            # Fibre component advances by eps * w, then is carried to p.
            u = par_trans(u0 + eps * w0, p0, end_point=p)
            # v is corrected by the curvature term -R(u, w) v before transport.
            v = par_trans(
                v0 - eps * (self._space.space.metric.curvature(u0, w0, v0, p0)),
                p0,
                end_point=p,
            )
            w = par_trans(w0, p0, end_point=p)
            p0, u0 = p, u
            v0, w0 = v, w
        return self._space._stack(p, u)

    def log(self, point, base_point):
        """Compute the Riemannian logarithm of a point.

        Logarithmic map at base_point of point computed by iteratively
        relaxing a discretized geodesic between base_point and point.

        Parameters
        ----------
        point : array-like, shape=[..., 2, *M.shape]
            Point in the tangent bundle TM of manifold M.
        base_point : array-like, shape=[..., 2, *M.shape]
            Point in the tangent bundle TM of manifold M.

        Returns
        -------
        log : array-like, shape=[..., 2, *M.shape]
            Tangent vector at the base point equal to the Riemannian logarithm
            of point at the base point.
        """
        par_trans = self._space.space.metric.parallel_transport
        pu = self.geodesic_discrete(base_point, point)
        # Second node of the discrete geodesic (index 1 on the step axis).
        pu1 = gs.take(pu, 1, axis=-(self._space.point_ndim + 1))
        p1, u1 = self._space._unstack(pu1)
        p0, u0 = self._space._unstack(base_point)
        w = par_trans(u1, p1, end_point=p0) - u0
        v = self._space.space.metric.log(point=p1, base_point=p0)
        # The first segment spans 1 / n_steps of the curve; rescale to unit time.
        return self.n_steps * self._space._stack(v, w)

    def geodesic_discrete(self, initial_point, end_point):
        """Compute a Sasaki geodesic employing a variational time discretization.

        Parameters
        ----------
        initial_point : array-like, shape=[..., 2, *M.shape]
            Points in the tangent bundle TM of manifold M.
        end_point : array-like, shape=[..., 2, *M.shape]
            Points in the tangent bundle TM of manifold M.

        Returns
        -------
        geodesic : array-like, shape=[..., n_steps + 1, 2, *M.shape]
            Discrete geodesics of form x(s)=(p(s), u(s)) in Sasaki metric
            connecting initial_point = x(0) and end_point = x(1).
            Batched inputs are broadcast against each other and each pair is
            computed independently (possibly in parallel via joblib).
        """
        metric = self._space.space.metric
        par_trans = metric.parallel_transport

        def _grad(pu, i_pt, e_pt):
            """Gradient of discrete geodesic energy w.r.t. the interior nodes."""
            # Re-attach the fixed end points so neighbours of every interior
            # node are available.
            pu = gs.vstack(
                [
                    gs.expand_dims(i_pt, axis=0),
                    pu,
                    gs.expand_dims(e_pt, axis=0),
                ]
            )
            p, u = self._space._unstack(pu)
            p1, p2, p3 = p[:-2], p[1:-1], p[2:]
            u1, u2, u3 = u[:-2], u[1:-1], u[2:]
            eps = 1 / self.n_steps
            v2 = metric.log(p3, p2) / eps
            w2 = (par_trans(u3, p3, end_point=p2) - u2) / eps
            gp = (metric.log(p3, p2) + metric.log(p1, p2)) / (
                2 * eps**2
            ) - metric.curvature(u2, w2, v2, p2)
            gu = (
                par_trans(u3, p3, end_point=p2)
                - 2 * u2
                + par_trans(u1, p1, end_point=p2)
            ) / eps**2
            return -self._space._stack(gp, gu) * eps

        def _geodesic_discrete_single(initial_point, end_point):
            """Calculate the discrete geodesic between two single points."""
            ijk = "ijk"[: self._space.space.point_ndim]

            def _scalarmul(scalar, point):
                return gs.einsum(f"...,...{ijk}->...{ijk}", scalar, point)

            p0, u0 = initial_point[0], initial_point[1]
            pL, uL = end_point[0], end_point[1]

            # Initialize interior nodes on the base geodesic p0 -> pL, with
            # fibre components linearly blended after parallel transport.
            v = metric.log(pL, p0)
            s = gs.linspace(0.0, 1.0, self.n_steps + 1)[1:-1]
            p_ini = metric.exp(gs.einsum(f"p, {ijk}->p{ijk}", s, v), p0)
            u_ini = _scalarmul(
                (1.0 - s), par_trans(u0, p0, end_point=p_ini)
            ) + _scalarmul(s, par_trans(uL, pL, end_point=p_ini))
            pu_ini = self._space._stack(p_ini, u_ini)

            x = self._gradient_descent.minimize(
                pu_ini, initial_point, end_point, _grad, self.exp
            )
            return gs.vstack(
                [
                    gs.expand_dims(initial_point, axis=0),
                    x,
                    gs.expand_dims(end_point, axis=0),
                ]
            )

        is_batch = check_is_batch(self._space.point_ndim, initial_point, end_point)
        if not is_batch:
            return _geodesic_discrete_single(initial_point, end_point)

        initial_point, end_point = gs.broadcast_arrays(initial_point, end_point)
        with Parallel(n_jobs=min(self.n_jobs, len(end_point)), verbose=0) as parallel:
            rslt = parallel(
                delayed(_geodesic_discrete_single)(i_pt, e_pt)
                for i_pt, e_pt in zip(initial_point, end_point)
            )
        return gs.array(rslt)

    def inner_product(self, tangent_vec_a, tangent_vec_b, base_point):
        """Inner product between two tangent vectors at a base point.

        Sum of the base-manifold inner products of the horizontal components
        and of the vertical components, both taken at the base point of M.

        Parameters
        ----------
        tangent_vec_a : array-like, shape=[..., 2, *M.shape]
            Tangent vector in TTM of the tangent bundle TM.
        tangent_vec_b : array-like, shape=[..., 2, *M.shape]
            Tangent vector in TTM of the tangent bundle TM.
        base_point : array-like, shape=[..., 2, *M.shape]
            Point in the tangent bundle TM of manifold M.

        Returns
        -------
        inner_product : array-like, shape=[...,]
            Inner-product.
        """
        vec_a_0, vec_a_1 = self._space._unstack(tangent_vec_a)
        vec_b_0, vec_b_1 = self._space._unstack(tangent_vec_b)
        pt, _ = self._space._unstack(base_point)
        inner = self._space.space.metric.inner_product
        return inner(vec_a_0, vec_b_0, pt) + inner(vec_a_1, vec_b_1, pt)