-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
ENH: Make brier_score_loss Array API compatible #31191
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR @lithomas1
y_proba, | ||
ensure_2d=False, | ||
dtype=tuple( | ||
xp.__array_namespace_info__().dtypes(kind="real floating").values() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we really need this change aside from just replacing np with xp in the floatdata types? Is there some other float dtype that we want to support?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did this since some libraries like PyTorch MPS don't support xp.float32.
(Also float16 is not in the array API standard. Should we make a special exception for np.float16?)
Maybe it would be good to put this in a helper in the array API utils module?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So far, we used _find_matching_floating_dtype
for this use case. We could update that utility to leverage __array_namespace_info__
as you did here.
We could also improve check_array
to accept dtype="floating"
and do device/namespace specific conversion when provided with integer inputs.
transformed_labels = xp.asarray(transformed_labels, device=device) | ||
y_proba = xp.asarray(y_proba, device=device) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't these be on the device already. I think y_proba might be shifted to cpu because of the check_array function but assuming that y_true and y_prob are on the expected device transformed_labels should be on the device as well. Or is this just handling for the array-api-strict?
|
||
# If transformed_labels is integer array, cast it to the floating dtype of | ||
# y_proba | ||
transformed_labels = xp.astype(transformed_labels, y_proba.dtype, device=device) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here we are again moving it on the device?
or np.array_equal(classes, [1]) | ||
xp, _, device = get_namespace_and_device(y_true) | ||
classes = xp.unique_values(y_true) | ||
if (_is_numpy_namespace(xp) and classes.dtype.kind in "OUS") or not ( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I seem to recall seeing a similar kind of change in another PR?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the part from #30878.
Reference Issues/PRs
xref #26024
Depends on #30878
What does this implement/fix? Explain your changes.
Makes brier_score_loss Array API compatible.
Any other comments?