-
-
Notifications
You must be signed in to change notification settings - Fork 25.9k
fix: mps
device support in entropy
#29321
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Thanks for the fix. I see 3 possible solutions for problem with sparse matrices.
Also, I will try to check each PR in colab more carefully. |
Yes, you are right! I think it could be nice to support multilabel. |
accuracy
and mps support for entropy
I've implemented something in the direction of option 3. I will work on Not sure if I should make 2 PRs, since the fixes for entropy and for accuracy are unrelated. 🤔 |
That would be great thanks. That would ease the review in case one of the changes is controversial. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
accuracy
and mps support for entropy
mps
device support in entropy
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Created the 2nd PR #29336 for |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM once https://github.com/scikit-learn/scikit-learn/pull/29321/files#r1652036505 is addressed.
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks @EdAbati
Reference Issues/PRs
As mentioned in #29300, we have some tests regarding the Array API that are failing in
main
.and
What does this implement/fix? Explain your changes.
The first commit fixes the issue with
mps
.I believe that the others were caused by #29269.
If I remember correctly, the Array API does not support sparse matrices, and therefore should not work with those metrics in the multilabel case. It seems that the code introduced in that PR only works for numpy and torch in cpu, or am I missing something?
Should we revert this change?
cc @ogrisel