ENH: Update KDTree, and example documentation #25482

Merged 13 commits on Feb 17, 2023
doc/modules/neighbors.rst — 12 changes: 8 additions & 4 deletions
@@ -136,9 +136,13 @@ have the same interface; we'll show an example of using the KD Tree here:
Refer to the :class:`KDTree` and :class:`BallTree` class documentation
for more information on the options available for nearest neighbors searches,
including specification of query strategies, distance metrics, etc. For a list
-of available metrics, see the documentation of the :class:`DistanceMetric` class
-and the metrics listed in `sklearn.metrics.pairwise.PAIRWISE_DISTANCE_FUNCTIONS`.
-Note that the "cosine" metric uses :func:`~sklearn.metrics.pairwise.cosine_distances`.
+of valid metrics use :meth:`KDTree.valid_metrics` and :meth:`BallTree.valid_metrics`:
+
+>>> from sklearn.neighbors import KDTree, BallTree
+>>> KDTree.valid_metrics()
+['euclidean', 'l2', 'minkowski', 'p', 'manhattan', 'cityblock', 'l1', 'chebyshev', 'infinity']
+>>> BallTree.valid_metrics()
+['euclidean', 'l2', 'minkowski', 'p', 'manhattan', 'cityblock', 'l1', 'chebyshev', 'infinity', 'seuclidean', 'mahalanobis', 'wminkowski', 'hamming', 'canberra', 'braycurtis', 'matching', 'jaccard', 'dice', 'rogerstanimoto', 'russellrao', 'sokalmichener', 'sokalsneath', 'haversine', 'pyfunc']

.. _classification:

@@ -476,7 +480,7 @@ A list of valid metrics for any of the above algorithms can be obtained by using
``valid_metric`` attribute. For example, valid metrics for ``KDTree`` can be generated by:

>>> from sklearn.neighbors import KDTree
->>> print(sorted(KDTree.valid_metrics))
+>>> print(sorted(KDTree.valid_metrics()))
['chebyshev', 'cityblock', 'euclidean', 'infinity', 'l1', 'l2', 'manhattan', 'minkowski', 'p']
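For reference, the two lists printed in the doctests are nested: every KD-tree metric is also valid for the ball tree. A quick standalone check (hardcoding the outputs shown above so the snippet runs without scikit-learn installed):

```python
# Metric lists as printed by KDTree.valid_metrics() and BallTree.valid_metrics()
# in the doc example above (hardcoded here rather than imported).
kd_metrics = ['euclidean', 'l2', 'minkowski', 'p', 'manhattan',
              'cityblock', 'l1', 'chebyshev', 'infinity']
ball_only_extra = ['seuclidean', 'mahalanobis', 'wminkowski', 'hamming',
                   'canberra', 'braycurtis', 'matching', 'jaccard', 'dice',
                   'rogerstanimoto', 'russellrao', 'sokalmichener',
                   'sokalsneath', 'haversine', 'pyfunc']
ball_metrics = kd_metrics + ball_only_extra

# Every KD-tree metric is also a ball-tree metric.
assert set(kd_metrics) <= set(ball_metrics)

# Metrics available only with the ball tree:
print(sorted(set(ball_metrics) - set(kd_metrics)))
```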


sklearn/neighbors/_base.py — 4 changes: 2 additions & 2 deletions

@@ -66,8 +66,8 @@
    SCIPY_METRICS += ["kulsinski"]

VALID_METRICS = dict(
-    ball_tree=BallTree.valid_metrics,
-    kd_tree=KDTree.valid_metrics,
+    ball_tree=BallTree._valid_metrics,
+    kd_tree=KDTree._valid_metrics,
    # The following list comes from the
    # sklearn.metrics.pairwise doc string
    brute=sorted(set(PAIRWISE_DISTANCE_FUNCTIONS).union(SCIPY_METRICS)),
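The `VALID_METRICS` dict in this hunk dispatches on the algorithm name. A minimal standalone sketch of the same lookup-and-validate pattern (the metric lists and `check_metric` helper here are illustrative, not scikit-learn's real API):

```python
# Illustrative per-algorithm metric tables, mirroring the shape of the
# VALID_METRICS dict in _base.py above (not the real lists).
VALID_METRICS = dict(
    kd_tree=["euclidean", "manhattan", "chebyshev", "minkowski"],
    ball_tree=["euclidean", "manhattan", "chebyshev", "minkowski", "haversine"],
    brute=["euclidean", "manhattan", "cosine"],
)


def check_metric(algorithm, metric):
    """Raise ValueError if `metric` is not supported by `algorithm`."""
    if metric not in VALID_METRICS[algorithm]:
        raise ValueError(f"invalid metric for {algorithm}: {metric!r}")


check_metric("ball_tree", "haversine")  # fine: the ball tree table has it
```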
sklearn/neighbors/_binary_tree.pxi — 25 changes: 19 additions & 6 deletions

@@ -234,11 +234,11 @@ leaf_size : positive int, default=40
metric : str or DistanceMetric object, default='minkowski'
    Metric to use for distance computation. Default is "minkowski", which
    results in the standard Euclidean distance when p = 2.
-    {binary_tree}.valid_metrics gives a list of the metrics which are valid for
-    {BinaryTree}. See the documentation of `scipy.spatial.distance
-    <https://docs.scipy.org/doc/scipy/reference/spatial.distance.html>`_ and the
-    metrics listed in :class:`~sklearn.metrics.pairwise.distance_metrics` for
-    more information.
+    A list of valid metrics for {BinaryTree} is given by
+    :meth:`{BinaryTree}.valid_metrics`.
+    See the documentation of `scipy.spatial.distance
+    <https://docs.scipy.org/doc/scipy/reference/spatial.distance.html>`_ and the metrics listed in :class:`~sklearn.metrics.pairwise.distance_metrics` for
+    more information on any distance metric.

    Additional keywords are passed to the distance metric class.
    Note: Callable functions in the metric parameter are NOT supported for KDTree
@@ -791,7 +791,7 @@ cdef class BinaryTree:
    cdef int n_splits
    cdef int n_calls

-    valid_metrics = VALID_METRIC_IDS
+    _valid_metrics = VALID_METRIC_IDS

    # Use cinit to initialize all arrays to empty: this will prevent memory
    # errors and seg-faults in rare cases where __init__ is not called
@@ -979,6 +979,19 @@
            self.node_bounds.base,
        )

+    @classmethod
+    def valid_metrics(cls):
+        """Get list of valid distance metrics.
+
+        .. versionadded:: 1.3
+
+        Returns
+        -------
+        valid_metrics: list of str
+            List of valid distance metrics.
+        """
+        return cls._valid_metrics
+
    cdef inline DTYPE_t dist(self, DTYPE_t* x1, DTYPE_t* x2,
                             ITYPE_t size) nogil except -1:
        """Compute the distance between arrays x1 and x2"""
sklearn/neighbors/_kde.py — 6 changes: 3 additions & 3 deletions

@@ -174,12 +174,12 @@ def _choose_algorithm(self, algorithm, metric):
        # algorithm to compute the result.
        if algorithm == "auto":
            # use KD Tree if possible
-            if metric in KDTree.valid_metrics:
+            if metric in KDTree.valid_metrics():
                return "kd_tree"
-            elif metric in BallTree.valid_metrics:
+            elif metric in BallTree.valid_metrics():
                return "ball_tree"
        else:  # kd_tree or ball_tree
-            if metric not in TREE_DICT[algorithm].valid_metrics:
+            if metric not in TREE_DICT[algorithm].valid_metrics():
                raise ValueError(
                    "invalid metric for {0}: '{1}'".format(TREE_DICT[algorithm], metric)
                )
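The `_choose_algorithm` logic in this hunk prefers a KD tree whenever the metric allows it and otherwise falls back to a ball tree. A standalone sketch of that selection (the metric lists and function name are illustrative, not scikit-learn's):

```python
# Illustrative metric tables: everything a KD tree handles, a ball tree
# handles too, plus a few extras.
KD_TREE_METRICS = ["euclidean", "manhattan", "chebyshev", "minkowski"]
BALL_TREE_METRICS = KD_TREE_METRICS + ["haversine", "mahalanobis"]


def choose_algorithm(algorithm, metric):
    """Mirror the auto-selection above: KD tree if possible, else ball tree."""
    if algorithm == "auto":
        if metric in KD_TREE_METRICS:
            return "kd_tree"
        if metric in BALL_TREE_METRICS:
            return "ball_tree"
        raise ValueError(f"no tree-based algorithm supports metric {metric!r}")
    return algorithm  # caller asked for kd_tree or ball_tree explicitly


print(choose_algorithm("auto", "haversine"))  # falls back to the ball tree
```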
sklearn/neighbors/tests/test_kde.py — 4 changes: 2 additions & 2 deletions

@@ -114,7 +114,7 @@ def test_kde_algorithm_metric_choice(algorithm, metric):

    kde = KernelDensity(algorithm=algorithm, metric=metric)

-    if algorithm == "kd_tree" and metric not in KDTree.valid_metrics:
+    if algorithm == "kd_tree" and metric not in KDTree.valid_metrics():
        with pytest.raises(ValueError, match="invalid metric"):
            kde.fit(X)
    else:
@@ -165,7 +165,7 @@ def test_kde_sample_weights():
    test_points = rng.rand(n_samples_test, d)
    for algorithm in ["auto", "ball_tree", "kd_tree"]:
        for metric in ["euclidean", "minkowski", "manhattan", "chebyshev"]:
-            if algorithm != "kd_tree" or metric in KDTree.valid_metrics:
+            if algorithm != "kd_tree" or metric in KDTree.valid_metrics():
                kde = KernelDensity(algorithm=algorithm, metric=metric)

                # Test that adding a constant sample weight has no effect
Expand Down