Skip to content

Navigation Menu

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

ENH: Make brier_score_loss Array API compatible #31191

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
Loading
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions 1 doc/modules/array_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ base estimator also does:
Metrics
-------

- :func:`sklearn.metrics.brier_score_loss` (only the binary class case is supported)
- :func:`sklearn.metrics.cluster.entropy`
- :func:`sklearn.metrics.accuracy_score`
- :func:`sklearn.metrics.d2_tweedie_score`
Expand Down
2 changes: 2 additions & 0 deletions 2 doc/whats_new/upcoming_changes/array-api/31191.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
- :func:`sklearn.metrics.brier_score_loss` now supports Array API compatible inputs for the binary class case.
By :user:`Thomas Li <lithomas1>`
40 changes: 28 additions & 12 deletions 40 sklearn/metrics/_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,10 +194,11 @@ def _validate_multiclass_probabilistic_prediction(
y_prob = check_array(
y_prob, ensure_2d=False, dtype=[np.float64, np.float32, np.float16]
)
xp, _ = get_namespace(y_true, y_prob)

if y_prob.max() > 1:
if xp.max(y_prob) > 1:
raise ValueError(f"y_prob contains values greater than 1: {y_prob.max()}")
if y_prob.min() < 0:
if xp.min(y_prob) < 0:
raise ValueError(f"y_prob contains values lower than 0: {y_prob.min()}")

check_consistent_length(y_prob, y_true, sample_weight)
Expand Down Expand Up @@ -3428,6 +3429,7 @@ def _validate_binary_probabilistic_prediction(y_true, y_prob, sample_weight, pos
assert_all_finite(y_prob)

check_consistent_length(y_prob, y_true, sample_weight)
xp, _, device = get_namespace_and_device(y_true, y_prob, sample_weight)

y_type = type_of_target(y_true, input_name="y_true")
if y_type != "binary":
Expand All @@ -3436,16 +3438,16 @@ def _validate_binary_probabilistic_prediction(y_true, y_prob, sample_weight, pos
"binary according to the shape of y_prob."
)

if y_prob.max() > 1:
if xp.max(y_prob) > 1:
raise ValueError(f"y_prob contains values greater than 1: {y_prob.max()}")
if y_prob.min() < 0:
if xp.min(y_prob) < 0:
raise ValueError(f"y_prob contains values less than 0: {y_prob.min()}")

# check that pos_label is consistent with y_true
try:
pos_label = _check_pos_label_consistency(pos_label, y_true)
except ValueError:
classes = np.unique(y_true)
classes = xp.unique_values(y_true)
if classes.dtype.kind not in ("O", "U", "S"):
# for backward compatibility, if classes are not string then
# `pos_label` will correspond to the greater label
Expand All @@ -3454,9 +3456,9 @@ def _validate_binary_probabilistic_prediction(y_true, y_prob, sample_weight, pos
raise

# convert (n_samples,) to (n_samples, 2) shape
y_true = np.array(y_true == pos_label, int)
transformed_labels = np.column_stack((1 - y_true, y_true))
y_prob = np.column_stack((1 - y_prob, y_prob))
y_true = xp.asarray(y_true == pos_label, dtype=xp.int64, device=device)
transformed_labels = xp.stack((1 - y_true, y_true), axis=1)
y_prob = xp.stack((1 - y_prob, y_prob), axis=1)

return transformed_labels, y_prob

Expand Down Expand Up @@ -3589,8 +3591,17 @@ def brier_score_loss(
... )
0.146...
"""
xp, _, device = get_namespace_and_device(
y_true,
y_proba,
sample_weight,
)
y_proba = check_array(
y_proba, ensure_2d=False, dtype=[np.float64, np.float32, np.float16]
y_proba,
ensure_2d=False,
dtype=tuple(
xp.__array_namespace_info__().dtypes(kind="real floating").values()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really need this change aside from just replacing np with xp in the float data types? Is there some other float dtype that we want to support?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did this since some libraries/devices like PyTorch MPS don't support xp.float64.
(Also float16 is not in the array API standard. Should we make a special exception for np.float16?)

Maybe it would be good to put this in a helper in the array API utils module?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So far, we used _find_matching_floating_dtype for this use case. We could update that utility to leverage __array_namespace_info__ as you did here.

We could also improve check_array to accept dtype="floating" and do device/namespace specific conversion when provided with integer inputs.

),
)

if y_proba.ndim == 1 or y_proba.shape[1] == 1:
Expand All @@ -3601,9 +3612,14 @@ def brier_score_loss(
transformed_labels, y_proba = _validate_multiclass_probabilistic_prediction(
y_true, y_proba, sample_weight, labels
)

brier_score = np.average(
np.sum((transformed_labels - y_proba) ** 2, axis=1), weights=sample_weight
transformed_labels = xp.asarray(transformed_labels, device=device)
y_proba = xp.asarray(y_proba, device=device)
Comment on lines +3615 to +3616
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't these be on the device already? I think y_proba might be moved to the CPU by the check_array function, but assuming that y_true and y_prob are on the expected device, transformed_labels should be on the device as well. Or is this just handling for array-api-strict?


# If transformed_labels is integer array, cast it to the floating dtype of
# y_proba
transformed_labels = xp.astype(transformed_labels, y_proba.dtype, device=device)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we are again moving it to the device?

brier_score = _average(
xp.sum((transformed_labels - y_proba) ** 2, axis=1), weights=sample_weight
)

if scale_by_half == "auto":
Expand Down
1 change: 1 addition & 0 deletions 1 sklearn/metrics/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2229,6 +2229,7 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
check_array_api_regression_metric_multioutput,
],
sigmoid_kernel: [check_array_api_metric_pairwise],
brier_score_loss: [check_array_api_binary_classification_metric],
}


Expand Down
41 changes: 26 additions & 15 deletions 41 sklearn/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,14 @@

from .. import get_config as _get_config
from ..exceptions import DataConversionWarning, NotFittedError, PositiveSpectrumWarning
from ..utils._array_api import _asarray_with_order, _is_numpy_namespace, get_namespace
from ..utils._array_api import (
_asarray_with_order,
_convert_to_numpy,
_is_numpy_namespace,
_max_precision_float_dtype,
get_namespace,
get_namespace_and_device,
)
from ..utils.deprecation import _deprecate_force_all_finite
from ..utils.fixes import ComplexWarning, _preserve_dia_indices_dtype
from ._isfinite import FiniteStatus, cy_isfinite
Expand Down Expand Up @@ -2148,9 +2155,10 @@ def _check_sample_weight(
dtype of the validated `sample_weight`.
If None, and `sample_weight` is an array:

- If `sample_weight.dtype` is one of `{np.float64, np.float32}`,
- If `sample_weight.dtype` is one of `{xp.float64, xp.float32}`,
then the dtype is preserved.
- Else the output has NumPy's default dtype: `np.float64`.
- Otherwise, the output has the highest precision floating point dtype
supported by the array namespace/device of the input arrays.

If `dtype` is not `{np.float32, np.float64, None}`, then output will
be `np.float64`.
Expand All @@ -2169,17 +2177,18 @@ def _check_sample_weight(
Validated sample weight. It is guaranteed to be "C" contiguous.
"""
n_samples = _num_samples(X)
xp, _, device = get_namespace_and_device(X, sample_weight)

if dtype is not None and dtype not in [np.float32, np.float64]:
dtype = np.float64
if dtype is not None and dtype not in [xp.float32, xp.float64]:
dtype = _max_precision_float_dtype(xp, device)

if sample_weight is None:
sample_weight = np.ones(n_samples, dtype=dtype)
sample_weight = xp.ones(n_samples, dtype=dtype)
elif isinstance(sample_weight, numbers.Number):
sample_weight = np.full(n_samples, sample_weight, dtype=dtype)
sample_weight = xp.full(n_samples, sample_weight, dtype=dtype)
else:
if dtype is None:
dtype = [np.float64, np.float32]
dtype = [xp.float64, xp.float32]
sample_weight = check_array(
sample_weight,
accept_sparse=False,
Expand Down Expand Up @@ -2629,14 +2638,16 @@ def _check_pos_label_consistency(pos_label, y_true):
# when elements in the two arrays are not comparable.
if pos_label is None:
# Compute classes only if pos_label is not specified:
classes = np.unique(y_true)
if classes.dtype.kind in "OUS" or not (
np.array_equal(classes, [0, 1])
or np.array_equal(classes, [-1, 1])
or np.array_equal(classes, [0])
or np.array_equal(classes, [-1])
or np.array_equal(classes, [1])
xp, _, device = get_namespace_and_device(y_true)
classes = xp.unique_values(y_true)
if (_is_numpy_namespace(xp) and classes.dtype.kind in "OUS") or not (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I seem to recall seeing a similar kind of change in another PR?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the part from #30878.

xp.all(classes == xp.asarray([0, 1], device=device))
or xp.all(classes == xp.asarray([-1, 1], device=device))
or xp.all(classes == xp.asarray([0], device=device))
or xp.all(classes == xp.asarray([-1], device=device))
or xp.all(classes == xp.asarray([1], device=device))
):
classes = _convert_to_numpy(classes, xp=xp)
classes_repr = ", ".join([repr(c) for c in classes.tolist()])
raise ValueError(
f"y_true takes value in {{{classes_repr}}} and pos_label is not "
Expand Down
Morty Proxy This is a proxified and sanitized view of the page, visit original site.