Source code for geomstats.numerics.bvp
"""Boundary value problem solvers implementation."""
import scipy
import geomstats.backend as gs
from geomstats.numerics._common import result_to_backend_type
[docs]
class ScipySolveBVP:
    """Wrapper for scipy.integrate.solve_bvp.

    User callbacks are wrapped so that the NumPy arrays SciPy hands them are
    first converted with ``gs.from_numpy``.

    Parameters
    ----------
    tol : float
        Forwarded to ``scipy.integrate.solve_bvp`` as ``tol``.
    max_nodes : int
        Forwarded to ``scipy.integrate.solve_bvp`` as ``max_nodes``.
    bc_tol : float, optional
        Forwarded to ``scipy.integrate.solve_bvp`` as ``bc_tol``.
    save_result : bool
        If true, each call to ``integrate`` stores its result in ``result_``.

    Attributes
    ----------
    result_ : object or None
        ``None`` until ``integrate`` runs with ``save_result`` enabled; then
        the value returned by the most recent such call.
    """

    def __init__(self, tol=1e-3, max_nodes=1000, bc_tol=None, save_result=False):
        self.tol = tol
        self.max_nodes = max_nodes
        self.bc_tol = bc_tol
        self.save_result = save_result

        self.result_ = None

    def integrate(self, fun, bc, x, y, fun_jac=None, bc_jac=None):
        """Solve a boundary value problem for a system of ODEs.

        Parameters
        ----------
        fun : callable
            Right-hand side of the system, called as ``fun(t, state)``.
        bc : callable
            Boundary condition residuals, called as ``bc(state_0, state_1)``.
        x : array-like
            Passed unchanged to SciPy as ``x`` (the initial mesh, per the
            SciPy documentation).
        y : array-like
            Passed unchanged to SciPy as ``y`` (the initial guess on the
            mesh, per the SciPy documentation).
        fun_jac : callable, optional
            Jacobian of ``fun``, called as ``fun_jac(t, state)``.
        bc_jac : callable, optional
            Jacobian of ``bc``, called as ``bc_jac(state_0, state_1)``.

        Returns
        -------
        result : object
            Output of ``scipy.integrate.solve_bvp`` after being passed through
            ``result_to_backend_type``.
        """

        def fun_(t, state):
            return fun(t, gs.from_numpy(state))

        def bc_(state_0, state_1):
            return bc(gs.from_numpy(state_0), gs.from_numpy(state_1))

        if fun_jac is not None:

            def fun_jac_(t, state):
                return fun_jac(t, gs.from_numpy(state))

        else:
            fun_jac_ = None

        # Wrap bc_jac like the other callbacks so that it also receives
        # backend arrays rather than the raw NumPy arrays SciPy passes in.
        if bc_jac is not None:

            def bc_jac_(state_0, state_1):
                return bc_jac(gs.from_numpy(state_0), gs.from_numpy(state_1))

        else:
            bc_jac_ = None

        result = scipy.integrate.solve_bvp(
            fun_,
            bc_,
            x,
            y,
            tol=self.tol,
            max_nodes=self.max_nodes,
            fun_jac=fun_jac_,
            bc_jac=bc_jac_,
            bc_tol=self.bc_tol,
        )

        result = result_to_backend_type(result)
        if self.save_result:
            self.result_ = result

        return result