"""Geodesic solvers implementation."""
from abc import ABC, abstractmethod
import geomstats.backend as gs
from geomstats.numerics.bvp import ScipySolveBVP
from geomstats.numerics.ivp import GSIVPIntegrator
from geomstats.numerics.optimization import ScipyMinimize
from geomstats.numerics.path import (
UniformlySampledDiscretePath,
UniformlySampledPathEnergy,
)
from geomstats.vectorization import get_batch_shape
[docs]
class ExpSolver(ABC):
    """Base interface for geodesic initial value problem (IVP) solvers.

    Parameters
    ----------
    solves_ivp : bool
        Flag telling whether the solver can evaluate the geodesic at
        different times ``t``.
    """

    def __init__(self, solves_ivp=False):
        self.solves_ivp = solves_ivp

    @abstractmethod
    def exp(self, tangent_vec, base_point):
        """Compute the exponential map.

        Parameters
        ----------
        tangent_vec : array-like, shape=[..., dim]
            Tangent vector at the base point.
        base_point : array-like, shape=[..., dim]
            Point on the manifold.

        Returns
        -------
        end_point : array-like, shape=[..., dim]
            Point on the manifold.
        """

    def geodesic_ivp(self, tangent_vec, base_point):
        """Build the geodesic curve of an initial value problem.

        The base implementation does not solve anything; subclasses able to
        evaluate the geodesic at several times are expected to override it.

        Parameters
        ----------
        tangent_vec : array-like, shape=[..., dim]
            Tangent vector at the base point.
        base_point : array-like, shape=[..., dim]
            Point on the manifold.

        Returns
        -------
        path : callable
            Time parametrized geodesic curve. `f(t)`.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """
        message = "Can't solve for geodesic at different t."
        raise NotImplementedError(message)
[docs]
class ExpODESolver(ExpSolver):
    """Geodesic initial value problem solver.

    Integrates the geodesic equation given by ``space.metric``.

    Parameters
    ----------
    space : Manifold
        Equipped manifold.
    integrator : ODEIVPIntegrator
        Instance of ODEIVP integrator. A ``GSIVPIntegrator()`` is created
        when None.
    """

    def __init__(self, space, integrator=None):
        self._space = space
        super().__init__()

        if integrator is None:
            integrator = GSIVPIntegrator()

        self._integrator = None
        # goes through the property setter, which also updates `solves_ivp`
        self.integrator = integrator

    @property
    def integrator(self):
        """An instance of ODEIVPIntegrator."""
        return self._integrator

    @integrator.setter
    def integrator(self, integrator):
        """Set integrator and mirror its ``tchosen`` flag into ``solves_ivp``."""
        self.solves_ivp = integrator.tchosen
        self._integrator = integrator

    def _solve(self, tangent_vec, base_point, t_eval=None):
        """Integrate the geodesic equation from an initial state.

        Parameters
        ----------
        tangent_vec : array-like, shape=[..., *space.shape]
            Tangent vector at the base point.
        base_point : array-like, shape=[..., *space.shape]
            Point on the manifold.
        t_eval : array-like, shape=[n_times,]
            Times at which the solution is requested. If None, the
            integrator's ``integrate`` is used, otherwise ``integrate_t``.

        Returns
        -------
        result : object
            Whatever the integrator returns.
        """
        # Broadcast symmetrically: either argument may carry the batch axes.
        if tangent_vec.shape != base_point.shape:
            tangent_vec, base_point = gs.broadcast_arrays(tangent_vec, base_point)

        # position and velocity are stacked on a new axis placed right
        # before the point axes
        state_axis = -(self._space.point_ndim + 1)
        initial_state = gs.stack([base_point, tangent_vec], axis=state_axis)

        def force(state, _):
            return self._space.metric.geodesic_equation(state)

        if t_eval is None:
            return self.integrator.integrate(force, initial_state)

        return self.integrator.integrate_t(force, initial_state, t_eval)

    def exp(self, tangent_vec, base_point):
        """Exponential map.

        Parameters
        ----------
        tangent_vec : array-like, shape=[..., *space.shape]
            Tangent vector at the base point.
        base_point : array-like, shape=[..., *space.shape]
            Point on the manifold.

        Returns
        -------
        end_point : array-like, shape=[..., *space.shape]
            Point on the manifold.
        """
        result = self._solve(tangent_vec, base_point)
        return self._simplify_exp_result(result)

    def geodesic_ivp(self, tangent_vec, base_point):
        """Geodesic curve for initial value problem.

        Parameters
        ----------
        tangent_vec : array-like, shape=[..., *space.shape]
            Tangent vector at the base point.
        base_point : array-like, shape=[..., *space.shape]
            Point on the manifold.

        Returns
        -------
        path : callable
            Time parametrized geodesic curve. `f(t)`.

        Raises
        ------
        NotImplementedError
            If ``solves_ivp`` is False for the current integrator.
        """
        if not self.solves_ivp:
            raise NotImplementedError(
                "Can't solve for geodesic at different t with this integrator."
            )

        # Broadcast symmetrically so that a batch of base points with a
        # single tangent vector is accepted as well.
        if tangent_vec.shape != base_point.shape:
            tangent_vec, base_point = gs.broadcast_arrays(tangent_vec, base_point)

        def path(t):
            """Time parametrized geodesic curve.

            Parameters
            ----------
            t : float or array-like, shape=[n_times,]

            Returns
            -------
            geodesic_points : array-like, shape=[..., n_times, *space.shape]
                Geodesic points evaluated at t.
            """
            if not gs.is_array(t):
                t = gs.array([t])

            if gs.ndim(t) == 0:
                t = gs.expand_dims(t, axis=0)

            result = self._solve(tangent_vec, base_point, t_eval=t)
            return self._simplify_result_t(result)

        return path

    def _simplify_exp_result(self, result):
        """Extract the position part of the last integrated state."""
        y = result.get_last_y()
        point_ndim_slc = tuple([slice(None)] * self._space.point_ndim)
        # index 0 on the state axis is the position (see `_solve` stacking)
        return y[(..., 0) + point_ndim_slc]

    def _simplify_result_t(self, result):
        """Extract the position part of the states for every time."""
        # assumes several t
        y = result.y
        point_ndim_slc = tuple([slice(None)] * self._space.point_ndim)
        return y[(..., slice(None), 0) + point_ndim_slc]
[docs]
class LogSolver(ABC):
    """Abstract class for geodesic boundary value problem solvers.

    Parameters
    ----------
    solves_bvp : bool
        Informs if solver is able to solve for geodesic at different t.
    """

    def __init__(self, solves_bvp=False):
        self.solves_bvp = solves_bvp

    @abstractmethod
    def log(self, point, base_point):
        """Logarithm map.

        The signature matches the concrete solvers of this module, which all
        implement ``log(point, base_point)``.

        Parameters
        ----------
        point : array-like, shape=[..., dim]
            Point on the manifold.
        base_point : array-like, shape=[..., dim]
            Point on the manifold.

        Returns
        -------
        tangent_vec : array-like, shape=[..., dim]
            Tangent vector at the base point.
        """

    def geodesic_bvp(self, point, base_point):
        """Geodesic curve for boundary value problem.

        The base implementation does not solve anything; subclasses able to
        evaluate the geodesic at several times are expected to override it.

        Parameters
        ----------
        point : array-like, shape=[..., dim]
            Point on the manifold.
        base_point : array-like, shape=[..., dim]
            Point on the manifold.

        Returns
        -------
        path : callable
            Time parametrized geodesic curve. `f(t)`.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """
        raise NotImplementedError("Can't solve for geodesic at different t.")
class _LogBatchMixins:
    """Provide a batched ``log`` built on top of a single-pair ``_log_single``.

    The host class must define ``self._space`` with a ``point_ndim``
    attribute, which is read to detect batched inputs.
    """

    @abstractmethod
    def _log_single(self, point, base_point):
        """Logarithm map for one pair of points.

        Parameters
        ----------
        point : array-like, shape=[dim]
            Point on the manifold.
        base_point : array-like, shape=[dim]
            Point on the manifold.

        Returns
        -------
        tangent_vec : array-like, shape=[dim]
            Tangent vector at the base point.
        """

    def log(self, point, base_point):
        """Logarithm map.

        Parameters
        ----------
        point : array-like, shape=[..., dim]
            Point on the manifold.
        base_point : array-like, shape=[..., dim]
            Point on the manifold.

        Returns
        -------
        tangent_vec : array-like, shape=[..., dim]
            Tangent vector at the base point.
        """
        # assumes inability to properly vectorize

        # Compare full shapes, not only ndim: with e.g. (1, dim) against
        # (n, dim) the `zip` below would otherwise silently stop after the
        # first pair.
        if point.shape != base_point.shape:
            point, base_point = gs.broadcast_arrays(point, base_point)

        is_batch = point.ndim > self._space.point_ndim
        if not is_batch:
            return self._log_single(point, base_point)

        # NOTE(review): iterates over the leading axis only, so more than one
        # batch axis hands batched arrays to `_log_single` - confirm callers.
        return gs.stack(
            [
                self._log_single(point_, base_point_)
                for point_, base_point_ in zip(point, base_point)
            ]
        )
[docs]
class LogShootingSolver:
    """Factory for geodesic boundary value problem solvers using shooting.

    Instantiating this class returns one of the two private shooting
    solvers, selected by ``flatten``.

    Parameters
    ----------
    space : Manifold
        Equipped manifold.
    optimizer : ScipyMinimize
        Instance of ScipyMinimize.
    initialization : callable
        Function to provide initial solution. `f(point, base_point)`.
        Defaults to linear initialization.
    flatten : bool
        If True, the optimization problem is solved together for all the
        batch points.
    """

    def __new__(cls, space, optimizer=None, initialization=None, flatten=True):
        """Instantiate a log shooting solver."""
        solver_cls = (
            _LogShootingSolverFlatten if flatten else _LogShootingSolverUnflatten
        )
        return solver_cls(
            space=space,
            optimizer=optimizer,
            initialization=initialization,
        )
class _LogShootingSolver(LogSolver, ABC):
    """Shared machinery of the shooting-based geodesic BVP solvers.

    Parameters
    ----------
    space : Manifold
        Equipped manifold.
    optimizer : ScipyMinimize
        Instance of ScipyMinimize. ``ScipyMinimize(autodiff_jac=True)`` is
        created when None.
    initialization : callable
        Function to provide initial solution. `f(point, base_point)`.
        Defaults to linear initialization.
    """

    def __init__(self, space, optimizer=None, initialization=None):
        super().__init__(solves_bvp=False)
        self._space = space

        self.optimizer = (
            ScipyMinimize(autodiff_jac=True) if optimizer is None else optimizer
        )
        self.initialization = (
            self._default_initialization if initialization is None else initialization
        )

    def _default_initialization(self, point, base_point):
        """Build the linear initial guess, i.e. the flattened difference.

        Parameters
        ----------
        point : array-like, shape=[..., *space.shape]
            Point on the manifold.
        base_point : array-like, shape=[..., *space.shape]
            Point on the manifold.

        Returns
        -------
        flat_tangent_vec : array-like
            Flattened ``point - base_point``.
        """
        difference = point - base_point
        return gs.flatten(difference)

    def _objective(self, flat_tangent_vec, point, base_point, batch_shape):
        """Evaluate the shooting objective.

        It is the sum of squared differences between the target point and
        the point reached by shooting from the base point.

        Parameters
        ----------
        flat_tangent_vec : array-like, shape=[prod(batch_shape)*prod(space.shape)]
            Flattened tangent vector.
        point : array-like, shape=[..., *space.shape]
            Point on the manifold.
        base_point : array-like, shape=[..., *space.shape]
            Point on the manifold.
        batch_shape : tuple
            Batch shape.

        Returns
        -------
        cost : scalar
            Sum of squared residuals.
        """
        full_shape = batch_shape + self._space.shape
        shot_point = self._space.metric.exp(
            gs.reshape(flat_tangent_vec, full_shape), base_point
        )
        residual = shot_point - point
        return gs.sum(residual**2)
class _LogShootingSolverFlatten(_LogShootingSolver):
    """Shooting solver handling a whole batch as one optimization problem.

    Parameters
    ----------
    space : Manifold
        Equipped manifold.
    optimizer : ScipyMinimize
        Instance of ScipyMinimize.
    initialization : callable
        Function to provide initial solution. `f(point, base_point)`.
        Defaults to linear initialization.

    Notes
    -----
    Differs from `_LogShootingSolverUnflatten` as it always solves one
    optimization problem, even in batch mode.
    """

    def log(self, point, base_point):
        """Logarithm map.

        Parameters
        ----------
        point : array-like, shape=[..., *space.shape]
            Point on the manifold.
        base_point : array-like, shape=[..., *space.shape]
            Point on the manifold.

        Returns
        -------
        tangent_vec : array-like, shape=[..., *space.shape]
            Tangent vector at the base point.
        """
        batch_shape = get_batch_shape(self._space.point_ndim, point, base_point)

        def objective(velocity):
            return self._objective(velocity, point, base_point, batch_shape)

        initial_guess = self.initialization(point, base_point)
        solution = self.optimizer.minimize(objective, initial_guess)

        # the optimizer works on a flat vector; restore batch and point axes
        return gs.reshape(solution.x, batch_shape + self._space.shape)
class _LogShootingSolverUnflatten(_LogBatchMixins, _LogShootingSolver):
    """Shooting solver running one optimization per pair of points.

    Parameters
    ----------
    space : Manifold
        Equipped manifold.
    optimizer : ScipyMinimize
        Instance of ScipyMinimize.
    initialization : callable
        Function to provide initial solution. `f(point, base_point)`.
        Defaults to linear initialization.

    Notes
    -----
    Differs from `_LogShootingSolverFlatten` as it solves one
    optimization problem for each combination of point and base point.
    """

    def _log_single(self, point, base_point):
        """Logarithm map for one pair of points.

        Parameters
        ----------
        point : array-like, shape=[*space.shape]
            Point on the manifold.
        base_point : array-like, shape=[*space.shape]
            Point on the manifold.

        Returns
        -------
        tangent_vec : array-like, shape=[*space.shape]
            Tangent vector at the base point.
        """

        def objective(velocity):
            return self._objective(velocity, point, base_point, batch_shape=())

        initial_guess = self.initialization(point, base_point)
        flat_solution = self.optimizer.minimize(objective, initial_guess).x

        # only points with more than one axis need to be unflattened
        if self._space.point_ndim <= 1:
            return flat_solution
        return gs.reshape(flat_solution, self._space.shape)
[docs]
class LogODESolver(_LogBatchMixins, LogSolver):
    """Geodesic boundary value problem using an ODE solver.

    Parameters
    ----------
    space : Manifold
        Equipped manifold.
    n_nodes : int
        Number of mesh nodes.
    integrator : ScipySolveBVP
        Instance of ScipySolveBVP.
    initialization : callable
        Function to provide initial solution. `f(point, base_point)`.
        Defaults to linear initialization.
    use_jac : bool
        If True, `_jacobian` is passed to the integrator as ``fun_jac``.
    """

    def __init__(
        self, space, n_nodes=10, integrator=None, initialization=None, use_jac=True
    ):
        self._space = space
        super().__init__(solves_bvp=True)

        if integrator is None:
            integrator = ScipySolveBVP()

        if initialization is None:
            initialization = self._default_initialization

        self.n_nodes = n_nodes
        self.integrator = integrator
        self.initialization = initialization
        self.use_jac = use_jac

        self.grid = self._create_grid()

    def _create_grid(self):
        """Create the uniform mesh of ``n_nodes`` nodes on [0, 1]."""
        return gs.linspace(0.0, 1.0, num=self.n_nodes)

    def _default_initialization(self, point, base_point):
        """Linear initialization.

        Parameters
        ----------
        point : array-like, shape=[*space.shape]
            Point on the manifold.
        base_point : array-like, shape=[*space.shape]
            Point on the manifold.

        Returns
        -------
        state : array-like, shape=[2*dim, n_nodes]
            Positions on the straight line between the (flattened) points,
            stacked over the corresponding finite-difference velocities.
        """
        if point.ndim == 1:
            point_0, point_1 = base_point, point
        else:
            point_0 = gs.flatten(base_point)
            point_1 = gs.flatten(point)

        pos_init = gs.transpose(gs.linspace(point_0, point_1, self.n_nodes))

        # The mesh has n_nodes nodes on [0, 1], i.e. a spacing of
        # 1 / (n_nodes - 1), so finite differences are scaled by n_nodes - 1.
        vel_init = (self.n_nodes - 1) * (pos_init[:, 1:] - pos_init[:, :-1])
        # Repeat the last velocity so that every node has one; indexing the
        # last column (instead of the second to last) also works for two nodes.
        vel_init = gs.hstack([vel_init, vel_init[:, [-1]]])

        return gs.vstack([pos_init, vel_init])

    def _boundary_condition(self, state_0, state_1, point_0, point_1):
        """Residual of the boundary conditions.

        Compares the position part of both end states with the flattened
        end points.
        """
        pos_0 = state_0[: point_0.shape[0]]
        pos_1 = state_1[: point_1.shape[0]]
        return gs.hstack((pos_0 - point_0, pos_1 - point_1))

    def _bvp(self, raveled_state):
        """Boundary value problem.

        Parameters
        ----------
        raveled_state : array-like, shape=[2*dim, n_nodes]
            Vector of state variables (position and speed).

        Returns
        -------
        sol : array-like, shape=[2*dim, n_nodes]
        """
        # unravel into [n_nodes, 2, *space.shape] for the geodesic equation
        state = gs.moveaxis(
            gs.reshape(raveled_state, (2,) + self._space.shape + (-1,)), -1, 0
        )

        new_state = self._space.metric.geodesic_equation(state)

        # ravel back and put the node axis last again
        return gs.moveaxis(gs.reshape(new_state, (-1, raveled_state.shape[0])), -2, -1)

    def _jacobian(self, _, raveled_state):
        """Jacobian of boundary value problem.

        Parameters
        ----------
        _ : float
            Unused.
        raveled_state : array-like, shape=[2*dim, n_nodes]
            Vector of state variables (position and speed).

        Returns
        -------
        jac : array-like, shape=[2*dim, 2*dim, n_nodes]
            Block matrix ``[[0, I], [df/dposition, df/dvelocity]]`` per node.
        """
        dim = self._space.dim
        n_nodes = raveled_state.shape[-1]
        position, velocity = gs.transpose(raveled_state[:dim]), raveled_state[dim:]

        dgamma = self._space.metric.jacobian_christoffels(position)

        df_dposition = -gs.einsum(
            "j...,...ijkl,k...->il...", velocity, dgamma, velocity
        )

        gamma = self._space.metric.christoffels(position)
        df_dvelocity = -2 * gs.einsum("...ijk,k...->ij...", gamma, velocity)

        jac_nw = gs.zeros((dim, dim, n_nodes))
        # No squeeze here: it would drop the size-one axes when dim == 1 and
        # break the concatenation with the other [dim, dim, n_nodes] blocks.
        jac_ne = gs.transpose(gs.tile(gs.eye(dim), (n_nodes, 1, 1)))
        jac_sw = df_dposition
        jac_se = df_dvelocity
        jac = gs.concatenate(
            (
                gs.concatenate((jac_nw, jac_ne), axis=1),
                gs.concatenate((jac_sw, jac_se), axis=1),
            ),
            axis=0,
        )

        return jac

    def _solve(self, point, base_point):
        """Run the integrator for one pair of points.

        Parameters
        ----------
        point : array-like, shape=[*space.shape]
            Point on the manifold.
        base_point : array-like, shape=[*space.shape]
            Point on the manifold.

        Returns
        -------
        result : object
            Whatever ``integrator.integrate`` returns.
        """
        flat_base_point = gs.flatten(base_point)
        flat_point = gs.flatten(point)

        def bvp(_, state):
            return self._bvp(state)

        def bc(state_0, state_1):
            return self._boundary_condition(
                state_0, state_1, flat_base_point, flat_point
            )

        jacobian = self._jacobian if self.use_jac else None

        y = self.initialization(point, base_point)

        return self.integrator.integrate(bvp, bc, self.grid, y, fun_jac=jacobian)

    def _log_single(self, point, base_point):
        """Logarithm map for one pair of points."""
        res = self._solve(point, base_point)
        return self._simplify_log_result(res)

    def geodesic_bvp(self, point, base_point):
        """Geodesic curve for boundary value problem.

        Parameters
        ----------
        point : array-like, shape=[..., dim]
            Point on the manifold.
        base_point : array-like, shape=[..., dim]
            Point on the manifold.

        Returns
        -------
        path : callable
            Time parametrized geodesic curve. `f(t)`. 0 <= t <= 1.
        """
        # Compare full shapes, not only ndim: with e.g. (1, dim) against
        # (n, dim) the `zip` below would otherwise silently stop after the
        # first pair.
        if point.shape != base_point.shape:
            point, base_point = gs.broadcast_arrays(point, base_point)

        is_batch = point.ndim > self._space.point_ndim
        if is_batch:
            results = [
                self._solve(point_, base_point_)
                for point_, base_point_ in zip(point, base_point)
            ]
        else:
            results = [self._solve(point, base_point)]

        def path(t):
            """Time parametrized geodesic curve.

            Parameters
            ----------
            t : float or array-like, shape=[n_times,]

            Returns
            -------
            geodesic_points : array-like, shape=[..., n_times, dim]
                Geodesic points evaluated at t.
            """
            if not gs.is_array(t):
                t = gs.array([t])

            if gs.ndim(t) == 0:
                t = gs.expand_dims(t, axis=0)

            geodesic_points = [
                self._simplify_result_t(result.sol(t)) for result in results
            ]
            if not is_batch:
                return geodesic_points[0]

            return gs.array(geodesic_points)

        return path

    def _simplify_log_result(self, result):
        """Extract the velocity part of the state at the first mesh node."""
        return gs.reshape(result.y[..., 0], (2,) + self._space.shape)[1]

    def _simplify_result_t(self, result):
        """Extract the positions and put the time axis first."""
        return gs.moveaxis(
            gs.reshape(result[: result.shape[0] // 2, :], self._space.shape + (-1,)),
            -1,
            0,
        )
class _DiscreteGeodesicBVPBatchMixins:
    """Provide a batched ``discrete_geodesic_bvp`` from a single-pair solver.

    The host class must define ``self._space`` with a ``point_ndim``
    attribute, which is read to detect batched inputs.
    """

    @abstractmethod
    def _discrete_geodesic_bvp_single(self, point, base_point):
        """Solve boundary value problem (BVP).

        Given an initial point and an end point, solve the geodesic equation
        via minimizing the Riemannian path energy.

        Parameters
        ----------
        point : array-like, shape=[*point_shape]
            Point on the manifold.
        base_point : array-like, shape=[*point_shape]
            Point on the manifold.

        Returns
        -------
        discr_geod_path : array-like, shape=[n_nodes, *point_shape]
            Discrete geodesic.
        """

    def discrete_geodesic_bvp(self, point, base_point):
        """Solve boundary value problem (BVP).

        Given an initial point and an end point, solve the geodesic equation
        via minimizing the Riemannian path energy.

        Parameters
        ----------
        point : array-like, shape=[..., *point_shape]
            Point on the manifold.
        base_point : array-like, shape=[..., *point_shape]
            Point on the manifold.

        Returns
        -------
        discr_geod_path : array-like, shape=[..., n_nodes, *point_shape]
            Discrete geodesic.
        """
        # Compare full shapes, not only ndim: with e.g. (1, dim) against
        # (n, dim) the `zip` below would otherwise silently stop after the
        # first pair.
        if point.shape != base_point.shape:
            point, base_point = gs.broadcast_arrays(point, base_point)

        is_batch = point.ndim > self._space.point_ndim
        if not is_batch:
            return self._discrete_geodesic_bvp_single(point, base_point)

        # NOTE(review): iterates over the leading axis only, so more than one
        # batch axis hands batched arrays to the single solver - confirm callers.
        return gs.stack(
            [
                self._discrete_geodesic_bvp_single(point_, base_point_)
                for point_, base_point_ in zip(point, base_point)
            ]
        )
[docs]
class PathBasedLogSolver(LogSolver, ABC):
    """Geodesic BVP solver relying on a discrete geodesic path.

    Subclasses provide `discrete_geodesic_bvp`; the logarithm and the
    continuous geodesic are derived from the returned discrete path.

    Parameters
    ----------
    space : Manifold
        Equipped manifold.
    """

    def __init__(self, space):
        self._space = space
        super().__init__(solves_bvp=True)

    def path_n_nodes(self, path):
        """Get number of nodes of a path.

        Parameters
        ----------
        path : array-like, shape=[..., n_nodes, *point_shape]
            Discrete path.

        Returns
        -------
        n_nodes : int
            Number of path nodes.
        """
        # the node axis sits right before the point axes
        node_axis = -(self._space.point_ndim + 1)
        return path.shape[node_axis]

    @abstractmethod
    def discrete_geodesic_bvp(self, point, base_point):
        """Solve boundary value problem (BVP).

        Given an initial point and an end point, solve the geodesic equation
        via minimizing the Riemannian path energy.

        Parameters
        ----------
        point : array-like, shape=[..., *point_shape]
            Point on the manifold.
        base_point : array-like, shape=[..., *point_shape]
            Point on the manifold.

        Returns
        -------
        discr_geod_path : array-like, shape=[..., n_nodes, *point_shape]
            Discrete geodesic.
        """

    def log(self, point, base_point):
        """Logarithm map.

        Computed as the forward difference between the first two nodes of
        the discrete geodesic, scaled by the number of path segments.

        Parameters
        ----------
        point : array-like, shape=[..., *point_shape]
            Point on the manifold.
        base_point : array-like, shape=[..., *point_shape]
            Point on the manifold.

        Returns
        -------
        tangent_vec : array-like, shape=[..., *point_shape]
            Tangent vector at the base point.
        """
        discrete_path = self.discrete_geodesic_bvp(point, base_point)
        n_segments = self.path_n_nodes(discrete_path) - 1

        point_axes = (slice(None),) * self._space.point_ndim
        first_node = discrete_path[(..., 0) + point_axes]
        second_node = discrete_path[(..., 1) + point_axes]

        return n_segments * (second_node - first_node)

    def geodesic_bvp(self, point, base_point):
        """Geodesic curve for boundary value problem.

        Parameters
        ----------
        point : array-like, shape=[..., *point_shape]
            Point on the manifold.
        base_point : array-like, shape=[..., *point_shape]
            Point on the manifold.

        Returns
        -------
        path : callable
            Time parametrized geodesic curve. `f(t)`.
        """
        return UniformlySampledDiscretePath(
            self.discrete_geodesic_bvp(point, base_point),
            point_ndim=self._space.point_ndim,
        )
[docs]
class PathStraightening(_DiscreteGeodesicBVPBatchMixins, PathBasedLogSolver):
    """Geodesic boundary value problem with path-straightening.

    The interior nodes of a discrete path with fixed end points are
    optimized so as to minimize the path energy.

    Parameters
    ----------
    space : Manifold
        Equipped manifold.
    path_energy : callable
        Method to compute Riemannian path energy.
        ``UniformlySampledPathEnergy(space)`` is created when None.
    n_nodes : int
        Number of path discretization points.
    optimizer : ScipyMinimize
        An optimizer to solve path energy minimization problem.
    initialization : callable
        A method to get initial guess for optimization.
    symmetric : bool
        If to use a symmetrized version of the energy.

    References
    ----------
    .. [HSKCB2022] "Elastic shape analysis of surfaces with second-order
        Sobolev metrics: a comprehensive numerical framework".
        arXiv:2204.04238 [cs.CV], 25 Sep 2022
    """

    def __init__(
        self,
        space,
        path_energy=None,
        n_nodes=100,
        optimizer=None,
        initialization=None,
        symmetric=False,
    ):
        super().__init__(space)

        if optimizer is None:
            optimizer = ScipyMinimize(
                method="L-BFGS-B",
                autodiff_jac=True,
                options={"disp": False},
            )
        if path_energy is None:
            path_energy = UniformlySampledPathEnergy(space)
        if initialization is None:
            initialization = self._default_initialization

        self.n_nodes = n_nodes
        self.optimizer = optimizer
        self.path_energy = path_energy
        self.initialization = initialization
        self.symmetric = symmetric

    def _default_initialization(self, point, base_point):
        """Linear initialization.

        Parameters
        ----------
        point : array-like, shape=[..., *point_shape]
            Point on the manifold.
        base_point : array-like, shape=[..., *point_shape]
            Point on the manifold.

        Returns
        -------
        path : array-like, shape=[n_nodes, ..., *point_shape]
            Linear path between base point and point, time axis first.
        """
        times = gs.linspace(0.0, 1.0, self.n_nodes)
        scaled_steps = gs.einsum("t,...->t...", times, point - base_point)
        return base_point + scaled_steps

    def _make_objective(self, point, base_point):
        """Create objective function.

        Parameters
        ----------
        point : array-like, shape=[1, *point_shape]
            End point of the path, with a leading node axis.
        base_point : array-like, shape=[1, *point_shape]
            Start point of the path, with a leading node axis.

        Returns
        -------
        objective : callable
            Objective function.
        """
        inner_shape = (self.n_nodes - 2,) + self._space.shape

        def objective(midpoints):
            """Compute path energy of paths going through a midpoint.

            Parameters
            ----------
            midpoints : array-like, shape=[(self.n_nodes-2) * math.prod(*point_shape)]
                Midpoints of the path.

            Returns
            -------
            _ : float
                Energy of the path.
            """
            inner_nodes = gs.reshape(midpoints, inner_shape)
            path = gs.concatenate([base_point, inner_nodes, point])

            forward_energy = self.path_energy(path)
            if not self.symmetric:
                return forward_energy

            backward_energy = self.path_energy(gs.flip(path, axis=0))
            return (forward_energy + backward_energy) / 2

        return objective

    def _discrete_geodesic_bvp_single(self, point, base_point):
        """Solve boundary value problem (BVP).

        Given an initial point and an end point, solve the geodesic equation
        via minimizing the Riemannian path energy.

        Parameters
        ----------
        point : array-like, shape=[*point_shape]
            Point on the manifold.
        base_point : array-like, shape=[*point_shape]
            Point on the manifold.

        Returns
        -------
        discr_geod_path : array-like, shape=[n_nodes, *point_shape]
            Discrete geodesic.
        """
        # only interior nodes are optimized; end points stay fixed
        interior_guess = self.initialization(point, base_point)[1:-1]
        flat_guess = gs.reshape(interior_guess, (-1,))

        # give the end points a node axis so they concatenate with midpoints
        start_node = gs.expand_dims(base_point, axis=0)
        end_node = gs.expand_dims(point, axis=0)

        objective = self._make_objective(end_node, start_node)
        solution = self.optimizer.minimize(objective, flat_guess)

        interior_nodes = gs.reshape(
            gs.array(solution.x), (self.n_nodes - 2,) + self._space.shape
        )
        return gs.concatenate([start_node, interior_nodes, end_node], axis=0)
[docs]
class MultiresPathStraightening(_DiscreteGeodesicBVPBatchMixins, PathBasedLogSolver):
    """Geodesic boundary value problem with multiresolution path straightening.

    The path energy is minimized successively for each number of nodes in
    ``n_nodes``; every level after the first is initialized by resampling
    the solution of the previous level.

    Parameters
    ----------
    space : Manifold
        Equipped manifold.
    path_energy : callable
        Method to compute Riemannian path energy.
        ``UniformlySampledPathEnergy(space)`` is created when None.
    n_nodes : list
        Number of path discretization points at each resolution.
    n_levels : int
        Number of resolutions to use. Sets number of nodes following a sequence.
        Ignored if ``n_nodes`` is not None.
    optimizer : ScipyMinimize
        An optimizer to solve path energy minimization problem.
    initialization : callable
        A method to get initial guess for optimization.
    symmetric : bool
        If to use a symmetrized version of the energy.
    early_stop : bool
        If True, stop refining as soon as a resolution is solved with one
        optimizer iteration.
    """

    def __init__(
        self,
        space,
        path_energy=None,
        n_nodes=None,
        n_levels=3,
        optimizer=None,
        initialization=None,
        symmetric=False,
        early_stop=False,
    ):
        super().__init__(space)

        if optimizer is None:
            optimizer = ScipyMinimize(
                method="L-BFGS-B",
                autodiff_jac=True,
                options={"disp": False},
            )
        if n_nodes is None:
            n_nodes = self._create_levels_sequence(3, n_levels)

        if path_energy is None:
            path_energy = UniformlySampledPathEnergy(space)

        if initialization is None:
            initialization = self._default_initialization

        self.n_nodes = n_nodes
        self.optimizer = optimizer
        self.path_energy = path_energy
        self.initialization = initialization
        self.symmetric = symmetric
        self.early_stop = early_stop

    @staticmethod
    def _create_levels_sequence(init_val, n_levels):
        """Create the number of nodes per level.

        Each level has ``2 * previous - 1`` nodes, i.e. one new node is
        inserted between every pair of consecutive nodes.
        """
        seq = [init_val]
        for _ in range(n_levels - 1):
            seq.append(seq[-1] * 2 - 1)

        return seq

    def _default_initialization(self, point, base_point):
        """Linear initialization at the first resolution.

        Parameters
        ----------
        point : array-like, shape=[..., *point_shape]
            Point on the manifold.
        base_point : array-like, shape=[..., *point_shape]
            Point on the manifold.

        Returns
        -------
        path : array-like, shape=[n_nodes[0], ..., *point_shape]
            Linear path between base point and point, time axis first.
        """
        times = gs.linspace(0.0, 1.0, self.n_nodes[0])
        linear_deformation = point - base_point
        return base_point + gs.einsum("t,...->t...", times, linear_deformation)

    def _initialization_from_previous(self, path, n_nodes):
        """Initialize next resolution using previous resolution solution.

        Parameters
        ----------
        path : array-like, shape=[..., n_nodes_previous, *point_shape]
            Discrete path.
        n_nodes : int
            Number of nodes of the next resolution.

        Returns
        -------
        interp_path : array-like, shape=[..., n_nodes, *point_shape]
            Discrete path.
        """
        discr_geod_path = UniformlySampledDiscretePath(
            path, point_ndim=self._space.point_ndim
        )
        times = gs.linspace(0.0, 1.0, n_nodes)
        return discr_geod_path(times)

    def _make_objective(self, point, base_point, n_nodes):
        """Create objective function.

        Parameters
        ----------
        point : array-like, shape=[1, *point_shape]
            End point of the path, with a leading node axis.
        base_point : array-like, shape=[1, *point_shape]
            Start point of the path, with a leading node axis.
        n_nodes : int
            Number of nodes of the path, end points included.

        Returns
        -------
        objective : callable
            Objective function.
        """

        def objective(midpoints):
            """Compute path energy of paths going through a midpoint.

            Parameters
            ----------
            midpoints : array-like, shape=[(n_nodes-2) * math.prod(*point_shape)]
                Midpoints of the path.

            Returns
            -------
            _ : float
                Energy of the path.
            """
            midpoints = gs.reshape(midpoints, (n_nodes - 2,) + self._space.shape)

            path = gs.concatenate(
                [
                    base_point,
                    midpoints,
                    point,
                ],
            )

            energy = self.path_energy(path)
            if self.symmetric:
                reversed_path = gs.flip(path, axis=0)
                energy = (energy + self.path_energy(reversed_path)) / 2

            return energy

        return objective

    def _discrete_geodesic_bvp_single(self, point, base_point):
        """Solve boundary value problem (BVP).

        Given an initial point and an end point, solve the geodesic equation
        via minimizing the Riemannian path energy.

        Parameters
        ----------
        point : array-like, shape=[*point_shape]
            Point on the manifold.
        base_point : array-like, shape=[*point_shape]
            Point on the manifold.

        Returns
        -------
        discr_geod_path : array-like, shape=[n_nodes, *point_shape]
            Discrete geodesic at the last solved resolution.
        """
        # The (possibly user-supplied) initialization is called exactly once,
        # with the unexpanded points; it seeds the first resolution only.
        init_midpoints = self.initialization(point, base_point)[1:-1]

        # give the end points a node axis so they concatenate with midpoints
        base_point = gs.expand_dims(base_point, axis=0)
        point = gs.expand_dims(point, axis=0)

        discrete_path = None
        for n_nodes in self.n_nodes:
            if discrete_path is not None:
                # finer levels start from the resampled previous solution
                init_midpoints = self._initialization_from_previous(
                    discrete_path, n_nodes
                )[1:-1]

            objective = self._make_objective(point, base_point, n_nodes)
            res = self.optimizer.minimize(
                objective, gs.reshape(init_midpoints, (-1,))
            )

            path_midpoints = gs.reshape(
                gs.array(res.x), (n_nodes - 2,) + self._space.shape
            )

            discrete_path = gs.concatenate(
                [
                    base_point,
                    path_midpoints,
                    point,
                ],
                axis=0,
            )

            if self.early_stop and res.nit == 1:
                break

        return discrete_path