Source code for geomstats.test.utils
from abc import ABC
[docs]
class PointTransformer(ABC):
    """Interface for objects that map points and tangent vectors back and forth.

    Every method here is a placeholder that raises ``NotImplementedError``;
    subclasses override the operations they support.
    """

    def transform_point(self, point):
        """Transform a point.

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        message = "`transform_point` not implemented"
        raise NotImplementedError(message)

    def transform_tangent_vec(self, tangent_vec, base_point):
        """Transform a tangent vector given at a base point.

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        message = "`transform_tangent_vec` not implemented"
        raise NotImplementedError(message)

    def inverse_transform_point(self, other_point):
        """Apply the inverse transformation to a point.

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        message = "`inverse_transform_point` not implemented"
        raise NotImplementedError(message)

    def inverse_transform_tangent_vec(self, other_tangent_vec, other_base_point):
        """Apply the inverse transformation to a tangent vector.

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        message = "`inverse_transform_tangent_vec` not implemented"
        raise NotImplementedError(message)
[docs]
class IdentityPointTransformer(PointTransformer):
    """Point transformer that returns every input unchanged.

    All four ``PointTransformer`` operations are overridden so that none of
    them falls through to the base-class stubs, which raise
    ``NotImplementedError``.
    """

    def transform_point(self, point):
        """Return ``point`` unchanged.

        Parameters
        ----------
        point : object
            Point to transform.

        Returns
        -------
        object
            The very same ``point`` object.
        """
        return point

    def transform_tangent_vec(self, tangent_vec, base_point):
        """Return ``tangent_vec`` unchanged.

        Parameters
        ----------
        tangent_vec : object
            Tangent vector to transform.
        base_point : object
            Ignored; accepted to match the ``PointTransformer`` signature.

        Returns
        -------
        object
            The very same ``tangent_vec`` object.
        """
        return tangent_vec

    def inverse_transform_point(self, other_point):
        """Return ``other_point`` unchanged.

        Parameters
        ----------
        other_point : object
            Point to map back.

        Returns
        -------
        object
            The very same ``other_point`` object.
        """
        return other_point

    def inverse_transform_tangent_vec(self, other_tangent_vec, other_base_point):
        """Return ``other_tangent_vec`` unchanged.

        Parameters
        ----------
        other_tangent_vec : object
            Tangent vector to map back.
        other_base_point : object
            Ignored; accepted to match the ``PointTransformer`` signature.

        Returns
        -------
        object
            The very same ``other_tangent_vec`` object.
        """
        return other_tangent_vec
[docs]
class PointTransformerFromDiffeo(PointTransformer):
    """Point transformer whose tangent-vector operations delegate to a diffeo.

    Only the two tangent-vector methods are overridden here; the point-level
    methods are inherited from ``PointTransformer``.

    Parameters
    ----------
    diffeo : object
        Object exposing ``tangent(tangent_vec, base_point)`` and
        ``inverse_tangent(tangent_vec, base_point)`` methods.
    """

    def __init__(self, diffeo):
        self.diffeo = diffeo

    def transform_tangent_vec(self, tangent_vec, base_point):
        """Return ``self.diffeo.tangent(tangent_vec, base_point)``."""
        tangent_map = self.diffeo.tangent
        return tangent_map(tangent_vec, base_point)

    def inverse_transform_tangent_vec(self, other_tangent_vec, other_base_point):
        """Return ``self.diffeo.inverse_tangent`` applied to both arguments."""
        inverse_tangent_map = self.diffeo.inverse_tangent
        return inverse_tangent_map(other_tangent_vec, other_base_point)