Source code for geomstats.learning.knn
"""The KNN classifier on manifolds.
Lead author: Yann Cabanes.
"""
import functools

from sklearn.neighbors import KNeighborsClassifier

import geomstats.backend as gs
[docs]
def wrap(function):
    """Wrap a function to first convert its positional args to arrays.

    Parameters
    ----------
    function : callable
        Function to wrap.

    Returns
    -------
    wrapped_function : callable
        Callable that applies ``gs.from_numpy`` to every positional
        argument, then calls ``function`` with the converted arguments.
        Keyword arguments are forwarded unchanged. The name and docstring
        of ``function`` are preserved on the returned callable.
    """

    @functools.wraps(function)
    def wrapped_function(*args, **kwargs):
        # Only positional arguments are converted; kwargs pass through as-is.
        new_args = map(gs.from_numpy, args)
        return function(*new_args, **kwargs)

    return wrapped_function
[docs]
class KNearestNeighborsClassifier(KNeighborsClassifier):
    """K-nearest neighbors vote classifier for data lying on a manifold.

    Thin specialization of scikit-learn's ``KNeighborsClassifier``: the
    neighbors search is forced to ``algorithm="brute"`` and distances are
    computed with ``space.metric.dist``.

    Parameters
    ----------
    space : Manifold
        Equipped manifold. Its ``metric.dist`` is used as the distance
        between points.
    n_neighbors : int, optional (default = 5)
        Default number of neighbors taken into account.
    weights : string or callable, optional (default = 'uniform')
        How neighbors are weighted in the prediction. Possible values:

        - 'uniform' : every point of a neighborhood has the same weight.
        - 'distance' : points are weighted by the inverse of their distance,
          so that closer neighbors of a query point have more influence
          than neighbors which are further away.
        - [callable] : user-defined function taking an array of distances
          and returning an array of weights of the same shape.
    n_jobs : int or None, optional (default = None)
        Number of parallel jobs run for the neighbors search.
        ``None`` means 1; ``-1`` means using all processors.

    Attributes
    ----------
    classes_ : array, shape=[n_classes,]
        Class labels known to the classifier.
    effective_metric_ : string or callable
        Distance metric used by the parent scikit-learn estimator. This
        class passes a callable wrapping ``space.metric.dist`` as ``metric``.
    effective_metric_params_ : dict
        Additional keyword arguments for the distance function.
    outputs_2d_ : bool
        False when `y`'s shape is (n_samples, ) or (n_samples, 1) during fit,
        otherwise True.

    References
    ----------
    This algorithm uses the scikit-learn library:
    https://github.com/scikit-learn/scikit-learn/blob/95d4f0841/sklearn/
    neighbors/_classification.py#L25
    """

    def __init__(self, space, n_neighbors=5, weights="uniform", n_jobs=None):
        self.space = space
        # The manifold distance is handed to scikit-learn as a callable
        # metric; inputs are converted to backend arrays before each call.
        super().__init__(
            metric=wrap(space.metric.dist),
            algorithm="brute",
            n_neighbors=n_neighbors,
            weights=weights,
            n_jobs=n_jobs,
        )