Source code for geomstats.geometry.stratified.point_set

"""Class for Stratified Spaces.

Lead authors: Anna Calissano & Jonas Lueg
"""

import functools
import itertools
from abc import ABC, abstractmethod


[docs] def broadcast_lists(list_a, list_b): """Broadcast two lists. Similar behavior as ``gs.broadcast_arrays``, but for lists. """ n_a = len(list_a) n_b = len(list_b) if n_a == n_b: return list_a, list_b if n_a == 1: return itertools.zip_longest(list_a, list_b, fillvalue=list_a[0]) if n_b == 1: return itertools.zip_longest(list_a, list_b, fillvalue=list_b[0]) raise Exception(f"Cannot broadcast lens {n_a} and {n_b}")
def _manipulate_input(arg): if not (type(arg) in [list, tuple]): return [arg] return arg def _vectorize_point(*args_positions, manipulate_input=_manipulate_input): """Check point type and transform in iterable if not the case. Parameters ---------- args_positions : tuple Position and corresponding argument name. A tuple for each position. Notes ----- Explicitly defining args_positions and args names ensures it works for all combinations of input calling. """ def _dec(func): @functools.wraps(func) def _wrapped(*args, **kwargs): args = list(args) for pos, name in args_positions: if name in kwargs: kwargs[name] = manipulate_input(kwargs[name]) else: args[pos] = manipulate_input(args[pos]) return func(*args, **kwargs) return _wrapped return _dec
[docs] class Point(ABC): r"""Class for points of a set.""" @abstractmethod def __repr__(self): """Produce a string with a verbal description of the point.""" @abstractmethod def __hash__(self): """Define a hash for the point."""
[docs] @abstractmethod def to_array(self): """Turn the point into a numpy array. Returns ------- array_point : array-like, shape=[...] An array representation of the Point type. """
[docs] class PointSet(ABC): r"""Class for a set of points of type Point.""" def __init__(self, equip=True): if equip: self.equip_with_metric()
[docs] def equip_with_metric(self, Metric=None, **metric_kwargs): """Equip manifold with Metric. Parameters ---------- Metric : PointSetMetric object If None, default metric will be used. """ if Metric is None: out = self.default_metric() if isinstance(out, tuple): Metric, kwargs = out kwargs.update(metric_kwargs) metric_kwargs = kwargs else: Metric = out self.metric = Metric(self, **metric_kwargs)
[docs] @abstractmethod def belongs(self, point, atol): r"""Evaluate if a point belongs to the set. Parameters ---------- point : Point-like, shape=[...] Point to evaluate. atol : float Absolute tolerance. Optional, default: backend atol. Returns ------- belongs : array-like, shape=[...] Boolean evaluating if point belongs to the set. """
[docs] @abstractmethod def random_point(self, n_samples=1): r"""Sample random points on the PointSet. Parameters ---------- n_samples : int Number of samples. Optional, default: 1. Returns ------- samples : List of Point Points sampled on the PointSet. """
[docs] @abstractmethod def set_to_array(self, points): """Convert a set of points into an array. Parameters ---------- points : list of Point, shape=[...] Number of samples of point type to turn into an array. Returns ------- points_array : array-like, shape=[...] Points sampled on the PointSet. """
[docs] class PointSetMetric(ABC): r"""Class for the lenght spaces.""" def __init__(self, space): self._space = space
[docs] @abstractmethod def dist(self, point_a, point_b, **kwargs): """Distance between two points in the PointSet. Parameters ---------- point_a: Point or List of Point, shape=[...] Point in the PointSet. point_b: Point or List of Point, shape=[...] Point in the PointSet. Returns ------- distance : array-like, shape=[...] Distance. """
[docs] @abstractmethod def geodesic(self, initial_point, end_point, **kwargs): """Compute the geodesic in the PointSet. Parameters ---------- initial_point: Point or List of Points, shape=[...] Point in the PointSet. end_point: Point or List of Points, shape=[...] Point in the PointSet. Returns ------- path : callable Time parameterized geodesic curve. """