WIP, MAINT: NeighborsBase KDTree upstream #31358

Open
wants to merge 1 commit into base: main
Conversation

tylerjereddy
Contributor

  • This is an extension of the concept in MAINT: mutual information using upstream KDTree #31347: here, part of the usage of the in-house KDTree in NeighborsBase is replaced by its upstream counterpart from SciPy. This is a much more challenging effort that clearly exposes some substantial differences between the two KDTree APIs/methods and the shims needed to bridge them. At the moment there is still a small number of residual test failures (29 locally) in the full test suite.

  • Some kind of API unification/equivalence of offerings seems likely to be needed for these kinds of replacements to be sustainable (the shims added here were quite time consuming to figure out). Some of the test expectations may also be debatable for cases with, e.g., degenerate input. A minimal sketch of the main API difference follows below.
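
For readers following along, here is a minimal standalone sketch (not part of this diff) of the API gap being worked around: the in-house sklearn.neighbors.KDTree exposes query_radius with optional distances, while SciPy's scipy.spatial.KDTree splits that into query_ball_point (indices only) and query with a distance_upper_bound (fixed-width, inf-padded output). All names below are illustrative.

import numpy as np
from scipy.spatial import KDTree as spKDTree      # upstream tree
from sklearn.neighbors import KDTree as skKDTree  # in-house tree

rng = np.random.default_rng(0)
X = rng.random((50, 3))
radius = 0.3

# In-house tree: a single radius query returns ragged per-point index
# arrays, optionally paired with the corresponding distances.
sk_tree = skKDTree(X, leaf_size=30)
ind, dist = sk_tree.query_radius(X[:5], r=radius, return_distance=True)

# Upstream tree: the closest analogue is query_ball_point (indices only);
# distances within the radius have to be recovered separately, e.g. via a
# k-nearest query capped by distance_upper_bound, which pads missing
# neighbors with distance == inf and index == n.
sp_tree = spKDTree(X, leafsize=30)
ind_sp = sp_tree.query_ball_point(X[:5], r=radius)
dd, ii = sp_tree.query(X[:5], k=10, distance_upper_bound=radius)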


❌ Linting issues

This PR is introducing linting issues. Here's a summary of the issues. Note that you can avoid having linting issues by enabling pre-commit hooks. Instructions to enable them can be found here.

You can see the details of the linting issues under the lint job here


ruff check

ruff detected issues. Please run ruff check --fix --output-format=full locally, fix the remaining issues, and push the changes. Here you can see the detected issues. Note that the installed ruff version is ruff=0.11.7.


sklearn/neighbors/_base.py:1308:89: E501 Line too long (92 > 88)
     |
1306 |                 nn_vals.append(len(sub_arr))
1307 |             if return_distance:
1308 |                 dd, ii = self._sp_tree.query(X, k=max(nn_vals), distance_upper_bound=radius)
     |                                                                                         ^^^^ E501
1309 |                 dd_new = []
1310 |                 ii_new = []
     |

sklearn/neighbors/_base.py:1326:89: E501 Line too long (98 > 88)
     |
1324 |                 ii = ii_new
1325 |                 try:
1326 |                     chunked_results = [(np.asarray(ii, dtype=int), np.asarray(dd, dtype=X.dtype))]
     |                                                                                         ^^^^^^^^^^ E501
1327 |                 except ValueError:
1328 |                     chunked_results = [(np.asarray(ii, dtype=object), np.asarray(dd, dtype=object))]
     |

sklearn/neighbors/_base.py:1328:89: E501 Line too long (100 > 88)
     |
1326 |                     chunked_results = [(np.asarray(ii, dtype=int), np.asarray(dd, dtype=X.dtype))]
1327 |                 except ValueError:
1328 |                     chunked_results = [(np.asarray(ii, dtype=object), np.asarray(dd, dtype=object))]
     |                                                                                         ^^^^^^^^^^^^ E501
1329 |             else:
1330 |                 for idx, sub_ele in enumerate(chunked_results[0]):
     |

Found 3 errors.

ruff format

ruff detected issues. Please run ruff format locally and push the changes. Here you can see the detected issues. Note that the installed ruff version is ruff=0.11.7.


--- sklearn/neighbors/_base.py
+++ sklearn/neighbors/_base.py
@@ -1305,7 +1305,9 @@
             for sub_arr in chunked_results[0]:
                 nn_vals.append(len(sub_arr))
             if return_distance:
-                dd, ii = self._sp_tree.query(X, k=max(nn_vals), distance_upper_bound=radius)
+                dd, ii = self._sp_tree.query(
+                    X, k=max(nn_vals), distance_upper_bound=radius
+                )
                 dd_new = []
                 ii_new = []
                 for i in range(len(dd)):
@@ -1323,9 +1325,13 @@
                 dd = dd_new
                 ii = ii_new
                 try:
-                    chunked_results = [(np.asarray(ii, dtype=int), np.asarray(dd, dtype=X.dtype))]
+                    chunked_results = [
+                        (np.asarray(ii, dtype=int), np.asarray(dd, dtype=X.dtype))
+                    ]
                 except ValueError:
-                    chunked_results = [(np.asarray(ii, dtype=object), np.asarray(dd, dtype=object))]
+                    chunked_results = [
+                        (np.asarray(ii, dtype=object), np.asarray(dd, dtype=object))
+                    ]
             else:
                 for idx, sub_ele in enumerate(chunked_results[0]):
                     chunked_results[0][idx] = np.sort(chunked_results[0][idx])

1 file would be reformatted, 917 files already formatted

Generated for commit: b29ab20. Link to the linter CI: here

self._sp_tree = spKDTree(
    X,
    self.leaf_size,
)
Contributor Author

Replacing _tree itself was much messier, so I tried to keep this scoped for prototyping.
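
For context, a hypothetical sketch of the scoped approach described above: the upstream tree is built next to the existing in-house tree during fitting rather than replacing self._tree outright, so only the prototyped code path needs to change. The class below is illustrative, not the actual NeighborsBase code.

from scipy.spatial import KDTree as spKDTree
from sklearn.neighbors import KDTree as skKDTree

class NeighborsSketch:
    """Illustrative stand-in for the prototyping approach, not the real class."""

    def __init__(self, leaf_size=30):
        self.leaf_size = leaf_size

    def _fit(self, X):
        # Keep the in-house tree so most existing code paths stay untouched...
        self._tree = skKDTree(X, leaf_size=self.leaf_size)
        # ...and add the upstream tree alongside it for the prototyped
        # radius-neighbors path (leafsize is SciPy's second positional argument).
        self._sp_tree = spKDTree(X, self.leaf_size)
        return self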

for sub_arr in chunked_results[0]:
    nn_vals.append(len(sub_arr))
if return_distance:
    dd, ii = self._sp_tree.query(X, k=max(nn_vals), distance_upper_bound=radius)
Contributor Author

Needing to call query_ball_point above and query here demonstrates how the API differences between the two KDTree implementations make substituted workflows awkward.
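
As a standalone illustration of that two-call pattern (assumed shape only; the actual shim in this diff differs in its surrounding details), query_ball_point supplies the ragged per-point neighbor counts, and query then retrieves distances for a fixed k sized to the largest neighborhood:

import numpy as np
from scipy.spatial import KDTree

rng = np.random.default_rng(0)
X = rng.random((100, 2))
radius = 0.2
tree = KDTree(X)

# First call: ragged lists of in-radius indices, one list per query point.
neigh_ind = tree.query_ball_point(X, r=radius)
nn_vals = [len(ind) for ind in neigh_ind]

# Second call: a k-nearest query sized to the largest neighborhood, with the
# same radius as an upper bound; shorter rows are padded with inf distances
# and index == n.
dd, ii = tree.query(X, k=max(nn_vals), distance_upper_bound=radius)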

if sort_results:
    sort_inds = np.argsort(finite_indices)
    sorted_inds = finite_indices[sort_inds]
    sorted_dists = finite_dists[sort_inds]
Contributor Author

The two trees also differ in how they handle ragged data structures and inf/invalid values, which requires these additional shims.
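
A hedged sketch of the clean-up this implies (illustrative names; the real shim may differ): SciPy's query pads every row to length k with distance == inf and index == n, so the padding has to be stripped and the surviving entries re-sorted to recover a ragged, in-house-style result.

import numpy as np
from scipy.spatial import KDTree

X = np.random.default_rng(0).random((100, 2))
tree = KDTree(X)
dd, ii = tree.query(X, k=5, distance_upper_bound=0.1)

neigh_ind, neigh_dist = [], []
for dist_row, ind_row in zip(dd, ii):
    finite = np.isfinite(dist_row)          # drop the inf / index == n padding
    finite_dists = dist_row[finite]
    finite_indices = ind_row[finite]
    sort_inds = np.argsort(finite_indices)  # mirror the index-based sort above
    neigh_ind.append(finite_indices[sort_inds])
    neigh_dist.append(finite_dists[sort_inds])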

@tylerjereddy
Contributor Author

Some of this work was done in person with @virchan today.
