From da64dc8da3738f88c78618566071874f77654e84 Mon Sep 17 00:00:00 2001 From: Tyler Reddy Date: Fri, 9 May 2025 16:49:18 -0600 Subject: [PATCH] MAINT: mutual info upstream KDTree * This patch aims to test the waters for reducing community duplication of effort/maintenance with KDTree. In particular, a few select use cases of the `sklearn` in-house `KDTree` by the mutual information infrastructure are replaced with upstream `KDTree` from SciPy. The full test suite appears to pass locally. --- sklearn/feature_selection/_mutual_info.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/sklearn/feature_selection/_mutual_info.py b/sklearn/feature_selection/_mutual_info.py index aef9097879fca..40851e30cf3e8 100644 --- a/sklearn/feature_selection/_mutual_info.py +++ b/sklearn/feature_selection/_mutual_info.py @@ -5,10 +5,11 @@ import numpy as np from scipy.sparse import issparse +from scipy.spatial import KDTree from scipy.special import digamma from ..metrics.cluster import mutual_info_score -from ..neighbors import KDTree, NearestNeighbors +from ..neighbors import NearestNeighbors from ..preprocessing import scale from ..utils import check_random_state from ..utils._param_validation import Interval, StrOptions, validate_params @@ -62,12 +63,12 @@ def _compute_mi_cc(x, y, n_neighbors): # KDTree is explicitly fit to allow for the querying of number of # neighbors within a specified radius - kd = KDTree(x, metric="chebyshev") - nx = kd.query_radius(x, radius, count_only=True, return_distance=False) + kd = KDTree(x) + nx = kd.query_ball_point(x, radius, p=np.inf, return_length=True) nx = np.array(nx) - 1.0 - kd = KDTree(y, metric="chebyshev") - ny = kd.query_radius(y, radius, count_only=True, return_distance=False) + kd = KDTree(y) + ny = kd.query_ball_point(y, radius, p=np.inf, return_length=True) ny = np.array(ny) - 1.0 mi = ( @@ -140,7 +141,7 @@ def _compute_mi_cd(c, d, n_neighbors): radius = radius[mask] kd = KDTree(c) - m_all = kd.query_radius(c, radius, count_only=True, return_distance=False) + m_all = kd.query_ball_point(c, radius, return_length=True) m_all = np.array(m_all) mi = (