# Source code for geomstats.numerics.path

"""Discrete-path related machinery."""

import geomstats.backend as gs
from geomstats.numerics.finite_differences import forward_difference
from geomstats.numerics.interpolation import UniformUnitIntervalLinearInterpolator


class UniformlySampledPathEnergy:
    """Riemannian path energy of a uniformly-sampled path.

    Parameters
    ----------
    space : Manifold
        Equipped manifold. Its ``point_ndim`` attribute fixes which trailing
        axes of a path are point axes, and its ``metric`` is used to measure
        the squared norm of each segment's tangent vector.
    """

    def __init__(self, space):
        self._space = space

    def __call__(self, path):
        """Compute Riemannian path energy.

        Parameters
        ----------
        path : array-like, shape=[..., n_times, *point_shape]
            Piecewise linear path.

        Returns
        -------
        energy : array-like, shape=[...,]
            Path energy.
        """
        return self.energy(path)

    def energy_per_time(self, path):
        """Compute Riemannian path energy per time.

        For every segment ``i`` of the path (``n_times - 1`` of them), the
        contribution is the metric squared norm of the forward finite
        difference at node ``i``, evaluated at base point ``path[..., i, ...]``,
        divided by ``2 * (n_times - 1)``.

        Parameters
        ----------
        path : array-like, shape=[..., n_times, *point_shape]
            Piecewise linear path.

        Returns
        -------
        energy : array-like, shape=[..., n_times - 1,]
            Stepwise path energy.
        """
        point_ndim = self._space.point_ndim
        # The time axis sits immediately before the point axes.
        time_axis = -(point_ndim + 1)
        n_times = path.shape[time_axis]

        tangent_vecs = forward_difference(path, axis=time_axis)

        # Drop the last node along the time axis while keeping every point
        # axis whole: these are the base points of the n_times - 1 segments.
        segment_starts = path[(..., slice(None, -1)) + (slice(None),) * point_ndim]

        # 1 / (n_times - 1) plays the role of a uniform time step; this
        # assumes forward_difference returns velocities on the unit interval
        # (NOTE(review): confirm against forward_difference's default delta).
        return self._space.metric.squared_norm(tangent_vecs, segment_starts) / (
            2 * (n_times - 1)
        )

    def energy(self, path):
        """Compute Riemannian path energy.

        Parameters
        ----------
        path : array-like, shape=[..., n_times, *point_shape]
            Piecewise linear path.

        Returns
        -------
        energy : array-like, shape=[...,]
            Path energy, i.e. the per-segment energies summed over time.
        """
        return gs.sum(self.energy_per_time(path), axis=-1)
class UniformlySampledDiscretePath:
    """A uniformly-sampled discrete path.

    Parameters
    ----------
    path : array-like, [..., *point_shape]
        Sampled path. Only used to build the default interpolator, i.e. it is
        ignored when ``interpolator`` is given.
    interpolator : Interpolator1D, optional
        Interpolator evaluated by ``__call__``. If None, a
        ``UniformUnitIntervalLinearInterpolator`` is built from ``path``.
    **interpolator_kwargs
        Keyword arguments forwarded to ``UniformUnitIntervalLinearInterpolator``
        when ``interpolator`` is None; ignored otherwise.
    """

    def __init__(self, path, interpolator=None, **interpolator_kwargs):
        if interpolator is None:
            interpolator = UniformUnitIntervalLinearInterpolator(
                path, **interpolator_kwargs
            )
        self.interpolator = interpolator

    def __call__(self, t):
        """Interpolate path.

        Parameters
        ----------
        t : array-like, shape=[n_time]
            Interpolation time. A scalar or a Python sequence of times is also
            accepted; a scalar is treated as a single time.

        Returns
        -------
        point : array-like, shape=[..., n_time, *point_shape]
        """
        if not gs.is_array(t):
            # Convert the value as-is: wrapping it in an extra list would turn
            # a sequence of n times into a (1, n) array instead of shape (n,).
            t = gs.array(t)
        if gs.ndim(t) == 0:
            # A single scalar time becomes a length-1 time array.
            t = gs.expand_dims(t, axis=0)
        return self.interpolator(t)