Skip to content

Navigation Menu

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit d30b4fc

Browse filesBrowse files
committed
MAINT: mutual info upstream KDTree
* This patch aims to test the waters for reducing community duplication of effort/maintenance with KDTree. In particular, a few select use cases of the `sklearn` in-house `KDTree` by the mutual information infrastructure are replaced with upstream `KDTree` from SciPy. The full test suite appears to pass locally.
1 parent aa21650 commit d30b4fc
Copy full SHA for d30b4fc

File tree

1 file changed

+10
-6
lines changed
Filter options

1 file changed

+10
-6
lines changed

‎sklearn/feature_selection/_mutual_info.py

Copy file name to clipboardExpand all lines: sklearn/feature_selection/_mutual_info.py
+10-6
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55

66
import numpy as np
77
from scipy.sparse import issparse
8+
from scipy.spatial import KDTree
89
from scipy.special import digamma
910

1011
from ..metrics.cluster import mutual_info_score
11-
from ..neighbors import KDTree, NearestNeighbors
12+
from ..neighbors import NearestNeighbors
1213
from ..preprocessing import scale
1314
from ..utils import check_random_state
1415
from ..utils._param_validation import Interval, StrOptions, validate_params
@@ -62,12 +63,14 @@ def _compute_mi_cc(x, y, n_neighbors):
6263

6364
# KDTree is explicitly fit to allow for the querying of number of
6465
# neighbors within a specified radius
65-
kd = KDTree(x, metric="chebyshev")
66-
nx = kd.query_radius(x, radius, count_only=True, return_distance=False)
66+
kd = KDTree(x)
67+
nx = kd.query_ball_point(x, radius, p=np.inf)
68+
nx = [len(sub_list) for sub_list in nx]
6769
nx = np.array(nx) - 1.0
6870

69-
kd = KDTree(y, metric="chebyshev")
70-
ny = kd.query_radius(y, radius, count_only=True, return_distance=False)
71+
kd = KDTree(y)
72+
ny = kd.query_ball_point(y, radius, p=np.inf)
73+
ny = [len(sub_list) for sub_list in ny]
7174
ny = np.array(ny) - 1.0
7275

7376
mi = (
@@ -140,7 +143,8 @@ def _compute_mi_cd(c, d, n_neighbors):
140143
radius = radius[mask]
141144

142145
kd = KDTree(c)
143-
m_all = kd.query_radius(c, radius, count_only=True, return_distance=False)
146+
m_all = kd.query_ball_point(c, radius)
147+
m_all = [len(sub_list) for sub_list in m_all]
144148
m_all = np.array(m_all)
145149

146150
mi = (

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.