"""Expectation maximization algorithm.
Lead authors: Thomas Gerald and Hadi Zaatiti.
"""
import logging
from sklearn.base import BaseEstimator, ClusterMixin
import geomstats.backend as gs
from geomstats.learning._template import TransformerMixin
from geomstats.learning.frechet_mean import FrechetMean, variance
from geomstats.learning.kmeans import RiemannianKMeans
PDF_TOL = 1e-6
SUM_CHECK_PDF = 1e-4
MIN_VAR_INIT = 1e-3
class GaussianMixtureModel:
r"""Gaussian mixture model (GMM).
Parameters
----------
space : Manifold
Equipped manifold.
means : array-like, shape=[n_gaussians, dim]
Means of each component of the GMM.
variances : array-like, shape=[n_gaussians,]
Variances of each component of the GMM.
zeta_lower_bound : float
Optional, default: 5e-1.
Lower bound of the variance grid on which the normalization factor
is precomputed.
zeta_upper_bound : float
Optional, default: 2.0.
Upper bound of the variance grid.
zeta_step : float
Optional, default: 0.01.
Step of the variance grid.
Attributes
----------
normalization_factor_var : array-like, shape=[n_variances,]
Array of computed normalization factors.
variances_range : array-like, shape=[n_variances,]
Grid of variance values on which the normalization factor is
precomputed.
phi_inv_var : array-like, shape=[n_variances,]
Array of values of the function
:math:`\phi : \sigma \mapsto \sigma^3 \frac{d}{d\sigma}\log \zeta_m(\sigma)`,
computed on the variance grid and used to invert :math:`\phi`,
where :math:`\sigma` denotes the variance,
:math:`\zeta_m` the normalization coefficient
and :math:`m` the dimension.
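Example
-------
A minimal construction sketch on the 2D Poincaré ball (illustrative, not
a tested doctest; the means, variances and query points are arbitrary
choices, and the shape shown assumes the numpy backend):

>>> import geomstats.backend as gs
>>> from geomstats.geometry.poincare_ball import PoincareBall
>>> space = PoincareBall(dim=2)
>>> model = GaussianMixtureModel(
...     space,
...     means=gs.array([[0.0, 0.0], [0.3, 0.3]]),
...     variances=gs.array([0.6, 0.8]),
... )
>>> model.pdf(gs.array([[0.1, 0.2], [-0.1, 0.4], [0.2, -0.3]])).shape
(3, 2)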
"""
def __init__(
self,
space,
means=None,
variances=None,
zeta_lower_bound=5e-1,
zeta_upper_bound=2.0,
zeta_step=0.01,
):
self.space = space
self.means = means
self.variances = variances
self.zeta_lower_bound = zeta_lower_bound
self.zeta_upper_bound = zeta_upper_bound
self.zeta_step = zeta_step
(
self.variances_range,
self.normalization_factor_var,
self.phi_inv_var,
) = self._normalization_factor_init()
def _normalization_factor_init(self):
r"""Set up function for the normalization factor.
The normalization factor is used to define Gaussian distributions
at initialization.
"""
variances = gs.arange(
self.zeta_lower_bound, self.zeta_upper_bound, self.zeta_step
)
normalization_factor_var = self.space.metric.normalization_factor(variances)
# A non-finite sum flags an intractable normalization factor:
# cond_1 uses the NaN-check idiom (x != x is True only for NaN).
cond_1 = normalization_factor_var.sum() != normalization_factor_var.sum()
cond_2 = normalization_factor_var.sum() == float("+inf")
cond_3 = normalization_factor_var.sum() == float("-inf")
if cond_1 or cond_2 or cond_3:
logging.warning("Intractable normalization factor :")
# Truncate the grid where the factor becomes non-finite: x / x is NaN
# there, so (x / x) * 0 is non-zero exactly at those entries.
limit_nf = (
((normalization_factor_var / normalization_factor_var) * 0)
.nonzero()[0]
.item()
)
max_nf = len(variances)
variances = variances[0:limit_nf]
normalization_factor_var = normalization_factor_var[0:limit_nf]
if cond_1:
logging.warning("\t Nan value " "in processing normalization factor")
if cond_2 or cond_3:
raise ValueError("\t +-inf value in " "processing normalization factor")
logging.warning("\t Max variance is now : %s", str(variances[-1]))
logging.warning(
"\t Number of possible variance is now: %s / %s ",
str(len(variances)),
str(max_nf),
)
_, log_grad_zeta = self.space.metric.norm_factor_gradient(variances)
phi_inv_var = variances**3 * log_grad_zeta
return variances, normalization_factor_var, phi_inv_var
def pdf(self, data):
"""Return the separate probability density function of GMM.
The probability density function is computed for
each component of the GMM separately (i.e., mixture coefficients
are not taken into account).
Parameters
----------
data : array-like, shape=[n_samples, dim]
Points at which the GMM probability density is computed.
Returns
-------
pdf : array-like, shape=[n_samples, n_gaussians,]
Probability density function computed at each data
sample and for each component of the GMM.
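Notes
-----
As implemented, for a sample :math:`x_i` and a component with mean
:math:`\mu_j` and variance parameter :math:`\sigma_j`, the returned
density is
:math:`p_{ij} = \exp\left(-d(x_i, \mu_j)^2 / (2\sigma_j^2)\right) / \zeta(\sigma_j)`,
where :math:`d` is the Riemannian distance and :math:`\zeta(\sigma_j)` is
the normalization factor looked up in the table precomputed at
initialization.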
"""
data_length = data.shape[0]
variances_expanded = gs.expand_dims(self.variances, 0)
variances_expanded = gs.repeat(variances_expanded, data_length, 0)
variances_flatten = variances_expanded.flatten()
distances = -(self.space.metric.dist_broadcast(data, self.means) ** 2)
distances = gs.reshape(distances, (data.shape[0] * self.variances.shape[0],))
num = gs.exp(distances / (2 * variances_flatten**2))
den = self._compute_normalization_factor()
den = gs.expand_dims(den, 0)
den = gs.repeat(den, data_length, axis=0).flatten()
pdf = num / den
pdf = gs.reshape(pdf, (data.shape[0], self.means.shape[0]))
return pdf
def _compute_normalization_factor(self):
"""Find the normalization factor given some variances.
Returns
-------
norm_factor : array-like, shape=[n_gaussians,]
Array of normalization factors for the given
variances.
"""
n_gaussians, precision = self.variances.shape[0], self.variances_range.shape[0]
ref = gs.expand_dims(self.variances_range, 0)
ref = gs.repeat(ref, n_gaussians, axis=0)
val = gs.expand_dims(self.variances, 1)
val = gs.repeat(val, precision, axis=1)
difference = gs.abs(ref - val)
index = gs.argmin(difference, axis=-1)
norm_factor = self.normalization_factor_var[index]
return norm_factor
def compute_variance_from_index(self, weighted_distances):
r"""Return the variance given weighted distances.
Parameters
----------
weighted_distances : array-like, shape=[n_gaussians,]
Weighted mean of the squared distances between the training data
and the current barycentres. The weight of each data sample
corresponds to its probability of belonging to a component
of the Gaussian mixture model.
Returns
-------
var : array-like, shape=[n_gaussians,]
Estimated variances for each component of the GMM.
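Notes
-----
The estimate inverts the function
:math:`\phi : \sigma \mapsto \sigma^3 \frac{d}{d\sigma}\log \zeta_m(\sigma)`
by a nearest-neighbour lookup: the returned variance is the entry of
`variances_range` whose precomputed value in `phi_inv_var` is closest to
the given weighted distance.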
"""
n_gaussians, precision = (
weighted_distances.shape[0],
self.variances_range.shape[0],
)
ref = gs.expand_dims(self.phi_inv_var, 0)
ref = gs.repeat(ref, n_gaussians, axis=0)
val = gs.expand_dims(weighted_distances, 1)
val = gs.repeat(val, precision, axis=1)
abs_difference = gs.abs(ref - val)
index = gs.argmin(abs_difference, -1)
var = self.variances_range[index]
return var
def weighted_pdf(self, mixture_coefficients, mesh_data):
"""Return the probability density function of a GMM.
Parameters
----------
mixture_coefficients : array-like, shape=[n_gaussians,]
Coefficients of the Gaussian mixture model.
mesh_data : array-like, shape=[n_precision, dim]
Points at which the GMM probability density is computed.
Returns
-------
weighted_pdf : array-like, shape=[n_precision, n_gaussians,]
Probability density function computed for each point of
the mesh data, for each component of the GMM.
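Notes
-----
Each entry is a component density scaled by its mixture coefficient,
:math:`\pi_k \exp\left(-d(x, \mu_k)^2 / (2\sigma_k^2)\right) / \zeta(\sigma_k)`,
with the normalization :math:`\zeta(\sigma_k)` evaluated inline in closed
form rather than read from the precomputed table.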
"""
distance_to_mean = self.space.metric.dist_broadcast(mesh_data, self.means)
variances_units = gs.expand_dims(self.variances, 0)
variances_units = gs.repeat(variances_units, distance_to_mean.shape[0], axis=0)
distribution_normal = gs.exp(-(distance_to_mean**2) / (2 * variances_units**2))
zeta_sigma = (2 * gs.pi) ** (2 / 3) * self.variances
zeta_sigma = zeta_sigma * gs.exp(
(self.variances**2 / 2) * gs.erf(self.variances / gs.sqrt(2))
)
result_num = gs.expand_dims(mixture_coefficients, 0)
result_num = gs.repeat(result_num, len(distribution_normal), axis=0)
result_num = result_num * distribution_normal
result_denum = gs.expand_dims(zeta_sigma, 0)
result_denum = gs.repeat(result_denum, len(distribution_normal), axis=0)
weighted_pdf = result_num / result_denum
return weighted_pdf
class RiemannianEM(TransformerMixin, ClusterMixin, BaseEstimator):
r"""Expectation-maximization algorithm.
A class for performing Expectation-Maximization to fit a Gaussian Mixture
Model (GMM) to data on a manifold. This method is only implemented for
the hypersphere and the Poincaré ball.
Parameters
----------
space : Manifold
Equipped manifold.
n_gaussians : int
Number of Gaussian components in the mix.
initialisation_method : str
Optional, default: 'random'.
Choice of initialization method for the means, variances and weights.
- 'random' : will select training points uniformly at random as
initial cluster centers.
- 'kmeans' : will apply Riemannian k-means to deduce the initial
means and variances used by the EM.
tol : float
Optional, default: 1e-2.
Convergence tolerance: the algorithm is considered converged when
the difference of mean distance between two steps is lower than tol.
max_iter : int
Maximum number of iterations for the gradient descent.
Optional, default: 100.
conv_rate : float
Optional, default: 1e-4.
Convergence threshold: the EM loop stops when the mean absolute
change of the posterior probabilities between two iterations falls
below conv_rate.
minimum_epochs : int
Optional, default: 10.
Minimum number of EM iterations performed before the convergence
criterion can stop the loop.
Attributes
----------
mixture_coefficients_ : array-like, shape=[n_gaussians,]
Weights for each GMM component.
variances_ : array-like, shape=[n_gaussians,]
Variances for each GMM component.
means_ : array-like, shape=[n_gaussians, dim]
Barycentre of each component of the GMM.
Example
-------
Available example on the Poincaré Ball manifold
:mod:`examples.plot_expectation_maximization_ball`
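A minimal usage sketch on the 2D Poincaré ball (illustrative, not a
tested doctest; the synthetic data and the number of components are
arbitrary choices, and the shapes shown assume the numpy backend):

>>> import geomstats.backend as gs
>>> from geomstats.geometry.poincare_ball import PoincareBall
>>> space = PoincareBall(dim=2)
>>> data = 0.4 * (gs.random.rand(50, 2) - 0.5)
>>> em = RiemannianEM(space, n_gaussians=2, initialisation_method="kmeans")
>>> em = em.fit(data)
>>> em.means_.shape, em.variances_.shape, em.mixture_coefficients_.shape
((2, 2), (2,), (2,))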
"""
def __init__(
self,
space,
n_gaussians=8,
initialisation_method="random",
tol=1e-2,
max_iter=100,
conv_rate=1e-4,
minimum_epochs=10,
):
self.space = space
self.n_gaussians = n_gaussians
self.initialisation_method = initialisation_method
self.tol = tol
self.max_iter = max_iter
self.conv_rate = conv_rate
self.minimum_epochs = minimum_epochs
self.mean_estimator = FrechetMean(space)
self.mean_estimator.method = "batch"
self.mean_estimator.set(
max_iter=100,
epsilon=1e-4,
init_step_size=1.0,
)
self._model = GaussianMixtureModel(self.space)
self.mixture_coefficients_ = None
@property
def means_(self):
"""Means of each component of the GMM."""
return self._model.means
@property
def variances_(self):
"""Array of standard deviations."""
return self._model.variances
def _update_posterior_probabilities(self, posterior_probabilities):
"""Posterior probabilities update function.
Parameters
----------
posterior_probabilities : array-like, shape=[n_samples, n_gaussians,]
Probability of a given sample to belong to a component
of the GMM, computed for all components.
"""
self.mixture_coefficients_ = gs.mean(posterior_probabilities, 0)
if gs.any(gs.isnan(self.mixture_coefficients_)):
logging.warning(
"UPDATE : mixture coefficients contain elements that are not numbers"
)
def _update_means(self, data, posterior_probabilities):
"""Update means."""
n_gaussians = posterior_probabilities.shape[-1]
data_expand = gs.expand_dims(data, 1)
data_expand = gs.repeat(data_expand, n_gaussians, axis=1)
self.mean_estimator.fit(data_expand, weights=posterior_probabilities)
self._model.means = gs.squeeze(self.mean_estimator.estimate_)
if gs.any(gs.isnan(self._model.means)):
logging.warning("UPDATE : means contain not a number elements")
def _update_variances(self, data, posterior_probabilities):
"""Update variances function.
Parameters
----------
data : array-like, shape=[n_samples, n_features,]
Training data, where n_samples is the number of samples and
n_features is the number of features.
posterior_probabilities : array-like, shape=[n_samples, n_gaussians,]
Probability of a given sample to belong to a component
of the GMM, computed for all components.
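Notes
-----
For each component :math:`k`, the weighted mean of squared distances
:math:`\bar{d}_k = \sum_i p_{ik}\, d(x_i, \mu_k)^2 / \sum_i p_{ik}`
is computed, with weights :math:`p_{ik}` the posterior probabilities,
and the variance is then recovered from :math:`\bar{d}_k` with
`compute_variance_from_index`.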
"""
dist_means_data = self.space.metric.dist_broadcast(data, self._model.means) ** 2
weighted_dist_means_data = (dist_means_data * posterior_probabilities).sum(
0
) / posterior_probabilities.sum(0)
self._model.variances = self._model.compute_variance_from_index(
weighted_dist_means_data
)
if gs.any(gs.isnan(self._model.variances)):
logging.warning("UPDATE : variances contain not a number elements")
def _expectation(self, data):
"""Update the posterior probabilities.
Parameters
----------
data : array-like, shape=[n_samples, n_features]
Training data, where n_samples is the number of samples and
n_features is the number of features.
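Notes
-----
The posterior probabilities follow Bayes' rule: with mixture
coefficients :math:`\pi_k` and component densities
:math:`p(x_i \mid k)` given by `pdf`,
:math:`p(k \mid x_i) = \pi_k\, p(x_i \mid k) / \sum_j \pi_j\, p(x_i \mid j)`.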
"""
pdf = self._model.pdf(data)
if gs.any(gs.isnan(pdf)):
logging.warning(
"EXPECTATION : Probability distribution function"
"contain elements that are not numbers"
)
num_normalized_pdf = gs.einsum("j,...j->...j", self.mixture_coefficients_, pdf)
valid_pdf_condition = gs.amin(gs.sum(num_normalized_pdf, -1))
if valid_pdf_condition <= PDF_TOL:
num_normalized_pdf[gs.sum(num_normalized_pdf, -1) <= PDF_TOL] = 1
sum_pdf = gs.sum(num_normalized_pdf, -1)
posterior_probabilities = gs.einsum(
"...i,...->...i", num_normalized_pdf, 1 / sum_pdf
)
if gs.any(gs.isnan(posterior_probabilities)):
logging.warning(
"EXPECTATION : posterior probabilities "
"contain elements that are not numbers."
)
if gs.abs(gs.mean(gs.sum(posterior_probabilities, 1)) - 1) > SUM_CHECK_PDF:
logging.warning("EXPECTATION : posterior probabilities do not sum to 1.")
if gs.any(gs.sum(posterior_probabilities, 0) < PDF_TOL):
logging.warning(
"EXPECTATION : Gaussian got no elements "
"(precision error) reinitialize"
)
posterior_probabilities[posterior_probabilities == 0] = PDF_TOL
return posterior_probabilities
def _maximization(self, data, posterior_probabilities):
"""Update function for the means and variances.
Parameters
----------
data : array-like, shape=[n_samples, n_features,]
Training data, where n_samples is the number of samples and
n_features is the number of features.
posterior_probabilities : array-like, shape=[n_samples, n_gaussians,]
Probability of a given sample to belong to a component
of the GMM, computed for all components.
"""
self._update_posterior_probabilities(posterior_probabilities)
self._update_means(data, posterior_probabilities)
self._update_variances(data, posterior_probabilities)
def _initialization(self, X):
"""Initialize the means and variances, randomly or with Riemannian k-means."""
if self.initialisation_method == "kmeans":
kmeans_estimator = RiemannianKMeans(
space=self.space,
n_clusters=self.n_gaussians,
init="random",
)
kmeans_estimator.fit(X=X)
cluster_centers = kmeans_estimator.cluster_centers_
labels = kmeans_estimator.labels_
means = cluster_centers
variances = gs.zeros(self.n_gaussians)
labeled_data = gs.vstack([labels, gs.transpose(X)])
labeled_data = gs.transpose(labeled_data)
for label, cluster_center in enumerate(cluster_centers):
label_mask = gs.where(labeled_data[:, 0] == label)
grouped_by_label = labeled_data[label_mask][:, 1:]
v = variance(self.space, grouped_by_label, cluster_center)
if grouped_by_label.shape[0] == 1:
v += MIN_VAR_INIT
variances[label] = v
else:
dim = self.space.shape[-1]
means = (gs.random.rand(self.n_gaussians, dim) - 0.5) / dim
variances = gs.random.rand(self.n_gaussians) / 10 + 0.8
return means, variances
def fit(self, X, y=None):
"""Fit a Gaussian mixture model (GMM) given the data.
Alternates between Expectation and Maximization steps
for some number of iterations.
Parameters
----------
X : array-like, shape=[n_samples, n_features]
Training data, where n_samples is the number of samples
and n_features is the number of features.
y : None
Target values. Ignored.
Returns
-------
self : object
Returns self.
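Notes
-----
The loop alternates expectation and maximization steps. It stops early
when the mean absolute change of the posterior probabilities between two
iterations, :math:`\mathrm{mean}(|P^{(t)} - P^{(t-1)}|)`, falls below
`conv_rate`, provided more than `minimum_epochs` iterations have run;
otherwise it runs for `max_iter` iterations.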
"""
self._model.means, self._model.variances = self._initialization(X)
self.mixture_coefficients_ = gs.ones(self.n_gaussians) / self.n_gaussians
posterior_probabilities = gs.ones((X.shape[0], self.n_gaussians))
for epoch in range(self.max_iter):
old_posterior_probabilities = posterior_probabilities
posterior_probabilities = self._expectation(X)
condition = gs.mean(
gs.abs(old_posterior_probabilities - posterior_probabilities)
)
if condition < self.conv_rate and epoch > self.minimum_epochs:
logging.info("EM converged in %s iterations", epoch)
break
self._maximization(X, posterior_probabilities)
else:
logging.warning(
"EM did not converge within max_iter iterations. "
"Consider increasing max_iter."
)
return self