# Source code for geomstats.geometry.stratified.spider

"""Class for the spider.

Lead authors: Anna Calissano & Jonas Lueg
"""

import geomstats.backend as gs
from geomstats.geometry.euclidean import Euclidean
from geomstats.geometry.stratified.point_set import (
Point,
PointSet,
PointSetMetric,
_vectorize_point,
)

class SpiderPoint(Point):
    r"""Class for points of the Spider.

    A point in the Spider is :math:`(s, c) \in \mathbb{N} \times \mathbb{R}`.

    Parameters
    ----------
    stratum : int
        The stratum, an integer indicating the stratum the point lies in.
        If zero, then the point is on the origin.
    stratum_coord : float
        A positive number, the coordinate of the point. It must be zero if and
        only if the stratum is zero, i.e. the origin.

    Raises
    ------
    ValueError
        If ``stratum`` is zero but ``stratum_coord`` is not.
    """

    def __init__(self, stratum, stratum_coord):
        super().__init__()
        # Only "origin => zero coordinate" is enforced here; the converse
        # (non-zero stratum => positive coordinate) is verified by
        # ``Spider.belongs`` rather than at construction time.
        if stratum == 0 and stratum_coord != 0:
            raise ValueError("If the stratum is zero, x must be zero.")
        self.stratum = stratum
        self.stratum_coord = stratum_coord

    def __repr__(self):
        """Return a readable representation of the instance."""
        return f"r{self.stratum}: {self.stratum_coord}"

    def __hash__(self):
        """Return the hash of the instance.

        Only the stratum is hashed: ``__eq__`` compares coordinates up to the
        tolerance ``gs.atol``, so hashing the exact coordinate would give
        points that compare equal different hashes, breaking set and dict
        lookups.
        """
        return hash(self.stratum)

    def __eq__(self, other):
        """Compare two points.

        Two points are equal when they share the same stratum and their
        coordinates differ by less than ``gs.atol``. Comparison with objects
        that are not ``SpiderPoint`` instances is delegated back to Python.
        """
        if not isinstance(other, SpiderPoint):
            return NotImplemented
        return (
            self.stratum == other.stratum
            and abs(self.stratum_coord - other.stratum_coord) < gs.atol
        )

    def to_array(self):
        """Return the point as the array ``[stratum, stratum_coord]``."""
        return gs.array([self.stratum, self.stratum_coord])

class Spider(PointSet):
    r"""Spider: a set of rays attached to the origin.

    The k-spider consists of k copies of the positive real line
    :math:`\mathbb{R}_{\geq 0}` glued together at the origin [Feragen2020]_.

    Parameters
    ----------
    n_rays : int
        Number of rays to attach to the origin.
        Note that zero counts as the origin not as a ray.

    References
    ----------
    .. [Feragen2020]  Feragen, Aasa, and Tom Nye. "Statistics on stratified spaces."
        Riemannian Geometric Statistics in Medical Image Analysis.
    """

    def __init__(self, n_rays, equip=True):
        super().__init__(equip=equip)
        self.n_rays = n_rays

    @staticmethod
    def default_metric():
        """Metric to equip the space with if equip is True."""
        return SpiderMetric

    def random_point(self, n_samples=1):
        r"""Compute a random point of the spider set.

        The stratum is drawn with ``gs.random.randint`` from
        ``{0, ..., n_rays}`` (0 being the origin). The coordinate is the
        absolute value of a normal sample with mean 10 and standard deviation
        1, and is set to 0 for points drawn on the origin.

        Parameters
        ----------
        n_samples : int
            Number of samples.
            Optional, default: 1.

        Returns
        -------
        samples : list of SpiderPoint, shape=[...]
            List of SpiderPoints randomly sampled from the Spider.
        """
        if self.n_rays == 0:
            # Only the origin exists; build independent instances.
            return [SpiderPoint(stratum=0, stratum_coord=0) for _ in range(n_samples)]

        # ``randint`` excludes ``high``: use ``n_rays + 1`` so that the last
        # ray (stratum ``n_rays``) can be sampled as well.
        s = gs.random.randint(low=0, high=self.n_rays + 1, size=n_samples)
        x = gs.abs(gs.random.normal(loc=10, scale=1, size=n_samples))
        # Points on the origin must have a zero coordinate.
        x[s == 0] = 0
        return [
            SpiderPoint(stratum=s[k], stratum_coord=x[k]) for k in range(n_samples)
        ]

    @_vectorize_point((1, "point"))
    def belongs(self, point):
        r"""Check if points belong to the spider set.

        Parameters
        ----------
        point : SpiderPoint or list of SpiderPoint, shape=[...]
            Point to be checked.

        Returns
        -------
        belongs : array-like, shape=[...]
            Boolean denoting if the SpiderPoint belongs to the Spider Set.
        """
        # The type check comes first so that objects lacking the
        # ``stratum``/``stratum_coord`` attributes yield False instead of
        # raising AttributeError.
        return gs.array(
            [
                isinstance(single_point, SpiderPoint)
                and self._coord_check(single_point)
                and self._n_rays_check(single_point)
                for single_point in point
            ]
        )

    def _n_rays_check(self, single_point):
        r"""Check if a point lies on an existing stratum.

        Parameters
        ----------
        single_point : SpiderPoint
            Point to be checked.

        Returns
        -------
        belongs : boolean
            True if the stratum is one of ``0, ..., n_rays``.
        """
        return single_point.stratum in range(self.n_rays + 1)

    @staticmethod
    def _coord_check(single_point):
        r"""Check if a point has a valid coordinate.

        Parameters
        ----------
        single_point : SpiderPoint
            Point to be checked.

        Returns
        -------
        belongs : boolean
            Boolean denoting if the point has a positive length when on non-zero ray.
        """
        if single_point.stratum != 0 and single_point.stratum_coord <= 0:
            return False
        return True

    @_vectorize_point((1, "point"))
    def set_to_array(self, point):
        r"""Turn a point into an array compatible with the dimension of the space.

        Parameters
        ----------
        point : SpiderPoint or list of SpiderPoint, shape=[...]
            Points to be converted.

        Returns
        -------
        point_array : array-like, shape=[...,n_rays]
            An array with the stratum_coord parameter in the stratum position;
            the origin maps to the zero vector.
        """
        point_to_array = gs.zeros((len(point), self.n_rays))
        for i, pt in enumerate(point):
            # Ray k is stored in column k - 1. The origin is skipped: indexing
            # with ``stratum - 1 == -1`` would wrap to the last column, and
            # would fail outright when ``n_rays == 0``.
            if pt.stratum != 0:
                point_to_array[i, pt.stratum - 1] = pt.stratum_coord
        return point_to_array

class SpiderMetric(PointSetMetric):
    """Geometry on the Spider, induced by the rays Geometry.

    Parameters
    ----------
    space : Spider
        Space to equip with the metric.
    ray_metric : metric, optional
        Metric used on each ray; it must provide ``norm`` and ``geodesic``.
        Defaults to the metric of the 1-dimensional Euclidean space.
    """

    def __init__(self, space, ray_metric=None):
        super().__init__(space=space)
        if ray_metric is None:
            ray_metric = Euclidean(dim=1, equip=True).metric
        self.ray_metric = ray_metric

    @property
    def n_rays(self):
        """Get number of rays."""
        return self._space.n_rays

    @_vectorize_point((1, "a"), (2, "b"))
    def dist(self, point_a, point_b):
        """Compute the distance between two points on the Spider using the ray geometry.

        The spider metric is the metric in each ray extended to the Spider:
        given two points x, y on different rays, d(x, y) = d(x, 0) + d(0, y).

        Parameters
        ----------
        point_a : SpiderPoint or list of SpiderPoint, shape=[...]
            Point in the Spider.
        point_b : SpiderPoint or list of SpiderPoint, shape=[...]
            Point in the Spider.

        Returns
        -------
        point_array : array-like, shape=[...]
            An array with the distances, or a single distance when only one
            pair of points is given.
        """
        result = []
        for point_a_, point_b_ in zip(point_a, point_b):
            coord_a = gs.array([point_a_.stratum_coord])
            coord_b = gs.array([point_b_.stratum_coord])
            if (
                point_a_.stratum == point_b_.stratum
                or point_a_.stratum == 0
                or point_b_.stratum == 0
            ):
                # Both points lie on a common ray (the origin is on every ray).
                result.append(self.ray_metric.norm(coord_a - coord_b))
            else:
                # Different rays: the path goes through the origin, and both
                # legs are measured with the ray metric, consistently with
                # the same-ray branch above.
                result.append(
                    self.ray_metric.norm(coord_a) + self.ray_metric.norm(coord_b)
                )
        return gs.array(result) if len(result) != 1 else result[0]

    @_vectorize_point((1, "initial_point"), (2, "end_point"))
    def geodesic(self, initial_point, end_point):
        """Return the geodesic between two lists of Spider points.

        Parameters
        ----------
        initial_point : SpiderPoint or list of SpiderPoint, shape=[...]
            Point in the Spider.
        end_point : SpiderPoint or list of SpiderPoint, shape=[...]
            Point in the Spider.

        Returns
        -------
        geo : function
            Vectorized geodesic function of the time ``t``. For a single pair
            it returns the list of SpiderPoints along that geodesic; for
            several pairs, a list of such lists.
        """
        fncs = [
            self._point_geodesic(pt_a, pt_b)
            for pt_a, pt_b in zip(initial_point, end_point)
        ]

        def _vec(t):
            if len(fncs) == 1:
                return fncs[0](t)
            return [fnc(t) for fnc in fncs]

        return _vec

    def _point_geodesic(self, initial_point, end_point):
        """Compute the geodesic between two Spider points.

        Parameters
        ----------
        initial_point : SpiderPoint
            Point in the Spider.
        end_point : SpiderPoint
            Point in the Spider.

        Returns
        -------
        geo: function
            Geodesic between two Spider Points.
        """
        if (
            initial_point.stratum == end_point.stratum
            or initial_point.stratum == 0
            or end_point.stratum == 0
        ):
            # Both points lie on a common ray: follow the ray geodesic.
            s = gs.maximum(initial_point.stratum, end_point.stratum)
            ray_geodesic = self.ray_metric.geodesic(
                initial_point=gs.array([initial_point.stratum_coord]),
                end_point=gs.array([end_point.stratum_coord]),
            )

            def ray_geo(t):
                # A zero coordinate is the origin, i.e. stratum 0.
                return [
                    SpiderPoint(stratum=s if xx[0] else 0, stratum_coord=xx[0])
                    for xx in ray_geodesic(t)
                ]

            return ray_geo

        # Different rays: unfold both rays onto one line, the initial ray on
        # the negative half-line, so the geodesic passes through the origin.
        unfolded_geodesic = self.ray_metric.geodesic(
            initial_point=gs.array([-initial_point.stratum_coord]),
            end_point=gs.array([end_point.stratum_coord]),
        )

        def ray_geo(t):
            points = []
            for xx in unfolded_geodesic(t):
                coord = xx[0]
                if coord < 0.0:
                    points.append(
                        SpiderPoint(stratum=initial_point.stratum, stratum_coord=-coord)
                    )
                elif coord > 0.0:
                    points.append(
                        SpiderPoint(stratum=end_point.stratum, stratum_coord=coord)
                    )
                else:
                    # Exactly at the origin: stratum 0, so the point is valid.
                    points.append(SpiderPoint(stratum=0, stratum_coord=0.0))
            return points

        return ray_geo