Skip to content

Navigation Menu

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

ENH: Make brier_score_loss Array API compatible #31191

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
Loading
from

Conversation

lithomas1
Copy link
Contributor

@lithomas1 lithomas1 commented Apr 13, 2025

Reference Issues/PRs

xref #26024
Depends on #30878

What does this implement/fix? Explain your changes.

Makes brier_score_loss Array API compatible.

Any other comments?

@lithomas1 lithomas1 changed the title Brier array api ENH: Make brier_score_loss Array API compatible Apr 13, 2025
Copy link

github-actions bot commented Apr 13, 2025

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 4020016. Link to the linter CI: here

@lithomas1 lithomas1 marked this pull request as ready for review April 13, 2025 21:57
Copy link
Contributor

@OmarManzoor OmarManzoor left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR @lithomas1

y_proba,
ensure_2d=False,
dtype=tuple(
xp.__array_namespace_info__().dtypes(kind="real floating").values()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really need this change aside from just replacing np with xp in the floatdata types? Is there some other float dtype that we want to support?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did this since some libraries like PyTorch MPS don't support xp.float32.
(Also float16 is not in the array API standard. Should we make a special exception for np.float16?)

Maybe it would be good to put this in a helper in the array API utils module?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So far, we used _find_matching_floating_dtype for this use case. We could update that utility to leverage __array_namespace_info__ as you did here.

We could also improve check_array to accept dtype="floating" and do device/namespace specific conversion when provided with integer inputs.

Comment on lines +3615 to +3616
transformed_labels = xp.asarray(transformed_labels, device=device)
y_proba = xp.asarray(y_proba, device=device)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't these be on the device already. I think y_proba might be shifted to cpu because of the check_array function but assuming that y_true and y_prob are on the expected device transformed_labels should be on the device as well. Or is this just handling for the array-api-strict?


# If transformed_labels is integer array, cast it to the floating dtype of
# y_proba
transformed_labels = xp.astype(transformed_labels, y_proba.dtype, device=device)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we are again moving it on the device?

or np.array_equal(classes, [1])
xp, _, device = get_namespace_and_device(y_true)
classes = xp.unique_values(y_true)
if (_is_numpy_namespace(xp) and classes.dtype.kind in "OUS") or not (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I seem to recall seeing a similar kind of change in another PR?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the part from #30878.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants
Morty Proxy This is a proxified and sanitized view of the page, visit original site.