Source code for geomstats.geometry.riemannian_metric

"""Riemannian and pseudo-Riemannian metrics.

Lead author: Nina Miolane.
"""
from abc import ABC

import joblib

import geomstats.backend as gs
import geomstats.geometry as geometry
from geomstats.geometry.connection import Connection

# Fallback denominator used by ``RiemannianMetric.sectional_curvature`` when
# the normalization factor is numerically zero.
EPSILON = 1.0e-4
# The constants below are not referenced in this part of the module.
# NOTE(review): presumably consumed by callers / other modules -- verify
# before renaming or removing them.
N_CENTERS = 10
N_REPETITIONS = 20
N_MAX_ITERATIONS = 50_000
N_STEPS = 10


class RiemannianMetric(Connection, ABC):
    """Class for Riemannian and pseudo-Riemannian metrics.

    Besides the metric itself, instances expose the associated Levi-Civita
    connection on the tangent bundle through the methods inherited from
    ``Connection``.

    Parameters
    ----------
    dim : int
        Dimension of the manifold.
    shape : tuple of int
        Shape of one element of the manifold.
        Optional, default : (dim, ).
    signature : tuple
        Signature of the metric.
        Optional, default: (dim, 0).
    default_point_type : str, {'vector', 'matrix'}
        Point type.
        Optional, default: None, in which case the value is resolved by
        ``Connection``.
    """

    def __init__(self, dim, shape=None, signature=None, default_point_type=None):
        super().__init__(dim=dim, shape=shape, default_point_type=default_point_type)
        # With no explicit signature, fall back to (dim, 0) -- presumably
        # (n_positive, n_negative) directions, i.e. a positive-definite metric.
        self.signature = (dim, 0) if signature is None else signature
[docs] def metric_matrix(self, base_point=None): """Metric matrix at the tangent space at a base point. Parameters ---------- base_point : array-like, shape=[..., dim] Base point. Optional, default: None. Returns ------- mat : array-like, shape=[..., dim, dim] Inner-product matrix. """ raise NotImplementedError( "The computation of the metric matrix" " is not implemented." )
[docs] def cometric_matrix(self, base_point=None): """Inner co-product matrix at the cotangent space at a base point. This represents the cometric matrix, i.e. the inverse of the metric matrix. Parameters ---------- base_point : array-like, shape=[..., dim] Base point. Optional, default: None. Returns ------- mat : array-like, shape=[..., dim, dim] Inverse of inner-product matrix. """ metric_matrix = self.metric_matrix(base_point) cometric_matrix = gs.linalg.inv(metric_matrix) return cometric_matrix
[docs] def inner_product_derivative_matrix(self, base_point=None): """Compute derivative of the inner prod matrix at base point. Parameters ---------- base_point : array-like, shape=[..., dim] Base point. Optional, default: None. Returns ------- mat : array-like, shape=[..., dim, dim] Derivative of inverse of inner-product matrix. """ metric_derivative = gs.autodiff.jacobian(self.metric_matrix) return metric_derivative(base_point)
[docs] def christoffels(self, base_point): r"""Compute Christoffel symbols of the Levi-Civita connection. The Koszul formula defining the Levi-Civita connection gives the expression of the Christoffel symbols with respect to the metric: :math:`\Gamma^k_{ij}(p) = \frac{1}{2} g^{lk}( \partial_i g_{jl} + \partial_j g_{li} - \partial_l g_{ij})`, where: - :math:`p` represents the base point, and - :math:`g` represents the Riemannian metric tensor. Parameters ---------- base_point: array-like, shape=[..., dim] Base point. Returns ------- christoffels: array-like, shape=[..., dim, dim, dim] Christoffel symbols. """ cometric_mat_at_point = self.cometric_matrix(base_point) metric_derivative_at_point = self.inner_product_derivative_matrix(base_point) term_1 = gs.einsum( "...lk,...jli->...kij", cometric_mat_at_point, metric_derivative_at_point ) term_2 = gs.einsum( "...lk,...lij->...kij", cometric_mat_at_point, metric_derivative_at_point ) term_3 = -gs.einsum( "...lk,...ijl->...kij", cometric_mat_at_point, metric_derivative_at_point ) christoffels = 0.5 * (term_1 + term_2 + term_3) return christoffels
[docs] def inner_product(self, tangent_vec_a, tangent_vec_b, base_point): """Inner product between two tangent vectors at a base point. Parameters ---------- tangent_vec_a: array-like, shape=[..., dim] Tangent vector at base point. tangent_vec_b: array-like, shape=[..., dim] Tangent vector at base point. base_point: array-like, shape=[..., dim] Base point. Optional, default: None. Returns ------- inner_product : array-like, shape=[...,] Inner-product. """ inner_prod_mat = self.metric_matrix(base_point) aux = gs.einsum("...j,...jk->...k", tangent_vec_a, inner_prod_mat) inner_prod = gs.einsum("...k,...k->...", aux, tangent_vec_b) return inner_prod
[docs] def inner_coproduct(self, cotangent_vec_a, cotangent_vec_b, base_point): """Compute inner coproduct between two cotangent vectors at base point. This is the inner product associated to the cometric matrix. Parameters ---------- cotangent_vec_a : array-like, shape=[..., dim] Cotangent vector at `base_point`. cotangent_vet_b : array-like, shape=[..., dim] Cotangent vector at `base_point`. base_point : array-like, shape=[..., dim] Point on the manifold. Returns ------- inner_coproduct : float Inner coproduct between the two cotangent vectors. """ vector_2 = gs.einsum( "...ij,...j->...i", self.cometric_matrix(base_point), cotangent_vec_b ) inner_coproduct = gs.einsum("...i,...i->...", cotangent_vec_a, vector_2) return inner_coproduct
[docs] def hamiltonian(self, state): r"""Compute the hamiltonian energy associated to the cometric. The Hamiltonian at state :math: `(q, p)` is defined by .. math: H(q, p) = \frac{1}{2} <p, p>_q where :math: `<\cdot, \cdot>_q` is the cometric at :math: `q`. Parameters ---------- state : tuple of arrays Position and momentum variables. The position is a point on the manifold, while the momentum is cotangent vector. Returns ------- energy : float Hamiltonian energy at `state`. """ position, momentum = state return 1.0 / 2 * self.inner_coproduct(momentum, momentum, position)
[docs] def squared_norm(self, vector, base_point=None): """Compute the square of the norm of a vector. Squared norm of a vector associated to the inner product at the tangent space at a base point. Parameters ---------- vector : array-like, shape=[..., dim] Vector. base_point : array-like, shape=[..., dim] Base point. Optional, default: None. Returns ------- sq_norm : array-like, shape=[...,] Squared norm. """ sq_norm = self.inner_product(vector, vector, base_point) return sq_norm
[docs] def norm(self, vector, base_point=None): """Compute norm of a vector. Norm of a vector associated to the inner product at the tangent space at a base point. Note: This only works for positive-definite Riemannian metrics and inner products. Parameters ---------- vector : array-like, shape=[..., dim] Vector. base_point : array-like, shape=[..., dim] Base point. Optional, default: None. Returns ------- norm : array-like, shape=[...,] Norm. """ sq_norm = self.squared_norm(vector, base_point) norm = gs.sqrt(sq_norm) return norm
[docs] def normalize(self, vector, base_point): """Normalize tangent vector at a given point. Parameters ---------- vector : array-like, shape=[..., dim] Tangent vector at base_point. base_point : array-like, shape=[..., dim] Point. Returns ------- normalized_vector : array-like, shape=[..., dim] Unit tangent vector at base_point. """ norm = self.norm(vector, base_point) norm = gs.where(norm == 0, gs.ones(norm.shape), norm) normalized_vector = gs.einsum("...i,...->...i", vector, 1 / norm) return normalized_vector
[docs] def random_unit_tangent_vec(self, base_point, n_vectors=1): """Generate a random unit tangent vector at a given point. Parameters ---------- base_point : array-like, shape=[..., dim] Point. n_vectors : float Number of vectors to be generated at base_point. For vectorization purposes n_vectors can be greater than 1 iff base_point constitues of a single point. Returns ------- normalized_vector : array-like, shape=[..., n_vectors, dim] Random unit tangent vector at base_point. """ shape = base_point.shape if len(shape) > 1 and shape[-2] > 1 and n_vectors > 1: raise ValueError( "Several tangent vectors is only applicable to a single base point." ) random_vector = gs.squeeze(gs.random.rand(n_vectors, *shape)) normalized_vector = self.normalize(random_vector, base_point) return gs.squeeze(normalized_vector)
[docs] def squared_dist(self, point_a, point_b, **kwargs): """Squared geodesic distance between two points. Parameters ---------- point_a : array-like, shape=[..., dim] Point. point_b : array-like, shape=[..., dim] Point. Returns ------- sq_dist : array-like, shape=[...,] Squared distance. """ log = self.log(point=point_b, base_point=point_a, **kwargs) sq_dist = self.squared_norm(vector=log, base_point=point_a) return sq_dist
[docs] def dist(self, point_a, point_b, **kwargs): """Geodesic distance between two points. Note: It only works for positive definite Riemannian metrics. Parameters ---------- point_a : array-like, shape=[..., dim] Point. point_b : array-like, shape=[..., dim] Point. Returns ------- dist : array-like, shape=[...,] Distance. """ sq_dist = self.squared_dist(point_a, point_b, **kwargs) dist = gs.sqrt(sq_dist) return dist
[docs] def dist_broadcast(self, point_a, point_b): """Compute the geodesic distance between points. If n_samples_a == n_samples_b then dist is the element-wise distance result of a point in points_a with the point from points_b of the same index. If n_samples_a not equal to n_samples_b then dist is the result of applying geodesic distance for each point from points_a to all points from points_b. Parameters ---------- point_a : array-like, shape=[n_samples_a, dim] Set of points in the Poincare ball. point_b : array-like, shape=[n_samples_b, dim] Second set of points in the Poincare ball. Returns ------- dist : array-like, shape=[n_samples_a, dim] or [n_samples_a, n_samples_b, dim] Geodesic distance between the two points. """ ndim = len(self.shape) if point_a.shape[-ndim:] != point_b.shape[-ndim:]: raise ValueError("Manifold dimensions not equal") if ndim in (point_a.ndim, point_b.ndim) or (point_a.shape == point_b.shape): return self.dist(point_a, point_b) n_samples = point_a.shape[0] * point_b.shape[0] point_a_broadcast, point_b_broadcast = gs.broadcast_arrays( point_a[:, None], point_b[None, ...] ) point_a_flatten = gs.reshape( point_a_broadcast, (n_samples,) + point_a.shape[-ndim:] ) point_b_flatten = gs.reshape( point_b_broadcast, (n_samples,) + point_a.shape[-ndim:] ) dist = self.dist(point_a_flatten, point_b_flatten) dist = gs.reshape(dist, (point_a.shape[0], point_b.shape[0])) dist = gs.squeeze(dist) return dist
[docs] def dist_pairwise(self, points, n_jobs=1, **joblib_kwargs): """Compute the pairwise distance between points. Parameters ---------- points : array-like, shape=[n_samples, dim] Set of points in the manifold. n_jobs : int Number of jobs to run in parallel, using joblib. Note that a higher number of jobs may not be beneficial when one computation of a geodesic distance is cheap. Optional. Default: 1. **joblib_kwargs : dict Keyword arguments to joblib.Parallel Returns ------- dist : array-like, shape=[n_samples, n_samples] Pairwise distance matrix between all the points. See Also -------- `joblib documentations <https://joblib.readthedocs.io/en/latest/>`_ """ n_samples = points.shape[0] rows, cols = gs.triu_indices(n_samples) @joblib.delayed @joblib.wrap_non_picklable_objects def pickable_dist(x, y): """Wrap distance function to make it pickable.""" return self.dist(x, y) pool = joblib.Parallel(n_jobs=n_jobs, **joblib_kwargs) out = pool(pickable_dist(points[i], points[j]) for i, j in zip(rows, cols)) pairwise_dist = geometry.symmetric_matrices.SymmetricMatrices.from_vector( gs.array(out) ) return pairwise_dist
[docs] def diameter(self, points): """Give the distance between two farthest points. Distance between the two points that are farthest away from each other in points. Parameters ---------- points : array-like, shape=[..., dim] Points. Returns ------- diameter : float Distance between two farthest points. """ diameter = 0.0 n_points = points.shape[0] for i in range(n_points - 1): dist_to_neighbors = self.dist(points[i, :], points[i + 1 :, :]) dist_to_farthest_neighbor = gs.amax(dist_to_neighbors) diameter = gs.maximum(diameter, dist_to_farthest_neighbor) return diameter
[docs] def closest_neighbor_index(self, point, neighbors): """Closest neighbor of point among neighbors. Parameters ---------- point : array-like, shape=[..., dim] Point. neighbors : array-like, shape=[n_neighbors, dim] Neighbors. Returns ------- closest_neighbor_index : int Index of closest neighbor. """ n_points = point.shape[0] if gs.ndim(point) == gs.ndim(neighbors) else 1 n_neighbors = neighbors.shape[0] if n_points > 1 and n_neighbors > 1: neighbors = gs.repeat(neighbors, n_points, axis=0) point = gs.concatenate([point for _ in range(n_neighbors)]) closest_neighbor_index = gs.argmin( gs.transpose( gs.reshape(self.dist(point, neighbors), (n_neighbors, n_points)), ), axis=1, ) if n_points == 1: return closest_neighbor_index[0] return closest_neighbor_index
[docs] def normal_basis(self, basis, base_point=None): """Normalize the basis with respect to the metric. This corresponds to a renormalization of each basis vector. Parameters ---------- basis : array-like, shape=[dim, dim] Matrix of a metric. base_point Returns ------- basis : array-like, shape=[dim, n, n] Normal basis. """ norms = self.squared_norm(basis, base_point) return gs.einsum("i, ikl->ikl", 1.0 / gs.sqrt(norms), basis)
[docs] def sectional_curvature(self, tangent_vec_a, tangent_vec_b, base_point=None): r"""Compute the sectional curvature. For two orthonormal tangent vectors :math:`x,y` at a base point, the sectional curvature is defined by :math:`<R(x, y)x, y> = <R_x(y), y>`. For non-orthonormal vectors vectors, it is :math:`<R(x, y)x, y> / \\|x \\wedge y\\|^2`. Parameters ---------- tangent_vec_a : array-like, shape=[..., n, n] Tangent vector at `base_point`. tangent_vec_b : array-like, shape=[..., n, n] Tangent vector at `base_point`. base_point : array-like, shape=[..., n, n] Point in the group. Optional, default is the identity Returns ------- sectional_curvature : array-like, shape=[...,] Sectional curvature at `base_point`. See Also -------- https://en.wikipedia.org/wiki/Sectional_curvature """ curvature = self.curvature( tangent_vec_a, tangent_vec_b, tangent_vec_a, base_point ) sectional = self.inner_product(curvature, tangent_vec_b, base_point) norm_a = self.squared_norm(tangent_vec_a, base_point) norm_b = self.squared_norm(tangent_vec_b, base_point) inner_ab = self.inner_product(tangent_vec_a, tangent_vec_b, base_point) normalization_factor = norm_a * norm_b - inner_ab**2 condition = gs.isclose(normalization_factor, 0.0) normalization_factor = gs.where(condition, EPSILON, normalization_factor) return gs.where(~condition, sectional / normalization_factor, 0.0)