"""K-means clustering.
Lead author: Hadi Zaatiti.
"""
import logging
from random import randint
from scipy.stats import rv_discrete
from sklearn.base import BaseEstimator, ClusterMixin
import geomstats.backend as gs
from geomstats.learning._template import TransformerMixin
from geomstats.learning.frechet_mean import FrechetMean
class RiemannianKMeans(TransformerMixin, ClusterMixin, BaseEstimator):
    """Class for k-means clustering on manifolds.

    K-means algorithm using Riemannian manifolds.

    Parameters
    ----------
    space : Manifold
        Equipped manifold.
    n_clusters : int
        Number of clusters (k value of the k-means).
        Optional, default: 8.
    init : str or callable or array-like, shape=[n_clusters, n_features]
        How to initialize cluster centers at the beginning of the algorithm. The
        choice 'random' will select training points as initial cluster centers
        uniformly at random. The choice 'kmeans++' selects cluster centers
        heuristically to improve the convergence rate. When providing an array
        of shape ``(n_clusters, n_features)``, the cluster centers are chosen as the
        rows of that array. When providing a callable, it receives as arguments
        the argument ``X`` to :meth:`fit` and the number of cluster centers
        ``n_clusters`` and is expected to return an array as above.
        Optional, default: 'random'.
    tol : float
        Convergence factor. Convergence is achieved when the mean distance
        between the cluster centers of two consecutive steps is lower than tol.
        Optional, default: 1e-2.
    max_iter : int
        Maximum number of iterations.
        Optional, default: 100
    verbose : int
        If verbose > 0, information will be printed during learning.
        Optional, default: 0.

    Attributes
    ----------
    init_cluster_centers_ : array-like, shape=[n_clusters, n_features]
        Copy of the cluster centers the last call to :meth:`fit` started from.
        None before :meth:`fit` is called.
    cluster_centers_ : array-like, shape=[n_clusters, n_features]
        Cluster centers found by :meth:`fit`. None before :meth:`fit` is called.
    labels_ : array-like, shape=[n_samples,]
        Index of the closest cluster center for each training sample.
    inertia_ : float
        Sum of squared distances of the training samples to their closest
        cluster center.
    mean_estimator : object
        Estimator used to recompute each cluster center from its members.

    Notes
    -----
    * Required metric methods: `dist`.

    Example
    -------
    Available example on the Poincaré Ball and Hypersphere manifolds
    :mod:`examples.plot_kmeans_manifolds`
    """

    def __init__(
        self,
        space,
        n_clusters=8,
        init="random",
        tol=1e-2,
        max_iter=100,
        verbose=0,
    ):
        self.space = space
        self.n_clusters = n_clusters
        self.init = init
        self.tol = tol
        self.verbose = verbose
        self.max_iter = max_iter

        self.init_cluster_centers_ = None

        self.mean_estimator = FrechetMean(space=space)
        # NOTE(review): the check below suggests that the constructor call may
        # hand back an estimator that is not a FrechetMean for some spaces
        # (not visible from this file), so the optimizer settings are applied
        # only when it really is one. Do not drop the check.
        if isinstance(self.mean_estimator, FrechetMean):
            self.mean_estimator.set(max_iter=100, init_step_size=1.0)

        self.cluster_centers_ = None
        self.labels_ = None
        self.inertia_ = None

    def _pick_init_cluster_centers(self, X):
        """Choose the cluster centers the algorithm starts from.

        Parameters
        ----------
        X : array-like, shape=[n_samples, n_features]
            Training data.

        Returns
        -------
        cluster_centers : array-like, shape=[n_clusters, n_features]
            Initial cluster centers. When ``self.init`` is an array, that very
            object is returned (no copy is made here).

        Raises
        ------
        ValueError
            If ``self.init`` is an unknown string, or if the number or the
            dimension of the initial cluster centers does not match
            ``n_clusters`` / ``X``.
        """
        n_samples = X.shape[0]

        if isinstance(self.init, str):
            if self.init == "kmeans++":
                # First center: a training point drawn uniformly at random.
                cluster_centers = [X[randint(0, n_samples - 1)]]
                indices = gs.to_numpy(gs.arange(n_samples))
                for _ in range(self.n_clusters - 1):
                    dists = gs.array(
                        [
                            self.space.metric.dist(cluster_center, X)
                            for cluster_center in cluster_centers
                        ]
                    )
                    dists_to_closest_cluster_center = gs.amin(dists, axis=0)
                    total_dist = gs.sum(dists_to_closest_cluster_center)
                    if total_dist == 0:
                        # Every sample coincides with an already chosen
                        # center: normalizing would divide by zero and yield
                        # NaN weights, so draw uniformly instead.
                        weights = gs.ones(n_samples) / n_samples
                    else:
                        # Samples far from the current centers are more
                        # likely to be drawn as the next center.
                        weights = dists_to_closest_cluster_center / total_dist
                    index = rv_discrete(
                        values=(indices, gs.to_numpy(weights))
                    ).rvs()
                    cluster_centers.append(X[index])
            elif self.init == "random":
                # Drawn with replacement: the same point may be picked twice.
                cluster_centers = [
                    X[randint(0, n_samples - 1)] for _ in range(self.n_clusters)
                ]
            else:
                raise ValueError(
                    f"Unknown initial cluster centers method '{self.init}'."
                )

            cluster_centers = gs.stack(cluster_centers, axis=0)
        else:
            if callable(self.init):
                cluster_centers = self.init(X, self.n_clusters)
            else:
                cluster_centers = self.init

            if cluster_centers.shape[0] != self.n_clusters:
                raise ValueError("Need as many initial cluster centers as clusters.")

            if cluster_centers.shape[1] != X.shape[1]:
                raise ValueError(
                    "Dimensions of initial cluster centers and "
                    "training data do not match."
                )

        return cluster_centers

    def _dists_to_cluster_centers(self, X, cluster_centers):
        """Compute the distance of every sample to every cluster center.

        Parameters
        ----------
        X : array-like, shape=[n_samples, n_features]
            Data points.
        cluster_centers : array-like, shape=[n_clusters, n_features]
            Cluster centers.

        Returns
        -------
        dists : array-like, shape=[n_samples, n_clusters]
            ``dists[s, c]`` is the metric distance between sample ``s`` and
            cluster center ``c``.
        """
        # Each per-center distance vector is turned into a column so that the
        # result stays 2-D even with a single sample or a single cluster.
        columns = [
            gs.to_ndarray(self.space.metric.dist(cluster_center, X), 2, 1)
            for cluster_center in cluster_centers
        ]
        return gs.hstack(columns)

    def fit(self, X):
        """Provide cluster centers and data labels.

        Alternate between computing the mean of each cluster
        and labelling data according to the new positions of the cluster centers.

        Parameters
        ----------
        X : array-like, shape=[n_samples, n_features]
            Training data, where n_samples is the number of samples and
            n_features is the number of features.

        Returns
        -------
        self : object
            Returns self.
        """
        n_samples = X.shape[0]
        if self.verbose > 0:
            logging.info("Initializing...")

        init_cluster_centers = self._pick_init_cluster_centers(X)
        self.init_cluster_centers_ = gs.copy(init_cluster_centers)
        # Work on a private copy: the centers are updated in place below and
        # a user-provided ``init`` array must not be overwritten.
        cluster_centers = gs.copy(init_cluster_centers)

        dists = self._dists_to_cluster_centers(X, cluster_centers)
        self.labels_ = gs.argmin(dists, 1)

        for index in range(self.max_iter):
            if self.verbose > 0:
                logging.info(f"Iteration {index}...")

            old_cluster_centers = gs.copy(cluster_centers)
            for i in range(self.n_clusters):
                fold = X[self.labels_ == i]

                if len(fold) > 0:
                    self.mean_estimator.fit(fold)
                    cluster_centers[i] = self.mean_estimator.estimate_
                else:
                    # Empty cluster: restart it from a random training point.
                    cluster_centers[i] = X[randint(0, n_samples - 1)]

            dists = self._dists_to_cluster_centers(X, cluster_centers)
            self.labels_ = gs.argmin(dists, 1)
            dists_to_closest_cluster_center = gs.amin(dists, 1)
            self.inertia_ = gs.sum(dists_to_closest_cluster_center**2)

            # Convergence criterion: mean displacement of the cluster centers.
            cluster_centers_distances = self.space.metric.dist(
                old_cluster_centers, cluster_centers
            )
            mean_displacement = gs.mean(cluster_centers_distances)
            if self.verbose > 0:
                logging.info(
                    f"Convergence criterion at the end of iteration {index} "
                    f"is {mean_displacement}."
                )

            if mean_displacement < self.tol:
                if self.verbose > 0:
                    # ``index`` is 0-based, hence the + 1 for a count.
                    logging.info(
                        f"Convergence reached after {index + 1} iterations."
                    )
                break
        else:
            # Reached only when the loop was never left through ``break``.
            logging.warning(
                f"K-means maximum number of iterations {self.max_iter} reached. "
                "The mean may be inaccurate."
            )

        self.cluster_centers_ = cluster_centers

        return self

    def predict(self, X):
        """Predict the labels for each data point.

        Label each data point with the cluster having the nearest
        cluster center using metric distance.

        Parameters
        ----------
        X : array-like, shape[n_samples, n_features]
            Input data.

        Returns
        -------
        labels : array-like, shape=[n_samples,]
            Array of predicted cluster indices for each sample.

        Raises
        ------
        RuntimeError
            If :meth:`fit` has not been called yet.
        """
        if self.cluster_centers_ is None:
            raise RuntimeError("fit needs to be called first.")

        # Keep the (n_samples, n_clusters) layout and reduce over the cluster
        # axis explicitly; squeezing first would drop a length-1 axis and
        # return a scalar when there is one cluster or one sample.
        dists = self._dists_to_cluster_centers(X, self.cluster_centers_)

        return gs.argmin(dists, 1)