"""Statistical Manifold of Dirichlet distributions with the Fisher metric.
Lead author: Alice Le Brigant.
"""
import math
import numpy as np
from scipy.optimize import minimize
from scipy.stats import dirichlet
import geomstats.backend as gs
from geomstats.algebra_utils import from_vector_to_diagonal_matrix
from geomstats.geometry.base import VectorSpaceOpenSet
from geomstats.geometry.euclidean import Euclidean
from geomstats.geometry.riemannian_metric import RiemannianMetric
from geomstats.information_geometry.base import (
InformationManifoldMixin,
ScipyMultivariateRandomVariable,
)
from geomstats.numerics.bvp import ScipySolveBVP
from geomstats.numerics.geodesic import ExpODESolver, LogODESolver
from geomstats.numerics.ivp import ScipySolveIVP
from geomstats.vectorization import repeat_out
class DirichletDistributions(InformationManifoldMixin, VectorSpaceOpenSet):
    """Class for the manifold of Dirichlet distributions.

    This is Dirichlet = :math:`(R_+^*)^dim`, the positive quadrant of the
    dim-dimensional Euclidean space.

    Parameters
    ----------
    dim : int
        Dimension of the manifold of Dirichlet distributions.
    equip : bool
        Forwarded to the parent constructor; see `default_metric`.
        Optional, default: True.

    Attributes
    ----------
    dim : int
        Dimension of the manifold of Dirichlet distributions.
    """

    def __init__(self, dim, equip=True):
        super().__init__(
            dim=dim,
            support_shape=(dim,),
            embedding_space=Euclidean(dim=dim, equip=False),
            equip=equip,
        )
        # Random-variable helper backing `sample` and `point_to_pdf`.
        self._scp_rv = DirichletRandomVariable(self)

    @staticmethod
    def default_metric():
        """Metric to equip the space with if equip is True."""
        return DirichletMetric

    def belongs(self, point, atol=gs.atol):
        """Evaluate if a point belongs to the manifold of Dirichlet distributions.

        Check that point defines parameters for a Dirichlet distribution,
        i.e. belongs to the positive quadrant of the Euclidean space.
        Every coordinate must be at least `atol`, which is the same
        threshold `projection` floors coordinates to, so projected points
        always belong to the manifold.

        Parameters
        ----------
        point : array-like, shape=[..., dim]
            Point to be checked.
        atol : float
            Tolerance to evaluate positivity.
            Optional, default: gs.atol

        Returns
        -------
        belongs : array-like, shape=[...,]
            Boolean indicating whether point represents a Dirichlet
            distribution.
        """
        if point.shape[-1] != self.dim:
            # Wrong number of parameters: no point of the batch belongs.
            return gs.zeros(point.shape[:-1], dtype=bool)
        # The set is open: zero or negative parameters must be rejected,
        # hence the comparison against +atol rather than -atol.
        return gs.all(point >= atol, axis=-1)

    def random_point(self, n_samples=1, bound=5.0):
        """Sample parameters of Dirichlet distributions.

        The uniform distribution on [0, bound]^dim is used.

        Parameters
        ----------
        n_samples : int
            Number of samples.
            Optional, default: 1.
        bound : float
            Side of the square where the Dirichlet parameters are sampled.
            Optional, default: 5.

        Returns
        -------
        samples : array-like, shape=[..., dim]
            Sample of points representing Dirichlet distributions.
        """
        # A single sample is returned without a leading batch axis.
        size = (self.dim,) if n_samples == 1 else (n_samples, self.dim)
        return bound * gs.random.rand(*size)

    def projection(self, point, atol=gs.atol):
        """Project a point in ambient space to the open set.

        Every coordinate smaller than `atol` is replaced by `atol`; the
        other coordinates are left untouched.

        Parameters
        ----------
        point : array-like, shape=[..., dim]
            Point in ambient space.
        atol : float
            Tolerance to evaluate positivity.
            Optional, default: gs.atol

        Returns
        -------
        projected : array-like, shape=[..., dim]
            Projected point.
        """
        return gs.where(point < atol, atol, point)

    def sample(self, point, n_samples=1):
        """Sample from the Dirichlet distribution.

        Sample from the Dirichlet distribution with parameters provided
        by point. This gives n_samples points in the simplex.

        Parameters
        ----------
        point : array-like, shape=[..., dim]
            Point representing a Dirichlet distribution.
        n_samples : int
            Number of points to sample for each set of parameters in point.
            Optional, default: 1.

        Returns
        -------
        samples : array-like, shape=[..., n_samples, dim]
            Sample from the Dirichlet distributions.
        """
        return self._scp_rv.rvs(point, n_samples)

    def point_to_pdf(self, point):
        """Compute pdf associated to point.

        Compute the probability density function of the Dirichlet
        distribution with parameters provided by point.

        Parameters
        ----------
        point : array-like, shape=[..., dim]
            Point representing a Dirichlet distribution.

        Returns
        -------
        pdf : function
            Probability density function of the Dirichlet distribution with
            parameters provided by point.
        """
        return lambda x: self._scp_rv.pdf(x, point=point)
class DirichletMetric(RiemannianMetric):
    """Class for the Fisher information metric on Dirichlet distributions.

    Parameters
    ----------
    space : DirichletDistributions
        Manifold of Dirichlet distributions to equip with the metric.
    """

    def __init__(self, space):
        super().__init__(space=space)
        # Numerical geodesic solvers: the logarithm is a boundary value
        # problem, the exponential an initial value problem.
        self.log_solver = LogODESolver(
            space, n_nodes=1000, integrator=ScipySolveBVP(max_nodes=1000)
        )
        self.exp_solver = ExpODESolver(space, integrator=ScipySolveIVP(method="LSODA"))

    def metric_matrix(self, base_point):
        """Compute the inner-product matrix.

        Compute the inner-product matrix of the Fisher information metric
        at the tangent space at base point. Its entries are
        :math:`g_{ij} = \\delta_{ij} \\psi'(x_i) - \\psi'(x_1 + ... + x_n)`
        where :math:`\\psi'` is the trigamma function.

        Parameters
        ----------
        base_point : array-like, shape=[..., dim]
            Base point.

        Returns
        -------
        mat : array-like, shape=[..., dim, dim]
            Inner-product matrix.
        """
        base_point = gs.to_ndarray(base_point, to_ndim=2)
        n_points = base_point.shape[0]
        mat_ones = gs.ones((n_points, self._space.dim, self._space.dim))
        poly_sum = gs.polygamma(1, gs.sum(base_point, -1))
        mat_diag = from_vector_to_diagonal_matrix(gs.polygamma(1, base_point))
        mat = mat_diag - gs.einsum("i,ijk->ijk", poly_sum, mat_ones)
        # NOTE(review): assuming numpy semantics, squeeze drops every
        # length-one axis, so a batch holding a single point comes back
        # without its batch axis.
        return gs.squeeze(mat)

    def christoffels(self, base_point):
        """Compute the Christoffel symbols.

        Compute the Christoffel symbols of the Fisher information metric.

        Parameters
        ----------
        base_point : array-like, shape=[..., dim]
            Base point.

        Returns
        -------
        christoffels : array-like, shape=[..., dim, dim, dim]
            Christoffel symbols, with the contravariant index on
            the first dimension.
            :math:`christoffels[..., i, j, k] = Gamma^i_{jk}`

        References
        ----------
        .. [LPP2021] A. Le Brigant, S. C. Preston, S. Puechmorel. Fisher-Rao
            geometry of Dirichlet Distributions. Differential Geometry
            and its Applications, 74, 101702, 2021.
        """
        base_point = gs.to_ndarray(base_point, to_ndim=2)
        n_points = base_point.shape[0]
        dim = self._space.dim

        # Quantities that do not depend on the contravariant index are
        # computed once instead of once per index.
        param_sum = gs.sum(base_point, -1)
        trigamma_sum = gs.polygamma(1, param_sum)
        tetragamma_sum = gs.polygamma(2, param_sum)
        denominator = 1 / trigamma_sum - gs.sum(1 / gs.polygamma(1, base_point), -1)
        mat_ones = gs.ones((n_points, dim, dim))
        mat_diag = from_vector_to_diagonal_matrix(
            -gs.polygamma(2, base_point) / gs.polygamma(1, base_point)
        )

        def coefficients(ind_k):
            """Christoffel symbols for contravariant index ind_k."""
            param_k = base_point[..., ind_k]
            c1 = 1 / gs.polygamma(1, param_k) / denominator
            c2 = -c1 * tetragamma_sum / trigamma_sum

            # One-hot row selecting coordinate ind_k, repeated for each point.
            arrays = [
                gs.zeros((1, ind_k)),
                gs.ones((1, 1)),
                gs.zeros((1, dim - ind_k - 1)),
            ]
            vec_k = gs.tile(gs.hstack(arrays), (n_points, 1))
            val_k = gs.polygamma(2, param_k) / gs.polygamma(1, param_k)
            vec_k = gs.einsum("i,ij->ij", val_k, vec_k)
            mat_k = from_vector_to_diagonal_matrix(vec_k)

            mat = (
                gs.einsum("i,ijk->ijk", c2, mat_ones)
                - gs.einsum("i,ijk->ijk", c1, mat_diag)
                + mat_k
            )
            return 1 / 2 * mat

        christoffels = gs.stack([coefficients(ind_k) for ind_k in range(dim)], 1)
        return gs.squeeze(christoffels)

    def jacobian_christoffels(self, base_point):
        """Compute the Jacobian of the Christoffel symbols.

        Compute the Jacobian of the Christoffel symbols of the
        Fisher information metric.

        Parameters
        ----------
        base_point : array-like, shape=[..., dim]
            Base point.

        Returns
        -------
        jac : array-like, shape=[..., dim, dim, dim, dim]
            Jacobian of the Christoffel symbols.
            :math:`jac[..., i, j, k, l] = dGamma^i_{jk} / dx_l`
        """
        dim = self._space.dim
        n_dim = base_point.ndim
        # Work with the coordinate axis first (assuming numpy transpose
        # semantics, a 1-D point is left unchanged).
        param = gs.transpose(base_point)
        sum_param = gs.sum(param, 0)

        # Building blocks, with psi' = polygamma(1), psi'' = polygamma(2)
        # and psi''' = polygamma(3):
        #   term_1, term_2: 1 / psi' at each parameter and at their sum;
        #   term_3, term_4: derivatives of term_1 and term_2;
        #   term_5, term_6: -psi'' / psi' at each parameter and at the sum;
        #   term_7, term_8: derivatives of term_5 and term_6;
        #   term_9: the denominator shared by the Christoffel symbols.
        term_1 = 1 / gs.polygamma(1, param)
        term_2 = 1 / gs.polygamma(1, sum_param)
        term_3 = -gs.polygamma(2, param) / gs.polygamma(1, param) ** 2
        term_4 = -gs.polygamma(2, sum_param) / gs.polygamma(1, sum_param) ** 2
        term_5 = term_3 / term_1
        term_6 = term_4 / term_2
        term_7 = (
            gs.polygamma(2, param) ** 2
            - gs.polygamma(1, param) * gs.polygamma(3, param)
        ) / gs.polygamma(1, param) ** 2
        term_8 = (
            gs.polygamma(2, sum_param) ** 2
            - gs.polygamma(1, sum_param) * gs.polygamma(3, sum_param)
        ) / gs.polygamma(1, sum_param) ** 2
        term_9 = term_2 - gs.sum(term_1, 0)

        # Each jac_*_mat below is one additive contribution, tiled or
        # placed on diagonals so that all of them share the same shape.
        jac_1 = term_1 * term_8 / term_9
        jac_1_mat = gs.squeeze(gs.tile(jac_1, (dim, dim, dim, 1, 1)))
        jac_2 = (
            -term_6 / term_9**2 * gs.einsum("j...,i...->ji...", term_4 - term_3, term_1)
        )
        jac_2_mat = gs.squeeze(gs.tile(jac_2, (dim, dim, 1, 1, 1)))
        jac_3 = term_3 * term_6 / term_9
        jac_3_mat = gs.transpose(from_vector_to_diagonal_matrix(gs.transpose(jac_3)))
        jac_3_mat = gs.squeeze(gs.tile(jac_3_mat, (dim, dim, 1, 1, 1)))
        jac_4 = (
            1
            / term_9**2
            * gs.einsum("k...,j...,i...->kji...", term_5, term_4 - term_3, term_1)
        )
        jac_4_mat = gs.transpose(from_vector_to_diagonal_matrix(gs.transpose(jac_4)))
        jac_5 = -gs.einsum("j...,i...->ji...", term_7, term_1) / term_9
        jac_5_mat = from_vector_to_diagonal_matrix(gs.transpose(jac_5))
        jac_5_mat = gs.transpose(from_vector_to_diagonal_matrix(jac_5_mat))
        jac_6 = -gs.einsum("k...,j...->kj...", term_5, term_3) / term_9
        jac_6_mat = gs.transpose(from_vector_to_diagonal_matrix(gs.transpose(jac_6)))
        # The batched case needs explicit axis permutations around the
        # diagonal embedding; a single point does not.
        jac_6_mat = (
            gs.transpose(
                from_vector_to_diagonal_matrix(gs.transpose(jac_6_mat, [0, 1, 3, 2])),
                [0, 1, 3, 4, 2],
            )
            if n_dim > 1
            else from_vector_to_diagonal_matrix(jac_6_mat)
        )
        jac_7 = -from_vector_to_diagonal_matrix(gs.transpose(term_7))
        jac_7_mat = from_vector_to_diagonal_matrix(jac_7)
        jac_7_mat = gs.transpose(from_vector_to_diagonal_matrix(jac_7_mat))

        jac = (
            1
            / 2
            * (
                jac_1_mat
                + jac_2_mat
                + jac_3_mat
                + jac_4_mat
                + jac_5_mat
                + jac_6_mat
                + jac_7_mat
            )
        )
        # Reorder the axes to the documented layout; the batched case
        # carries one extra axis.
        return (
            gs.transpose(jac, [3, 1, 0, 2])
            if n_dim == 1
            else gs.transpose(jac, [4, 3, 1, 0, 2])
        )

    def injectivity_radius(self, base_point=None):
        """Compute the radius of the injectivity domain.

        This is the supremum of radii r for which the exponential map is a
        diffeomorphism from the open ball of radius r centered at the base point onto
        its image.

        For the Fisher metric on Dirichlet distributions it does not depend
        on the base point and is infinite everywhere: the manifold is
        geodesically complete with negative sectional curvature [LPP2021]_.

        Parameters
        ----------
        base_point : array-like, shape=[..., dim]
            Point on the manifold.

        Returns
        -------
        radius : array-like, shape=[...,]
            Injectivity radius.

        References
        ----------
        .. [LPP2021] A. Le Brigant, S. C. Preston, S. Puechmorel. Fisher-Rao
            geometry of Dirichlet Distributions. Differential Geometry
            and its Applications, 74, 101702, 2021.
        """
        radius = gs.array(math.inf)
        return repeat_out(self._space.point_ndim, radius, base_point)

    def _approx_geodesic_bvp(
        self,
        initial_point,
        end_point,
        degree=5,
        method="BFGS",
        n_times=200,
        jac_on=True,
    ):
        """Solve approximation of the geodesic boundary value problem.

        The space of solutions is restricted to curves whose coordinates are
        polynomial functions of time. The boundary value problem is solved by
        minimizing the energy among all such curves starting from initial_point
        and ending at end_point, i.e. curves t -> (x_1(t),...,x_n(t)) where x_i
        are polynomial functions of time t, such that (x_1(0),..., x_n(0)) is
        initial_point and (x_1(1),..., x_n(1)) is end_point. The parameterized
        curve is computed at n_times discrete times.

        Parameters
        ----------
        initial_point : array-like, shape=(dim,)
            Starting point of the geodesic.
        end_point : array-like, shape=(dim,)
            End point of the geodesic.
        degree : int
            Degree of the coordinates' polynomial functions of time.
        method : str
            Minimization method to use in scipy.optimize.minimize.
        n_times : int
            Number of sample times.
        jac_on : bool
            If jac_on=True, use the Jacobian of the energy cost function in
            scipy.optimize.minimize.

        Returns
        -------
        dist : float
            Length of the polynomial approximation of the geodesic.
        curve : array-like, shape=(n_times, dim)
            Polynomial approximation of the geodesic.
        velocity : array-like, shape=(n_times, dim)
            Velocity of the polynomial approximation of the geodesic.
        """
        dim = initial_point.shape[0]

        # The sample times and the monomial bases do not depend on the
        # optimization variable, so they are built once.
        t = gs.linspace(0.0, 1.0, n_times)
        t_position = gs.stack([t**i for i in range(degree + 1)])
        t_velocity = gs.stack([i * t ** (i - 1) for i in range(1, degree + 1)])

        def polynomial_curve(param):
            """Evaluate the polynomial curve and its velocity.

            Parameters
            ----------
            param : array-like, shape=(degree - 1, dim)
                Parameters of the curve coordinates' polynomial functions of time.

            Returns
            -------
            curve : array-like, shape=(n_times, dim)
                Curve sampled at the n_times times.
            velocity : array-like, shape=(n_times, dim)
                Time derivative of the curve at the same times.
            """
            # The constant coefficient pins the curve at t=0 and the
            # highest-degree one is chosen so that it hits end_point at t=1.
            last_coef = end_point - initial_point - gs.sum(param, axis=0)
            coef = gs.vstack((initial_point, param, last_coef))
            curve = gs.einsum("ij,ik->kj", coef, t_position)
            velocity = gs.einsum("ij,ik->kj", coef[1:], t_velocity)
            return curve, velocity

        def cost_fun(param):
            """Compute the energy of the polynomial curve defined by param.

            Parameters
            ----------
            param : array-like, shape=(degree - 1, dim)
                Parameters of the curve coordinates' polynomial functions of time.

            Returns
            -------
            energy : float
                Energy of the polynomial approximation of the geodesic.
            length : float
                Length of the polynomial approximation of the geodesic.
            curve : array-like, shape=(n_times, dim)
                Polynomial approximation of the geodesic.
            velocity : array-like, shape=(n_times, dim)
                Velocity of the polynomial approximation of the geodesic.
            """
            curve, velocity = polynomial_curve(param)
            if curve.min() < 0:
                # The curve has a negative coordinate: return an infinite
                # cost so the optimizer rejects it.
                return np.inf, np.inf, curve, np.nan
            velocity_sqnorm = self.squared_norm(vector=velocity, base_point=curve)
            length = gs.sum(velocity_sqnorm ** (1 / 2)) / n_times
            energy = gs.sum(velocity_sqnorm) / n_times
            return energy, length, curve, velocity

        def cost_jacobian(param):
            """Compute the jacobian of the cost function at polynomial curve.

            Parameters
            ----------
            param : array-like, shape=(degree - 1, dim)
                Parameters of the curve coordinates' polynomial functions of time.

            Returns
            -------
            jac : array-like, shape=(dim * (degree - 1),)
                Jacobian of the cost function at polynomial curve.
            """
            position, velocity = polynomial_curve(param)

            # Sensitivity of the velocity (fac1) and of the position (fac2)
            # to each free coefficient; the degree-th monomial appears
            # because the last coefficient depends on all the free ones.
            fac1 = gs.stack(
                [
                    k * t ** (k - 1) - degree * t ** (degree - 1)
                    for k in range(1, degree)
                ]
            )
            fac2 = gs.stack([t**k - t**degree for k in range(1, degree)])
            # Half the derivative of the squared norm with respect to the
            # velocity (fac3) and its derivative with respect to the
            # position (fac4).
            fac3 = (velocity * gs.polygamma(1, position)).T - gs.sum(
                velocity, 1
            ) * gs.polygamma(1, gs.sum(position, 1))
            fac4 = (velocity**2 * gs.polygamma(2, position)).T - gs.sum(
                velocity, 1
            ) ** 2 * gs.polygamma(2, gs.sum(position, 1))

            cost_jac = (
                2 * gs.einsum("ij,kj->ik", fac1, fac3)
                + gs.einsum("ij,kj->ik", fac2, fac4)
            ) / n_times
            return cost_jac.T.reshape(dim * (degree - 1))

        def f2minimize(x):
            """Compute function to minimize."""
            param = gs.transpose(x.reshape((dim, degree - 1)))
            return cost_fun(param)[0]

        def jacobian(x):
            """Compute jacobian of the function to minimize."""
            param = gs.transpose(x.reshape((dim, degree - 1)))
            return cost_jacobian(param)

        x0 = gs.ones(dim * (degree - 1))
        jac = jacobian if jac_on else None
        sol = minimize(f2minimize, x0, method=method, jac=jac)
        opt_param = sol.x.reshape((dim, degree - 1)).T
        _, dist, curve, velocity = cost_fun(opt_param)
        return dist, curve, velocity
class DirichletRandomVariable(ScipyMultivariateRandomVariable):
    """Dirichlet random variable built on ``scipy.stats.dirichlet``.

    Parameters
    ----------
    space : DirichletDistributions
        Manifold whose points parameterize the distributions.
    """

    def __init__(self, space):
        def density(x, point):
            """Evaluate the Dirichlet density of parameters point at x.

            The samples are transposed before being handed to scipy.
            """
            return dirichlet.pdf(gs.transpose(x), point)

        super().__init__(space, dirichlet.rvs, density)