Source code for geomstats.learning.mds

"""Multidimensional scaling (MDS)."""

import sklearn.manifold._mds as _mds
import sklearn.metrics.pairwise as mp
from sklearn.manifold import MDS as _MDS

import geomstats.backend as gs
from geomstats._sklearngs.manifold._classical_mds import ClassicalMDS
from geomstats._sklearngs.manifold._mds import smacof
from geomstats.geometry.manifold import Manifold

from ._sklearn import (
    SklearnInteropMixin,
    check_array_allow_nd,
    validate_data_skip_check_array,
)


def pairwise_dists(points, dist_fnc):
    """Compute the pairwise distance between points.

    Parameters
    ----------
    points : array-like, shape=[n_samples, dim]
        Set of points in the manifold. Must support ``len``, integer
        indexing and slicing.
    dist_fnc : callable
        Distance function. It is called as ``dist_fnc(point, other_points)``
        with a single point and the slice of all points that come after it,
        and must return one distance per element of ``other_points``.

    Returns
    -------
    pairwise_dist_matrix : array-like, shape=[n_samples, n_samples]
        Pairwise distance matrix between all the points. The diagonal is
        left at zero.
    """
    n_samples = len(points)
    pairwise_dist_matrix = gs.zeros((n_samples, n_samples))

    # Distances are evaluated once per unordered pair (strict upper
    # triangle) and mirrored into the lower triangle; the last point has no
    # successors, hence the ``n_samples - 1`` bound.
    for i in range(n_samples - 1):
        dists = dist_fnc(points[i], points[i + 1 :])
        pairwise_dist_matrix[i, i + 1 :] = dists
        pairwise_dist_matrix[i + 1 :, i] = dists

    return pairwise_dist_matrix
class MDS(SklearnInteropMixin, _MDS):
    r"""Multidimensional scaling (MDS).

    MDS turns the pairwise distances between objects into points of an
    abstract Cartesian space, usually two-dimensional. MDS in general does
    not require a metric, but this implementation uses the one carried by
    ``space``.

    Parameters
    ----------
    space : Manifold or PointSet
        Space equipped with a distance metric.
    n_components : int
        Number of dimensions in which to immerse the dissimilarities.

    Attributes
    ----------
    embedding_ : array-like, shape=[n_samples, n_components]
        Data transformed in the new space.

    Notes
    -----
    * For all other parameters see documentation in scikit-learn.
    * Required metric methods for general case:
        * `dist`

    References
    ----------
    This algorithm uses the scikit-learn library:
    https://scikit-learn.org/stable/modules/generated/sklearn.manifold.MDS.html
    """

    _patched_methods = {
        "fit_transform",
    }

    def __init__(
        self,
        space,
        n_components=2,
        metric_mds=True,
        n_init=1,
        init="classical_mds",
        max_iter=300,
        verbose=0,
        eps=1e-06,
        n_jobs=None,
        random_state=None,
        metric_params=None,
        normalized_stress="auto",
    ):
        self.space = space
        self._set_interop(space)

        # The distance of the space's metric is handed to scikit-learn as
        # the ``metric`` callable; everything else is forwarded unchanged.
        super().__init__(
            metric=space.metric.dist,
            metric_params=metric_params,
            metric_mds=metric_mds,
            n_components=n_components,
            n_init=n_init,
            init=init,
            max_iter=max_iter,
            eps=eps,
            normalized_stress=normalized_stress,
            random_state=random_state,
            n_jobs=n_jobs,
            verbose=verbose,
        )

    def _set_interop(self, space):
        """Decide which scikit-learn internals must be patched for ``space``.

        Sets ``self._use_sklearn_patches`` and, when it is true, stores the
        ``(module, attribute name, replacement)`` triples to apply in
        ``self._sklearn_patches``. How and when the triples are applied is
        not visible in this class (presumably handled by
        ``SklearnInteropMixin`` -- to be confirmed).

        Parameters
        ----------
        space : Manifold or PointSet
            Space whose points will be passed to the estimator.
        """
        is_array_space = isinstance(space, Manifold)
        torch_backend = gs.__name__.endswith("pytorch")
        # ``point_ndim`` is only read when ``space`` is a ``Manifold``.
        multi_dim_points = is_array_space and space.point_ndim > 1

        self._use_sklearn_patches = (
            not is_array_space or multi_dim_points or torch_backend
        )
        if not self._use_sklearn_patches:
            return

        patches = []

        if torch_backend:
            patches.append((_mds, "smacof", smacof))
            patches.append((_mds, "ClassicalMDS", ClassicalMDS))

        if multi_dim_points:
            patches.append((_mds, "validate_data", validate_data_skip_check_array))
            patches.append((mp, "check_array", check_array_allow_nd))

        if not is_array_space:
            patches.append((_mds, "validate_data", validate_data_skip_check_array))
            patches.append(
                (
                    _mds,
                    "pairwise_distances",
                    lambda X, metric: pairwise_dists(X, metric),
                )
            )

        self._sklearn_patches = tuple(patches)