MAINT Introduce dispatchers for PairwiseDistancesReductions #23515

jjerphan
Member

@jjerphan jjerphan commented Jun 1, 2022

Reference Issues/PRs

Logic extracted from #22590 mainly to discuss design independently from 32bit support.

What does this implement/fix? Explain your changes.

This PR introduces Python interfaces.

Those dispatchers are meant to be used in the Python code, decoupling the actual implementation from the Python code. This allows changing all of the private implementation while maintaining a contract for the Python callers.

Each dispatcher extending the base PairwiseDistancesReduction dispatcher must implement the compute classmethod.

Under the hood, such a function must only define the logic to dispatch at runtime to the correct dtype-specialized PairwiseDistancesReduction implementation based on the dtype of X and of Y.

This refactoring will ease support for other dtypes, such as float32, and implementations are mostly left unchanged (a few empty callbacks have been introduced, based on some changes made to GEMMTermComputer64).
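To make the dispatching idea above concrete, here is a minimal sketch in plain Python. It is not the code from this PR: `PairwiseDistancesArgKmin64` is a hypothetical stand-in for the Cython float64 implementation, its NumPy brute-force body is an assumption made only so the example runs, and the `k` parameter and the error message are illustrative.

```python
import numpy as np


class PairwiseDistancesReduction:
    """Base dispatcher: a plain Python class that holds no computation logic."""

    @classmethod
    def compute(cls, X, Y, **kwargs):
        # Each concrete dispatcher must override this classmethod.
        raise NotImplementedError


class PairwiseDistancesArgKmin64:
    """Hypothetical stand-in for the float64-specialized implementation.

    The real implementation is written in Cython; this NumPy version is
    only an assumption used to keep the sketch self-contained.
    """

    @classmethod
    def compute(cls, X, Y, k):
        # Squared Euclidean distances between every row of X and every row of Y.
        dist = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=-1)
        # Indices of the k closest rows of Y for each row of X.
        return np.argsort(dist, axis=1)[:, :k]


class PairwiseDistancesArgKmin(PairwiseDistancesReduction):
    """Dispatcher: only selects an implementation from the dtypes of X and Y."""

    @classmethod
    def compute(cls, X, Y, k):
        if X.dtype == Y.dtype == np.float64:
            return PairwiseDistancesArgKmin64.compute(X, Y, k)
        # A float32 branch would be added here once such an implementation exists.
        raise ValueError(
            f"Unsupported dtypes: X.dtype={X.dtype}, Y.dtype={Y.dtype}"
        )
```

Under these assumptions, callers only ever use `PairwiseDistancesArgKmin.compute(X, Y, k=...)`, so adding a float32 implementation later would mean adding one branch in the dispatcher without changing any calling code.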

Remarks / points to discuss

  • GEMMTermComputer has been suffixed, moved upward in the file and extended a bit for consistency and to introduce new dtype-specific implementation more easily.

  • Regarding naming: currently, interfaces took the name of previous implementations and new implementations are solely suffixed with 64. I think this naming could be improved.

  • I think it's time to create a private submodule consisting of several files for PairwiseDistancesReductions.

Those interfaces are meant to be used in the Python code, decoupling the
actual implementation from the Python code. This allows changing all the
private implementation while maintaining a contract for the Python
callers.

Each interface extending the base `PairwiseDistancesReduction` interface
must implement the :meth:`compute` classmethod.

Under the hood, such a function must only define the logic to dispatch
at runtime to the correct dtype-specialized `PairwiseDistancesReduction`
implementation based on the dtype of X and of Y.

This refactoring will ease other dtype support such as float32 support.
@jjerphan jjerphan marked this pull request as ready for review June 1, 2022 16:25
@ogrisel
Member

ogrisel commented Jun 2, 2022

I think it's time to create a private submodule consisting of several files for PairwiseDistancesReductions.

I am not opposed to that.

There is a bunch of tests that fail. I did not investigate. Please feel free to ping me for a review once those are fixed. I don't know if fixing them has an impact on the design of the refactoring or not.

@jjerphan jjerphan changed the title MAINT Introduce interfaces for PairwiseDistancesReductions MAINT Introduce dispatcher for PairwiseDistancesReductions Jun 8, 2022
@jjerphan jjerphan changed the title MAINT Introduce dispatcher for PairwiseDistancesReductions MAINT Introduce dispatchers for PairwiseDistancesReductions Jun 8, 2022
Member

@jeremiedbb jeremiedbb left a comment

I'm ok with making a submodule in which the dispatchers are separated from the implementations.
As discussed irl, the dispatchers can probably be made standard Python classes now (the implementations don't need to inherit from PairwiseDistancesReduction).

jjerphan and others added 2 commits June 9, 2022 15:12
… with some ASCII art. 🎨

Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Jérémie du Boisberranger <jeremiedbb@users.noreply.github.com>
@jjerphan
Member Author

As discussed irl, the dispatchers can probably be made standard Python classes now (the implementations don't need to inherit from PairwiseDistancesReduction).

Done in cd49d8a.

I've done the full private module refactoring (see jjerphan#14 which targets the branch of this PR).

As this refactoring is quite large, I would submit it in another PR. @ogrisel thinks it's fine to have it integrated in this PR.

Let me know whether I should introduce this refactoring in this PR or submit it in another PR.

Member

@ogrisel ogrisel left a comment

The class hierarchy design and the updated documentation look (very) good to me.

The top level dispatchers could be Python classes instead of Cython classes but we can change that later if we prefer. No strong opinion.

EDIT: this point was addressed in jjerphan#14

I think we can merge this PR to main first (before doing the multi-file split of the module) if you or @jeremiedbb prefer. I don't mind either option.

Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
@jjerphan
Member Author

jjerphan commented Jun 10, 2022

The top level dispatchers could be Python classes instead of Cython classes but we can change that later if we prefer. No strong opinion.

Note that this was addressed in this PR via cd49d8a (which is also present in the follow-up PR, jjerphan#14).

Member

@thomasjpfan thomasjpfan left a comment

I'm overall happy with this type of refactor. Moving the Cython data structures down to the implementation, allowing for the dispatcher to be in regular Python makes sense to me.

@@ -764,11 +1160,11 @@ cdef class PairwiseDistancesArgKmin(PairwiseDistancesReduction):
Y_start + j,
)

@final
Member

Does removing the @final here lead to any performance regressions for metrics that are not Euclidean?

Member Author

Thanks for the comment. This needs to be assessed.

IIRC from my experiments, the optimisations provided by @final mainly apply to classes that do not make use of polymorphism (but we do use polymorphism here).

It's more a matter of implementation explicitness: the usage of @final on a method indicates that this method is not overridden in any subclass (similarly, the usage of @final on a class indicates that the class isn't extended).

Overall, I think we must check for any performance regression before merging this PR.

@jeremiedbb
Member

I think we can merge this PR to main first (before doing the multi-file split of the module) if you or @jeremiedbb prefer. I don't mind either option.

I'm ok with merging this PR first and following up with reorganizing into a submodule.

Co-authored-by: Jérémie du Boisberranger <jeremiedbb@users.noreply.github.com>
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
jjerphan added a commit to jjerphan/scikit-learn that referenced this pull request Jun 15, 2022
@jjerphan
Member Author

jjerphan commented Jun 15, 2022

No regression on a machine with 256 threads for the FastEuclideanPairwiseDistancesArgKmin specialisation, via jjerphan@31fb45e using:

asv continuous -b kneigh -e main maint/pairwise-distances-reductions-interfaces                                                                                                 

· Creating environments
· Discovering benchmarks
·· Uninstalling from conda-py3.9-cython-joblib-numpy-scipy-threadpoolctl.
·· Installing e6f9c9aa <maint/pairwise-distances-reductions-interfaces> into conda-py3.9-cython-joblib-numpy-scipy-threadpoolctl.                                                                         
· Running 2 total benchmarks (2 commits * 1 environments * 1 benchmarks)
[  0.00%] · For scikit-learn commit e6f9c9aa <maint/pairwise-distances-reductions-interfaces> (round 1/1):
[  0.00%] ·· Benchmarking conda-py3.9-cython-joblib-numpy-scipy-threadpoolctl
[ 50.00%] ··· bruteforce.BruteForceNearestNeighborsBenchmark.time_kneighbors                                                                                                                              
                                                 ok
[ 50.00%] ··· ========= ======== ============ ============ =========== ============ ==============
              --                                                    k_radius
              ------------------------------- ----------------------------------------------------
               n_train   n_test   n_features     (1, 1)      (10, 10)   (100, 100)   (1000, 1000)
              ========= ======== ============ ============ =========== ============ ==============
                100000   100000       50       5.45±0.08s   5.76±0.7s   8.89±0.6s     22.0±0.2s
                100000   100000      100       8.13±0.3s     8.05±1s    11.3±0.2s     21.4±0.3s
                100000   100000      500       18.3±0.6s    21.3±0.3s   24.1±0.8s     31.9±0.2s
              ========= ======== ============ ============ =========== ============ ==============

[ 50.00%] · For scikit-learn commit fc9be34b <main> (round 1/1):
[ 50.00%] ·· Building for conda-py3.9-cython-joblib-numpy-scipy-threadpoolctl..
[ 50.00%] ·· Benchmarking conda-py3.9-cython-joblib-numpy-scipy-threadpoolctl
[100.00%] ··· bruteforce.BruteForceNearestNeighborsBenchmark.time_kneighbors                                                                                                                              
                                                 ok
[100.00%] ··· ========= ======== ============ =========== =========== ============ ==============
              --                                                    k_radius
              ------------------------------- ---------------------------------------------------
               n_train   n_test   n_features     (1, 1)     (10, 10)   (100, 100)   (1000, 1000)
              ========= ======== ============ =========== =========== ============ ==============
                100000   100000       50        9.36±1s    5.59±0.2s   9.26±0.2s     20.2±0.1s
                100000   100000      100       7.84±0.5s   7.42±0.5s   10.5±0.2s     23.2±0.7s
                100000   100000      500       18.8±0.2s   19.2±0.6s    22.2±1s      31.1±0.2s
              ========= ======== ============ =========== =========== ============ ==============


BENCHMARKS NOT SIGNIFICANTLY CHANGED.

jjerphan added a commit to jjerphan/scikit-learn that referenced this pull request Jun 15, 2022
@jjerphan
Member Author

jjerphan commented Jun 15, 2022

No regression on a machine with 256 threads for the PairwiseDistancesArgKmin (on 'manhattan'), via jjerphan@cd1e6ec using:

asv continuous -b kneigh -e main maint/pairwise-distances-reductions-interfaces                                                                           
· Creating environments
· Discovering benchmarks
· Running 2 total benchmarks (2 commits * 1 environments * 1 benchmarks)
[  0.00%] · For scikit-learn commit e6f9c9aa <maint/pairwise-distances-reductions-interfaces> (round 1/1):                                                                          
[  0.00%] ·· Benchmarking conda-py3.9-cython-joblib-numpy-scipy-threadpoolctl
[ 50.00%] ··· bruteforce.BruteForceNearestNeighborsBenchmark.time_kneighbors                                                                                                       ok
[ 50.00%] ··· ========= ======== ============ ============ ============ ============ ==============                                                                                 
              --                                                     k_radius                                                                                                       
              ------------------------------- -----------------------------------------------------                                                                                 
               n_train   n_test   n_features     (1, 1)      (10, 10)    (100, 100)   (1000, 1000)                                                                                  
              ========= ======== ============ ============ ============ ============ ==============                                                                                 
                10000    10000        50       2.21±0.02s   2.30±0.03s   2.26±0.03s    2.48±0.02s                                                                                   
                10000    10000       100       2.26±0.01s   2.27±0.01s   2.34±0.02s    2.68±0.02s                                                                                   
                10000    10000       500       3.42±0.01s   3.47±0.01s   3.55±0.01s    3.86±0.02s                                                                                   
              ========= ======== ============ ============ ============ ============ ==============                                                                                 

[ 50.00%] · For scikit-learn commit fc9be34b <main> (round 1/1):
[ 50.00%] ·· Building for conda-py3.9-cython-joblib-numpy-scipy-threadpoolctl...
[ 50.00%] ·· Benchmarking conda-py3.9-cython-joblib-numpy-scipy-threadpoolctl
[100.00%] ··· bruteforce.BruteForceNearestNeighborsBenchmark.time_kneighbors                                                                                                       ok
[100.00%] ··· ========= ======== ============ ============ ============ ============ ==============
              --                                                     k_radius                      
              ------------------------------- -----------------------------------------------------
               n_train   n_test   n_features     (1, 1)      (10, 10)    (100, 100)   (1000, 1000) 
              ========= ======== ============ ============ ============ ============ ==============
                10000    10000        50       2.25±0.02s   2.25±0.02s   2.26±0.03s    2.48±0.02s  
                10000    10000       100       2.25±0.02s   2.26±0.01s   2.34±0.01s    2.68±0.02s  
                10000    10000       500       3.42±0.02s   3.51±0.03s   3.58±0.02s    3.85±0.02s  
              ========= ======== ============ ============ ============ ============ ==============


BENCHMARKS NOT SIGNIFICANTLY CHANGED.

@ogrisel
Member

ogrisel commented Jun 15, 2022

Thanks for checking the performance. +1 again for merge on my side.

@thomasjpfan @jeremiedbb anything else?

@ogrisel
Member

ogrisel commented Jun 17, 2022

I removed the blocker label as it's meant to only be used to label bugs that should be fixed before making a specific release.

Here, this is not a bug (it's a performance improvement) and it's not blocking any particular release. It's blocking progress on follow-up PRs, but we don't have a label for that.

Member

@jeremiedbb jeremiedbb left a comment

LGTM. Let's merge

@jeremiedbb jeremiedbb merged commit 74cda4b into scikit-learn:main Jun 22, 2022
@jeremiedbb
Member

Thanks @jjerphan. I guess you can now make a PR targeting main out of jjerphan#14

@jjerphan jjerphan deleted the maint/pairwise-distances-reductions-interfaces branch June 22, 2022 15:22
@jjerphan
Member Author

Thanks @jjerphan. I guess you can now make a PR targeting main out of jjerphan#14

Thanks. The follow-up is reviewable: #23724

ogrisel added a commit to ogrisel/scikit-learn that referenced this pull request Jul 11, 2022
…-learn#23515)

Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Jérémie du Boisberranger <jeremiedbb@users.noreply.github.com>
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>