"""Geodesic solvers implementation."""
import math
from abc import ABC, abstractmethod
import geomstats.backend as gs
from geomstats.numerics.bvp import ScipySolveBVP
from geomstats.numerics.ivp import GSIVPIntegrator
from geomstats.numerics.optimizers import ScipyMinimize
from geomstats.vectorization import get_batch_shape
[docs]
class ExpSolver(ABC):
    """Abstract class for geodesic initial value problem solvers.

    Concrete solvers must implement both `exp` and `geodesic_ivp`.
    """

    @abstractmethod
    def exp(self, space, tangent_vec, base_point):
        """Exponential map.

        Parameters
        ----------
        space : Manifold
            Equipped manifold.
        tangent_vec : array-like, shape=[..., dim]
            Tangent vector at the base point.
        base_point : array-like, shape=[..., dim]
            Point on the manifold.

        Returns
        -------
        end_point : array-like, shape=[..., dim]
            Point on the manifold.
        """

    @abstractmethod
    def geodesic_ivp(self, space, tangent_vec, base_point):
        """Geodesic curve for initial value problem.

        Parameters
        ----------
        space : Manifold
            Equipped manifold.
        tangent_vec : array-like, shape=[..., dim]
            Tangent vector at the base point.
        base_point : array-like, shape=[..., dim]
            Point on the manifold.

        Returns
        -------
        path : callable
            Time parametrized geodesic curve. `f(t)`.
        """
[docs]
class ExpODESolver(ExpSolver):
    """Geodesic initial value problem solver.

    Parameters
    ----------
    integrator : ODEIVPSolver
        Instance of ODEIVP integrator. A `GSIVPIntegrator` is created
        when none is given.
    """

    def __init__(self, integrator=None):
        if integrator is None:
            integrator = GSIVPIntegrator()
        self.integrator = integrator

    def _solve(self, space, tangent_vec, base_point, t_eval=None):
        """Integrate the geodesic equation from `(base_point, tangent_vec)`.

        Parameters
        ----------
        space : Manifold
            Equipped manifold.
        tangent_vec : array-like, shape=[..., *space.shape]
            Initial velocity.
        base_point : array-like, shape=[..., *space.shape]
            Initial position.
        t_eval : array-like, shape=[n_times,]
            Times at which the solution is requested. If None, the
            integrator's `integrate` is used instead of `integrate_t`.

        Returns
        -------
        result : object
            Integrator result, with the attribute `batch_shape` attached.
        """
        batch_shape = get_batch_shape(space.point_ndim, base_point, tangent_vec)
        # Broadcast in both directions: either argument may carry the batch.
        base_point, tangent_vec = gs.broadcast_arrays(base_point, tangent_vec)

        if self.integrator.state_is_raveled:
            if space.point_ndim > 1:
                # Flatten each point so that position and velocity can be
                # laid side by side in a single state vector.
                dim_vec = math.prod(space.shape)
                batch_shape_ = (-1,) if batch_shape else ()
                base_point = gs.reshape(base_point, batch_shape_ + (dim_vec,))
                tangent_vec = gs.reshape(tangent_vec, batch_shape_ + (dim_vec,))

            initial_state = gs.hstack([base_point, tangent_vec])
        else:
            initial_state = gs.stack([base_point, tangent_vec])

        force = self._get_force(space)
        if t_eval is None:
            result = self.integrator.integrate(force, initial_state)
        else:
            result = self.integrator.integrate_t(force, initial_state, t_eval)
        # Kept on the result so the `_simplify_*` helpers can restore shapes.
        result.batch_shape = batch_shape

        return result

    def exp(self, space, tangent_vec, base_point):
        """Exponential map.

        Parameters
        ----------
        space : Manifold
            Equipped manifold.
        tangent_vec : array-like, shape=[..., *space.shape]
            Tangent vector at the base point.
        base_point : array-like, shape=[..., *space.shape]
            Point on the manifold.

        Returns
        -------
        end_point : array-like, shape=[..., *space.shape]
            Point on the manifold.
        """
        result = self._solve(space, tangent_vec, base_point)
        return self._simplify_exp_result(result, space)

    def geodesic_ivp(self, space, tangent_vec, base_point):
        """Geodesic curve for initial value problem.

        Parameters
        ----------
        space : Manifold
            Equipped manifold.
        tangent_vec : array-like, shape=[..., *space.shape]
            Tangent vector at the base point.
        base_point : array-like, shape=[..., *space.shape]
            Point on the manifold.

        Returns
        -------
        path : callable
            Time parametrized geodesic curve. `f(t)`.
        """
        base_point, tangent_vec = gs.broadcast_arrays(base_point, tangent_vec)

        def path(t):
            """Time parametrized geodesic curve.

            Parameters
            ----------
            t : float or array-like, shape=[n_times,]

            Returns
            -------
            geodesic_points : array-like, shape=[..., n_times, *space.shape]
                Geodesic points evaluated at t.
            """
            if not gs.is_array(t):
                # Convert without wrapping so a list of times stays 1-D.
                t = gs.array(t)
            if gs.ndim(t) == 0:
                t = gs.expand_dims(t, axis=0)

            result = self._solve(space, tangent_vec, base_point, t_eval=t)
            return self._simplify_result_t(result, space)

        return path

    def _get_force(self, space):
        """Build the right-hand side callable expected by the integrator.

        The argument order is `(t, state)` when the integrator sets
        `tfirst`, and `(state, t)` otherwise.
        """
        if self.integrator.state_is_raveled:
            force_ = lambda state, t: self._force_raveled_state(state, t, space=space)
        else:
            force_ = lambda state, t: self._force_unraveled_state(
                state, t, space=space
            )

        if self.integrator.tfirst:
            return lambda t, state: force_(state, t)

        return force_

    def _force_raveled_state(self, raveled_initial_state, _, space):
        """Evaluate the geodesic equation on a flat state vector."""
        # input: (n,)
        # assumes unvectorize
        state = gs.reshape(raveled_initial_state, (2,) + space.shape)

        eq = space.metric.geodesic_equation(state, _)

        return gs.flatten(eq)

    def _force_unraveled_state(self, initial_state, _, space):
        """Evaluate the geodesic equation on a stacked state."""
        return space.metric.geodesic_equation(initial_state, _)

    def _simplify_exp_result(self, result, space):
        """Extract the end point (position part) from the last state."""
        y = result.get_last_y()

        if self.integrator.state_is_raveled:
            # The first `dim_vec` entries of the raveled state are the position.
            dim_vec = math.prod(space.shape)
            exp = y[..., :dim_vec]
            if space.point_ndim > 1:
                return gs.reshape(exp, result.batch_shape + space.shape)
            return exp

        return y[0]

    def _simplify_result_t(self, result, space):
        """Extract the positions along the path, batch axis first."""
        # assumes several t
        y = result.y

        if self.integrator.state_is_raveled:
            dim_vec = math.prod(space.shape)
            y = y[..., :dim_vec]
            if space.point_ndim > 1:
                y = gs.reshape(y, y.shape[:-1] + space.shape)

            if result.batch_shape:
                return gs.moveaxis(y, 0, 1)

            return y

        # Index 0 on the second axis selects the position of each state.
        y = y[:, 0, :, ...]
        if result.batch_shape:
            return gs.moveaxis(y, 1, 0)
        return y
[docs]
class LogSolver(ABC):
    """Abstract class for geodesic boundary value problem solvers.

    Concrete solvers must implement both `log` and `geodesic_bvp`.
    """

    @abstractmethod
    def log(self, space, point, base_point):
        """Logarithm map.

        Parameters
        ----------
        space : Manifold
            Equipped manifold.
        point : array-like, shape=[..., dim]
            Point on the manifold.
        base_point : array-like, shape=[..., dim]
            Point on the manifold.

        Returns
        -------
        tangent_vec : array-like, shape=[..., dim]
            Tangent vector at the base point.
        """

    @abstractmethod
    def geodesic_bvp(self, space, point, base_point):
        """Geodesic curve for boundary value problem.

        Parameters
        ----------
        space : Manifold
            Equipped manifold.
        point : array-like, shape=[..., dim]
            Point on the manifold.
        base_point : array-like, shape=[..., dim]
            Point on the manifold.

        Returns
        -------
        path : callable
            Time parametrized geodesic curve. `f(t)`.
        """
class _GeodesicBVPFromExpMixins:
    """Provides method to get geodesic given exp.

    The host class is expected to provide `log(space, point, base_point)`.
    """

    def _geodesic_bvp_single(self, space, t, tangent_vec, base_point):
        """Evaluate one geodesic at the times `t` through the metric exp."""
        # One einsum index per point axis, so each time scales a full vector.
        idx = "ijk"[: space.point_ndim]

        tangent_vec_ = gs.einsum(f"...,...{idx}->...{idx}", t, tangent_vec)
        return space.metric.exp(tangent_vec_, base_point)

    def geodesic_bvp(self, space, point, base_point):
        """Geodesic curve for boundary value problem.

        Parameters
        ----------
        space : Manifold
            Equipped manifold.
        point : array-like, shape=[..., dim]
            Point on the manifold.
        base_point : array-like, shape=[..., dim]
            Point on the manifold.

        Returns
        -------
        path : callable
            Time parametrized geodesic curve. `f(t)`.
        """
        tangent_vec = self.log(space, point, base_point)

        # Broadcast on shape, not only on ndim: otherwise a (1, dim) base
        # point paired with (n, dim) tangent vectors is truncated by `zip`.
        if base_point.shape != tangent_vec.shape:
            tangent_vec, base_point = gs.broadcast_arrays(tangent_vec, base_point)
        is_batch = tangent_vec.ndim > space.point_ndim

        def path(t):
            """Time parametrized geodesic curve.

            Parameters
            ----------
            t : float or array-like, shape=[n_times,]

            Returns
            -------
            geodesic_points : array-like, shape=[..., n_times, dim]
                Geodesic points evaluated at t.
            """
            if not gs.is_array(t):
                # Convert without wrapping so a list of times stays 1-D.
                t = gs.array(t)
            if gs.ndim(t) == 0:
                t = gs.expand_dims(t, axis=0)

            if not is_batch:
                return self._geodesic_bvp_single(space, t, tangent_vec, base_point)

            return gs.stack(
                [
                    self._geodesic_bvp_single(space, t, tangent_vec_, base_point_)
                    for tangent_vec_, base_point_ in zip(tangent_vec, base_point)
                ]
            )

        return path
class _LogBatchMixins:
    """Provides method to compute log for multiple points.

    Subclasses implement `_log_single`; `log` loops over the batch.
    """

    @abstractmethod
    def _log_single(self, space, point, base_point):
        """Logarithm map.

        Parameters
        ----------
        space : Manifold
            Equipped manifold.
        point : array-like, shape=[dim]
            Point on the manifold.
        base_point : array-like, shape=[dim]
            Point on the manifold.

        Returns
        -------
        tangent_vec : array-like, shape=[dim]
            Tangent vector at the base point.
        """

    def log(self, space, point, base_point):
        """Logarithm map.

        Parameters
        ----------
        space : Manifold
            Equipped manifold.
        point : array-like, shape=[..., dim]
            Point on the manifold.
        base_point : array-like, shape=[..., dim]
            Point on the manifold.

        Returns
        -------
        tangent_vec : array-like, shape=[..., dim]
            Tangent vector at the base point.
        """
        # assumes inability to properly vectorize
        # Compare shapes rather than ndim: with equal ndim but different
        # batch sizes, `zip` below would silently drop pairs.
        if point.shape != base_point.shape:
            point, base_point = gs.broadcast_arrays(point, base_point)

        is_batch = point.ndim > space.point_ndim
        if not is_batch:
            return self._log_single(space, point, base_point)

        return gs.stack(
            [
                self._log_single(space, point_, base_point_)
                for point_, base_point_ in zip(point, base_point)
            ]
        )
[docs]
class LogShootingSolver:
    """Geodesic boundary value problem solver using shooting.

    Instantiating this class returns one of two private solver
    implementations, selected by `flatten`.

    Parameters
    ----------
    optimizer : ScipyMinimize
        Instance of ScipyMinimize.
    initialization : callable
        Function to provide initial solution. `f(space, point, base_point)`.
        Defaults to linear initialization.
    flatten : bool
        If True, the optimization problem is solved together for all the points.
    """

    def __new__(cls, optimizer=None, initialization=None, flatten=True):
        """Instantiate a log shooting solver."""
        solver_cls = (
            _LogShootingSolverFlatten if flatten else _LogShootingSolverUnflatten
        )
        return solver_cls(optimizer=optimizer, initialization=initialization)
class _LogShootingSolverFlatten(_GeodesicBVPFromExpMixins, LogSolver):
    """Shooting solver that optimizes all tangent vectors in one problem."""

    def __init__(self, optimizer=None, initialization=None):
        if optimizer is None:
            optimizer = ScipyMinimize(jac="autodiff")

        if initialization is None:
            initialization = self._default_initialization

        self.optimizer = optimizer
        self.initialization = initialization

    def _default_initialization(self, space, point, base_point):
        """Flattened difference `point - base_point` as initial guess."""
        return gs.flatten(point - base_point)

    def _objective(self, velocity, space, point, base_point):
        """Sum of squared differences between `exp(velocity)` and `point`."""
        velocity = gs.reshape(velocity, base_point.shape)
        delta = space.metric.exp(velocity, base_point) - point
        return gs.sum(delta**2)

    def log(self, space, point, base_point):
        """Logarithm map.

        Parameters
        ----------
        space : Manifold
            Equipped manifold.
        point : array-like, shape=[..., *space.shape]
            Point on the manifold.
        base_point : array-like, shape=[..., *space.shape]
            Point on the manifold.

        Returns
        -------
        tangent_vec : array-like, shape=[..., *space.shape]
            Tangent vector at the base point.
        """
        # Compare shapes rather than ndim: the flat velocity is reshaped to
        # `base_point.shape`, so both arguments must share the full shape.
        if point.shape != base_point.shape:
            point, base_point = gs.broadcast_arrays(point, base_point)

        objective = lambda velocity: self._objective(velocity, space, point, base_point)
        init_tangent_vec = self.initialization(space, point, base_point)

        res = self.optimizer.minimize(objective, init_tangent_vec)

        tangent_vec = gs.reshape(res.x, base_point.shape)

        return tangent_vec
class _LogShootingSolverUnflatten(
    _LogBatchMixins, _GeodesicBVPFromExpMixins, LogSolver
):
    """Shooting solver that runs one optimization per pair of points."""

    def __init__(self, optimizer=None, initialization=None):
        self.optimizer = (
            ScipyMinimize(jac="autodiff") if optimizer is None else optimizer
        )
        self.initialization = (
            self._default_initialization if initialization is None else initialization
        )

    def _default_initialization(self, space, point, base_point):
        """Difference `point - base_point` as initial guess."""
        return point - base_point

    def _objective(self, velocity, space, point, base_point):
        """Sum of squared differences between `exp(velocity)` and `point`."""
        shaped = gs.reshape(velocity, space.shape) if space.point_ndim > 1 else velocity
        residual = space.metric.exp(shaped, base_point) - point
        return gs.sum(residual**2)

    def _log_single(self, space, point, base_point):
        """Logarithm map for a single pair of points.

        Parameters
        ----------
        space : Manifold
            Equipped manifold.
        point : array-like, shape=[*space.shape]
            Point on the manifold.
        base_point : array-like, shape=[*space.shape]
            Point on the manifold.

        Returns
        -------
        tangent_vec : array-like, shape=[*space.shape]
            Tangent vector at the base point.
        """

        def cost(velocity):
            return self._objective(velocity, space, point, base_point)

        # The optimizer is given a flat vector as starting point.
        start = gs.flatten(self.initialization(space, point, base_point))
        solution = self.optimizer.minimize(cost, start).x

        if space.point_ndim > 1:
            return gs.reshape(solution, space.shape)
        return solution
[docs]
class LogODESolver(_LogBatchMixins, LogSolver):
    """Geodesic boundary value problem using an ODE solver.

    Parameters
    ----------
    n_nodes : int
        Number of mesh nodes.
    integrator : ScipySolveBVP
        Instance of ScipySolveBVP.
    initialization : callable
        Function to provide initial solution. `f(space, point, base_point)`.
        Defaults to linear initialization.
    use_jac : bool
        If True, the Jacobian built by `_jacobian` is passed to the
        integrator as `fun_jac`.
    """

    def __init__(self, n_nodes=10, integrator=None, initialization=None, use_jac=True):
        if integrator is None:
            integrator = ScipySolveBVP()

        if initialization is None:
            initialization = self._default_initialization

        self.n_nodes = n_nodes
        self.integrator = integrator
        self.initialization = initialization
        self.use_jac = use_jac

        self.grid = self._create_grid()

    def _create_grid(self):
        """Uniform mesh of `n_nodes` points on [0, 1]."""
        return gs.linspace(0.0, 1.0, num=self.n_nodes)

    def _default_initialization(self, space, point, base_point):
        """Linear initial guess for positions and velocities.

        Returns
        -------
        state : array-like, shape=[2*dim, n_nodes]
            Positions stacked on top of finite-difference velocities.
        """
        point_0, point_1 = base_point, point

        pos_init = gs.transpose(gs.linspace(point_0, point_1, self.n_nodes))

        # The grid spans [0, 1] with `n_nodes - 1` intervals, so a forward
        # difference is divided by a step of 1 / (n_nodes - 1).
        n_intervals = self.n_nodes - 1
        vel_init = n_intervals * (pos_init[:, 1:] - pos_init[:, :-1])
        # Repeat the last difference so there is one velocity per node.
        vel_init = gs.hstack([vel_init, vel_init[:, -1:]])

        return gs.vstack([pos_init, vel_init])

    def _boundary_condition(self, state_0, state_1, space, point_0, point_1):
        """Residuals of the position constraints at both ends."""
        pos_0 = state_0[: space.dim]
        pos_1 = state_1[: space.dim]
        return gs.hstack((pos_0 - point_0, pos_1 - point_1))

    def _bvp(self, _, raveled_state, space):
        """Boundary value problem.

        Parameters
        ----------
        _ : float
            Unused.
        raveled_state : array-like, shape=[2*dim, n_nodes]
            Vector of state variables (position and speed).

        Returns
        -------
        sol : array-like, shape=[2*dim, n_nodes]
        """
        # Move the node axis in front of the coordinate axis before calling
        # the geodesic equation, then restore the raveled layout.
        state = gs.moveaxis(gs.reshape(raveled_state, (2, space.dim, -1)), -2, -1)

        eq = space.metric.geodesic_equation(state, _)

        return gs.reshape(gs.moveaxis(eq, -2, -1), (2 * space.dim, -1))

    def _jacobian(self, _, raveled_state, space):
        """Jacobian of boundary value problem.

        Parameters
        ----------
        _ : float
            Unused.
        raveled_state : array-like, shape=[2*dim, n_nodes]
            Vector of state variables (position and speed).

        Returns
        -------
        jac : array-like, shape=[2*dim, 2*dim, n_nodes]
        """
        dim = space.dim
        n_nodes = raveled_state.shape[-1]
        position, velocity = raveled_state[:dim], raveled_state[dim:]

        dgamma = space.metric.jacobian_christoffels(gs.transpose(position))

        df_dposition = -gs.einsum(
            "j...,...ijkl,k...->il...", velocity, dgamma, velocity
        )

        gamma = space.metric.christoffels(gs.transpose(position))
        df_dvelocity = -2 * gs.einsum("...ijk,k...->ij...", gamma, velocity)

        jac_nw = gs.zeros((dim, dim, n_nodes))
        # No squeeze here: it would drop singleton axes (e.g. dim == 1) and
        # the block would no longer match the other (dim, dim, n_nodes) ones.
        jac_ne = gs.transpose(gs.tile(gs.eye(dim), (n_nodes, 1, 1)))
        jac_sw = df_dposition
        jac_se = df_dvelocity
        jac = gs.concatenate(
            (
                gs.concatenate((jac_nw, jac_ne), axis=1),
                gs.concatenate((jac_sw, jac_se), axis=1),
            ),
            axis=0,
        )

        return jac

    def _solve(self, space, point, base_point):
        """Run the integrator for a single pair of points."""
        bvp = lambda t, state: self._bvp(t, state, space)
        bc = lambda state_0, state_1: self._boundary_condition(
            state_0, state_1, space, base_point, point
        )

        jacobian = None
        if self.use_jac:
            jacobian = lambda t, state: self._jacobian(t, state, space=space)

        y = self.initialization(space, point, base_point)

        return self.integrator.integrate(bvp, bc, self.grid, y, fun_jac=jacobian)

    def _log_single(self, space, point, base_point):
        """Logarithm map for a single pair of points."""
        res = self._solve(space, point, base_point)
        return self._simplify_log_result(res, space)

    def geodesic_bvp(self, space, point, base_point):
        """Geodesic curve for boundary value problem.

        Parameters
        ----------
        space : Manifold
            Equipped manifold.
        point : array-like, shape=[..., dim]
            Point on the manifold.
        base_point : array-like, shape=[..., dim]
            Point on the manifold.

        Returns
        -------
        path : callable
            Time parametrized geodesic curve. `f(t)`. 0 <= t <= 1.
        """
        # Compare shapes rather than ndim: with equal ndim but different
        # batch sizes, `zip` below would silently drop pairs.
        if point.shape != base_point.shape:
            point, base_point = gs.broadcast_arrays(point, base_point)

        is_batch = point.ndim > space.point_ndim
        if is_batch:
            results = [
                self._solve(space, point_, base_point_)
                for point_, base_point_ in zip(point, base_point)
            ]
        else:
            results = [self._solve(space, point, base_point)]

        def path(t):
            """Time parametrized geodesic curve.

            Parameters
            ----------
            t : float or array-like, shape=[n_times,]

            Returns
            -------
            geodesic_points : array-like, shape=[..., n_times, dim]
                Geodesic points evaluated at t.
            """
            if not gs.is_array(t):
                # Convert without wrapping so a list of times stays 1-D.
                t = gs.array(t)
            if gs.ndim(t) == 0:
                t = gs.expand_dims(t, axis=0)

            curves = [
                self._simplify_result_t(result.sol(t), space) for result in results
            ]
            if not is_batch:
                return curves[0]

            return gs.stack(curves)

        return path

    def _simplify_log_result(self, result, space):
        """Velocity part of the solution state at the first mesh node."""
        _, tangent_vec = gs.reshape(gs.transpose(result.y)[0], (2, space.dim))

        return tangent_vec

    def _simplify_result_t(self, result, space):
        """Position rows of the evaluated solution, with time axis first."""
        return gs.transpose(result[: space.dim, :])