ENH: Make brier_score_loss Array API compatible #31191

Open
wants to merge 3 commits into
base: main

lithomas1
Contributor

@lithomas1 lithomas1 commented Apr 13, 2025

Reference Issues/PRs

xref #26024
Depends on #30878

What does this implement/fix? Explain your changes.

Makes brier_score_loss Array API compatible.

Any other comments?

@lithomas1 lithomas1 changed the title from "Brier array api" to "ENH: Make brier_score_loss Array API compatible" on Apr 13, 2025

github-actions bot commented Apr 13, 2025

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 4020016.

@lithomas1 lithomas1 marked this pull request as ready for review April 13, 2025 21:57
Contributor

@OmarManzoor OmarManzoor left a comment

Thanks for the PR @lithomas1

y_proba,
ensure_2d=False,
dtype=tuple(
xp.__array_namespace_info__().dtypes(kind="real floating").values()
Contributor

Do we really need this change, aside from just replacing np with xp in the float data types? Is there some other float dtype that we want to support?

Contributor Author

I did this since some backends, like PyTorch MPS, don't support xp.float64.
(Also float16 is not in the array API standard. Should we make a special exception for np.float16?)

Maybe it would be good to put this in a helper in the array API utils module?

Member

So far, we used _find_matching_floating_dtype for this use case. We could update that utility to leverage __array_namespace_info__ as you did here.

We could also improve check_array to accept dtype="floating" and do device/namespace specific conversion when provided with integer inputs.
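As a rough illustration of the helper idea discussed in this thread, the sketch below queries the Array API inspection namespace for the real floating dtypes available on a given device. The name `supported_float_dtypes` is hypothetical and is not an existing scikit-learn utility. `_FakeInfo` and `fake_xp` are stand-ins invented for the demo, modelling a device (labelled "mps" here) that lacks float64; with a real library one would pass the actual namespace instead.

```python
from types import SimpleNamespace


def supported_float_dtypes(xp, device=None):
    """Return the real floating dtypes that `xp` reports for `device`.

    Hypothetical helper: relies only on the Array API inspection call
    `xp.__array_namespace_info__().dtypes(device=..., kind=...)`.
    """
    info = xp.__array_namespace_info__()
    # The inspection API returns a mapping of dtype name -> dtype object.
    return tuple(info.dtypes(device=device, kind="real floating").values())


class _FakeInfo:
    """Demo-only stand-in for an inspection namespace (not a real library)."""

    def dtypes(self, *, device=None, kind=None):
        # Strings stand in for dtype objects to keep the demo dependency-free.
        table = {"float32": "float32", "float64": "float64"}
        if device == "mps":
            # Assumption for the demo: this device has no float64 support.
            table.pop("float64")
        return table


fake_xp = SimpleNamespace(__array_namespace_info__=lambda: _FakeInfo())

default_dtypes = supported_float_dtypes(fake_xp)
mps_dtypes = supported_float_dtypes(fake_xp, device="mps")
```

The resulting tuple could then be passed as the `dtype` argument in the same way as the inline expression in the diff above, which keeps the device-specific logic in one place.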

Comment on lines +3615 to +3616
transformed_labels = xp.asarray(transformed_labels, device=device)
y_proba = xp.asarray(y_proba, device=device)
Contributor

Shouldn't these be on the device already? I think y_proba might be shifted to cpu because of the check_array function, but assuming that y_true and y_prob are on the expected device, transformed_labels should be on the device as well. Or is this just handling for the array-api-strict?

Member

> y_proba might be shifted to cpu because of the check_array

In which case could y_proba be shifted to cpu?

If it is possible that check_array alters the device, we probably need to call get_namespace_and_device before any check_array call; e.g., in _validate_binary_probabilistic_prediction we call column_or_1d first, which calls check_array internally. I wonder if it would be good practice to just always call get_namespace_and_device first?

> Or is this just handling for the array-api-strict?

Question, why would it be needed for array-api-strict?


# If transformed_labels is integer array, cast it to the floating dtype of
# y_proba
transformed_labels = xp.astype(transformed_labels, y_proba.dtype, device=device)
Contributor

Here we are again moving it to the device?

or np.array_equal(classes, [1])
xp, _, device = get_namespace_and_device(y_true)
classes = xp.unique_values(y_true)
if (_is_numpy_namespace(xp) and classes.dtype.kind in "OUS") or not (
Contributor

I seem to recall seeing a similar kind of change in another PR?

Contributor Author

This is the part from #30878.

Member

@lucyleeow lucyleeow left a comment

Looks good but had some comments/questions. Thanks!

@@ -0,0 +1,2 @@
- :func:`sklearn.metrics.brier_score_loss` now support Array API compatible inputs for the binary class case.
Member

Suggested change
- :func:`sklearn.metrics.brier_score_loss` now support Array API compatible inputs for the binary class case.
- :func:`sklearn.metrics.brier_score_loss` now supports Array API compatible inputs for the binary class case.

nit

Comment on lines -198 to +201
-    if y_prob.max() > 1:
+    if xp.max(y_prob) > 1:
         raise ValueError(f"y_prob contains values greater than 1: {y_prob.max()}")
-    if y_prob.min() < 0:
+    if xp.min(y_prob) < 0:
Member

Not necessarily for this PR, but this code seems to be repeated four times in this module; maybe we could refactor it out?
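One possible shape for such a refactor is sketched below. `check_proba_range` is a hypothetical name and not part of the PR or of scikit-learn. The "less than 0" message is assumed by symmetry with the "greater than 1" message shown in the diff. `demo_xp` is a stand-in namespace built from Python builtins so the sketch runs without any array library; a real caller would pass the namespace obtained for `y_prob`.

```python
from types import SimpleNamespace


def check_proba_range(y_prob, xp, name="y_prob"):
    """Raise ValueError if any value of `y_prob` lies outside [0, 1].

    Hypothetical helper: `xp` only needs to provide `max` and `min`.
    """
    # Compute each reduction once and reuse it in the error message.
    largest = xp.max(y_prob)
    if largest > 1:
        raise ValueError(f"{name} contains values greater than 1: {largest}")
    smallest = xp.min(y_prob)
    if smallest < 0:
        # Assumed wording, mirroring the message above.
        raise ValueError(f"{name} contains values less than 0: {smallest}")


# Demo-only namespace: the builtins act like xp.max / xp.min on a plain list.
demo_xp = SimpleNamespace(max=max, min=min)

check_proba_range([0.1, 0.9], demo_xp)  # valid input, returns None silently
```

Each of the repeated call sites could then reduce to a single call, and the reductions would go through `xp` consistently rather than through array methods.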

Comment on lines -2151 to +2158
-    - If `sample_weight.dtype` is one of `{np.float64, np.float32}`,
+    - If `sample_weight.dtype` is one of `{xp.float64, xp.float32}`,
Member

If we are not changing types in public functions, I wonder if we should keep private ones as is too, for consistency?

@@ -2169,17 +2177,18 @@ def _check_sample_weight(
Validated sample weight. It is guaranteed to be "C" contiguous.
Member

Just checking, are these the same changes as those made in #30878? Since that one has the additional tests, maybe it should be merged first?

5 participants