ENH: Make brier_score_loss Array API compatible #31191
Changelog entry (new fragment added by this PR):

@@ -0,0 +1,2 @@
+- :func:`sklearn.metrics.brier_score_loss` now supports Array API compatible inputs for the binary class case.
+  By :user:`Thomas Li <lithomas1>`
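For context, here is a minimal sketch of the usage this entry describes, assuming PyTorch and array-api-compat are installed and array API dispatch is enabled (illustrative only, not part of the diff):

```python
# Sketch: binary Brier score computed on torch tensors via array API dispatch.
import torch

import sklearn
from sklearn.metrics import brier_score_loss

y_true = torch.tensor([0, 1, 1, 0])
y_proba = torch.tensor([0.1, 0.9, 0.8, 0.3])

with sklearn.config_context(array_api_dispatch=True):
    # With this PR, the binary case is computed in the torch namespace
    # instead of being converted to NumPy first.
    score = brier_score_loss(y_true, y_proba)

print(float(score))  # expected ~0.0375, the classical binary Brier score here
```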
sklearn/metrics/_classification.py:
@@ -194,10 +194,11 @@ def _validate_multiclass_probabilistic_prediction(
     y_prob = check_array(
         y_prob, ensure_2d=False, dtype=[np.float64, np.float32, np.float16]
     )
+    xp, _ = get_namespace(y_true, y_prob)

-    if y_prob.max() > 1:
+    if xp.max(y_prob) > 1:
         raise ValueError(f"y_prob contains values greater than 1: {y_prob.max()}")
-    if y_prob.min() < 0:
+    if xp.min(y_prob) < 0:
         raise ValueError(f"y_prob contains values lower than 0: {y_prob.min()}")

     check_consistent_length(y_prob, y_true, sample_weight)

Review comment (lines -198 to +201): Not necessarily for this PR, but this code seems to be repeated 4x in this module; maybe we could refactor it out?
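A rough sketch of what such a shared helper could look like (hypothetical name, not part of this PR):

```python
# Hypothetical refactoring of the repeated range check, sketched for illustration.
def _check_proba_in_unit_interval(y_prob, xp):
    """Raise ValueError if any probability lies outside [0, 1]."""
    if xp.max(y_prob) > 1:
        raise ValueError(f"y_prob contains values greater than 1: {xp.max(y_prob)}")
    if xp.min(y_prob) < 0:
        raise ValueError(f"y_prob contains values lower than 0: {xp.min(y_prob)}")
```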
@@ -3428,6 +3429,7 @@ def _validate_binary_probabilistic_prediction(y_true, y_prob, sample_weight, pos_label):
     assert_all_finite(y_prob)

     check_consistent_length(y_prob, y_true, sample_weight)
+    xp, _, device = get_namespace_and_device(y_true, y_prob, sample_weight)

     y_type = type_of_target(y_true, input_name="y_true")
     if y_type != "binary":
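For readers unfamiliar with the private helper, `get_namespace_and_device` returns the array namespace, an array-API-dispatch flag, and the device of the inputs; a rough illustration with plain NumPy inputs (the exact objects returned depend on the scikit-learn version and configuration):

```python
# Illustration only: inspect what the private helper returns for NumPy inputs.
import numpy as np
from sklearn.utils._array_api import get_namespace_and_device

y_true = np.array([0, 1, 1, 0])
y_prob = np.array([0.2, 0.7, 0.9, 0.4])

xp, is_array_api_compliant, device = get_namespace_and_device(y_true, y_prob, None)
# With array API dispatch disabled (the default), xp is a NumPy-backed wrapper
# namespace, the flag is False, and device designates the CPU.
print(xp, is_array_api_compliant, device)
```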
@@ -3436,16 +3438,16 @@ def _validate_binary_probabilistic_prediction(y_true, y_prob, sample_weight, pos_label):
             "binary according to the shape of y_prob."
         )

-    if y_prob.max() > 1:
+    if xp.max(y_prob) > 1:
         raise ValueError(f"y_prob contains values greater than 1: {y_prob.max()}")
-    if y_prob.min() < 0:
+    if xp.min(y_prob) < 0:
         raise ValueError(f"y_prob contains values less than 0: {y_prob.min()}")

     # check that pos_label is consistent with y_true
     try:
         pos_label = _check_pos_label_consistency(pos_label, y_true)
     except ValueError:
-        classes = np.unique(y_true)
+        classes = xp.unique_values(y_true)
         if classes.dtype.kind not in ("O", "U", "S"):
             # for backward compatibility, if classes are not string then
             # `pos_label` will correspond to the greater label
@@ -3454,9 +3456,9 @@ def _validate_binary_probabilistic_prediction(y_true, y_prob, sample_weight, pos_label):
             raise

     # convert (n_samples,) to (n_samples, 2) shape
-    y_true = np.array(y_true == pos_label, int)
-    transformed_labels = np.column_stack((1 - y_true, y_true))
-    y_prob = np.column_stack((1 - y_prob, y_prob))
+    y_true = xp.asarray(y_true == pos_label, dtype=xp.int64, device=device)
+    transformed_labels = xp.stack((1 - y_true, y_true), axis=1)
+    y_prob = xp.stack((1 - y_prob, y_prob), axis=1)

     return transformed_labels, y_prob
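For reference, a quick check (with NumPy standing in for the array namespace) that the `xp.stack(..., axis=1)` spelling matches the previous `np.column_stack` behaviour for 1-D inputs:

```python
import numpy as np

y_prob = np.array([0.1, 0.9, 0.8, 0.3])

old = np.column_stack((1 - y_prob, y_prob))   # shape (4, 2)
new = np.stack((1 - y_prob, y_prob), axis=1)  # array-API-portable spelling
assert np.array_equal(old, new)
```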
@@ -3589,8 +3591,17 @@ def brier_score_loss(
     ... )
     0.146...
     """
+    xp, _, device = get_namespace_and_device(
+        y_true,
+        y_proba,
+        sample_weight,
+    )
     y_proba = check_array(
-        y_proba, ensure_2d=False, dtype=[np.float64, np.float32, np.float16]
+        y_proba,
+        ensure_2d=False,
+        dtype=tuple(
+            xp.__array_namespace_info__().dtypes(kind="real floating").values()
+        ),
     )

     if y_proba.ndim == 1 or y_proba.shape[1] == 1:

Review comment: Do we really need this change, aside from just replacing np with xp in the float data types? Is there some other float dtype that we want to support?

Reply: I did this since some libraries like PyTorch MPS don't support xp.float64. Maybe it would be good to put this in a helper in the array API utils module?

Reply: So far, we used [...]. We could also improve [...].
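The helper mentioned in the reply could look roughly like this (hypothetical name and location in the array API utils module, not part of the diff):

```python
# Hypothetical helper: query the namespace for the real floating dtypes it
# actually provides, for use as check_array's `dtype` argument.
def _real_floating_dtypes(xp):
    """Return the real floating dtypes supported by the namespace `xp`."""
    info = xp.__array_namespace_info__()
    # dtypes(kind="real floating") returns a mapping of dtype names to dtypes.
    return tuple(info.dtypes(kind="real floating").values())
```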
@@ -3601,9 +3612,14 @@ def brier_score_loss(
         transformed_labels, y_proba = _validate_multiclass_probabilistic_prediction(
             y_true, y_proba, sample_weight, labels
         )

-    brier_score = np.average(
-        np.sum((transformed_labels - y_proba) ** 2, axis=1), weights=sample_weight
+    transformed_labels = xp.asarray(transformed_labels, device=device)
+    y_proba = xp.asarray(y_proba, device=device)
+
+    # If transformed_labels is integer array, cast it to the floating dtype of
+    # y_proba
+    transformed_labels = xp.astype(transformed_labels, y_proba.dtype, device=device)
+
+    brier_score = _average(
+        xp.sum((transformed_labels - y_proba) ** 2, axis=1), weights=sample_weight
     )

     if scale_by_half == "auto":

Review comment (lines +3615 to +3616): Shouldn't these be on the device already? I think y_proba might be shifted to CPU because of the check_array function, but assuming that y_true and y_prob are on the expected device, transformed_labels should be on the device as well. Or is this just handling for array-api-strict?

Reply: In which case could [...]? If it is possible that [...]. Question: why would it be needed for [...]?

Review comment (on the xp.astype line): Here we are again moving it to the device?
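For intuition, the replaced block computes the (optionally weighted) mean, over samples, of the squared differences between the one-hot labels and the predicted probabilities; a NumPy-only sketch of the same quantity:

```python
import numpy as np

# One-hot encoded y_true and per-class predicted probabilities for 3 samples.
transformed_labels = np.array([[1, 0], [0, 1], [0, 1]], dtype=float)
y_proba = np.array([[0.9, 0.1], [0.2, 0.8], [0.4, 0.6]])
sample_weight = np.array([1.0, 2.0, 1.0])

per_sample = np.sum((transformed_labels - y_proba) ** 2, axis=1)
brier = np.average(per_sample, weights=sample_weight)
print(brier)  # 0.125 for these inputs
# When scale_by_half ends up True (e.g. the binary "auto" case),
# brier_score_loss returns half of this value.
```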
sklearn/utils/validation.py:
@@ -18,7 +18,14 @@

 from .. import get_config as _get_config
 from ..exceptions import DataConversionWarning, NotFittedError, PositiveSpectrumWarning
-from ..utils._array_api import _asarray_with_order, _is_numpy_namespace, get_namespace
+from ..utils._array_api import (
+    _asarray_with_order,
+    _convert_to_numpy,
+    _is_numpy_namespace,
+    _max_precision_float_dtype,
+    get_namespace,
+    get_namespace_and_device,
+)
 from ..utils.deprecation import _deprecate_force_all_finite
 from ..utils.fixes import ComplexWarning, _preserve_dia_indices_dtype
 from ._isfinite import FiniteStatus, cy_isfinite
@@ -2148,9 +2155,10 @@ def _check_sample_weight(
         dtype of the validated `sample_weight`.
         If None, and `sample_weight` is an array:

-        - If `sample_weight.dtype` is one of `{np.float64, np.float32}`,
+        - If `sample_weight.dtype` is one of `{xp.float64, xp.float32}`,
           then the dtype is preserved.
-        - Else the output has NumPy's default dtype: `np.float64`.
+        - Otherwise, the output has the highest precision floating point dtype
+          supported by the array namespace/device of the input arrays.

         If `dtype` is not `{np.float32, np.float64, None}`, then output will
         be `np.float64`.

Review comment (lines -2151 to +2158): If we are not changing types in public functions, I wonder if we should keep private ones as is too, for consistency?
@@ -2169,17 +2177,18 @@ def _check_sample_weight(
         Validated sample weight. It is guaranteed to be "C" contiguous.
     """
     n_samples = _num_samples(X)
+    xp, _, device = get_namespace_and_device(X, sample_weight)

-    if dtype is not None and dtype not in [np.float32, np.float64]:
-        dtype = np.float64
+    if dtype is not None and dtype not in [xp.float32, xp.float64]:
+        dtype = _max_precision_float_dtype(xp, device)

     if sample_weight is None:
-        sample_weight = np.ones(n_samples, dtype=dtype)
+        sample_weight = xp.ones(n_samples, dtype=dtype)
     elif isinstance(sample_weight, numbers.Number):
-        sample_weight = np.full(n_samples, sample_weight, dtype=dtype)
+        sample_weight = xp.full(n_samples, sample_weight, dtype=dtype)
     else:
         if dtype is None:
-            dtype = [np.float64, np.float32]
+            dtype = [xp.float64, xp.float32]
         sample_weight = check_array(
             sample_weight,
             accept_sparse=False,

Review comment: Just checking, are these the same changes as made in #30878? Since that one has the additional tests, maybe it should be merged first?
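The practical effect of the `_max_precision_float_dtype` fallback is that the default weight dtype follows what the namespace/device can represent; a rough illustration (assuming the private helper behaves as its name suggests, returning float64 where available and float32 otherwise):

```python
# Illustration only: the dtype fallback used when `dtype` is not a standard float.
import numpy as np
from sklearn.utils._array_api import (
    _max_precision_float_dtype,
    get_namespace_and_device,
)

X = np.ones((5, 3))
xp, _, device = get_namespace_and_device(X)
print(_max_precision_float_dtype(xp, device))  # float64 on NumPy/CPU

# On a backend without float64 support (e.g. PyTorch on the "mps" device),
# the same call is expected to return float32 instead.
```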
@@ -2629,14 +2638,16 @@ def _check_pos_label_consistency(pos_label, y_true):
     # when elements in the two arrays are not comparable.
     if pos_label is None:
         # Compute classes only if pos_label is not specified:
-        classes = np.unique(y_true)
-        if classes.dtype.kind in "OUS" or not (
-            np.array_equal(classes, [0, 1])
-            or np.array_equal(classes, [-1, 1])
-            or np.array_equal(classes, [0])
-            or np.array_equal(classes, [-1])
-            or np.array_equal(classes, [1])
+        xp, _, device = get_namespace_and_device(y_true)
+        classes = xp.unique_values(y_true)
+        if (_is_numpy_namespace(xp) and classes.dtype.kind in "OUS") or not (
+            xp.all(classes == xp.asarray([0, 1], device=device))
+            or xp.all(classes == xp.asarray([-1, 1], device=device))
+            or xp.all(classes == xp.asarray([0], device=device))
+            or xp.all(classes == xp.asarray([-1], device=device))
+            or xp.all(classes == xp.asarray([1], device=device))
         ):
+            classes = _convert_to_numpy(classes, xp=xp)
             classes_repr = ", ".join([repr(c) for c in classes.tolist()])
             raise ValueError(
                 f"y_true takes value in {{{classes_repr}}} and pos_label is not "

Review comment (on the `_is_numpy_namespace` line): I seem to recall seeing a similar kind of change in another PR?

Reply: This is the part from #30878.
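As a reminder of what this fallback accepts, here is a NumPy-only illustration of the label sets for which `pos_label` may silently default to 1, mirroring the check above with NumPy standing in for `xp`:

```python
import numpy as np

def _is_conventional_binary(y_true):
    # pos_label may default to 1 only for these conventional encodings;
    # anything else makes the real helper raise and ask for an explicit pos_label.
    classes = np.unique(y_true)
    return any(
        np.array_equal(classes, ref)
        for ref in ([0, 1], [-1, 1], [0], [-1], [1])
    )

print(_is_conventional_binary([0, 1, 1, 0]))  # True
print(_is_conventional_binary([2, 3, 3, 2]))  # False -> ValueError in the real helper
```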