Source code for geomstats.learning.kmedoids
"""K-medoids clustering.
Lead author: Hadi Zaatiti.
"""
import logging
from sklearn.base import BaseEstimator, ClusterMixin
import geomstats.backend as gs
from geomstats.learning._template import TransformerMixin
[docs]
class RiemannianKMedoids(TransformerMixin, ClusterMixin, BaseEstimator):
"""Class for K-medoids clustering on manifolds.
K-medoids algorithm using Riemannian manifolds.
Parameters
----------
space : Manifold
Equipped manifold.
n_clusters : int
Number of clusters (k value of k-medoids).
Optional, default: 8.
max_iter : int
Maximum number of iterations.
Optional, default: 100.
init : str
How to initialize cluster centers at the beginning of the algorithm. The
choice 'random' will select training points as initial cluster centers
uniformly at random.
Optional, default: 'random'.
n_jobs : int
Number of jobs to run in parallel. `-1` means using all processors.
Optional, default: 1.
Notes
-----
* Required metric methods: `dist`, `dist_pairwise`.
Example
-------
Available example on the Poincaré Ball and Hypersphere manifolds
:mod:`examples.plot_kmedoids_manifolds`
"""
def __init__(self, space, n_clusters=8, init="random", max_iter=100, n_jobs=1):
self.space = space
self.n_clusters = n_clusters
self.max_iter = max_iter
self.init = init
self.n_jobs = n_jobs
self.cluster_centers_ = None
self.labels_ = None
self.medoid_indices_ = None
def _initialize_medoids(self, distances):
"""Select initial medoids when beginning clustering."""
if self.init == "random":
medoids = gs.random.choice(gs.arange(len(distances)), self.n_clusters)
else:
logging.error("Unknown initialization method.")
return medoids
[docs]
def fit(self, X):
"""Provide cluster centers and data labels.
Labels data by minimizing the distance between data points
and cluster center chosen from the data points.
Minimization is performed by swapping the cluster centers and data points.
Parameters
----------
X : array-like, shape=[n_samples, dim]
Training data, where n_samples is the number of samples and
dim is the number of dimensions.
Returns
-------
self : object
Returns self.
"""
distances = self.space.metric.dist_pairwise(X, n_jobs=self.n_jobs)
medoids_indices = self._initialize_medoids(distances)
for iteration in range(self.max_iter):
old_medoids_indices = gs.copy(medoids_indices)
labels = gs.argmin(distances[medoids_indices, :], axis=0)
self._update_medoid_indexes(distances, labels, medoids_indices)
if gs.all(old_medoids_indices == medoids_indices):
break
if iteration == self.max_iter - 1:
logging.warning(
"Maximum number of iteration reached before "
"convergence. Consider increasing max_iter to "
"improve the fit."
)
self.cluster_centers_ = X[medoids_indices]
self.labels_ = labels
self.medoid_indices_ = medoids_indices
return self
def _update_medoid_indexes(self, distances, labels, medoid_indices):
for cluster in range(self.n_clusters):
cluster_index = gs.where(labels == cluster)[0]
if len(cluster_index) == 0:
logging.warning("One cluster is empty.")
continue
in_cluster_distances = distances[
cluster_index, gs.expand_dims(cluster_index, axis=-1)
]
in_cluster_all_costs = gs.sum(in_cluster_distances, axis=1)
min_cost_index = gs.argmin(in_cluster_all_costs)
min_cost = in_cluster_all_costs[min_cost_index]
current_cost = in_cluster_all_costs[
gs.argmax(cluster_index == medoid_indices[cluster])
]
if min_cost < current_cost:
medoid_indices[cluster] = cluster_index[min_cost_index]
[docs]
def predict(self, X):
"""Predict the closest cluster for each sample in X.
Parameters
----------
X : array-like, shape=[n_samples, dim,]
Training data, where n_samples is the number of samples and
dim is the number of dimensions.
Returns
-------
labels : array-like, shape=[n_samples,]
Index of the cluster each sample belongs to.
"""
labels = gs.zeros(len(X))
for point_index, point_value in enumerate(X):
distances = gs.zeros(len(self.cluster_centers_))
for cluster_index, cluster_value in enumerate(self.cluster_centers_):
distances[cluster_index] = self.space.metric.dist(
point_value, cluster_value
)
labels[point_index] = gs.argmin(distances)
return labels