Source code for geomstats.learning.knn
"""The KNN classifier on manifolds."""
import sklearn.metrics.pairwise as mp
import sklearn.neighbors._base as nb
from sklearn.neighbors import KNeighborsClassifier
from geomstats.geometry.manifold import Manifold
from ._sklearn import (
OutputToBackendMixin,
SklearnInteropMixin,
check_array_allow_nd,
validate_data_skip_check_array,
)
[docs]
class KNearestNeighborsClassifier(
OutputToBackendMixin, SklearnInteropMixin, KNeighborsClassifier
):
"""Classifier implementing the k-nearest neighbors vote on manifolds.
Parameters
----------
space : Manifold
Equipped manifold.
n_neighbors : int, optional (default = 5)
Number of neighbors to use by default.
weights : string or callable, optional (default = 'uniform')
Weight function used in prediction. Possible values:
- 'uniform' : uniform weights. All points in each neighborhood
are weighted equally.
- 'distance' : weight points by the inverse of their distance.
in this case, closer neighbors of a query point will have a
greater influence than neighbors which are further away.
- [callable] : a user-defined function which accepts an
array of distances, and returns an array of the same shape
containing the weights.
n_jobs : int or None, optional (default = None)
The number of parallel jobs to run for neighbors search.
``None`` means 1; ``-1`` means using all processors.
Attributes
----------
classes_ : array, shape=[n_classes,]
Class labels known to the classifier
effective_metric_ : string or callable
The distance metric used. It will be same as the `distance` parameter
or a synonym of it, e.g. 'euclidean' if the `distance` parameter set to
'minkowski' and `p` parameter set to 2.
effective_metric_params_ : dict
Additional keyword arguments for the distance function.
For most distances will be same with `distance_params` parameter,
but may also contain the `p` parameter value if the
`effective_metric_` attribute is set to 'minkowski'.
outputs_2d_ : bool
False when `y`'s shape is (n_samples, ) or (n_samples, 1) during fit
otherwise True.
References
----------
This algorithm uses the scikit-learn library:
https://github.com/scikit-learn/scikit-learn/blob/95d4f0841/sklearn/neighbors/_classification.py#L25
"""
_output_to_backend_methods = (
"kneighbors",
"predict",
"predict_proba",
)
_patched_methods = {
"fit",
"kneighbors",
}
def __init__(
self,
space,
n_neighbors=5,
weights="uniform",
n_jobs=None,
):
self.space = space
self._set_interop(space)
super().__init__(
n_neighbors=n_neighbors,
weights=weights,
algorithm="brute",
metric=space.metric.dist,
n_jobs=n_jobs,
)
def _set_interop(self, space):
array_repr = isinstance(space, Manifold)
self._use_sklearn_patches = not array_repr or space.point_ndim > 1
if not self._use_sklearn_patches:
return
if array_repr:
patches = [
(nb, "validate_data", validate_data_skip_check_array),
(mp, "check_array", check_array_allow_nd),
]
else:
patches = [
(nb, "validate_data", validate_data_skip_check_array),
]
self._sklearn_patches = tuple(patches)