"""Statistical Manifold of Gamma distributions with the Fisher metric.
The natural coordinate system for a Gamma Distribution is:
point = [kappa, nu], where kappa is the shape parameter, and nu the rate, or 1/scale.
However, information geometry most often works with standard coordinates, given by:
point = [kappa, gamma] = [kappa, kappa/nu].
The standard coordinate system is the convention we use in this script.
All points and all vectors input are assumed to be given in the standard coordinate
system unless stated otherwise.
Some of the methods in GammaDistributions allow to easily make the associated
change of variable, either for a point or a vector.
Lead author: Jules Deschamps.
"""
from scipy.stats import gamma
import geomstats.backend as gs
from geomstats.algebra_utils import from_vector_to_diagonal_matrix
from geomstats.geometry.base import VectorSpaceOpenSet
from geomstats.geometry.diffeo import InvolutionDiffeomorphism
from geomstats.geometry.euclidean import Euclidean
from geomstats.geometry.riemannian_metric import RiemannianMetric
from geomstats.information_geometry.base import (
InformationManifoldMixin,
ScipyUnivariateRandomVariable,
)
from geomstats.numerics.bvp import ScipySolveBVP
from geomstats.numerics.geodesic import ExpODESolver, LogODESolver
from geomstats.numerics.ivp import ScipySolveIVP
from geomstats.vectorization import get_batch_shape
[docs]
class NaturalToStandardDiffeo(InvolutionDiffeomorphism):
"""Diffeomorphism between natural and standard coordinates."""
def __call__(self, point):
"""Convert point from natural coordinates to standard coordinates.
The change of variable is symmetric.
Parameters
----------
point : array-like, shape=[..., 2]
Point of the Gamma manifold, given in natural coordinates.
Returns
-------
point : array-like, shape=[..., 2]
Point of the Gamma manifold, given in standard coordinates.
"""
return gs.stack([point[..., 0], point[..., 0] / point[..., 1]], axis=-1)
[docs]
def tangent(self, tangent_vec, base_point=None, image_point=None):
"""Convert tangent vector from natural coordinates to standard coordinates.
The change of variable is symmetric.
Parameters
----------
tangent_vec : array-like, shape=[..., 2]
Tangent vector at base_point, given in natural coordinates.
base_point : array-like, shape=[..., 2]
Point of the Gamma manifold, given in natural coordinates.
image_point : array-like, shape=[..., 2]
Point of the Gamma manifold, given in standard coordinates.
Returns
-------
image_tangent_vec : array-like, shape=[..., 2]
Tangent vector at base_point, given in standard coordinates.
"""
if base_point is None:
base_point = self(image_point)
kappa, scale = base_point[..., 0], base_point[..., 1]
jac_row_1 = gs.array([1.0, 0.0])
jac_row_2 = gs.stack([1 / scale, -kappa / scale**2], axis=-1)
point_batch_shape = get_batch_shape(1, base_point)
if point_batch_shape:
jac_row_1 = gs.broadcast_to(jac_row_1, point_batch_shape + (2,))
jac = gs.stack([jac_row_1, jac_row_2], axis=-2)
return gs.einsum("...jk,...k->...j", jac, tangent_vec)
[docs]
class GammaDistributions(InformationManifoldMixin, VectorSpaceOpenSet):
"""Class for the manifold of Gamma distributions.
This is :math:`Gamma = (R_+^*)^2`, the positive quadrant of the
2-dimensional Euclidean space.
"""
def __init__(self, equip=True):
super().__init__(
dim=2,
embedding_space=Euclidean(2, equip=False),
support_shape=(),
equip=equip,
)
self._scp_rv = GammaDistributionsRandomVariable(self)
[docs]
@staticmethod
def default_metric():
"""Metric to equip the space with if equip is True."""
return GammaMetric
[docs]
def belongs(self, point, atol=gs.atol):
"""Evaluate if a point belongs to the manifold of Gamma distributions.
Check that point defines parameters for a Gamma distribution,
i.e. belongs to the positive quadrant of the Euclidean space.
Parameters
----------
point : array-like, shape=[..., 2]
Point to be checked.
atol : float
Tolerance to evaluate positivity.
Optional, default: gs.atol
Returns
-------
belongs : array-like, shape=[...,]
Boolean indicating whether point represents a Gamma
distribution.
"""
point_dim = point.shape[-1]
belongs = point_dim == 2
return gs.logical_and(belongs, gs.all(point >= -atol, axis=-1))
[docs]
def random_point(self, n_samples=1, upper_bound=5.0, lower_bound=0.0):
"""Sample parameters of Gamma distributions.
The uniform distribution on [0, bound]^2 is used.
Parameters
----------
n_samples : int
Number of samples.
Optional, default: 1.
bound : float
Side of the square where the Gamma parameters are sampled.
Optional, default: 5.
Returns
-------
samples : array-like, shape=[..., 2]
Sample of points representing Gamma distributions.
"""
upper_bound = upper_bound * gs.ones(2)
lower_bound = lower_bound * gs.ones(2)
if gs.any((upper_bound - lower_bound) < 0):
raise ValueError("upper_bound cannot be greater than lower_bound.")
size = (2,) if n_samples == 1 else (n_samples, 2)
return lower_bound + (upper_bound - lower_bound) * gs.random.rand(*size)
[docs]
def projection(self, point, atol=gs.atol):
"""Project a point in ambient space to the open set.
The last coordinate is floored to `gs.atol` if it is negative.
Parameters
----------
point : array-like, shape=[..., 2]
Point in ambient space.
atol : float
Tolerance to evaluate positivity.
Returns
-------
projected : array-like, shape=[..., 2]
Projected point.
"""
return gs.where(point < atol, atol, point)
[docs]
def sample(self, point, n_samples=1):
"""Sample from the Gamma distribution.
Sample from the Gamma distribution with parameters provided
by point. This gives n_samples points.
Parameters
----------
point : array-like, shape=[..., dim]
Point representing a Gamma distribution.
n_samples : int
Number of points to sample for each set of parameters in point.
Optional, default: 1.
Returns
-------
samples : array-like, shape=[..., n_samples]
Sample from the Gamma distributions.
"""
return self._scp_rv.rvs(point, n_samples)
[docs]
def point_to_pdf(self, point):
"""Compute pdf associated to point.
Compute the probability density function of the Gamma
distribution with parameters provided by point.
Parameters
----------
point : array-like, shape=[..., dim]
Point representing a Gamma distribution.
Returns
-------
pdf : function
Probability density function of the Gamma distribution with
parameters provided by point.
"""
kappa = gs.expand_dims(point[..., 0], axis=-1)
gamma = gs.expand_dims(point[..., 1], axis=-1)
def pdf(x):
"""Generate parameterized function for Gamma pdf.
Parameters
----------
x : array-like, shape=[n_samples,]
Points at which to compute the probability
density function.
Returns
-------
pdf_at_x : array-like, shape=[..., n_samples]
Values of pdf at x for each value of the parameters provided
by point.
"""
x = gs.reshape(gs.array(x), (-1,))
return (
kappa**kappa
* x ** (kappa - 1)
* gs.exp(-kappa * x / gamma)
/ (gamma**kappa * gs.gamma(kappa))
)
return pdf
[docs]
@staticmethod
def maximum_likelihood_fit(data):
"""Estimate parameters from samples.
This is a wrapper around scipy's maximum likelihood estimator to
estimate the parameters of a gamma distribution from samples.
Parameters
----------
data : list or list of lists/arrays
Data to estimate parameters from. Lists of
different length may be passed.
Returns
-------
parameter : array-like, shape=[..., 2]
Estimate of parameter obtained by maximum likelihood.
"""
def is_nested(sample):
"""Check if sample contains an iterable."""
for el in sample:
try:
return iter(el)
except TypeError:
return False
if not is_nested(data):
data = [data]
parameters = []
for sample in data:
sample = gs.array(sample)
kappa, _, scale = gamma.fit(sample, floc=0)
nu = 1 / scale
parameters.append(gs.array([kappa, kappa / nu]))
return parameters[0] if len(data) == 1 else gs.stack(parameters)
[docs]
class GammaMetric(RiemannianMetric):
"""Class for the Fisher information metric on Gamma distributions.
References
----------
.. [AD2008] Arwini, K. A., & Dodson, C. T. (2008).
Information geometry (pp. 31-54). Springer Berlin Heidelberg.
"""
def __init__(self, space):
super().__init__(space=space)
self.log_solver = LogODESolver(
space, n_nodes=1000, integrator=ScipySolveBVP(max_nodes=1000)
)
self.exp_solver = ExpODESolver(space, integrator=ScipySolveIVP(method="LSODA"))
[docs]
def metric_matrix(self, base_point):
"""Compute the inner-product matrix.
Compute the inner-product matrix of the Fisher information metric
at the tangent space at base point.
Parameters
----------
base_point : array-like, shape=[..., 2]
Base point.
Returns
-------
mat : array-like, shape=[..., 2, 2]
Inner-product matrix.
"""
kappa, gamma = base_point[..., 0], base_point[..., 1]
mat_diag = gs.stack(
[gs.polygamma(1, kappa) - 1 / kappa, kappa / gamma**2], axis=-1
)
return from_vector_to_diagonal_matrix(mat_diag)
[docs]
def christoffels(self, base_point):
"""Compute the Christoffel symbols.
Compute the Christoffel symbols of the Fisher information metric.
For computation purposes, we replace the value of
(gs.polygamma(1, x) - 1/x) by an equivalent (close lower-bound) when it becomes
too difficult to compute, as per in [GQ2015]_.
Parameters
----------
base_point : array-like, shape=[..., 2]
Base point.
Returns
-------
christoffels : array-like, shape=[..., 2, 2, 2]
Christoffel symbols, with the contravariant index on
the first dimension.
:math:`christoffels[..., i, j, k] = Gamma^i_{jk}`
References
----------
.. [AD2008] Arwini, K. A., & Dodson, C. T. (2008).
Information geometry (pp. 31-54). Springer Berlin Heidelberg.
.. [GQ2015] Guo, B. N., Qi, F., Zhao, J. L., & Luo, Q. M. (2015).
Sharp inequalities for polygamma functions.
Mathematica Slovaca, 65(1), 103-120.
"""
base_point = gs.to_ndarray(base_point, to_ndim=2)
kappa, gamma = base_point[:, 0], base_point[:, 1]
if gs.any(kappa > 4e15):
raise ValueError(
"Christoffels computation overflows with values of kappa. "
"All values of kappa < 4e15 work."
)
shape = kappa.shape
c111 = gs.where(
gs.polygamma(1, kappa) - 1 / kappa > gs.atol,
(gs.polygamma(2, kappa) + gs.array(kappa) ** -2)
/ (2 * (gs.polygamma(1, kappa) - 1 / kappa)),
0.25 * (kappa**2 * gs.polygamma(2, kappa) + 1),
)
c122 = gs.where(
gs.polygamma(1, kappa) - 1 / kappa > gs.atol,
-1 / (2 * gamma**2 * (gs.polygamma(1, kappa) - 1 / kappa)),
-(kappa**2) / (4 * gamma**2),
)
c1 = gs.squeeze(
from_vector_to_diagonal_matrix(gs.transpose(gs.array([c111, c122])))
)
c2 = gs.squeeze(
gs.transpose(
gs.array(
[[gs.zeros(shape), 1 / (2 * kappa)], [1 / (2 * kappa), -1 / gamma]]
)
)
)
christoffels = gs.array([c1, c2])
if len(christoffels.shape) == 4:
christoffels = gs.transpose(christoffels, [1, 0, 2, 3])
return gs.squeeze(christoffels)
[docs]
def jacobian_christoffels(self, base_point):
"""Compute the Jacobian of the Christoffel symbols.
Compute the Jacobian of the Christoffel symbols of the
Fisher information metric.
For computation purposes, we replace the value of
(gs.polygamma(1, x) - 1/x) and (gs.polygamma(2,x) + 1/x**2) by an equivalent
(close bounds) when they become too difficult to compute.
Parameters
----------
base_point : array-like, shape=[..., 2]
Base point.
Returns
-------
jac : array-like, shape=[..., 2, 2, 2, 2]
Jacobian of the Christoffel symbols.
:math:`jac[..., i, j, k, l] = dGamma^i_{jk} / dx_l`
References
----------
.. [GQ2015] Guo, B. N., Qi, F., Zhao, J. L., & Luo, Q. M. (2015).
Sharp inequalities for polygamma functions.
Mathematica Slovaca, 65(1), 103-120.
"""
base_point = gs.to_ndarray(base_point, 2)
n_points = base_point.shape[0]
kappa, gamma = base_point[:, 0], base_point[:, 1]
term_0 = gs.zeros((n_points))
term_1 = 1 / gamma**2
term_2 = gs.where(
gs.polygamma(1, kappa) - 1 / kappa > gs.atol,
kappa / (gamma**3 * (kappa * gs.polygamma(1, kappa) - 1)),
kappa**2 / gamma**3,
)
term_3 = -1 / (2 * kappa**2)
term_4 = gs.where(
gs.polygamma(1, kappa) - 1 / kappa > gs.atol,
(kappa**2 * gs.polygamma(2, kappa) + 1)
/ (2 * gamma**2 * (kappa * gs.polygamma(1, kappa) - 1) ** 2),
(kappa**4 * gs.polygamma(2, kappa) + kappa**2) / (2 * gamma**2),
)
term_5 = gs.where(
gs.polygamma(1, kappa) - 1 / kappa > gs.atol,
(
kappa**4
* (
gs.polygamma(1, kappa) * gs.polygamma(3, kappa)
- gs.polygamma(2, kappa) ** 2
)
- kappa**3 * gs.polygamma(3, kappa)
- 2 * kappa**2 * gs.polygamma(2, kappa)
- 2 * kappa * gs.polygamma(1, kappa)
+ 1
)
/ (2 * (kappa**2 * gs.polygamma(1, kappa) - kappa) ** 2),
0.5
* (
kappa**4
* (
gs.polygamma(1, kappa) * gs.polygamma(3, kappa)
- gs.polygamma(2, kappa) ** 2
)
- kappa**3 * gs.polygamma(3, kappa)
- 2 * kappa**2 * gs.polygamma(2, kappa)
- 2 * kappa * gs.polygamma(1, kappa)
+ 1
),
)
jac = gs.array(
[
[
[[term_5, term_0], [term_0, term_0]],
[[term_0, term_0], [term_4, term_2]],
],
[
[[term_0, term_0], [term_3, term_0]],
[[term_3, term_0], [term_0, term_1]],
],
]
)
if n_points > 1:
jac = gs.transpose(jac, [4, 0, 1, 2, 3])
return gs.squeeze(jac)
[docs]
class GammaDistributionsRandomVariable(ScipyUnivariateRandomVariable):
"""A gamma random variable."""
def __init__(self, space):
super().__init__(space, gamma.rvs, gamma.pdf)
@staticmethod
def _flatten_params(point, pre_flat_shape):
param_a = gs.expand_dims(point[..., 0], axis=-1)
scale = gs.expand_dims(point[..., 1] / point[..., 0], axis=-1)
flat_param_a = gs.reshape(gs.broadcast_to(param_a, pre_flat_shape), (-1,))
flat_scale = gs.reshape(gs.broadcast_to(scale, pre_flat_shape), (-1,))
return {"a": flat_param_a, "scale": flat_scale}