# Source code for geomstats.learning.expectation_maximization

"""Expectation maximization algorithm.

Fits a Gaussian mixture model (GMM) to data lying on a manifold.
"""

import logging

from sklearn.base import BaseEstimator, ClusterMixin

import geomstats.backend as gs
from geomstats.learning._template import TransformerMixin
from geomstats.learning.frechet_mean import FrechetMean, variance
from geomstats.learning.kmeans import RiemannianKMeans

# Threshold used in RiemannianEM._expectation: samples whose weighted pdf sums
# to at most this value are reset, and zero posteriors are replaced by it when
# a component receives (numerically) no sample.
PDF_TOL = 1e-6
# Tolerance used in RiemannianEM._expectation when checking that the posterior
# probabilities of a sample sum to 1.
SUM_CHECK_PDF = 1e-4
# Offset added to the initial variance of a k-means cluster flagged as a
# singleton in RiemannianEM._initialization.
MIN_VAR_INIT = 1e-3

class GaussianMixtureModel:
    r"""Gaussian mixture model (GMM).

    Parameters
    ----------
    space : Manifold
        Equipped manifold.
    means : array-like, shape=[n_gaussians, dim]
        Means of each component of the GMM.
    variances : array-like, shape=[n_gaussians,]
        Variances of each component of the GMM.
    zeta_lower_bound : float
        Start of the grid of variances on which the normalization factor
        is tabulated (passed as ``start`` to ``gs.arange``).
        Optional, default: 5e-1.
    zeta_upper_bound : float
        End of the grid of variances (passed as ``stop`` to ``gs.arange``).
        Optional, default: 2.0.
    zeta_step : float
        Step of the grid of variances (passed as ``step`` to ``gs.arange``).
        Optional, default: 0.01.

    Attributes
    ----------
    normalization_factor_var : array-like, shape=[n_variances,]
        Array of computed normalization factor.
    variances_range : array-like, shape=[n_variances,]
        Array of standard deviations.
    phi_inv_var : array-like, shape=[n_variances,]
        Array of the computed inverse of a function phi
        whose expression is closed-form
        :math:`\sigma\mapsto \sigma^3 \times \frac{d  }
        {\mathstrut d\sigma}\log \zeta_m(\sigma)`
        where :math:`\sigma` denotes the variance
        and :math:`\zeta` the normalization coefficient
        and :math:`m` the dimension.
    """

    def __init__(
        self,
        space,
        means=None,
        variances=None,
        zeta_lower_bound=5e-1,
        zeta_upper_bound=2.0,
        zeta_step=0.01,
    ):
        self.space = space
        self.means = means
        self.variances = variances

        self.zeta_lower_bound = zeta_lower_bound
        self.zeta_upper_bound = zeta_upper_bound
        self.zeta_step = zeta_step

        (
            self.variances_range,
            self.normalization_factor_var,
            self.phi_inv_var,
        ) = self._normalization_factor_init()

    def _normalization_factor_init(self):
        r"""Set up function for the normalization factor.

        The normalization factor is used to define Gaussian distributions
        at initialization. It is tabulated on a regular grid of variances;
        if it is not finite on the whole grid, the grid is truncated just
        before the first intractable value.

        Returns
        -------
        variances : array-like, shape=[n_variances,]
            Grid of variances (possibly truncated).
        normalization_factor_var : array-like, shape=[n_variances,]
            Normalization factor evaluated on the grid.
        phi_inv_var : array-like, shape=[n_variances,]
            Values of :math:`\sigma^3 \frac{d}{d\sigma}\log\zeta_m(\sigma)`
            on the grid.

        Raises
        ------
        ValueError
            If the normalization factor sums to +-inf, or if no variance
            of the grid yields a tractable normalization factor.
        """
        variances = gs.arange(
            self.zeta_lower_bound, self.zeta_upper_bound, self.zeta_step
        )
        normalization_factor_var = self.space.metric.normalization_factor(variances)

        total = normalization_factor_var.sum()
        # NaN is the only value that differs from itself.
        has_nan = total != total
        has_inf = total == float("+inf") or total == float("-inf")

        if has_nan or has_inf:
            logging.warning("Untractable normalization factor :")

            # x / x * 0 is 0 for every finite non-zero x and NaN otherwise,
            # so the NaN entries flag the intractable grid points.
            invalid = gs.isnan(
                (normalization_factor_var / normalization_factor_var) * 0
            )
            invalid_indices = gs.where(invalid)[0]
            max_nf = len(variances)
            limit_nf = int(invalid_indices[0]) if len(invalid_indices) else max_nf

            variances = variances[0:limit_nf]
            normalization_factor_var = normalization_factor_var[0:limit_nf]
            if has_nan:
                logging.warning("\t Nan value " "in processing normalization factor")
            if has_inf:
                raise ValueError("\t +-inf value in " "processing normalization factor")
            if limit_nf == 0:
                raise ValueError(
                    "No variance of the grid yields a tractable normalization factor"
                )

            logging.warning("\t Max variance is now : %s", str(variances[-1]))
            logging.warning(
                "\t Number of possible variance is now: %s / %s ",
                str(len(variances)),
                str(max_nf),
            )

        # NOTE(review): assumes ``norm_factor_gradient`` returns a pair whose
        # second item is d/dsigma log(zeta) on the grid -- confirm against the
        # metric implementation.
        _, log_grad_zeta = self.space.metric.norm_factor_gradient(variances)
        phi_inv_var = variances**3 * log_grad_zeta

        return variances, normalization_factor_var, phi_inv_var

    @staticmethod
    def _nearest_index(grid, values):
        """Return, for each value, the index of the closest grid entry.

        Parameters
        ----------
        grid : array-like, shape=[n_grid,]
            Reference values.
        values : array-like, shape=[n_values,]
            Values to look up.

        Returns
        -------
        index : array-like, shape=[n_values,]
            Index in ``grid`` minimizing the absolute difference.
        """
        difference = gs.abs(gs.expand_dims(grid, 0) - gs.expand_dims(values, 1))
        return gs.argmin(difference, axis=-1)

    def pdf(self, data):
        """Return the separate probability density function of GMM.

        The probability density function is computed for
        each component of the GMM separately (i.e., mixture coefficients
        are not taken into account).

        Parameters
        ----------
        data : array-like, shape=[n_samples, dim]
            Points at which the GMM probability density is computed.

        Returns
        -------
        pdf : array-like, shape=[n_samples, n_gaussians,]
            Probability density function computed at each data
            sample and for each component of the GMM.
        """
        n_samples = data.shape[0]
        n_gaussians = self.means.shape[0]

        squared_distances = self.space.metric.dist_broadcast(data, self.means) ** 2
        squared_distances = gs.reshape(squared_distances, (n_samples, n_gaussians))

        # The per-component quantities broadcast over the sample axis.
        num = gs.exp(-squared_distances / (2 * self.variances**2))
        den = self._compute_normalization_factor()

        return num / den

    def _compute_normalization_factor(self):
        """Find the normalization factor given some variances.

        Each current variance is matched to the closest entry of the
        tabulated grid of variances.

        Returns
        -------
        norm_factor : array-like, shape=[n_gaussians,]
            Array of normalization factors for the given
            variances.
        """
        index = self._nearest_index(self.variances_range, self.variances)
        return self.normalization_factor_var[index]

    def compute_variance_from_index(self, weighted_distances):
        r"""Return the variance given weighted distances.

        Each weighted distance is matched to the closest entry of
        ``phi_inv_var``; the variance at the same grid position is returned.

        Parameters
        ----------
        weighted_distances : array-like, shape=[n_gaussians,]
            Mean of the weighted distances between training data
            and current barycentres. The weights of each data sample
            corresponds to the probability of belonging to a component
            of the Gaussian mixture model.

        Returns
        -------
        var : array-like, shape=[n_gaussians,]
            Estimated variances for each component of the GMM.
        """
        index = self._nearest_index(self.phi_inv_var, weighted_distances)
        return self.variances_range[index]

    def weighted_pdf(self, mixture_coefficients, mesh_data):
        """Return the probability density function of a GMM.

        Parameters
        ----------
        mixture_coefficients : array-like, shape=[n_gaussians,]
            Coefficients of the Gaussian mixture model.
        mesh_data : array-like, shape=[n_precision, dim]
            Points at which the GMM probability density is computed.

        Returns
        -------
        weighted_pdf : array-like, shape=[n_precision, n_gaussians,]
            Probability density function computed for each point of
            the mesh data, for each component of the GMM.
        """
        distance_to_mean = self.space.metric.dist_broadcast(mesh_data, self.means)

        distribution_normal = gs.exp(
            -(distance_to_mean**2) / (2 * self.variances**2)
        )

        # The normalization here is a hard-coded closed form; it does not use
        # the tabulated ``normalization_factor_var``.
        zeta_sigma = (2 * gs.pi) ** (2 / 3) * self.variances
        zeta_sigma = zeta_sigma * gs.exp(
            (self.variances**2 / 2) * gs.erf(self.variances / gs.sqrt(2))
        )

        # Coefficients and normalization broadcast over the mesh axis.
        return (mixture_coefficients * distribution_normal) / zeta_sigma

class RiemannianEM(TransformerMixin, ClusterMixin, BaseEstimator):
    r"""Expectation-maximization algorithm.

    A class for performing Expectation-Maximization to fit a Gaussian Mixture
    Model (GMM) to data on a manifold. This method is only implemented for
    the hypersphere and the Poincare ball.

    Parameters
    ----------
    space : Manifold
        Equipped manifold.
    n_gaussians : int
        Number of Gaussian components in the mix.
        Optional, default: 8.
    initialisation_method : str
        Optional, default: 'random'.
        Choice between initialization method for variances, means and weights.
           'random' : will select random uniformly train points as
            initial cluster centers.
           'kmeans' : will apply Riemannian kmeans to deduce
            variances and means that the EM will use initially.
    tol : float
        Optional, default: 1e-2.
        Convergence tolerance. Stored on the estimator; ``fit`` currently
        uses ``conv_rate`` as its stopping threshold.
    max_iter : int
        Maximum number of EM iterations.
        Optional, default: 100.
    conv_rate : float
        Threshold on the mean absolute change of the posterior probabilities
        between two iterations below which EM is considered converged.
        Optional, default: 1e-4.
    minimum_epochs : int
        Convergence is only accepted at iterations strictly after this index.
        Optional, default: 10.

    Attributes
    ----------
    mixture_coefficients_ : array-like, shape=[n_gaussians,]
        Weights for each GMM component.
    variances_ : array-like, shape=[n_gaussians,]
        Variances for each GMM component.
    means_ : array-like, shape=[n_gaussian, _dimension]
        Barycentre of each component of the GMM.

    Example
    -------
    Available example on the Poincaré Ball manifold
    :mod:`examples.plot_expectation_maximization_ball`
    """

    def __init__(
        self,
        space,
        n_gaussians=8,
        initialisation_method="random",
        tol=1e-2,
        max_iter=100,
        conv_rate=1e-4,
        minimum_epochs=10,
    ):
        self.space = space
        self.n_gaussians = n_gaussians
        self.initialisation_method = initialisation_method
        self.tol = tol
        self.max_iter = max_iter
        self.conv_rate = conv_rate
        self.minimum_epochs = minimum_epochs

        self.mean_estimator = FrechetMean(
            space,
            method="batch",
        ).set(
            max_iter=100,
            epsilon=1e-4,
            init_step_size=1.0,
        )

        self._model = GaussianMixtureModel(self.space)

        self.mixture_coefficients_ = None

    @property
    def means_(self):
        """Means of each component of the GMM."""
        return self._model.means

    @property
    def variances_(self):
        """Array of standard deviations."""
        return self._model.variances

    def _update_posterior_probabilities(self, posterior_probabilities):
        """Posterior probabilities update function.

        Sets the mixture coefficients to the mean posterior probability of
        each component over the samples.

        Parameters
        ----------
        posterior_probabilities : array-like, shape=[n_samples, n_gaussians,]
            Probability of a given sample to belong to a component
            of the GMM, computed for all components.
        """
        self.mixture_coefficients_ = gs.mean(posterior_probabilities, 0)

        if gs.any(gs.isnan(self.mixture_coefficients_)):
            logging.warning(
                "UPDATE : mixture coefficients contain elements that are not numbers"
            )

    def _update_means(self, data, posterior_probabilities):
        """Update means.

        Parameters
        ----------
        data : array-like, shape=[n_samples, n_features,]
            Training data.
        posterior_probabilities : array-like, shape=[n_samples, n_gaussians,]
            Probability of a given sample to belong to a component
            of the GMM, used as weights of the Frechet mean.
        """
        n_gaussians = posterior_probabilities.shape[-1]

        # One copy of the data per component, as expected by the batch
        # Frechet mean estimator.
        data_expand = gs.expand_dims(data, 1)
        data_expand = gs.repeat(data_expand, n_gaussians, axis=1)

        self.mean_estimator.fit(data_expand, weights=posterior_probabilities)
        self._model.means = gs.squeeze(self.mean_estimator.estimate_)

        if gs.any(gs.isnan(self._model.means)):
            logging.warning("UPDATE : means contain not a number elements")

    def _update_variances(self, data, posterior_probabilities):
        """Update variances function.

        Parameters
        ----------
        data : array-like, shape=[n_samples, n_features,]
            Training data, where n_samples is the number of samples and
            n_features is the number of features.
        posterior_probabilities : array-like, shape=[n_samples, n_gaussians,]
            Probability of a given sample to belong to a component
            of the GMM, computed for all components.
        """
        dist_means_data = self.space.metric.dist_broadcast(data, self._model.means) ** 2

        weighted_dist_means_data = (dist_means_data * posterior_probabilities).sum(
            0
        ) / posterior_probabilities.sum(0)

        self._model.variances = self._model.compute_variance_from_index(
            weighted_dist_means_data
        )

        if gs.any(gs.isnan(self._model.variances)):
            logging.warning("UPDATE : variances contain not a number elements")

    def _expectation(self, data):
        """Update the posterior probabilities.

        Parameters
        ----------
        data : array-like, shape=[n_samples, n_features]
            Training data, where n_samples is the number of samples and
            n_features is the number of features.

        Returns
        -------
        posterior_probabilities : array-like, shape=[n_samples, n_gaussians,]
            Probability of a given sample to belong to a component
            of the GMM, computed for all components.
        """
        pdf = self._model.pdf(data)

        if gs.any(gs.isnan(pdf)):
            logging.warning(
                "EXPECTATION : Probability distribution function "
                "contain elements that are not numbers"
            )

        num_normalized_pdf = gs.einsum("j,...j->...j", self.mixture_coefficients_, pdf)
        valid_pdf_condition = gs.amin(gs.sum(num_normalized_pdf, -1))

        if valid_pdf_condition <= PDF_TOL:
            # Samples with (numerically) zero density under every component
            # are given equal weights to avoid dividing by zero below.
            num_normalized_pdf[gs.sum(num_normalized_pdf, -1) <= PDF_TOL] = 1

        sum_pdf = gs.sum(num_normalized_pdf, -1)
        posterior_probabilities = gs.einsum(
            "...i,...->...i", num_normalized_pdf, 1 / sum_pdf
        )

        if gs.any(gs.isnan(posterior_probabilities)):
            logging.warning(
                "EXPECTATION : posterior probabilities "
                "contain elements that are not numbers."
            )

        mean_total = gs.mean(gs.sum(posterior_probabilities, 1))
        if gs.abs(mean_total - 1) > SUM_CHECK_PDF:
            logging.warning("EXPECTATION : posterior probabilities " "do not sum to 1.")

        if gs.any(gs.sum(posterior_probabilities, 0) < PDF_TOL):
            logging.warning(
                "EXPECTATION : Gaussian got no elements "
                "(precision error) reinitialize"
            )
            posterior_probabilities[posterior_probabilities == 0] = PDF_TOL

        return posterior_probabilities

    def _maximization(self, data, posterior_probabilities):
        """Update function for the means and variances.

        The mixture coefficients are updated first, then the means, then the
        variances (which are computed around the updated means).

        Parameters
        ----------
        data : array-like, shape=[n_samples, n_features,]
            Training data, where n_samples is the number of samples and
            n_features is the number of features.
        posterior_probabilities : array-like, shape=[n_samples, n_gaussians,]
            Probability of a given sample to belong to a component
            of the GMM, computed for all components.
        """
        self._update_posterior_probabilities(posterior_probabilities)
        self._update_means(data, posterior_probabilities)
        self._update_variances(data, posterior_probabilities)

    def _initialization(self, X):
        """Compute initial means and variances of the GMM components.

        With ``initialisation_method == "kmeans"``, means are the cluster
        centers of a Riemannian k-means and variances are the variances of
        each cluster around its center. Any other value draws means and
        variances at random.

        Parameters
        ----------
        X : array-like, shape=[n_samples, n_features]
            Training data.

        Returns
        -------
        means : array-like, shape=[n_gaussians, dim]
            Initial means.
        variances : array-like, shape=[n_gaussians,]
            Initial variances.
        """
        if self.initialisation_method == "kmeans":
            kmeans_estimator = RiemannianKMeans(
                space=self.space,
                n_clusters=self.n_gaussians,
                init="random",
            )

            kmeans_estimator.fit(X=X)
            cluster_centers = kmeans_estimator.cluster_centers_
            labels = kmeans_estimator.labels_

            means = cluster_centers
            variances = gs.zeros(self.n_gaussians)

            for label, cluster_center in enumerate(cluster_centers):
                grouped_by_label = X[labels == label]
                v = variance(self.space, grouped_by_label, cluster_center)
                if grouped_by_label.shape[0] == 1:
                    # A singleton cluster has zero variance; keep it positive.
                    v += MIN_VAR_INIT
                variances[label] = v
        else:
            dim = self.space.shape[-1]
            means = (gs.random.rand(self.n_gaussians, dim) - 0.5) / dim
            variances = gs.random.rand(self.n_gaussians) / 10 + 0.8

        return means, variances

    def fit(self, X, y=None):
        """Fit a Gaussian mixture model (GMM) given the data.

        Alternates between Expectation and Maximization steps
        for some number of iterations.

        Parameters
        ----------
        X : array-like, shape=[n_samples, n_features]
            Training data, where n_samples is the number of samples
            and n_features is the number of features.
        y : None
            Target values. Ignored.

        Returns
        -------
        self : object
            Returns self.
        """
        self._model.means, self._model.variances = self._initialization(X)

        self.mixture_coefficients_ = gs.ones(self.n_gaussians) / self.n_gaussians
        posterior_probabilities = gs.ones((X.shape[0], self.n_gaussians))

        for epoch in range(self.max_iter):
            old_posterior_probabilities = posterior_probabilities

            posterior_probabilities = self._expectation(X)

            condition = gs.mean(
                gs.abs(old_posterior_probabilities - posterior_probabilities)
            )

            if condition < self.conv_rate and epoch > self.minimum_epochs:
                logging.info("EM converged in %s iterations", epoch)
                break

            self._maximization(X, posterior_probabilities)
        else:
            # Reached only when the loop was never interrupted by ``break``.
            logging.warning(
                "EM did not converge within %s iterations. "
                "Please increase max_iter.",
                self.max_iter,
            )

        return self

`