# Source code for geomstats.test.utils

from abc import ABC


class PointTransformer(ABC):
    """Base interface pairing forward and inverse transformations.

    The class declares four hooks: two acting on points and two acting on
    tangent vectors. Each hook raises ``NotImplementedError`` until a
    subclass overrides it.

    Notes
    -----
    The hooks are not decorated with ``abstractmethod``, so the class can
    still be instantiated; the failure happens when a hook is called.
    """

    def transform_point(self, point):
        """Transform a point.

        Parameters
        ----------
        point : object
            Point to transform.

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        raise NotImplementedError("`transform_point` not implemented")

    def transform_tangent_vec(self, tangent_vec, base_point):
        """Transform a tangent vector.

        Parameters
        ----------
        tangent_vec : object
            Tangent vector to transform.
        base_point : object
            Point associated with ``tangent_vec``.

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        raise NotImplementedError("`transform_tangent_vec` not implemented")

    def inverse_transform_point(self, other_point):
        """Apply the inverse transformation to a point.

        Parameters
        ----------
        other_point : object
            Point to map back.

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        raise NotImplementedError("`inverse_transform_point` not implemented")

    def inverse_transform_tangent_vec(self, other_tangent_vec, other_base_point):
        """Apply the inverse transformation to a tangent vector.

        Parameters
        ----------
        other_tangent_vec : object
            Tangent vector to map back.
        other_base_point : object
            Point associated with ``other_tangent_vec``.

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        raise NotImplementedError(
            "`inverse_transform_tangent_vec` not implemented"
        )
class IdentityPointTransformer(PointTransformer):
    """Transformer that returns its inputs unchanged.

    Points are returned as given. For tangent vectors, the tangent vector
    argument is returned and the accompanying point argument is ignored.
    """

    def transform_point(self, point):
        """Return ``point`` unchanged.

        Parameters
        ----------
        point : object
            Point to transform.

        Returns
        -------
        object
            The same ``point`` object that was passed in.
        """
        return point

    def transform_tangent_vec(self, tangent_vec, base_point):
        """Return ``tangent_vec`` unchanged.

        Parameters
        ----------
        tangent_vec : object
            Tangent vector to transform.
        base_point : object
            Unused.

        Returns
        -------
        object
            The same ``tangent_vec`` object that was passed in.
        """
        return tangent_vec

    def inverse_transform_point(self, other_point):
        """Return ``other_point`` unchanged.

        Parameters
        ----------
        other_point : object
            Point to map back.

        Returns
        -------
        object
            The same ``other_point`` object that was passed in.
        """
        return other_point

    def inverse_transform_tangent_vec(self, other_tangent_vec, other_base_point):
        """Return ``other_tangent_vec`` unchanged.

        Parameters
        ----------
        other_tangent_vec : object
            Tangent vector to map back.
        other_base_point : object
            Unused.

        Returns
        -------
        object
            The same ``other_tangent_vec`` object that was passed in.
        """
        return other_tangent_vec
class PointTransformerFromDiffeo(PointTransformer):
    """Transformer that delegates every operation to a wrapped object.

    Parameters
    ----------
    diffeo : object
        Object that is callable and exposes ``tangent``, ``inverse`` and
        ``inverse_tangent`` methods. Presumably a geomstats diffeomorphism;
        confirm against the caller.
    """

    def __init__(self, diffeo):
        self.diffeo = diffeo

    def transform_point(self, point):
        """Return the result of calling ``self.diffeo`` on ``point``.

        Parameters
        ----------
        point : object
            Point to transform.

        Returns
        -------
        object
            Value of ``self.diffeo(point)``.
        """
        return self.diffeo(point)

    def transform_tangent_vec(self, tangent_vec, base_point):
        """Return ``self.diffeo.tangent(tangent_vec, base_point)``.

        Parameters
        ----------
        tangent_vec : object
            Tangent vector to transform.
        base_point : object
            Point associated with ``tangent_vec``; forwarded as the second
            positional argument.

        Returns
        -------
        object
            Value returned by ``self.diffeo.tangent``.
        """
        return self.diffeo.tangent(tangent_vec, base_point)

    def inverse_transform_point(self, other_point):
        """Return ``self.diffeo.inverse(other_point)``.

        Parameters
        ----------
        other_point : object
            Point to map back.

        Returns
        -------
        object
            Value returned by ``self.diffeo.inverse``.
        """
        return self.diffeo.inverse(other_point)

    def inverse_transform_tangent_vec(self, other_tangent_vec, other_base_point):
        """Return ``self.diffeo.inverse_tangent`` applied to the arguments.

        Parameters
        ----------
        other_tangent_vec : object
            Tangent vector to map back.
        other_base_point : object
            Point associated with ``other_tangent_vec``; forwarded as the
            second positional argument.

        Returns
        -------
        object
            Value returned by ``self.diffeo.inverse_tangent``.
        """
        return self.diffeo.inverse_tangent(other_tangent_vec, other_base_point)