Source code for geomstats.errors
"""Checks and associated errors."""
import math
import geomstats.backend as gs
[docs]
def check_integer(n, n_name):
    """Raise an error if n is not None, math.inf or a > 0 integer.

    Parameters
    ----------
    n : unspecified
        Parameter to be tested. The accepted values are None, math.inf
        and strictly positive integers. Booleans are rejected.
    n_name : string
        Name of the parameter, used in the error message.

    Raises
    ------
    ValueError
        If n is not None, not math.inf and not a strictly positive integer.
    """
    # bool is a subclass of int, so without the explicit exclusion
    # True would be accepted as the integer 1.
    is_positive_int = isinstance(n, int) and not isinstance(n, bool) and n > 0
    if n is not None and not is_positive_int and n != math.inf:
        raise ValueError(
            f"{n_name} is required to be either"
            " None, math.inf or a strictly positive integer,"
            f" got {n}."
        )
[docs]
def check_positive(param, param_name):
    """Raise an error if param is not a strictly positive number.

    Parameters
    ----------
    param : unspecified
        Parameter to be tested. It is accepted when it is an int, a float
        or a backend array with ``ndim == 0``, and compares greater than 0.
    param_name : string
        Name of the parameter, used in the error message.

    Raises
    ------
    ValueError
        If param is not of an accepted kind or is not greater than 0.
    """
    is_scalar = isinstance(param, (int, float))
    if not is_scalar:
        # Only consult the backend when param is not a plain Python number.
        is_scalar = gs.is_array(param) and param.ndim == 0

    if is_scalar and param > 0:
        return

    raise ValueError(f"{param_name} must be positive.")
[docs]
def check_belongs(point, manifold, atol=gs.atol):
    """Raise an error if point does not belong to the input manifold.

    Parameters
    ----------
    point : array-like
        Point to be tested.
    manifold : Manifold
        Manifold to which the point should belong. Its type name and its
        ``dim`` attribute are reported in the error message.
    atol : optional
        Tolerance passed to ``manifold.belongs`` as its ``atol`` argument.
        Default: gs.atol.

    Raises
    ------
    RuntimeError
        If ``manifold.belongs`` does not hold for all the tested points.
    """
    belonging = manifold.belongs(point, atol=atol)
    if gs.all(belonging):
        return

    manifold_name = type(manifold).__name__
    raise RuntimeError(
        f"Some points do not belong to manifold '{manifold_name}'"
        f" of dimension {manifold.dim}."
    )
[docs]
def check_parameter_accepted_values(param, param_name, accepted_values):
    """Raise an error if a parameter is not one of the accepted values.

    Parameters
    ----------
    param : unspecified
        Parameter to be tested.
    param_name : string
        Name of the parameter, used in the error message.
    accepted_values : list
        Accepted values that the parameter can take.

    Raises
    ------
    ValueError
        If param is not contained in accepted_values.
    """
    if param in accepted_values:
        return

    error_msg = (
        f"Parameter {param_name} needs to be in {accepted_values}, got: {param}."
    )
    raise ValueError(error_msg)
[docs]
def check_point_shape(point, manifold, suppress_error=False):
    """Check that the shape of point matches the shape of a manifold or metric.

    If the final elements of the shape of point do not match the shape of
    manifold (which may be any object with a shape attribute, such as a
    Riemannian metric) then point cannot be an array of points on the manifold
    (or similar) and a ShapeError is raised. The error can be suppressed by
    setting suppress_error to True.

    Parameters
    ----------
    point : array-like
        The point to check the shape of.
    manifold : {Manifold, RiemannianMetric}
        The object to check the point against.
    suppress_error : bool
        Whether to suppress the ShapeError if the shapes do not match.
        Optional, default is False.

    Returns
    -------
    shapes_match : bool
        Whether the shape of the point matches the shape of the manifold or
        metric.

    Raises
    ------
    ShapeError
        If the final dimensions of point are not equal to the final dimensions
        of manifold and suppress_error is False. ShapeError is a subclass of
        ValueError.
    """
    n_trailing = len(manifold.shape)
    # A slice starting at -0 selects the whole shape rather than nothing, so
    # an empty manifold shape is handled separately: it is matched by the
    # (empty) tail of any point shape.
    shapes_match = (
        n_trailing == 0
        or point.shape[-n_trailing:] == manifold.shape[-n_trailing:]
    )

    if not suppress_error and not shapes_match:
        shape_error_msg = (
            f"The shape of {point}, which is {point.shape}, is not"
            f" compatible with the shape of the {type(manifold).__name__}"
            f" object, which is {manifold.shape}."
        )
        raise ShapeError(shape_error_msg)

    return shapes_match
[docs]
class ShapeError(ValueError):
    """Error signalling that two shapes are incompatible.

    It derives from ValueError, so handlers catching ValueError also
    catch this exception.
    """