-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
WIP, MAINT: NeighborsBase KDTree
upstream
#31358
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,7 @@ | |
import numpy as np | ||
from joblib import effective_n_jobs | ||
from scipy.sparse import csr_matrix, issparse | ||
from scipy.spatial import KDTree as spKDTree | ||
|
||
from ..base import BaseEstimator, MultiOutputMixin, is_classifier | ||
from ..exceptions import DataConversionWarning, EfficiencyWarning | ||
|
@@ -694,6 +695,10 @@ def _fit(self, X, y=None): | |
"try algorithm='ball_tree' " | ||
"or algorithm='brute' instead." | ||
) | ||
self._sp_tree = spKDTree( | ||
X, | ||
self.leaf_size, | ||
) | ||
self._tree = KDTree( | ||
X, | ||
self.leaf_size, | ||
|
@@ -1265,7 +1270,7 @@ class from an array representing our data set and ask who's | |
neigh_dist[ii] = neigh_dist[ii][order] | ||
results = neigh_dist, neigh_ind | ||
|
||
elif self._fit_method in ["ball_tree", "kd_tree"]: | ||
elif self._fit_method == "ball_tree": | ||
if issparse(X): | ||
raise ValueError( | ||
"%s does not work with sparse matrices. Densify the data, " | ||
|
@@ -1278,6 +1283,53 @@ class from an array representing our data set and ask who's | |
delayed_query(X[s], radius, return_distance, sort_results=sort_results) | ||
for s in gen_even_slices(X.shape[0], n_jobs) | ||
) | ||
if return_distance: | ||
neigh_ind, neigh_dist = tuple(zip(*chunked_results)) | ||
results = np.hstack(neigh_dist), np.hstack(neigh_ind) | ||
else: | ||
results = np.hstack(chunked_results) | ||
elif self._fit_method == "kd_tree": | ||
if issparse(X): | ||
raise ValueError( | ||
"%s does not work with sparse matrices. Densify the data, " | ||
"or set algorithm='brute'" % self._fit_method | ||
) | ||
|
||
n_jobs = effective_n_jobs(self.n_jobs) | ||
delayed_query = delayed(self._sp_tree.query_ball_point) | ||
chunked_results = Parallel(n_jobs, prefer="threads")( | ||
delayed_query(X[s], radius, return_sorted=sort_results) | ||
for s in gen_even_slices(X.shape[0], n_jobs) | ||
) | ||
nn_vals = [] | ||
for sub_arr in chunked_results[0]: | ||
nn_vals.append(len(sub_arr)) | ||
if return_distance: | ||
dd, ii = self._sp_tree.query(X, k=max(nn_vals), distance_upper_bound=radius) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Needing to call |
||
dd_new = [] | ||
ii_new = [] | ||
for i in range(len(dd)): | ||
finite_indices = ii[i][np.isfinite(dd[i])] | ||
finite_dists = dd[i][np.isfinite(dd[i])] | ||
if sort_results: | ||
sort_inds = np.argsort(finite_indices) | ||
sorted_inds = finite_indices[sort_inds] | ||
sorted_dists = finite_dists[sort_inds] | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Different handling of ragged data structures and |
||
else: | ||
sorted_inds = finite_indices | ||
sorted_dists = finite_dists | ||
dd_new.append(sorted_dists) | ||
ii_new.append(sorted_inds) | ||
dd = dd_new | ||
ii = ii_new | ||
try: | ||
chunked_results = [(np.asarray(ii, dtype=int), np.asarray(dd, dtype=X.dtype))] | ||
except ValueError: | ||
chunked_results = [(np.asarray(ii, dtype=object), np.asarray(dd, dtype=object))] | ||
else: | ||
for idx, sub_ele in enumerate(chunked_results[0]): | ||
chunked_results[0][idx] = np.sort(chunked_results[0][idx]) | ||
|
||
if return_distance: | ||
neigh_ind, neigh_dist = tuple(zip(*chunked_results)) | ||
results = np.hstack(neigh_dist), np.hstack(neigh_ind) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Replacing `_tree` itself was much messier, so I tried to keep this scoped for prototyping.