Source code for geomstats.information_geometry.fisher_rao_metric

"""Class to implement simply the Fisher-Rao metric on information manifolds."""

from scipy.integrate import quad_vec

import geomstats.backend as gs
from geomstats.geometry.riemannian_metric import RiemannianMetric


[docs] class FisherRaoMetric(RiemannianMetric): r"""Class to derive the information metric from the pdf in InformationManifoldMixin. Given a statistical manifold with coordinates :math:`\theta`, the Fisher information metric is: .. math:: g_{j k}(\theta)=\int_X \frac{\partial \log p(x, \theta)} {\partial \theta_j}\frac{\partial \log p(x, \theta)} {\partial \theta_k} p(x, \theta) d x Attributes ---------- space : InformationManifold Riemannian Manifold for a parametric family of (real) distributions. support : list, shape = (2,) Left and right bounds for the support of the distribution. But this is just to help integration, bounds should be as large as needed. """ def __init__(self, space, support): super().__init__( space=space, signature=(space.dim, 0), ) self.support = support
[docs] def metric_matrix(self, base_point): r"""Compute the inner-product matrix. The Fisher information matrix is noted I in the literature. Compute the inner-product matrix of the Fisher information metric at the tangent space at base point. .. math:: I(\theta) = \mathbb{E}[YY^T] where, .. math:: Y = \nabla_{\theta} \log f_{\theta}(x) After manipulation and in indicial notation .. math:: I_{ij} = \int \ \partial_{i} f_{\theta}(x)\ \partial_{j} f_{\theta}(x)\ \frac{1}{f_{\theta}(x)} Parameters ---------- base_point : array-like, shape=[..., dim] Base point. Returns ------- metric_mat : array-like, shape=[..., dim, dim] Inner-product matrix. References ---------- .. [AS1985] Amari, S (1985) Differential Geometric Methods in Statistics, Berlin, Springer – Verlag. """ def pdf(x): """Compute pdf at a fixed point on the support. Parameters ---------- x : float, shape (,) Point on the support of the distribution """ return lambda point: gs.squeeze(self._space.point_to_pdf(point)(x), axis=-1) def _function_to_integrate(x): pdf_x = pdf(x) ( pdf_x_at_base_point, pdf_x_derivative_at_base_point, ) = gs.autodiff.value_and_grad(pdf_x)(base_point) return gs.einsum( "...ij,...->...ij", gs.einsum( "...i,...j->...ij", pdf_x_derivative_at_base_point, pdf_x_derivative_at_base_point, ), 1 / pdf_x_at_base_point, ) return quad_vec(_function_to_integrate, *self.support)[0]
[docs] def inner_product_derivative_matrix(self, base_point): r"""Compute the derivative of the inner-product matrix. Compute the derivative of the inner-product matrix of the Fisher information metric at the tangent space at base point. .. math:: \partial_{\theta} I(\theta) = \partial_{\theta} \mathbb{E}[YY^T] where, .. math:: Y = \nabla_{\theta} \log f_{\theta}(x) or, in indicial notation: .. math:: \partial_k I_{ij} = \int\ \partial_{ki}^2 f\partial_j f \frac{1}{f} + \ \partial_{kj}^2 f\partial_i f \frac{1}{f} - \ \partial_i f \partial_j f \partial_k f \frac{1}{f^2} with :math:`f = f_{\theta}(x)` Parameters ---------- base_point : array-like, shape=[..., dim] Base point. Returns ------- mat : array-like, shape=[..., dim, dim, dim] Derivative of the inner-product matrix, where the index k of the derivation is last: math:`mat_{ijk} = \partial_k g_{ij}`. """ def pdf(x): """Compute pdf at a fixed point on the support. Parameters ---------- x : float, shape (,) Point on the support of the distribution """ return lambda point: gs.squeeze(self._space.point_to_pdf(point)(x), axis=-1) def _function_to_integrate(x): pdf_x = pdf(x) ( pdf_x_at_base_point, pdf_x_derivative_at_base_point, pdf_x_hessian_at_base_point, ) = gs.autodiff.value_jacobian_and_hessian(pdf_x)(base_point) return gs.einsum( "...,...ijk->...ijk", 1 / (pdf_x_at_base_point**2), gs.einsum( "...,...ijk->...ijk", pdf_x_at_base_point, gs.einsum( "...ki,...j->...ijk", pdf_x_hessian_at_base_point, pdf_x_derivative_at_base_point, ) + gs.einsum( "...kj,...i->...ijk", pdf_x_hessian_at_base_point, pdf_x_derivative_at_base_point, ), ) - gs.einsum( "...i, ...j, ...k -> ...ijk", pdf_x_derivative_at_base_point, pdf_x_derivative_at_base_point, pdf_x_derivative_at_base_point, ), ) return quad_vec(_function_to_integrate, *self.support)[0]