Source code for geomstats.learning.riemannian_mean_shift

"""Riemannian mean-shift clustering.

Lead author: Nina Miolane and Shubham Talbar.
"""

import joblib
from sklearn.base import BaseEstimator, ClusterMixin

import geomstats.backend as gs
from geomstats.learning.frechet_mean import FrechetMean


[docs] class RiemannianMeanShift(ClusterMixin, BaseEstimator): """Class for Riemannian Mean Shift algorithm on manifolds. Mean Shift is a procedure for locating the maxima - the modes of a density function given discrete data sampled from that function. It is an iterative method for finding the centers of a collection of clusters. Following implementation assumes a flat kernel method. Parameters ---------- space : Manifold Equipped manifold. bandwidth : float Size of neighbourhood around each center. All points in 'bandwidth' size around center are considered for calculating new mean centers. tol : float Stopping condition. Computation of subsequent mean centers is stopped when the distance between them is less than 'tol'. Optional, default : 1e-2. n_clusters : int Number of centers. Optional, default : 1. n_jobs : int Number of parallel threads to be initiated for parallel jobs. Optional, default : 1. max_iter : int Upper bound on total number of iterations for the centers to converge. Optional, default : 100. init_centers : str Initializing centers, either from the given input points or random points uniformly distributed in the input manifold. Optional, default : "from_points". kernel : str Weighing function to assign kernel weights to each center. Optional, default : "flat". Notes ----- * Required metric methods: `dist`, `closest_neighbor_index`. """ def __init__( self, space, bandwidth, tol=1e-2, n_clusters=1, n_jobs=1, max_iter=100, init_centers="from_points", kernel="flat", ): self.space = space self.bandwidth = bandwidth self.tol = tol self.n_clusters = n_clusters self.n_jobs = n_jobs self.max_iter = max_iter self.init_centers = init_centers self.kernel = kernel self.cluster_centers_ = None self.mean_estimator = FrechetMean(space) def _dist_intersets(self, points_a, points_b): """Parallel computation of distances between two sets of points. Parameters ---------- points_a : array-like, shape=[..., n_features] Clusters of points. points_b : array-like, shape=[..., n_features] Clusters of points. """ n_a, n_b = points_a.shape[0], points_b.shape[0] @joblib.delayed @joblib.wrap_non_picklable_objects def pickable_dist(x, y): """Riemannian distance between points x & y. Parameters ---------- x : single point, shape=[1, n_features] Single point on manifold. y : array-like, shape=[1, n_features] Single point on manifold. """ return self.space.metric.dist(x, y) pool = joblib.Parallel(n_jobs=self.n_jobs) out = pool( pickable_dist(point_a, point_b) for point_a in points_a for point_b in points_b ) return gs.array(out).reshape((n_a, n_b)) def _initialization(self, X): if self.init_centers == "from_points": n_points = X.shape[0] centers = X[gs.random.randint(n_points, size=(self.n_clusters,)), :] elif self.init_centers == "random_uniform": centers = self.space.random_uniform(n_samples=self.n_clusters) return centers
[docs] def fit(self, X, y=None): """Fit centers in all the input points. Parameters ---------- X : array-like, shape=[n_samples, n_features] Clusters of points. y : None Target values. Ignored. Returns ------- self : object Returns self. """ @joblib.delayed @joblib.wrap_non_picklable_objects def pickable_mean(points, weights): """Frechet Mean of all points weighted by weights. Parameters ---------- points : array-like, shape=[..., n_features] Clusters of points. weights : array-like, Weight associated with each point in cluster. """ return self.mean_estimator.fit(points, weights=weights).estimate_ centers = self._initialization(X) for _ in range(self.max_iter): dists = self._dist_intersets(centers, X) if self.kernel == "flat": weights = gs.ones_like(dists) weights[dists > self.bandwidth] = 0.0 weights = weights / gs.sum(weights, axis=1, keepdims=True) points_to_average, nonzero_weights = [], [] for j in range(self.n_clusters): indexes = gs.where(weights[j] > 0) nonzero_weights += [ weights[j][indexes], ] points_to_average += [ X[indexes], ] pool = joblib.Parallel(n_jobs=self.n_jobs) out = pool( pickable_mean(points_to_average[j], nonzero_weights[j]) for j in range(self.n_clusters) ) new_centers = gs.array(out) displacements = [self.space.metric.dist(centers, new_centers)] centers = new_centers if (gs.array(displacements) < self.tol).all(): break self.cluster_centers_ = centers return self
[docs] def predict(self, X): """Predict the closest cluster each point in `points` belongs to. Parameters ---------- points : array-like, shape=[n_samples, n_features] Clusters of points. """ if self.cluster_centers_ is None: raise Exception("Not fitted") return self.space.metric.closest_neighbor_index(X, self.cluster_centers_)