From b29ab20cbf66d5201a590411b0e818b074d0194f Mon Sep 17 00:00:00 2001 From: Tyler Reddy Date: Mon, 12 May 2025 16:51:39 -0700 Subject: [PATCH] WIP, MAINT: NeighborsBase `KDTree` upstream * This is an extension of the concept in gh-31347; here, part of the usage of the in-house `KDTree` in `NeighborsBase` is replaced by its upstream version from SciPy. This is a much more challenging effort that clearly shows some substantial differences between the two `KDTree` APIs/methods, and the shims needed to address them. At the moment, there is still a small number of residual test failures (29 locally) in the full test suite. * Some kind of API unification/equivalence of offerings seems likely to be needed for these kinds of replacements to be more sustainable (the shims added here were quite time-consuming to figure out). Some of the test expectations may also be debatable for cases with, e.g., degenerate input. --- sklearn/neighbors/_base.py | 54 +++++++++++++++++++++++++++++++++++++- 1 file changed, 53 insertions(+), 1 deletion(-) diff --git a/sklearn/neighbors/_base.py b/sklearn/neighbors/_base.py index 767eee1358aa8..83b660f8eb331 100644 --- a/sklearn/neighbors/_base.py +++ b/sklearn/neighbors/_base.py @@ -13,6 +13,7 @@ import numpy as np from joblib import effective_n_jobs from scipy.sparse import csr_matrix, issparse +from scipy.spatial import KDTree as spKDTree from ..base import BaseEstimator, MultiOutputMixin, is_classifier from ..exceptions import DataConversionWarning, EfficiencyWarning @@ -694,6 +695,10 @@ def _fit(self, X, y=None): "try algorithm='ball_tree' " "or algorithm='brute' instead."
) + self._sp_tree = spKDTree( + X, + self.leaf_size, + ) self._tree = KDTree( X, self.leaf_size, @@ -1265,7 +1270,7 @@ class from an array representing our data set and ask who's neigh_dist[ii] = neigh_dist[ii][order] results = neigh_dist, neigh_ind - elif self._fit_method in ["ball_tree", "kd_tree"]: + elif self._fit_method == "ball_tree": if issparse(X): raise ValueError( "%s does not work with sparse matrices. Densify the data, " @@ -1278,6 +1283,53 @@ class from an array representing our data set and ask who's delayed_query(X[s], radius, return_distance, sort_results=sort_results) for s in gen_even_slices(X.shape[0], n_jobs) ) + if return_distance: + neigh_ind, neigh_dist = tuple(zip(*chunked_results)) + results = np.hstack(neigh_dist), np.hstack(neigh_ind) + else: + results = np.hstack(chunked_results) + elif self._fit_method == "kd_tree": + if issparse(X): + raise ValueError( + "%s does not work with sparse matrices. Densify the data, " + "or set algorithm='brute'" % self._fit_method + ) + + n_jobs = effective_n_jobs(self.n_jobs) + delayed_query = delayed(self._sp_tree.query_ball_point) + chunked_results = Parallel(n_jobs, prefer="threads")( + delayed_query(X[s], radius, return_sorted=sort_results) + for s in gen_even_slices(X.shape[0], n_jobs) + ) + nn_vals = [] + for sub_arr in chunked_results[0]: + nn_vals.append(len(sub_arr)) + if return_distance: + dd, ii = self._sp_tree.query(X, k=max(nn_vals), distance_upper_bound=radius) + dd_new = [] + ii_new = [] + for i in range(len(dd)): + finite_indices = ii[i][np.isfinite(dd[i])] + finite_dists = dd[i][np.isfinite(dd[i])] + if sort_results: + sort_inds = np.argsort(finite_indices) + sorted_inds = finite_indices[sort_inds] + sorted_dists = finite_dists[sort_inds] + else: + sorted_inds = finite_indices + sorted_dists = finite_dists + dd_new.append(sorted_dists) + ii_new.append(sorted_inds) + dd = dd_new + ii = ii_new + try: + chunked_results = [(np.asarray(ii, dtype=int), np.asarray(dd, dtype=X.dtype))] + except 
ValueError: + chunked_results = [(np.asarray(ii, dtype=object), np.asarray(dd, dtype=object))] + else: + for idx, sub_ele in enumerate(chunked_results[0]): + chunked_results[0][idx] = np.sort(chunked_results[0][idx]) + if return_distance: neigh_ind, neigh_dist = tuple(zip(*chunked_results)) results = np.hstack(neigh_dist), np.hstack(neigh_ind)