"""Class for the spider.
Lead authors: Anna Calissano & Jonas Lueg
"""
import geomstats.backend as gs
from geomstats.geometry.euclidean import Euclidean
from geomstats.geometry.stratified.point_set import (
Point,
PointSet,
PointSetMetric,
_vectorize_point,
broadcast_lists,
)
[docs]
class SpiderPoint(Point):
r"""Class for points of the Spider.
A point in the Spider is :math:`(s,c) \in \mathbb{N} \times \mathbb{R}`.
Parameters
----------
stratum : int
The stratum, an integer indicating the stratum the point lies in.
If zero, then the point is on the origin.
stratum_coord : float
A positive number, the coordinate of the point. It must be zero if and
only if the stratum is zero, i.e. the origin.
"""
def __init__(self, stratum, stratum_coord):
super().__init__()
if stratum == 0 and stratum_coord != 0:
raise ValueError("If the stratum is zero, x must be zero.")
self.stratum = stratum
self.stratum_coord = stratum_coord
def __repr__(self):
"""Return a readable representation of the instance."""
return f"r{self.stratum}: {self.stratum_coord}"
def __hash__(self):
"""Return the hash of the instance."""
return hash((self.stratum, self.stratum_coord))
def __eq__(self, other):
"""Compare two points."""
return (
self.stratum == other.stratum
and abs(self.stratum_coord - other.stratum_coord) < gs.atol
)
[docs]
def to_array(self):
"""Return the hash of the instance."""
return gs.array([self.stratum, self.stratum_coord])
[docs]
class Spider(PointSet):
r"""Spider: a set of rays attached to the origin.
The k-spider consists of k copies of the positive real line
:math:`\mathbb{R}_{\geq 0}` glued together at the origin [Feragen2020]_.
Parameters
----------
n_rays : int
Number of rays to attach to the origin.
Note that zero counts as the origin not as a ray.
References
----------
.. [Feragen2020] Feragen, Aasa, and Tom Nye. "Statistics on stratified spaces."
Riemannian Geometric Statistics in Medical Image Analysis.
Academic Press, 2020. 299-342.
"""
def __init__(self, n_rays, equip=True):
super().__init__(equip=equip)
self.n_rays = n_rays
[docs]
@staticmethod
def default_metric():
"""Metric to equip the space with if equip is True."""
return SpiderMetric
[docs]
def random_point(self, n_samples=1):
r"""Compute a random point of the spider set.
Parameters
----------
n_samples : int
Number of samples.
Optional, default: 1.
Returns
-------
samples : list of SpiderPoint, shape=[...]
List of SpiderPoints randomly sampled from the Spider.
"""
if self.n_rays != 0:
s = gs.random.randint(low=0, high=self.n_rays, size=n_samples)
x = gs.abs(gs.random.normal(loc=10, scale=1, size=n_samples))
x[s == 0] = 0
return [
SpiderPoint(stratum=s[k], stratum_coord=x[k]) for k in range(n_samples)
]
return [SpiderPoint(stratum=0, stratum_coord=0)] * n_samples
[docs]
@_vectorize_point((1, "point"))
def belongs(self, point):
r"""Check if a random point belongs to the spider set.
Parameters
----------
point : SpiderPoint or list of SpiderPoint, shape=[...]
Point to be checked.
Returns
-------
belongs : array-like, shape=[...]
Boolean denoting if the SpiderPoint belongs to the Spider Set.
"""
results = []
for single_point in point:
results += [
self._coord_check(single_point)
and self._n_rays_check(single_point)
and type(single_point) is SpiderPoint
]
return gs.array(results)
def _n_rays_check(self, single_point):
r"""Check if a random point has the correct number of rays.
Parameters
----------
single_point : SpiderPoint
Point to be checked.
Returns
-------
belongs : boolean
Boolean denoting if the point has a ray in the rays set.
"""
if single_point.stratum not in list(range(self.n_rays + 1)):
return False
return True
@staticmethod
def _coord_check(single_point):
r"""Check if a random point has the correct length.
Parameters
----------
single_point : SpiderPoint
Point to be checked.
Returns
-------
belongs : boolean
Boolean denoting if the point has a positive length when on non-zero ray.
"""
if single_point.stratum != 0 and single_point.stratum_coord <= 0:
return False
return True
[docs]
@_vectorize_point((1, "point"))
def set_to_array(self, point):
r"""Turn a point into an array compatible with the dimension of the space.
Parameters
----------
point : SpiderPoint or list of SpiderPoint, shape=[...]
Points to be checked.
Returns
-------
point_array : array-like, shape=[...,n_rays]
An array with the stratum_coord parameter in the stratum position.
"""
point_to_array = gs.zeros((len(point), self.n_rays))
for i, pt in enumerate(point):
point_to_array[i, pt.stratum - 1] = pt.stratum_coord
return point_to_array
[docs]
class SpiderMetric(PointSetMetric):
"""Geometry on the Spider, induced by the rays Geometry."""
def __init__(self, space, ray_metric=None):
super().__init__(space=space)
if ray_metric is None:
ray_metric = Euclidean(dim=1, equip=True).metric
self.ray_metric = ray_metric
@property
def n_rays(self):
"""Get number of rays."""
return self._space.n_rays
[docs]
@_vectorize_point((1, "a"), (2, "b"))
def dist(self, point_a, point_b):
"""Compute the distance between two points on the Spider using the ray geometry.
The spider metric is the metric in each ray extended to the Spider:
given two points x, y on different rays, d(x, y) = d(x, 0) + d(0, y).
Parameters
----------
point_a : SpiderPoint or list of SpiderPoint, shape=[...]
Point in the Spider.
point_b : SpiderPoint or list of SpiderPoint, shape=[...]
Point in the Spider.
Returns
-------
point_array : array-like, shape=[...]
An array with the distance.
"""
point_a, point_b = broadcast_lists(point_a, point_b)
result = []
for point_a_, point_b_ in zip(point_a, point_b):
if (
point_a_.stratum == point_b_.stratum
or point_a_.stratum == 0
or point_b_.stratum == 0
):
result += [
self.ray_metric.norm(
gs.array([point_a_.stratum_coord - point_b_.stratum_coord])
)
]
else:
result += [point_a_.stratum_coord + point_b_.stratum_coord]
return gs.array(result) if len(result) != 1 else result[0]
[docs]
@_vectorize_point((1, "initial_point"), (2, "end_point"))
def geodesic(self, initial_point, end_point):
"""Return the geodesic between two lists of Spider points.
Parameters
----------
initial_point : SpiderPoint or list of SpiderPoint, shape=[...]
Point in the Spider.
end_point : SpiderPoint or list of SpiderPoint, shape=[...]
Point in the Spider.
Returns
-------
geo : function
Return a vectorized geodesic function.
"""
def _vec(t, fncs):
if len(fncs) == 1:
return fncs[0](t)
return [fnc(t) for fnc in fncs]
initial_point, end_point = broadcast_lists(initial_point, end_point)
fncs = [
self._point_geodesic(pt_a, pt_b)
for (pt_a, pt_b) in zip(initial_point, end_point)
]
return lambda t: _vec(t, fncs=fncs)
def _point_geodesic(self, initial_point, end_point):
"""Compute the distance between two Spider points.
Parameters
----------
initial_point : SpiderPoint
Point in the Spider.
end_point : SpiderPoint
Point in the Spider.
Returns
-------
geo: function
Geodesic between two Spider Points.
"""
if (
initial_point.stratum == end_point.stratum
or initial_point.stratum == 0
or end_point.stratum == 0
):
s = gs.maximum(initial_point.stratum, end_point.stratum)
def ray_geo(t):
g = self.ray_metric.geodesic(
initial_point=gs.array([initial_point.stratum_coord]),
end_point=gs.array([end_point.stratum_coord]),
)
x = g(t)
return [
SpiderPoint(stratum=s if xx[0] else 0, stratum_coord=xx[0])
for xx in x
]
return ray_geo
def ray_geo(t):
g = self.ray_metric.geodesic(
initial_point=gs.array([-initial_point.stratum_coord]),
end_point=gs.array([end_point.stratum_coord]),
)
x = g(t)
return [
SpiderPoint(stratum=initial_point.stratum, stratum_coord=-xx[0])
if xx < 0.0
else SpiderPoint(stratum=end_point.stratum, stratum_coord=xx[0])
for xx in x
]
return ray_geo