Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 37b61e3

Browse filesBrowse files
committed
Plug PairwiseDistancesRadiusNeighborhood as a back-end
Also move the error message upfront if results have to be sorted without the distances being returned.
1 parent ac6f623 commit 37b61e3
Copy full SHA for 37b61e3

File tree

Expand file treeCollapse file tree

1 file changed

+36
-11
lines changed
Filter options
Expand file treeCollapse file tree

1 file changed

+36
-11
lines changed

‎sklearn/neighbors/_base.py

Copy file name to clipboardExpand all lines: sklearn/neighbors/_base.py
+36-11Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from ..base import BaseEstimator, MultiOutputMixin
2323
from ..base import is_classifier
2424
from ..metrics import pairwise_distances_chunked
25+
from ..metrics._pairwise_distances_reduction import PairwiseDistancesRadiusNeighborhood
2526
from ..metrics.pairwise import PAIRWISE_DISTANCE_FUNCTIONS
2627
from ..utils import (
2728
check_array,
@@ -1061,25 +1062,53 @@ class from an array representing our data set and ask who's
10611062
"""
10621063
check_is_fitted(self)
10631064

1064-
if X is not None:
1065-
query_is_train = False
1065+
if sort_results and not return_distance:
1066+
raise ValueError("return_distance must be True if sort_results is True.")
1067+
1068+
query_is_train = X is None
1069+
if query_is_train:
1070+
X = self._fit_X
1071+
else:
10661072
if self.metric == "precomputed":
10671073
X = _check_precomputed(X)
10681074
else:
1069-
X = self._validate_data(X, accept_sparse="csr", reset=False)
1070-
else:
1071-
query_is_train = True
1072-
X = self._fit_X
1075+
X = self._validate_data(X, accept_sparse="csr", reset=False, order="C")
10731076

10741077
if radius is None:
10751078
radius = self.radius
10761079

1077-
if self._fit_method == "brute" and self.metric == "precomputed" and issparse(X):
1080+
use_pairwise_distances_reductions = (
1081+
self._fit_method == "brute"
1082+
and PairwiseDistancesRadiusNeighborhood.is_usable_for(
1083+
X if X is not None else self._fit_X, self._fit_X, self.effective_metric_
1084+
)
1085+
)
1086+
1087+
if use_pairwise_distances_reductions:
1088+
results = PairwiseDistancesRadiusNeighborhood.compute(
1089+
X=X,
1090+
Y=self._fit_X,
1091+
radius=radius,
1092+
metric=self.effective_metric_,
1093+
metric_kwargs=self.effective_metric_params_,
1094+
n_threads=self.n_jobs,
1095+
strategy="auto",
1096+
return_distance=return_distance,
1097+
sort_results=sort_results,
1098+
)
1099+
1100+
elif (
1101+
self._fit_method == "brute" and self.metric == "precomputed" and issparse(X)
1102+
):
10781103
results = _radius_neighbors_from_graph(
10791104
X, radius=radius, return_distance=return_distance
10801105
)
10811106

10821107
elif self._fit_method == "brute":
1108+
# TODO: should no longer be needed once we have Cython-optimized
1109+
# implementation for radius queries, with support for sparse and/or
1110+
# float32 inputs.
1111+
10831112
# for efficiency, use squared euclidean distances
10841113
if self.effective_metric_ == "euclidean":
10851114
radius *= radius
@@ -1113,10 +1142,6 @@ class from an array representing our data set and ask who's
11131142
results = _to_object_array(neigh_ind_list)
11141143

11151144
if sort_results:
1116-
if not return_distance:
1117-
raise ValueError(
1118-
"return_distance must be True if sort_results is True."
1119-
)
11201145
for ii in range(len(neigh_dist)):
11211146
order = np.argsort(neigh_dist[ii], kind="mergesort")
11221147
neigh_ind[ii] = neigh_ind[ii][order]

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.