Source code for geomstats.information_geometry.binomial

"""Statistical Manifold of Binomial distributions with the Fisher metric.

Lead authors: Jules Deschamps, Tra My Nguyen.
"""

from scipy.special import factorial
from scipy.stats import binom

import geomstats.backend as gs
from geomstats.geometry.base import VectorSpaceOpenSet
from geomstats.geometry.euclidean import Euclidean
from geomstats.geometry.riemannian_metric import RiemannianMetric
from geomstats.information_geometry.base import (
    InformationManifoldMixin,
    ScipyUnivariateRandomVariable,
)


class BinomialDistributions(InformationManifoldMixin, VectorSpaceOpenSet):
    """Class for the manifold of binomial distributions.

    This is the parameter space of binomial distributions with a fixed
    number of draws, i.e. the open interval (0, 1) of success
    probabilities, seen as an open set of the 1-dimensional Euclidean space.

    Parameters
    ----------
    n_draws : int
        Number of draws, i.e. the `n` parameter of the binomial
        distribution. It is fixed for the whole manifold.
    equip : bool
        Forwarded to the parent constructor. If True, the space is equipped
        with the metric returned by `default_metric`.
        Optional, default: True.
    """

    def __init__(self, n_draws, equip=True):
        super().__init__(
            dim=1,
            support_shape=(),
            embedding_space=Euclidean(dim=1, equip=False),
            equip=equip,
        )
        self.n_draws = n_draws
        # Scipy-backed random variable used by `sample`.
        self._scp_rv = BinomialDistributionsRandomVariable(self)

    @staticmethod
    def default_metric():
        """Metric to equip the space with if equip is True."""
        return BinomialMetric

    def belongs(self, point, atol=gs.atol):
        """Evaluate if a point belongs to the manifold of binomial distributions.

        A point belongs to the manifold if its trailing shape matches the
        shape of the space and its value lies in [atol, 1 - atol]. This is
        the same interval `projection` maps onto, so a projected point
        always belongs to the manifold.

        Parameters
        ----------
        point : array-like, shape=[..., 1]
            Point to be checked.
        atol : float
            Tolerance to evaluate if point belongs to (0,1).
            Optional, default: gs.atol

        Returns
        -------
        belongs : array-like, shape=[...,]
            Boolean indicating whether point represents a binomial
            distribution.
        """
        has_valid_shape = self.shape == point.shape[-self.point_ndim :]
        if not has_valid_shape:
            batch_shape = point.shape[: -self.point_ndim]
            return gs.zeros(batch_shape, dtype=bool)

        # The bounds are tightened by atol (not loosened): the Fisher metric
        # n / (p * (1 - p)) is singular at p = 0 and p = 1, so the endpoints
        # and anything outside [0, 1] must be rejected.
        return gs.squeeze(
            gs.logical_and(atol <= point, point <= 1 - atol), axis=-1
        )

    def random_point(self, n_samples=1):
        """Sample parameters of binomial distributions.

        The uniform distribution on [0, 1] is used.

        Parameters
        ----------
        n_samples : int
            Number of samples.
            Optional, default: 1.

        Returns
        -------
        samples : array-like, shape=[..., dim]
            Sample of points representing binomial distributions.
        """
        size = (n_samples, self.dim) if n_samples != 1 else (self.dim,)
        return gs.random.rand(*size)

    def projection(self, point, atol=gs.atol):
        """Project a point in ambient space to the parameter set.

        The parameter is clamped to `atol` if it is smaller than `atol`
        and to `1 - atol` if it is greater than `1 - atol`. Other values
        are returned unchanged.

        Parameters
        ----------
        point : array-like, shape=[..., 1]
            Point in ambient space.
        atol : float
            Margin kept between the projected point and the bounds 0 and 1.
            Optional, default: gs.atol

        Returns
        -------
        projected : array-like, shape=[..., 1]
            Projected point.
        """
        too_low = point < atol
        too_high = point > 1 - atol
        # The two masks are mutually exclusive for atol < 0.5, so the sum
        # below picks exactly one of the two bounds where clamping is needed.
        clamped = (1 - atol) * gs.cast(too_high, point.dtype) + atol * gs.cast(
            too_low, point.dtype
        )
        return gs.where(gs.logical_or(too_low, too_high), clamped, point)

    def sample(self, point, n_samples=1):
        """Sample from the binomial distribution.

        Sample from the binomial distribution with parameter provided
        by point.

        Parameters
        ----------
        point : array-like, shape=[..., dim]
            Point representing a binomial distribution.
        n_samples : int
            Number of points to sample with for each parameter in point.
            Optional, default: 1.

        Returns
        -------
        samples : array-like, shape=[..., n_samples, dim]
            Sample from binomial distributions.
        """
        return self._scp_rv.rvs(point, n_samples)

    def point_to_pdf(self, point):
        """Compute pmf associated to point.

        Compute the probability mass function of the binomial
        distribution with parameters provided by point.

        Parameters
        ----------
        point : array-like, shape=[..., 1]
            Point representing a binomial distribution (probability of
            success).

        Returns
        -------
        pmf : function
            Probability mass function of the binomial distribution with
            parameters provided by point.
        """

        def pmf(k):
            """Generate parameterized function for binomial pmf.

            Parameters
            ----------
            k : array-like, shape=[n_samples,]
                Integers in {0, ..., n_draws} at which to compute the
                probability mass function.

            Returns
            -------
            pmf_at_k : array-like, shape=[..., n_samples]
                Values of pmf at k for each value of the parameters
                provided by point.
            """
            k = gs.reshape(gs.array(k), (-1,))
            # Binomial coefficient "n_draws choose k", computed with scipy
            # and converted back to the active backend afterwards.
            const = factorial(self.n_draws) / (
                factorial(k) * factorial(self.n_draws - k)
            )
            return (
                gs.from_numpy(const)
                * (point**k)
                * ((1 - point) ** (self.n_draws - k))
            )

        return pmf
class BinomialMetric(RiemannianMetric):
    """Class for the Fisher information metric on binomial distributions.

    All closed forms below rely on the change of variable
    phase = arcsin(sqrt(p)), in which geodesics are affine in time.

    References
    ----------
    .. [AM1981] Atkinson, C., & Mitchell, A. F. (1981). Rao's distance measure.
        Sankhyā: The Indian Journal of Statistics, Series A, 345-365.
    """

    def squared_dist(self, point_a, point_b):
        """Compute squared distance associated with the binomial Fisher Rao metric.

        Parameters
        ----------
        point_a : array-like, shape=[..., 1]
            Point representing a binomial distribution (probability of
            success).
        point_b : array-like, shape=[..., 1]
            Point representing a binomial distribution (probability of
            success).

        Returns
        -------
        squared_dist : array-like, shape=[...,]
            Squared distance between points point_a and point_b.
        """
        phase_gap = gs.arcsin(gs.sqrt(point_a)) - gs.arcsin(gs.sqrt(point_b))
        return gs.squeeze(4 * self._space.n_draws * phase_gap**2, axis=-1)

    def metric_matrix(self, base_point):
        """Compute the metric matrix at the tangent space at base_point.

        Parameters
        ----------
        base_point : array-like, shape=[..., 1]
            Point representing a binomial distribution.

        Returns
        -------
        mat : array-like, shape=[..., 1, 1]
            Metric matrix.
        """
        fisher_information = self._space.n_draws / (base_point * (1 - base_point))
        return gs.expand_dims(fisher_information, axis=-1)

    @staticmethod
    def _geodesic_path(t, initial_phase, frequency):
        """Generate parameterized function for geodesic curve.

        Evaluates sin(frequency * t + initial_phase) ** 2 and appends a
        trailing axis to the result.

        Parameters
        ----------
        t : array-like, shape=[n_times,]
            Times at which to compute points of the geodesics.
        initial_phase : array-like
            Phase of the geodesic at time 0.
        frequency : array-like
            Rate at which the phase grows with t.

        Returns
        -------
        geodesic : array-like, shape=[..., n_times, 1]
            Values of the geodesic at times t.
        """
        phase = frequency * t + initial_phase
        return gs.expand_dims(gs.sin(phase) ** 2, axis=-1)

    def _geodesic_ivp(self, initial_point, initial_tangent_vec):
        """Solve geodesic initial value problem.

        Compute the parameterized function for the geodesic starting at
        initial_point with initial velocity given by initial_tangent_vec.

        Parameters
        ----------
        initial_point : array-like, shape=[..., 1]
            Initial point.
        initial_tangent_vec : array-like, shape=[..., 1]
            Tangent vector at initial point.

        Returns
        -------
        path : function
            Parameterized function for the geodesic curve starting at
            initial_point with velocity initial_tangent_vec.
        """
        start_phase = gs.arcsin(gs.sqrt(initial_point))
        phase_speed = initial_tangent_vec / (
            2.0 * gs.sqrt(initial_point) * gs.sqrt(1 - initial_point)
        )

        def path(t):
            return self._geodesic_path(t, start_phase, phase_speed)

        return path

    def _geodesic_bvp(self, initial_point, end_point):
        """Solve geodesic boundary problem.

        Compute the parameterized function for the geodesic starting at
        initial_point and ending at end_point.

        Parameters
        ----------
        initial_point : array-like, shape=[..., 1]
            Initial point.
        end_point : array-like, shape=[..., 1]
            End point.

        Returns
        -------
        path : function
            Parameterized function for the geodesic curve starting at
            initial_point and ending at end_point.
        """
        start_phase = gs.arcsin(gs.sqrt(initial_point))
        # The phase travelled between t = 0 and t = 1 plays the role of the
        # frequency in `_geodesic_path`.
        phase_speed = gs.arcsin(gs.sqrt(end_point)) - start_phase

        def path(t):
            return self._geodesic_path(t, start_phase, phase_speed)

        return path

    def exp(self, tangent_vec, base_point):
        """Compute exp map of a base point in tangent vector direction.

        Parameters
        ----------
        tangent_vec : array-like, shape=[..., 1]
            Tangent vector at base point.
        base_point : array-like, shape=[..., 1]
            Base point.

        Returns
        -------
        exp : array-like, shape=[..., 1]
            End point of the geodesic starting at base_point with initial
            velocity tangent_vec.
        """
        phase_increment = tangent_vec / (
            2.0 * gs.sqrt(base_point) * gs.sqrt(1 - base_point)
        )
        end_phase = phase_increment + gs.arcsin(gs.sqrt(base_point))
        return gs.sin(end_phase) ** 2

    def log(self, point, base_point):
        """Compute log map using a base point and an end point.

        Parameters
        ----------
        point : array-like, shape=[..., 1]
            End point.
        base_point : array-like, shape=[..., 1]
            Base point.

        Returns
        -------
        tangent_vec : array-like, shape=[..., 1]
            Initial velocity of the geodesic starting at base_point and
            reaching point at time 1.
        """
        scale = 2 * gs.sqrt(base_point) * gs.sqrt(1 - base_point)
        phase_gap = gs.arcsin(gs.sqrt(point)) - gs.arcsin(gs.sqrt(base_point))
        return scale * phase_gap
class BinomialDistributionsRandomVariable(ScipyUnivariateRandomVariable):
    """A binomial random variable.

    Wraps `scipy.stats.binom` with the number of draws fixed to
    `space.n_draws`.
    """

    def __init__(self, space):
        def rvs(*args, **kwargs):
            # n_draws is read from the space at call time and passed as the
            # first positional argument of scipy's sampler.
            return binom.rvs(space.n_draws, *args, **kwargs)

        super().__init__(space, rvs)

    @staticmethod
    def _flatten_params(point, pre_flat_shape):
        """Broadcast point to pre_flat_shape and flatten it.

        Parameters
        ----------
        point : array-like
            Point representing a distribution.
        pre_flat_shape : tuple
            Shape the point is broadcast to before flattening.

        Returns
        -------
        params : dict
            Dictionary with the flattened point under the key "p".
        """
        broadcast_point = gs.broadcast_to(point, pre_flat_shape)
        return {"p": gs.reshape(broadcast_point, (-1,))}

    def pdf(self, x, point):
        """Evaluate the probability density function at x.

        Parameters
        ----------
        x : array-like, shape=[n_samples, *support_shape]
            Sample points at which to compute the probability density
            function.
        point : array-like, shape=[..., *space.shape]
            Point representing a distribution.

        Returns
        -------
        pdf_at_x : array-like, shape=[..., n_samples, *support_shape]
            Values of pdf at x for each value of the parameters provided
            by point.
        """
        pmf_values = binom.pmf(x, self.space.n_draws, point)
        return gs.from_numpy(pmf_values)