WIP, MAINT: NeighborsBase KDTree upstream #31358

Open · wants to merge 1 commit into base: main
sklearn/neighbors/_base.py (54 changes: 53 additions & 1 deletion)
@@ -13,6 +13,7 @@
import numpy as np
from joblib import effective_n_jobs
from scipy.sparse import csr_matrix, issparse
from scipy.spatial import KDTree as spKDTree

from ..base import BaseEstimator, MultiOutputMixin, is_classifier
from ..exceptions import DataConversionWarning, EfficiencyWarning
@@ -694,6 +695,10 @@ def _fit(self, X, y=None):
"try algorithm='ball_tree' "
"or algorithm='brute' instead."
)
self._sp_tree = spKDTree(
X,
self.leaf_size,
)
Contributor Author commented:
Replacing _tree itself was much messier, so I tried to keep this scoped for prototyping.
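
A standalone sketch (not the estimator internals) of the prototyping approach this takes: keep the existing scikit-learn KDTree as `_tree` and fit a `scipy.spatial.KDTree` on the same data alongside it, so either backend can be queried while experimenting. The data and variable names below are illustrative only.

```python
import numpy as np
from scipy.spatial import KDTree as spKDTree  # prototype backend
from sklearn.neighbors import KDTree          # existing backend

rng = np.random.default_rng(0)
X = rng.random((100, 3))  # illustrative data, not from the PR
leaf_size = 30

sk_tree = KDTree(X, leaf_size)             # what self._tree holds today
sp_tree = spKDTree(X, leafsize=leaf_size)  # what self._sp_tree adds for prototyping
```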

self._tree = KDTree(
X,
self.leaf_size,
@@ -1265,7 +1270,7 @@ class from an array representing our data set and ask who's
neigh_dist[ii] = neigh_dist[ii][order]
results = neigh_dist, neigh_ind

elif self._fit_method in ["ball_tree", "kd_tree"]:
elif self._fit_method == "ball_tree":
if issparse(X):
raise ValueError(
"%s does not work with sparse matrices. Densify the data, "
@@ -1278,6 +1283,53 @@ class from an array representing our data set and ask who's
delayed_query(X[s], radius, return_distance, sort_results=sort_results)
for s in gen_even_slices(X.shape[0], n_jobs)
)
if return_distance:
neigh_ind, neigh_dist = tuple(zip(*chunked_results))
results = np.hstack(neigh_dist), np.hstack(neigh_ind)
else:
results = np.hstack(chunked_results)
elif self._fit_method == "kd_tree":
if issparse(X):
raise ValueError(
"%s does not work with sparse matrices. Densify the data, "
"or set algorithm='brute'" % self._fit_method
)

n_jobs = effective_n_jobs(self.n_jobs)
delayed_query = delayed(self._sp_tree.query_ball_point)
chunked_results = Parallel(n_jobs, prefer="threads")(
delayed_query(X[s], radius, return_sorted=sort_results)
for s in gen_even_slices(X.shape[0], n_jobs)
)
nn_vals = []
for sub_arr in chunked_results[0]:
nn_vals.append(len(sub_arr))
if return_distance:
dd, ii = self._sp_tree.query(X, k=max(nn_vals), distance_upper_bound=radius)
Contributor Author commented:
Needing to call query_ball_point above and query here demonstrates the API differences between our KDTree and SciPy's, which make a drop-in substitution awkward.
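
A minimal side-by-side sketch of that API gap, using illustrative random data and an illustrative radius (not values from the PR): scikit-learn's KDTree.query_radius answers a radius query with indices and distances in one call, while SciPy's KDTree needs query_ball_point for the indices and a separate query(..., distance_upper_bound=r) for the distances.

```python
import numpy as np
from scipy.spatial import KDTree as spKDTree
from sklearn.neighbors import KDTree

rng = np.random.default_rng(0)
X = rng.random((50, 3))  # illustrative data
r = 0.4                  # illustrative radius

# scikit-learn: one call returns ragged index and distance arrays
sk_ind, sk_dist = KDTree(X).query_radius(X, r, return_distance=True)

# SciPy: indices from query_ball_point, then a padded k-NN query for distances
sp_tree = spKDTree(X)
sp_ind = sp_tree.query_ball_point(X, r)     # ragged list of index lists
k_max = max(len(ind) for ind in sp_ind)     # widest neighborhood found
sp_dist, sp_ind_padded = sp_tree.query(X, k=k_max, distance_upper_bound=r)
```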

dd_new = []
ii_new = []
for i in range(len(dd)):
finite_indices = ii[i][np.isfinite(dd[i])]
finite_dists = dd[i][np.isfinite(dd[i])]
if sort_results:
sort_inds = np.argsort(finite_indices)
sorted_inds = finite_indices[sort_inds]
sorted_dists = finite_dists[sort_inds]
Contributor Author commented:
The two trees also handle ragged data structures and inf/invalid values differently, which requires these additional shims.
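
A small standalone sketch, with illustrative data, of the padding convention the shim above works around: scipy.spatial.KDTree.query pads rows that have fewer than k neighbors with distance np.inf and index len(X), so the finite entries must be filtered out per row to recover the ragged structure scikit-learn's tree returns directly.

```python
import numpy as np
from scipy.spatial import KDTree as spKDTree

rng = np.random.default_rng(0)
X = rng.random((50, 3))  # illustrative data
tree = spKDTree(X)

# Rows with fewer than k=5 neighbors inside the radius are padded with
# dist == np.inf and index == len(X).
dist, ind = tree.query(X, k=5, distance_upper_bound=0.2)

ragged_ind, ragged_dist = [], []
for d_row, i_row in zip(dist, ind):
    keep = np.isfinite(d_row)  # drop the inf / index==len(X) padding
    ragged_ind.append(i_row[keep])
    ragged_dist.append(d_row[keep])
```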

else:
sorted_inds = finite_indices
sorted_dists = finite_dists
dd_new.append(sorted_dists)
ii_new.append(sorted_inds)
dd = dd_new
ii = ii_new
try:
chunked_results = [(np.asarray(ii, dtype=int), np.asarray(dd, dtype=X.dtype))]
except ValueError:
chunked_results = [(np.asarray(ii, dtype=object), np.asarray(dd, dtype=object))]
else:
for idx, sub_ele in enumerate(chunked_results[0]):
chunked_results[0][idx] = np.sort(chunked_results[0][idx])

if return_distance:
neigh_ind, neigh_dist = tuple(zip(*chunked_results))
results = np.hstack(neigh_dist), np.hstack(neigh_ind)