Source code for geomstats.information_geometry.beta

"""Statistical Manifold of beta distributions with the Fisher metric.

Lead author: Alice Le Brigant.
"""

from scipy.stats import beta

import geomstats.backend as gs
from geomstats.information_geometry.base import ScipyUnivariateRandomVariable
from geomstats.information_geometry.dirichlet import (
    DirichletDistributions,
    DirichletMetric,
)


[docs] class BetaDistributions(DirichletDistributions): r"""Class for the manifold of beta distributions. This is Beta = :math:`R_+^* \times R_+^*`, the upper-right quadrant of the Euclidean plane. Attributes ---------- dim : int Dimension of the manifold of beta distributions, equal to 2. embedding_space : Manifold Embedding manifold. """ def __init__(self, equip=True): super().__init__(dim=2, equip=equip) self.support_shape = () self._scp_rv = BetaDistributionsRandomVariable(self)
[docs] @staticmethod def default_metric(): """Metric to equip the space with if equip is True.""" return BetaMetric
[docs] def sample(self, point, n_samples=1): """Sample from the beta distribution. Sample from the beta distribution with parameters provided by point. This gives samples in the segment [0, 1]. Parameters ---------- point : array-like, shape=[..., 2] Point representing a beta distribution. n_samples : int Number of points to sample with each pair of parameters in point. Optional, default: 1. Returns ------- samples : array-like, shape=[..., n_samples] Sample from beta distributions.) """ return self._scp_rv.rvs(point, n_samples)
[docs] def point_to_pdf(self, point): """Compute pdf associated to point. Compute the probability density function of the beta distribution with parameters provided by point. Parameters ---------- point : array-like, shape=[..., 2] Point representing a beta distribution. Returns ------- pdf : function Probability density function of the beta distribution with parameters provided by point. """ alpha = gs.expand_dims(point[..., 0], axis=-1) beta = gs.expand_dims(point[..., 1], axis=-1) def pdf(x): """Generate parameterized function for normal pdf. Parameters ---------- x : array-like, shape=[n_samples,] Points at which to compute the probability density function. Returns ------- pdf_at_x : array-like, shape=[..., n_samples] Values of pdf at x for each value of the parameters provided by point. """ x = gs.reshape(gs.array(x), (-1,)) return ( x ** (alpha - 1) * (1 - x) ** (beta - 1) / (gs.gamma(alpha) * gs.gamma(beta) / gs.gamma(alpha + beta)) ) return pdf
[docs] @staticmethod def maximum_likelihood_fit(data, loc=0, scale=1, epsilon=1e-6): """Estimate parameters from samples. This a wrapper around scipy's maximum likelihood estimator to estimate the parameters of a beta distribution from samples. Parameters ---------- data : array-like, shape=[..., n_samples] Data to estimate parameters from. Arrays of different length may be passed. loc : float Location parameter of the distribution to estimate parameters from. It is kept fixed during optimization. Optional, default: 0. scale : float Scale parameter of the distribution to estimate parameters from. It is kept fixed during optimization. Optional, default: 1. Returns ------- parameter : array-like, shape=[..., 2] Estimate of parameter obtained by maximum likelihood. """ data = gs.where(data == 1.0, 1.0 - epsilon, data) data = gs.where(data == 0.0, epsilon, data) data = gs.to_ndarray(data, to_ndim=2) parameters = [] for sample in data: param_a, param_b, _, _ = beta.fit(sample, floc=loc, fscale=scale) parameters.append(gs.array([param_a, param_b])) return parameters[0] if len(data) == 1 else gs.stack(parameters)
[docs] class BetaMetric(DirichletMetric): """Class for the Fisher information metric on beta distributions."""
[docs] @staticmethod def metric_det(point): """Compute the determinant of the metric. Parameters ---------- point : array-like, shape=[..., 2] Point representing a beta distribution. Returns ------- metric_det : array-like, shape=[...,] Determinant of the metric. """ param_a = point[..., 0] param_b = point[..., 1] metric_det = gs.polygamma(1, param_a) * gs.polygamma(1, param_b) - gs.polygamma( 1, param_a + param_b ) * (gs.polygamma(1, param_a) + gs.polygamma(1, param_b)) return metric_det
[docs] class BetaDistributionsRandomVariable(ScipyUnivariateRandomVariable): """A beta random variable.""" def __init__(self, space): super().__init__(space, beta.rvs, beta.pdf) @staticmethod def _flatten_params(point, pre_flat_shape): param_a = gs.expand_dims(point[..., 0], axis=-1) param_b = gs.expand_dims(point[..., 1], axis=-1) flat_param_a = gs.reshape(gs.broadcast_to(param_a, pre_flat_shape), (-1,)) flat_param_b = gs.reshape(gs.broadcast_to(param_b, pre_flat_shape), (-1,)) return {"a": flat_param_a, "b": flat_param_b}