# Source code for geomstats.learning.geometric_median

"""Geometric median."""

import logging

from sklearn.base import BaseEstimator

import geomstats.backend as gs

class GeometricMedian(BaseEstimator):
    """Geometric median.

    Parameters
    ----------
    space : Manifold
        Equipped manifold.
    max_iter : int
        Maximum number of iterations for the algorithm.
        Optional, default : 100
    lr : float
        Learning rate to be used for the algorithm.
        Optional, default : 1.0
    init_point : array-like, shape=[*space.shape]
        Initialization to be used in the start.
        Optional, default : None, in which case it uses last sample.
    print_every : int
        Print updated median after print_every iterations.
        Optional, default : None
    epsilon : float
        Tolerance for stopping the algorithm (distance between two
        successive estimates).
        Optional, default : gs.atol

    Attributes
    ----------
    estimate_ : array-like, shape=[*space.shape]
        If fit, geometric median.

    Notes
    -----
    * Required metric methods: `dist`, `log`, `exp`.

    References
    ----------
    .. [FVJ2009] Fletcher PT, Venkatasubramanian S and Joshi S.
        "The geometric median on Riemannian manifolds with application to
        robust atlas estimation", NeuroImage, 2009
    """

    def __init__(
        self,
        space,
        max_iter=100,
        lr=1.0,
        init_point=None,
        print_every=None,
        epsilon=gs.atol,
    ):
        self.space = space
        self.max_iter = max_iter
        self.lr = lr
        self.init_point = init_point
        self.print_every = print_every
        self.epsilon = epsilon

        self.estimate_ = None

    def _iterate_once(self, current_median, X, weights, lr):
        """Compute a single iteration of Weiszfeld algorithm.

        Parameters
        ----------
        current_median : array-like, shape=[*space.shape]
            Current median.
        X : array-like, shape=[n_samples, *space.shape]
            Training input samples.
        weights : array-like, shape=[n_samples,]
            Weights for weighted sum.
        lr : float
            Learning rate for the current iteration.

        Returns
        -------
        updated_median : array-like, shape=[*space.shape]
            Updated median after single iteration. It is `current_median`
            itself when every sample lies within `gs.atol` of it.
        """
        metric = self.space.metric
        dists = metric.dist(current_median, X)

        # Samples (numerically) equal to the current estimate are left out:
        # their weight `weights / dists` would divide by ~0.
        is_non_zero = dists > gs.atol
        if not gs.any(is_non_zero):
            return current_median

        w = weights[is_non_zero] / dists[is_non_zero]

        # Normalized weighted sum of the log maps taken at the current
        # estimate, then a step of size `lr` along it via the exp map.
        logs = metric.log(X[is_non_zero], current_median)
        v_k = gs.einsum("n,n...->...", w / gs.sum(w), logs)
        updated_median = metric.exp(lr * v_k, current_median)
        return updated_median

    def fit(self, X, y=None, weights=None):
        """Compute the weighted geometric median.

        Compute the geometric median on manifold using Weiszfeld algorithm.

        Parameters
        ----------
        X : array-like, shape=[n_samples, *space.shape]
            Training input samples.
        y : None
            Target values. Ignored.
        weights : array-like, shape=[n_samples,]
            Weights associated to the samples. The array passed in is not
            modified.
            Optional, default: None, in which case it is equally weighted.

        Returns
        -------
        self : object
            Returns self.
        """
        median = X[-1] if self.init_point is None else self.init_point

        if weights is None:
            weights = gs.ones(X.shape[0])
        # Out-of-place normalization: `weights /= ...` would overwrite the
        # caller's array and fails for integer-typed weights.
        weights = weights / gs.sum(weights)

        for iteration in range(self.max_iter):
            new_median = self._iterate_once(median, X, weights, self.lr)
            shift = self.space.metric.dist(new_median, median)
            if shift < self.epsilon:
                # Converged: the previous estimate is kept, it is within
                # `epsilon` of the new one.
                break

            median = new_median
            if self.print_every and (iteration + 1) % self.print_every == 0:
                logging.info(f"median at iteration {iteration+1}:\n{median}")
        else:
            # Reached only when the loop ran out of iterations without
            # hitting the `break` above.
            logging.warning(
                "Maximum number of iterations %s reached. "
                "The median may be inaccurate",
                self.max_iter,
            )

        self.estimate_ = median

        return self