# Source code for geomstats.geometry.stratified.point_set

"""Class for Stratified Spaces.

Lead authors: Anna Calissano and Jonas Lueg
"""

from abc import ABC, abstractmethod

import geomstats.backend as gs


class Point(ABC):
    """Abstract base class for a single element of a point set."""

    @abstractmethod
    def equal(self, point, atol=gs.atol):
        """Test whether this point equals another point or batch of points.

        Concrete subclasses define what equality means for their points.

        Parameters
        ----------
        point : Point or PointBatch
            Point(s) to compare this point against.
        atol : float
            Absolute tolerance.
            Optional, default: backend atol.

        Returns
        -------
        is_equal : array-like, shape=[...]
            Outcome of the comparison, as defined by the subclass.
        """
class PointBatch(ABC, list):
    """List-based container holding a batch of points.

    Indexing with an integer-like key returns the stored element, while
    slicing returns a new batch of the same class.
    """

    def __getitem__(self, key):
        """Get a single element or a sub-batch.

        Parameters
        ----------
        key : int or slice
            Index of the element, or slice selecting several elements.
            Any key accepted by ``list.__getitem__`` is supported.

        Returns
        -------
        item : object or PointBatch
            The stored element for an index key; a new instance of
            ``self.__class__`` wrapping the selected elements for a slice.
        """
        item = list.__getitem__(self, key)
        # Only a slice produces a list that has to be re-wrapped. Checking
        # for ``slice`` instead of ``int`` keeps integer-like keys that are
        # not ``int`` subclasses (e.g. numpy integers) from having their
        # single element passed to the batch constructor.
        if isinstance(key, slice):
            return self.__class__(item)
        return item

    def equal(self, point, atol=gs.atol):
        """Check equality of each point in the batch against another point.

        Parameters
        ----------
        point : Point or PointBatch
            Point(s) to compare against. A list or tuple (which includes
            PointBatch) is compared element by element; anything else is
            compared against every point of the batch.
        atol : float
            Absolute tolerance, forwarded to each element's ``equal``.
            Optional, default: backend atol.

        Returns
        -------
        is_equal : array-like
            Array built from the result of ``equal`` of each element.
            For a list or tuple input the pairing is done with ``zip``,
            so it stops at the shorter of the two sequences.
        """
        if isinstance(point, (list, tuple)):
            return gs.array(
                [
                    collection_point.equal(point_, atol)
                    for collection_point, point_ in zip(self, point)
                ]
            )

        return gs.array(
            [collection_point.equal(point, atol) for collection_point in self]
        )
class PointSet(ABC):
    r"""Class for a set of points of type Point.

    Parameters
    ----------
    equip : bool
        If True, ``equip_with_metric`` is called with no arguments at
        construction time, i.e. the default metric is used.
        Optional, default: True.
    """

    def __init__(self, equip=True):
        if equip:
            self.equip_with_metric()

    def equip_with_metric(self, Metric=None, **metric_kwargs):
        """Equip the set with a metric, stored in ``self.metric``.

        Parameters
        ----------
        Metric : PointSetMetric object
            Callable invoked as ``Metric(self, **metric_kwargs)`` to build
            the metric. If None, default metric will be used.
        **metric_kwargs
            Keyword arguments forwarded to ``Metric``. When the default
            metric comes with its own keyword arguments, the ones given
            here take precedence.

        Notes
        -----
        When ``Metric`` is None this calls ``self.default_metric()``, which
        is not defined on this class in the code shown here and is expected
        to be provided by the subclass. It may return either the metric
        callable alone or a ``(Metric, kwargs)`` tuple.
        """
        if Metric is None:
            out = self.default_metric()
            if isinstance(out, tuple):
                Metric, default_kwargs = out
                # Merge into a new dict rather than updating in place, so
                # the dict handed out by ``default_metric`` is not mutated
                # (it may be shared between calls or instances).
                metric_kwargs = {**default_kwargs, **metric_kwargs}
            else:
                Metric = out

        self.metric = Metric(self, **metric_kwargs)

    @abstractmethod
    def belongs(self, point, atol=gs.atol):
        r"""Evaluate if a point belongs to the set.

        Parameters
        ----------
        point : Point or PointBatch
            Point to evaluate.
        atol : float
            Absolute tolerance.
            Optional, default: backend atol.

        Returns
        -------
        belongs : array-like, shape=[...]
            Boolean evaluating if point belongs to the set.
        """

    @abstractmethod
    def random_point(self, n_samples=1):
        r"""Sample random points on the PointSet.

        Parameters
        ----------
        n_samples : int
            Number of samples.
            Optional, default: 1.

        Returns
        -------
        samples : Point or PointBatch
            Points sampled on the PointSet.
        """
class PointSetMetric(ABC):
    r"""Abstract base class for metrics on length spaces.

    Parameters
    ----------
    space : PointSet
        Set the metric is attached to; kept in ``self._space``.
    """

    def __init__(self, space):
        self._space = space

    @abstractmethod
    def dist(self, point_a, point_b):
        """Compute the distance between two points of the PointSet.

        Parameters
        ----------
        point_a : Point or PointBatch
            First point in the PointSet.
        point_b : Point or PointBatch
            Second point in the PointSet.

        Returns
        -------
        distance : array-like, shape=[...]
            Distance between the given points.
        """

    @abstractmethod
    def geodesic(self, initial_point, end_point):
        """Compute the geodesic joining two points of the PointSet.

        Parameters
        ----------
        initial_point : Point or PointBatch
            Point in the PointSet where the geodesic starts.
        end_point : Point or PointBatch
            Point in the PointSet where the geodesic ends.

        Returns
        -------
        path : callable
            Time parameterized geodesic curve.
        """