22
22
from ..base import BaseEstimator , MultiOutputMixin
23
23
from ..base import is_classifier
24
24
from ..metrics import pairwise_distances_chunked
25
+ from ..metrics ._pairwise_distances_reduction import PairwiseDistancesRadiusNeighborhood
25
26
from ..metrics .pairwise import PAIRWISE_DISTANCE_FUNCTIONS
26
27
from ..utils import (
27
28
check_array ,
@@ -1061,25 +1062,53 @@ class from an array representing our data set and ask who's
1061
1062
"""
1062
1063
check_is_fitted (self )
1063
1064
1064
- if X is not None :
1065
- query_is_train = False
1065
+ if sort_results and not return_distance :
1066
+ raise ValueError ("return_distance must be True if sort_results is True." )
1067
+
1068
+ query_is_train = X is None
1069
+ if query_is_train :
1070
+ X = self ._fit_X
1071
+ else :
1066
1072
if self .metric == "precomputed" :
1067
1073
X = _check_precomputed (X )
1068
1074
else :
1069
- X = self ._validate_data (X , accept_sparse = "csr" , reset = False )
1070
- else :
1071
- query_is_train = True
1072
- X = self ._fit_X
1075
+ X = self ._validate_data (X , accept_sparse = "csr" , reset = False , order = "C" )
1073
1076
1074
1077
if radius is None :
1075
1078
radius = self .radius
1076
1079
1077
- if self ._fit_method == "brute" and self .metric == "precomputed" and issparse (X ):
1080
+ use_pairwise_distances_reductions = (
1081
+ self ._fit_method == "brute"
1082
+ and PairwiseDistancesRadiusNeighborhood .is_usable_for (
1083
+ X if X is not None else self ._fit_X , self ._fit_X , self .effective_metric_
1084
+ )
1085
+ )
1086
+
1087
+ if use_pairwise_distances_reductions :
1088
+ results = PairwiseDistancesRadiusNeighborhood .compute (
1089
+ X = X ,
1090
+ Y = self ._fit_X ,
1091
+ radius = radius ,
1092
+ metric = self .effective_metric_ ,
1093
+ metric_kwargs = self .effective_metric_params_ ,
1094
+ n_threads = self .n_jobs ,
1095
+ strategy = "auto" ,
1096
+ return_distance = return_distance ,
1097
+ sort_results = sort_results ,
1098
+ )
1099
+
1100
+ elif (
1101
+ self ._fit_method == "brute" and self .metric == "precomputed" and issparse (X )
1102
+ ):
1078
1103
results = _radius_neighbors_from_graph (
1079
1104
X , radius = radius , return_distance = return_distance
1080
1105
)
1081
1106
1082
1107
elif self ._fit_method == "brute" :
1108
+ # TODO: should no longer be needed once we have Cython-optimized
1109
+ # implementation for radius queries, with support for sparse and/or
1110
+ # float32 inputs.
1111
+
1083
1112
# for efficiency, use squared euclidean distances
1084
1113
if self .effective_metric_ == "euclidean" :
1085
1114
radius *= radius
@@ -1113,10 +1142,6 @@ class from an array representing our data set and ask who's
1113
1142
results = _to_object_array (neigh_ind_list )
1114
1143
1115
1144
if sort_results :
1116
- if not return_distance :
1117
- raise ValueError (
1118
- "return_distance must be True if sort_results is True."
1119
- )
1120
1145
for ii in range (len (neigh_dist )):
1121
1146
order = np .argsort (neigh_dist [ii ], kind = "mergesort" )
1122
1147
neigh_ind [ii ] = neigh_ind [ii ][order ]
0 commit comments