"""Multidimensional scaling (MDS)."""
from sklearn.manifold import MDS as _MDS
import geomstats.backend as gs
[docs]
def pairwise_dists(space, points):
    """Build the symmetric matrix of distances between all pairs of points.

    Only the strict upper triangle is computed, one row at a time, using
    ``space.metric.dist``. Each row is mirrored into the lower triangle and
    the diagonal is left at zero. ``space.metric`` is only accessed when
    there are at least two points.

    Parameters
    ----------
    space : Manifold or PointSet
        Space whose ``metric.dist`` is used to measure distances.
    points : array-like, shape=[n_samples, dim]
        Set of points in the manifold.

    Returns
    -------
    pairwise_dist_matrix : array-like, shape=[n_samples, n_samples]
        Pairwise distance matrix between all the points.
    """
    num_points = len(points)
    dist_matrix = gs.zeros((num_points, num_points))
    for row in range(num_points - 1):
        first_col = row + 1
        # ``dist`` receives one point and the batch of all later points;
        # assumes it broadcasts over that batch (TODO confirm for custom metrics).
        row_dists = space.metric.dist(points[row], points[first_col:])
        # Distances are treated as symmetric: one row fills both triangles.
        dist_matrix[row, first_col:] = row_dists
        dist_matrix[first_col:, row] = row_dists
    return dist_matrix
[docs]
class MDS(_MDS):
r"""Multidimensional scaling (MDS).
MDS translates the pairwise distances between N objects into N points
embedded in an abstract Cartesian space, usually two-dimensional.
In general, MDS only needs dissimilarities between objects, not a true
metric; this implementation derives them from the distance of the space's metric.
Parameters
----------
space : Manifold or PointSet
Space equipped with a distance metric.
n_components : int
Number of dimensions in which to immerse the dissimilarities.
Attributes
----------
embedding_ : array-like, shape=[n_samples, n_components]
Data transformed in the new space.
Notes
-----
* For all other parameters see documentation in scikit-learn.
* Required metric methods for general case:
* `dist`
References
----------
This algorithm uses the scikit-learn library:
https://scikit-learn.org/stable/modules/generated/sklearn.manifold.MDS.html
"""
def __init__(
    self,
    space,
    n_components=2,
    metric_mds=True,
    n_init=1,
    init="classical_mds",
    max_iter=300,
    verbose=0,
    eps=1e-06,
    n_jobs=None,
    random_state=None,
    metric_params=None,
    normalized_stress="auto",
):
    """Store the space and configure the underlying scikit-learn MDS.

    Every keyword except ``space`` is forwarded unchanged to
    ``sklearn.manifold.MDS``. The scikit-learn ``metric`` argument is
    always fixed to ``"precomputed"``, so the parent estimator expects a
    dissimilarity matrix as input rather than raw points.

    Parameters
    ----------
    space : Manifold or PointSet
        Space equipped with a distance metric.
    n_components : int
        Number of dimensions in which to immerse the dissimilarities.
    """
    # scikit-learn estimators expose constructor arguments through
    # same-named attributes (BaseEstimator.get_params), so keep ``space``
    # under exactly that name.
    self.space = space
    super().__init__(
        metric="precomputed",
        metric_params=metric_params,
        metric_mds=metric_mds,
        n_components=n_components,
        init=init,
        n_init=n_init,
        max_iter=max_iter,
        eps=eps,
        normalized_stress=normalized_stress,
        random_state=random_state,
        n_jobs=n_jobs,
        verbose=verbose,
    )