"""Scalar product of a Riemannian metric.
Define the product of a Riemannian metric with a scalar number.
Public functions:
    register_scaled_method(func_name, scaling_type)
Public classes:
    ScalarProductMetric
Lead author: John Harvey.
"""
import numbers
from functools import wraps

import geomstats.backend as gs
import geomstats.errors
def register_scaled_method(func_name, scaling_type):
    """Register the scaling factor of a method of a RiemannianMetric.

    The ScalarProductMetric class rescales various methods of a
    RiemannianMetric by the correct factor. Methods that have no registered
    scaling rule are exposed by ScalarProductMetric unchanged, i.e. they are
    not rescaled. This function allows the user to add a new method to be
    rescaled according to one of the supported rules.

    Note that this function must be called before the ScalarProductMetric is
    instantiated. It does not affect objects which already exist.

    Parameters
    ----------
    func_name : str
        The name of a method from a RiemannianMetric object which must be
        rescaled.
    scaling_type : {'sqrt', 'linear', 'quadratic', 'inverse', 'inverse_sqrt'}
        How the method should be rescaled as a function of
        ScalarProductMetric.scale: the value returned by the method is
        multiplied by sqrt(scale), scale, scale**2, 1 / scale or
        1 / sqrt(scale), respectively.

    Raises
    ------
    ValueError
        If ``func_name`` already has a scaling rule, is reserved for internal
        use ('underlying_metric' or 'scale'), is private (starts with an
        underscore), or if ``scaling_type`` is not an admissible value.
    """
    _ScaledMethodsRegistry._add_scaled_method(func_name, scaling_type)
def _wrap_attr(scaling_factor, func):
@wraps(func)
def response(*args, **kwargs):
res = scaling_factor * func(*args, **kwargs)
return res
return response
class _ScaledMethodsRegistry:
"""Class to hold lists of methods and their scaling functions."""
_SQRT_LIST = ["norm", "dist", "dist_broadcast", "dist_pairwise", "diameter"]
_LINEAR_LIST = [
"metric_matrix",
"inner_product",
"inner_product_derivative_matrix",
"squared_norm",
"squared_dist",
"covariant_riemann_tensor",
]
_QUADRATIC_LIST = []
_INVERSE_LIST = [
"cometric_matrix",
"inner_coproduct",
"hamiltonian",
"sectional_curvature",
"scalar_curvature",
]
_INVERSE_SQRT_LIST = ["normalize", "random_unit_tangent_vec", "normal_basis"]
_RESERVED_NAMES = ("underlying_metric", "scale")
_SCALING_LISTS = [
_SQRT_LIST,
_LINEAR_LIST,
_QUADRATIC_LIST,
_INVERSE_LIST,
_INVERSE_SQRT_LIST,
]
_SCALING_NAMES = ["sqrt", "linear", "quadratic", "inverse", "inverse_sqrt"]
@classmethod
def _add_scaled_method(cls, func_name, scaling_type):
"""Configure ScalarProductMetric to scale an attribute.
This method should be accessed via
geomstats.geometry.scalar_product_metric.register_scaled_method
"""
scaling_dict = dict(zip(cls._SCALING_NAMES, cls._SCALING_LISTS))
for list_of_methods in cls._SCALING_LISTS:
if func_name in list_of_methods:
msg = (
f"'{func_name}' already has an assigned scaling rule "
"which cannot be changed."
)
raise ValueError(msg)
if func_name in cls._RESERVED_NAMES:
raise ValueError(f"'{func_name}' is reserved for internal use.")
if func_name.startswith("_"):
raise ValueError("Private methods cannot be rescaled")
try:
scaling_dict[scaling_type].append(func_name)
except KeyError:
msg = (
f"'{scaling_type}' is not an admissible value. Please "
"provide one of 'sqrt', 'linear', 'quadratic', "
"'inverse', 'inverse_sqrt'."
)
raise ValueError(msg)
@classmethod
def _get_scaling_factor(cls, func_name, scale):
if func_name in cls._SQRT_LIST:
return gs.sqrt(scale)
if func_name in cls._LINEAR_LIST:
return scale
if func_name in cls._QUADRATIC_LIST:
return gs.power(scale, 2)
if func_name in cls._INVERSE_LIST:
return 1.0 / scale
if func_name in cls._INVERSE_SQRT_LIST:
return 1.0 / gs.sqrt(scale)
return None
class ScalarProductMetric:
    """Class for scalar products of Riemannian and pseudo-Riemannian metrics.

    This class multiplies the (0,2) metric tensor 'space.metric' by a
    positive scalar 'scale'. Note that this does not scale distances by
    'scale'; that would require multiplying the metric by the square of the
    scalar.

    The `space` is not automatically equipped with the `ScalarProductMetric`.
    An object of this type can also be instantiated by the expression
    ``scale * space.metric``.

    This class acts as a wrapper for the underlying Riemannian metric. All
    public attributes apart from 'underlying_metric' and 'scale' are loaded
    from the underlying metric at initialization. Methods with a registered
    scaling rule (see ``register_scaled_method``) are rescaled by the
    appropriate factor; all other methods and non-callable attributes are
    copied unchanged. Attributes are captured once, at initialization, so
    attributes later added to or rebound on the underlying metric are not
    reflected here.

    When 'space.metric' is itself a ScalarProductMetric, the wrapper is not
    wrapped again: its 'underlying_metric' is wrapped directly, with a scale
    equal to the product of both scales.

    Parameters
    ----------
    space : Manifold or ComplexManifold
        A manifold equipped with a metric which is being scaled.
    scale : float
        The value by which to scale the metric. Note that this rescales the
        (0,2) metric tensor, so distances are rescaled by the square root of
        this.

    Attributes
    ----------
    underlying_metric : metric object
        The unscaled metric whose attributes are wrapped.
    scale : float
        The total factor applied to the (0,2) metric tensor.
    """

    def __init__(self, space, scale):
        """Load all attributes from the underlying metric.

        ``scale`` is validated with ``geomstats.errors.check_positive``.

        Raises
        ------
        TypeError
            If ``space`` has no ``metric`` attribute.
        """
        geomstats.errors.check_positive(scale, "scale")
        if not hasattr(space, "metric"):
            raise TypeError("The variable 'space' must be equipped with a metric.")

        self._space = space

        # Collapse nested wrappers: scaling an already scaled metric wraps the
        # original metric once, with the product of both scales.
        if isinstance(space.metric, ScalarProductMetric):
            self.underlying_metric = space.metric.underlying_metric
            self.scale = scale * space.metric.scale
        else:
            self.underlying_metric = space.metric
            self.scale = scale

        for attr_name in dir(self.underlying_metric):
            if (
                attr_name.startswith("_")
                or attr_name in _ScaledMethodsRegistry._RESERVED_NAMES
            ):
                continue

            attr = getattr(self.underlying_metric, attr_name)
            if not callable(attr):
                # Data attributes are copied as-is, never rescaled.
                try:
                    setattr(self, attr_name, attr)
                except AttributeError:
                    # NOTE(review): setattr on this plain class is not
                    # expected to raise; this guard only matters for
                    # subclasses defining read-only attributes — confirm.
                    if not isinstance(
                        getattr(type(self.underlying_metric), attr_name, None),
                        property,
                    ):
                        raise
            else:
                factor = _ScaledMethodsRegistry._get_scaling_factor(
                    attr_name, self.scale
                )
                # Methods without a registered scaling rule are exposed
                # unchanged.
                method = attr if factor is None else _wrap_attr(factor, attr)
                setattr(self, attr_name, method)

    def __mul__(self, scalar):
        """Multiply the metric by a scalar.

        This method multiplies the (0,2) metric tensor by a scalar. Note that
        this does not scale distances by the scalar. That would require
        multiplication by the square of the scalar.

        Parameters
        ----------
        scalar : float or int
            The number by which to multiply the metric. Any ``numbers.Real``
            other than ``bool`` is accepted and converted to ``float``;
            other types make this method return ``NotImplemented``.

        Returns
        -------
        metric : ScalarProductMetric
            The metric multiplied by the scalar.

        Raises
        ------
        ValueError
            If the space is not equipped with this metric.
        """
        if isinstance(scalar, bool) or not isinstance(scalar, numbers.Real):
            return NotImplemented
        if not isinstance(scalar, float):
            scalar = float(scalar)
        # Identity check: the space must be equipped with this very object.
        if self._space.metric is not self:
            raise ValueError(
                "A space must be equipped with this metric before it is scaled."
            )
        return ScalarProductMetric(self._space, scalar)

    def __rmul__(self, scalar):
        """Multiply the metric by a scalar.

        This method multiplies the (0,2) metric tensor by a scalar. Note that
        this does not scale distances by the scalar. That would require
        multiplication by the square of the scalar.

        Parameters
        ----------
        scalar : float or int
            The number by which to multiply the metric.

        Returns
        -------
        metric : ScalarProductMetric
            The metric multiplied by the scalar.
        """
        return self * scalar