"""Initial value problem solvers implementation."""
from abc import ABC, abstractmethod
import scipy
import geomstats.backend as gs
import geomstats.integrator as gs_integrator
from geomstats.errors import check_parameter_accepted_values
from geomstats.numerics._common import result_to_backend_type
def _merge_scipy_results(results, same_t=False):
keys = list(results[0].keys())
merged_results = {key: [] for key in keys}
for result in results:
for key, value in merged_results.items():
value.append(result[key])
if same_t:
merged_results["t"] = merged_results["t"][0]
merged_results["y"] = gs.stack(merged_results["y"])
return merged_results
class OdeResult(scipy.optimize.OptimizeResult):
    """Bunch object (follows scipy).

    Its purpose is to homogenize the output of different integrators.
    """

    def get_last_y(self):
        """Get value for last y.

        Allows `y` to be represented either as a `gs.array` or as a
        `list[gs.array]` (the latter for cases where the entries have
        different shapes). In the list/tuple case, the last element of each
        entry is taken and the results are stacked.

        Assumes the last `t` is the same for all entries.
        """
        y_values = self.y
        if not isinstance(y_values, (list, tuple)):
            return y_values[-1]
        return gs.stack([trajectory[-1] for trajectory in y_values])
class ODEIVPIntegrator(ABC):
    """Abstract class for ODE initial value problem solvers.

    Parameters
    ----------
    save_result : bool
        If True, result is stored in `result_` after calling `integrate`
        or `integrate_t`.
    tchosen : bool
        Whether the solver can return the solution at user-chosen times.
        If False, `integrate_t` is not implemented.
    """

    def __init__(self, save_result=False, tchosen=False):
        self.result_ = None
        self.tchosen = tchosen
        self.save_result = save_result

    @abstractmethod
    def integrate(self, force, initial_state, end_time):
        """Integrate force from time 0 up to `end_time`.

        Parameters
        ----------
        force : callable
            Function to integrate: `f(state, t)`.
        initial_state : array-like, shape=[..., n_vars, *point_shape]
            Initial state.
        end_time : float
            Integration end time.

        Returns
        -------
        result : OdeResult
        """

    def integrate_t(self, force, initial_state, t_eval):
        """Integrate force, returning the solution at chosen time points.

        Parameters
        ----------
        force : callable
            Function to integrate: `f(state, t)`.
        initial_state : array-like, shape=[..., n_vars, *point_shape]
            Initial state.
        t_eval : array-like
            Times at which to store the computed solution.

        Returns
        -------
        result : OdeResult

        Raises
        ------
        NotImplementedError
            Always, in this base class; solvers supporting chosen times
            override this method.
        """
        raise NotImplementedError("Can't solve for chosen evaluating points.")
class GSIVPIntegrator(ODEIVPIntegrator):
    """In-house ODE integrator with a fixed number of equal-size steps.

    Parameters
    ----------
    n_steps : int
        Number of steps to perform.
    step_type : str or callable
        Type of integration step.
        Possible values are `euler`, `rk2`, `rk4`. A callable
        `step(force, state, time, dt)` can also be given; in that case
        `step_type` is reported as None and the number of function
        evaluations (`nfev`) is unknown and reported as None.
    save_result : bool
        If True, result is stored in `result_` after calling `integrate`.
    """

    def __init__(self, n_steps=10, step_type="euler", save_result=False):
        super().__init__(save_result=save_result, tchosen=False)
        self.step_type = step_type
        self.n_steps = n_steps

    @property
    def step_type(self):
        """Integrator step type (None when a custom step callable is used)."""
        return self._step_type

    @step_type.setter
    def step_type(self, value):
        if callable(value):
            # Custom step function: there is no registered name for it.
            step_function = value
            value = None
        else:
            check_parameter_accepted_values(
                value, "step_type", gs_integrator.STEP_FUNCTIONS
            )
            step_function = getattr(
                gs_integrator, gs_integrator.STEP_FUNCTIONS[value]
            )
        self._step_function = step_function
        self._step_type = value

    def _step(self, force, state, time, dt):
        return self._step_function(force, state, time, dt)

    def _get_n_fevals(self, n_steps):
        """Return the number of force evaluations used by `n_steps` steps.

        Returns None for a custom step callable, whose per-step cost is
        unknown (it has no entry in `FEVALS_PER_STEP`).
        """
        if self.step_type is None:
            return None
        return gs_integrator.FEVALS_PER_STEP[self.step_type] * n_steps

    def _integrate(self, force, initial_state, end_time=1.0):
        """Run the fixed-step scheme and return the list of visited states.

        The list has `n_steps + 1` entries: the initial state followed by the
        state after each step. Step `i` starts at time `i * dt`.
        """
        dt = end_time / self.n_steps
        states = [initial_state]
        current_state = initial_state
        for i in range(self.n_steps):
            current_state = self._step(
                force=force, state=current_state, time=i * dt, dt=dt
            )
            states.append(current_state)
        return states

    def integrate(self, force, initial_state, end_time=1.0):
        """Integrate force from time 0 up to `end_time`.

        Parameters
        ----------
        force : callable
            Function to integrate: `f(state, t)`.
        initial_state : array-like, shape=[..., n_vars, *point_shape]
            Initial state.
        end_time : float
            Integration end time.

        Returns
        -------
        result : OdeResult
            `t` holds the `n_steps + 1` equally spaced times in
            `[0, end_time]`; `y` holds the corresponding states, with time
            along the first axis; `nfev` is None for custom step callables.
        """
        states = self._integrate(force, initial_state, end_time=end_time)
        ts = gs.linspace(0.0, end_time, self.n_steps + 1)
        nfev = self._get_n_fevals(self.n_steps)
        # `success` (not `sucess`) matches scipy's OptimizeResult field name.
        result = OdeResult(
            t=ts, y=gs.array(states), nfev=nfev, njev=0, success=True
        )
        if self.save_result:
            self.result_ = result
        return result
class ScipySolveIVP(ODEIVPIntegrator):
    """Wrapper for scipy.integrate.solve_ivp.

    Check
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.solve_ivp.html
    for additional options.

    Parameters
    ----------
    method : str
        Integration method.
    save_result : bool
        If True, result is stored after calling `integrate` or `integrate_t`.
    point_ndim : int
        Dimension of array representing a point in the space.
    **options
        Extra keyword arguments forwarded to `scipy.integrate.solve_ivp`.
    """

    def __init__(self, method="RK45", save_result=False, point_ndim=1, **options):
        super().__init__(save_result=save_result, tchosen=True)
        self.method = method
        self.point_ndim = point_ndim
        self.options = options

    def _integrate(self, force, initial_state, end_time=1.0, t_eval=None):
        # A state is [n_vars, *point_shape]; any extra leading axis is a batch.
        is_batched = initial_state.ndim > self.point_ndim + 1

        if is_batched:
            # scipy solves one system at a time: solve each sample, then merge.
            per_sample = [
                self._integrate_single(force, sample_state, end_time, t_eval)
                for sample_state in initial_state
            ]
            merged = _merge_scipy_results(per_sample, same_t=t_eval is not None)
            result = OdeResult(merged)
        else:
            single = self._integrate_single(
                force, initial_state, end_time, t_eval=t_eval
            )
            result = OdeResult(**single)

        if self.save_result:
            self.result_ = result
        return result

    def integrate(self, force, initial_state, end_time=1.0):
        """Integrate force from time 0 up to `end_time`.

        Parameters
        ----------
        force : callable
            Function to integrate: `f(state, t)`.
        initial_state : array-like, shape=[..., n_vars, *point_shape]
            Initial state.
        end_time : float
            Integration end time.

        Returns
        -------
        result : OdeResult
        """
        return self._integrate(force, initial_state, end_time=end_time)

    def integrate_t(self, force, initial_state, t_eval):
        """Integrate force, storing the solution at the `t_eval` points.

        Integration runs from time 0 up to the last entry of `t_eval`.

        Parameters
        ----------
        force : callable
            Function to integrate: `f(state, t)`.
        initial_state : array-like, shape=[..., n_vars, *point_shape]
            Initial state.
        t_eval : array-like
            Times at which to store the computed solution.

        Returns
        -------
        result : OdeResult
        """
        last_time = t_eval[-1]
        return self._integrate(force, initial_state, end_time=last_time, t_eval=t_eval)

    def _integrate_single(self, force, initial_state, end_time=1.0, t_eval=None):
        """Solve a single (unbatched) initial value problem with scipy."""
        state_shape = initial_state.shape

        def flat_force(t, flat_state):
            # scipy works on 1D states; restore the original shape for `force`
            # and flatten its output back to what scipy expects.
            flat_state = gs.from_numpy(flat_state)
            velocity = force(gs.reshape(flat_state, state_shape), t)
            return gs.reshape(velocity, flat_state.shape)

        solution = scipy.integrate.solve_ivp(
            flat_force,
            (0.0, end_time),
            gs.reshape(initial_state, (-1,)),
            method=self.method,
            t_eval=t_eval,
            **self.options,
        )
        solution = result_to_backend_type(solution)
        # scipy's `y` has shape (n, n_points): put time first, then unflatten.
        solution.y = gs.reshape(
            gs.moveaxis(solution.y, 0, -1),
            (-1,) + state_shape,
        )
        return solution