Source code for geomstats.geometry.stratified.quotient

"""Quotient structure for a geodesic metric space."""

from abc import ABC

import geomstats.backend as gs
from geomstats.geometry.stratified.point_set import PointSetMetric


class Aligner(ABC):
    """Bundle structure.

    Aligns points of the total space to one another, delegating the actual
    work to an optional alignment algorithm.

    Parameters
    ----------
    total_space : PointSet
        Set with quotient structure.
    align_algo : AlignerAlgorithm, optional
        Algorithm performing alignment. It must provide an
        ``align(point, base_point)`` method. If None, no ``align_algo``
        attribute is set on the instance.
    """

    def __init__(self, total_space, align_algo=None):
        self._total_space = total_space
        # Only assign when given: subclasses may supply `align_algo` in
        # another way (e.g. class attribute), and `align` checks for its
        # presence with `hasattr` rather than comparing to None.
        if align_algo is not None:
            self.align_algo = align_algo

    def align(self, point, base_point):
        """Align point to base point.

        Parameters
        ----------
        point : array-like, shape=[..., *point_shape]
            Point to align.
        base_point : array-like, shape=[..., *point_shape]
            Reference point.

        Returns
        -------
        aligned_point : array-like, shape=[..., *point_shape]
            Result of ``self.align_algo.align(point, base_point)``.

        Raises
        ------
        NotImplementedError
            If no ``align_algo`` attribute is available.
        """
        if hasattr(self, "align_algo"):
            return self.align_algo.align(point, base_point)
        raise NotImplementedError("`align` is not implemented")
class QuotientMetric(PointSetMetric, ABC):
    """Quotient metric.

    Distances and geodesics are computed in the total space, after aligning
    the second point to the first one with the total space's aligner.

    Parameters
    ----------
    space : PointSet
        Set to equip with metric.
    total_space : PointSet
        Set with quotient structure. It must expose an ``aligner`` (with an
        ``align(point, base_point)`` method) and a ``metric`` (with
        ``squared_dist`` and ``geodesic`` methods).
    """

    def __init__(self, space, total_space):
        self._total_space = total_space
        super().__init__(space)

    def squared_dist(self, point_a, point_b):
        """Compute squared distance between two points.

        ``point_b`` is first aligned to ``point_a`` with the total space's
        aligner. The squared distance is then measured by the total space's
        metric between ``point_a`` and the aligned point.

        Parameters
        ----------
        point_a : array-like, shape=[..., *point_shape]
            Reference point.
        point_b : array-like, shape=[..., *point_shape]
            Point aligned to ``point_a`` before measuring.

        Returns
        -------
        squared_dist : array-like, shape=[...]
            Squared distance between the points.
        """
        aligned_point_b = self._total_space.aligner.align(point_b, point_a)
        return self._total_space.metric.squared_dist(
            point_a,
            aligned_point_b,
        )

    def dist(self, point_a, point_b):
        """Compute distance between two points.

        Square root of :meth:`squared_dist`.

        Parameters
        ----------
        point_a : array-like, shape=[..., *point_shape]
            Reference point.
        point_b : array-like, shape=[..., *point_shape]
            Point aligned to ``point_a`` before measuring.

        Returns
        -------
        distance : array-like, shape=[...]
            Distance between the points.
        """
        return gs.sqrt(self.squared_dist(point_a, point_b))

    def geodesic(self, initial_point, end_point):
        """Compute geodesic between two points.

        ``end_point`` is first aligned to ``initial_point`` with the total
        space's aligner. The geodesic is then computed by the total space's
        metric between ``initial_point`` and the aligned end point.

        Parameters
        ----------
        initial_point : array-like, shape=[..., *point_shape]
            Initial point.
        end_point : array-like, shape=[..., *point_shape]
            End point, aligned to ``initial_point`` first.

        Returns
        -------
        geodesic : callable
            Geodesic function.
        """
        aligned_end_point = self._total_space.aligner.align(
            end_point, initial_point
        )
        return self._total_space.metric.geodesic(
            initial_point=initial_point, end_point=aligned_end_point
        )