"""Manifold module.
A manifold is a topological space that locally resembles
Euclidean space near each point.
Lead author: Nina Miolane.
"""
import abc
import inspect
import types
import geomstats.backend as gs
import geomstats.errors
from geomstats.geometry.fiber_bundle import FiberBundle
from geomstats.geometry.quotient_metric import QuotientMetric
[docs]
class Manifold(abc.ABC):
    r"""Class for manifolds.

    Parameters
    ----------
    dim : int
        Dimension of the manifold.
    shape : tuple of int
        Shape of one element of the manifold. Required; a non-tuple
        value raises ValueError.
    intrinsic : bool
        Coordinate type flag. When True, `projection` simply returns a
        copy of its input (presumably meaning points are stored in
        intrinsic coordinates).
        Optional, default: True.
    equip : bool
        If True, equip space with default metric.
        Optional, default: True.

    Attributes
    ----------
    point_ndim : int
        Dimension of point array, i.e. ``len(shape)``.
    """

    def __init__(
        self,
        dim,
        shape,
        intrinsic=True,
        equip=True,
    ):
        geomstats.errors.check_integer(dim, "dim")

        if not isinstance(shape, tuple):
            raise ValueError("Expected a tuple for the shape argument.")

        if not isinstance(equip, bool):
            raise TypeError("Expected a boolean value for 'equip' argument.")

        self.dim = dim
        self.shape = shape
        self.intrinsic = intrinsic
        self.point_ndim = len(self.shape)

        if equip:
            self.equip_with_metric()

    def equip_with_metric(self, Metric=None, **metric_kwargs):
        """Equip manifold with a Riemannian metric.

        Parameters
        ----------
        Metric : RiemannianMetric class or instance, or ScalarProductMetric instance
            If None, the metric given by `default_metric` is used.
            A class is instantiated as ``Metric(self, **metric_kwargs)``.
            An instance must have been built on this very space.
        **metric_kwargs
            Keyword arguments passed to the metric constructor when
            `Metric` is a class. They override any default kwargs returned
            by `default_metric`.

        Returns
        -------
        self : Manifold
            The space itself, to allow chaining.

        Raises
        ------
        ValueError
            If `Metric` is an instance attached to another space.
        """
        if Metric is None:
            out = self.default_metric()
            # `default_metric` returns either a class or a
            # ``(class, default_kwargs)`` pair. Merge into a fresh dict so
            # the defaults object is never mutated; user kwargs win.
            if isinstance(out, tuple):
                Metric, default_kwargs = out
                metric_kwargs = {**default_kwargs, **metric_kwargs}
            else:
                Metric = out

        if inspect.isclass(Metric):
            self.metric = Metric(self, **metric_kwargs)
        else:
            # Validate the metric being installed, not the (possibly
            # missing) current `self.metric`.
            if Metric._space is not self:
                raise ValueError(
                    "Cannot equip space with metric instantiated with another space."
                )
            self.metric = Metric

        return self

    def equip_with_group_action(self, group_action):
        """Equip manifold with group action.

        Parameters
        ----------
        group_action : str
            Group action. Used as part of the key when looking up a
            registered quotient structure.

        Returns
        -------
        self : Manifold
            The space itself, to allow chaining.
        """
        self.group_action = group_action
        return self

    def equip_with_quotient(self):
        """Equip manifold with quotient structure.

        Creates attributes `quotient` and `fiber_bundle` or `aligner`
        (`aligner` is used in quotient contexts where the notion
        of fiber bundle is not defined).

        Returns
        -------
        quotient : Manifold or None
            Quotient space equipped with a quotient metric, or None when
            no quotient metric is registered for this structure.

        Raises
        ------
        ValueError
            If no quotient structure is registered for this manifold, or
            if the metric / group action needed to select one is missing.
        """
        if not _QuotientStructureRegistry.has_quotient(self):
            raise ValueError("No quotient structure defined for this manifold.")

        FiberBundle_, QuotientMetric_ = (
            _QuotientStructureRegistry.get_fiber_bundle_and_quotient_metric(
                self,
            )
        )

        fiber_bundle = FiberBundle_(total_space=self)
        # Only objects exposing a Riemannian submersion are true fiber
        # bundles; anything else merely aligns points.
        if hasattr(fiber_bundle, "riemannian_submersion"):
            self.fiber_bundle = fiber_bundle
        else:
            self.aligner = fiber_bundle

        if QuotientMetric_ is None:
            return None

        # NOTE(review): `new` is not defined in this class; presumably
        # provided by subclasses to build a fresh space of the same kind.
        self.quotient = self.new(equip=False)
        self.quotient.equip_with_metric(QuotientMetric_, total_space=self)

        return self.quotient

    @abc.abstractmethod
    def belongs(self, point, atol=gs.atol):
        """Evaluate if a point belongs to the manifold.

        Parameters
        ----------
        point : array-like, shape=[..., *point_shape]
            Point to evaluate.
        atol : float
            Absolute tolerance.
            Optional, default: backend atol.

        Returns
        -------
        belongs : array-like, shape=[...,]
            Boolean evaluating if point belongs to the manifold.
        """

    @abc.abstractmethod
    def is_tangent(self, vector, base_point=None, atol=gs.atol):
        """Check whether the vector is tangent at base_point.

        Parameters
        ----------
        vector : array-like, shape=[..., *point_shape]
            Vector.
        base_point : array-like, shape=[..., *point_shape]
            Point on the manifold.
        atol : float
            Absolute tolerance.
            Optional, default: backend atol.

        Returns
        -------
        is_tangent : bool
            Boolean denoting if vector is a tangent vector at the base point.
        """

    @abc.abstractmethod
    def to_tangent(self, vector, base_point=None):
        """Project a vector to a tangent space of the manifold.

        Parameters
        ----------
        vector : array-like, shape=[..., *point_shape]
            Vector.
        base_point : array-like, shape=[..., *point_shape]
            Point on the manifold.

        Returns
        -------
        tangent_vec : array-like, shape=[..., *point_shape]
            Tangent vector at base point.
        """

    @abc.abstractmethod
    def random_point(self, n_samples=1, bound=1.0):
        """Sample random points on the manifold according to some distribution.

        If the manifold is compact, preferably a uniform distribution will be used.

        Parameters
        ----------
        n_samples : int
            Number of samples.
            Optional, default: 1.
        bound : float
            Bound of the interval in which to sample for non compact manifolds.
            Optional, default: 1.

        Returns
        -------
        samples : array-like, shape=[..., *point_shape]
            Points sampled on the manifold.
        """

    def regularize(self, point):
        """Regularize a point to the canonical representation for the manifold.

        The default implementation returns an unchanged copy.

        Parameters
        ----------
        point : array-like, shape=[..., *point_shape]
            Point.

        Returns
        -------
        regularized_point : array-like, shape=[..., *point_shape]
            Regularized point.
        """
        return gs.copy(point)

    def random_tangent_vec(self, base_point=None, n_samples=1):
        """Generate random tangent vec.

        A standard normal sample in the ambient shape is projected onto
        the tangent space with `to_tangent`. This method is not recommended
        for statistical purposes, as the tangent vectors generated are not
        drawn from a distribution related to the Riemannian metric.

        Parameters
        ----------
        base_point : array-like, shape={[n_samples, *point_shape], [*point_shape,]}
            Point.
        n_samples : int
            Number of samples.
            Optional, default: 1.

        Returns
        -------
        tangent_vec : array-like, shape=[..., *point_shape]
            Tangent vec at base point.

        Raises
        ------
        ValueError
            If several samples are requested with a batch of base points
            whose size differs from `n_samples`.
        """
        if (
            n_samples > 1
            and base_point is not None
            and base_point.ndim > len(self.shape)
            and n_samples != len(base_point)
        ):
            raise ValueError(
                "The number of base points must be the same as the "
                "number of samples, when the number of base points is different from 1."
            )

        batch_size = () if n_samples == 1 else (n_samples,)
        return self.to_tangent(
            gs.random.normal(size=batch_size + self.shape), base_point
        )

    def projection(self, point):
        """Project a point to the manifold.

        Parameters
        ----------
        point : array-like, shape=[..., *point_shape]
            Point.

        Returns
        -------
        point : array-like, shape=[..., *point_shape]
            Point.

        Raises
        ------
        NotImplementedError
            If the manifold is not intrinsic; subclasses must override.
        """
        if self.intrinsic:
            return gs.copy(point)

        raise NotImplementedError("`projection` is not implemented yet")
class _QuotientStructureRegistry:
    """Registry for quotient structures.

    `STRUCTURES` maps ``(Space, Metric, GroupAction)`` keys to
    ``(FiberBundle, QuotientMetric)`` pairs. Keys are compared by
    equality (as ``dict`` lookups do), so string and tuple keys match
    regardless of object identity.
    """

    STRUCTURES = {}

    @classmethod
    def _as_key(cls, Obj):
        """Transform an instance of a class into a key.

        Parameters
        ----------
        Obj : type or instance or str

        Returns
        -------
        Obj : type or str
            Hashable object as used to create dict keys
            within STRUCTURES: classes, functions, strings and tuples are
            returned unchanged; any other object is replaced by its type.
        """
        if inspect.isclass(Obj) or isinstance(
            Obj, (types.FunctionType, str, tuple)
        ):
            return Obj
        return type(Obj)

    @classmethod
    def has_quotient(cls, Space):
        """Check if a given type has an associated quotient structure.

        Parameters
        ----------
        Space : type or instance or str

        Returns
        -------
        has_quotient : bool
        """
        Space = cls._as_key(Space)
        return any(Space_ == Space for Space_, _, _ in cls.STRUCTURES)

    @classmethod
    def get_available_quotients(cls, Space, Metric=None, GroupAction=None):
        """Get available quotient structures.

        Parameters
        ----------
        Space : type or instance or str
        Metric : type or instance or str
        GroupAction : type or instance or str

        Returns
        -------
        available_structures : list[tuple[type or str]]
            If neither `Metric` nor `GroupAction` is given, the
            ``(Metric, GroupAction)`` pairs registered for `Space`.
            If only one of them is given, 1-tuples with the registered
            values of the other one. If both are given, an empty list.
        """
        Space = cls._as_key(Space)

        if Metric is None and GroupAction is None:
            return [
                (Metric_, GroupAction_)
                for Space_, Metric_, GroupAction_ in cls.STRUCTURES
                if Space_ == Space
            ]

        if GroupAction is None:
            Metric = cls._as_key(Metric)
            return [
                (GroupAction_,)
                for Space_, Metric_, GroupAction_ in cls.STRUCTURES
                if Space_ == Space and Metric_ == Metric
            ]

        if Metric is None:
            GroupAction = cls._as_key(GroupAction)
            return [
                (Metric_,)
                for Space_, Metric_, GroupAction_ in cls.STRUCTURES
                if Space_ == Space and GroupAction_ == GroupAction
            ]

        # Both structures already fixed: nothing left to enumerate.
        return []

    @classmethod
    def get_fiber_bundle_and_quotient_metric(cls, Space, Metric=None, GroupAction=None):
        """Get fiber bundle and quotient metric.

        Checks are done along the way. Raised errors carry meaningful
        messages listing the available structures.

        Parameters
        ----------
        Space : type or instance or str
            When a class is passed, `Metric` and `GroupAction` must be given.
        Metric : type or instance or str
            Optional; defaults to ``Space.metric``.
        GroupAction : type or instance or str
            Optional; defaults to ``Space.group_action``.

        Returns
        -------
        FiberBundle : type
        QuotientMetric : type

        Raises
        ------
        ValueError
            If a class is passed without metric and group action, if a
            required structure is missing, or if no mapping is registered.
        """
        if (Metric is None or GroupAction is None) and inspect.isclass(Space):
            raise ValueError("Pass instantiated space or metric and group action info.")

        if Metric is None:
            Metric = getattr(Space, "metric", None)
        if GroupAction is None:
            GroupAction = getattr(Space, "group_action", None)

        for structure, structure_name in zip(
            [Metric, GroupAction], ["metric", "group_action"]
        ):
            if structure is None:
                available_structures = cls.get_available_quotients(
                    Space, Metric=Metric, GroupAction=GroupAction
                )
                structs_str = "\n\t".join(
                    [
                        ", ".join(str(elem) for elem in struct)
                        for struct in available_structures
                    ]
                )
                raise ValueError(
                    f"Need to equip with `{structure_name}` first. "
                    f"Available structures:\n\t{structs_str}"
                )

        Space = cls._as_key(Space)
        Metric = cls._as_key(Metric)
        GroupAction = cls._as_key(GroupAction)

        key = (Space, Metric, GroupAction)
        out = cls.STRUCTURES.get(key)
        if out is not None:
            return out

        # Unregistered tuple group actions fall back to a generic
        # aligner-based fiber bundle with the generic quotient metric.
        if isinstance(GroupAction, tuple):
            return (
                lambda *args, **kwargs: FiberBundle(*args, **kwargs, aligner=True),
                QuotientMetric,
            )
        raise ValueError(f"No mapping for key: {key}")
[docs]
def register_quotient(Space, Metric, GroupAction, FiberBundle, QuotientMetric=None):
    """Register a quotient structure in the global quotient registry.

    The triple ``(Space, Metric, GroupAction)`` becomes a lookup key whose
    value is the pair ``(FiberBundle, QuotientMetric)``. Registering the
    same triple again overwrites the previous entry.

    Parameters
    ----------
    Space : type or str
        Total space.
    Metric : type or str
        Metric of the total space.
    GroupAction : type or str
        Group action defining the quotient.
    FiberBundle : type or str
        Fiber bundle (or aligner) class to use for this structure.
    QuotientMetric : type or str
        Quotient metric class; None when no quotient metric is available.
        Optional, default: None.
    """
    registry_key = (Space, Metric, GroupAction)
    registry_value = (FiberBundle, QuotientMetric)
    _QuotientStructureRegistry.STRUCTURES[registry_key] = registry_value