-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
ENH: Make brier_score_loss Array API compatible #31191
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
- :func:`sklearn.metrics.brier_score_loss` now support Array API compatible inputs for the binary class case. | ||
By :user:`Thomas Li <lithomas1>` |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -194,10 +194,11 @@ def _validate_multiclass_probabilistic_prediction( | |
y_prob = check_array( | ||
y_prob, ensure_2d=False, dtype=[np.float64, np.float32, np.float16] | ||
) | ||
xp, _ = get_namespace(y_true, y_prob) | ||
|
||
if y_prob.max() > 1: | ||
if xp.max(y_prob) > 1: | ||
raise ValueError(f"y_prob contains values greater than 1: {y_prob.max()}") | ||
if y_prob.min() < 0: | ||
if xp.min(y_prob) < 0: | ||
raise ValueError(f"y_prob contains values lower than 0: {y_prob.min()}") | ||
|
||
check_consistent_length(y_prob, y_true, sample_weight) | ||
|
@@ -3428,6 +3429,7 @@ def _validate_binary_probabilistic_prediction(y_true, y_prob, sample_weight, pos | |
assert_all_finite(y_prob) | ||
|
||
check_consistent_length(y_prob, y_true, sample_weight) | ||
xp, _, device = get_namespace_and_device(y_true, y_prob, sample_weight) | ||
|
||
y_type = type_of_target(y_true, input_name="y_true") | ||
if y_type != "binary": | ||
|
@@ -3436,16 +3438,16 @@ def _validate_binary_probabilistic_prediction(y_true, y_prob, sample_weight, pos | |
"binary according to the shape of y_prob." | ||
) | ||
|
||
if y_prob.max() > 1: | ||
if xp.max(y_prob) > 1: | ||
raise ValueError(f"y_prob contains values greater than 1: {y_prob.max()}") | ||
if y_prob.min() < 0: | ||
if xp.min(y_prob) < 0: | ||
raise ValueError(f"y_prob contains values less than 0: {y_prob.min()}") | ||
|
||
# check that pos_label is consistent with y_true | ||
try: | ||
pos_label = _check_pos_label_consistency(pos_label, y_true) | ||
except ValueError: | ||
classes = np.unique(y_true) | ||
classes = xp.unique_values(y_true) | ||
if classes.dtype.kind not in ("O", "U", "S"): | ||
# for backward compatibility, if classes are not string then | ||
# `pos_label` will correspond to the greater label | ||
|
@@ -3454,9 +3456,9 @@ def _validate_binary_probabilistic_prediction(y_true, y_prob, sample_weight, pos | |
raise | ||
|
||
# convert (n_samples,) to (n_samples, 2) shape | ||
y_true = np.array(y_true == pos_label, int) | ||
transformed_labels = np.column_stack((1 - y_true, y_true)) | ||
y_prob = np.column_stack((1 - y_prob, y_prob)) | ||
y_true = xp.asarray(y_true == pos_label, dtype=xp.int64, device=device) | ||
transformed_labels = xp.stack((1 - y_true, y_true), axis=1) | ||
y_prob = xp.stack((1 - y_prob, y_prob), axis=1) | ||
|
||
return transformed_labels, y_prob | ||
|
||
|
@@ -3589,8 +3591,17 @@ def brier_score_loss( | |
... ) | ||
0.146... | ||
""" | ||
xp, _, device = get_namespace_and_device( | ||
y_true, | ||
y_proba, | ||
sample_weight, | ||
) | ||
y_proba = check_array( | ||
y_proba, ensure_2d=False, dtype=[np.float64, np.float32, np.float16] | ||
y_proba, | ||
ensure_2d=False, | ||
dtype=tuple( | ||
xp.__array_namespace_info__().dtypes(kind="real floating").values() | ||
), | ||
) | ||
|
||
if y_proba.ndim == 1 or y_proba.shape[1] == 1: | ||
|
@@ -3601,9 +3612,14 @@ def brier_score_loss( | |
transformed_labels, y_proba = _validate_multiclass_probabilistic_prediction( | ||
y_true, y_proba, sample_weight, labels | ||
) | ||
|
||
brier_score = np.average( | ||
np.sum((transformed_labels - y_proba) ** 2, axis=1), weights=sample_weight | ||
transformed_labels = xp.asarray(transformed_labels, device=device) | ||
y_proba = xp.asarray(y_proba, device=device) | ||
Comment on lines
+3615
to
+3616
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't these be on the device already. I think y_proba might be shifted to cpu because of the check_array function but assuming that y_true and y_prob are on the expected device transformed_labels should be on the device as well. Or is this just handling for the array-api-strict? |
||
|
||
# If transformed_labels is integer array, cast it to the floating dtype of | ||
# y_proba | ||
transformed_labels = xp.astype(transformed_labels, y_proba.dtype, device=device) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here we are again moving it on the device? |
||
brier_score = _average( | ||
xp.sum((transformed_labels - y_proba) ** 2, axis=1), weights=sample_weight | ||
) | ||
|
||
if scale_by_half == "auto": | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,7 +18,14 @@ | |
|
||
from .. import get_config as _get_config | ||
from ..exceptions import DataConversionWarning, NotFittedError, PositiveSpectrumWarning | ||
from ..utils._array_api import _asarray_with_order, _is_numpy_namespace, get_namespace | ||
from ..utils._array_api import ( | ||
_asarray_with_order, | ||
_convert_to_numpy, | ||
_is_numpy_namespace, | ||
_max_precision_float_dtype, | ||
get_namespace, | ||
get_namespace_and_device, | ||
) | ||
from ..utils.deprecation import _deprecate_force_all_finite | ||
from ..utils.fixes import ComplexWarning, _preserve_dia_indices_dtype | ||
from ._isfinite import FiniteStatus, cy_isfinite | ||
|
@@ -2148,9 +2155,10 @@ def _check_sample_weight( | |
dtype of the validated `sample_weight`. | ||
If None, and `sample_weight` is an array: | ||
|
||
- If `sample_weight.dtype` is one of `{np.float64, np.float32}`, | ||
- If `sample_weight.dtype` is one of `{xp.float64, xp.float32}`, | ||
then the dtype is preserved. | ||
- Else the output has NumPy's default dtype: `np.float64`. | ||
- Otherwise, the output has the highest precision floating point dtype | ||
supported by the array namespace/device of the input arrays. | ||
|
||
If `dtype` is not `{np.float32, np.float64, None}`, then output will | ||
be `np.float64`. | ||
|
@@ -2169,17 +2177,18 @@ def _check_sample_weight( | |
Validated sample weight. It is guaranteed to be "C" contiguous. | ||
""" | ||
n_samples = _num_samples(X) | ||
xp, _, device = get_namespace_and_device(X, sample_weight) | ||
|
||
if dtype is not None and dtype not in [np.float32, np.float64]: | ||
dtype = np.float64 | ||
if dtype is not None and dtype not in [xp.float32, xp.float64]: | ||
dtype = _max_precision_float_dtype(xp, device) | ||
|
||
if sample_weight is None: | ||
sample_weight = np.ones(n_samples, dtype=dtype) | ||
sample_weight = xp.ones(n_samples, dtype=dtype) | ||
elif isinstance(sample_weight, numbers.Number): | ||
sample_weight = np.full(n_samples, sample_weight, dtype=dtype) | ||
sample_weight = xp.full(n_samples, sample_weight, dtype=dtype) | ||
else: | ||
if dtype is None: | ||
dtype = [np.float64, np.float32] | ||
dtype = [xp.float64, xp.float32] | ||
sample_weight = check_array( | ||
sample_weight, | ||
accept_sparse=False, | ||
|
@@ -2629,14 +2638,16 @@ def _check_pos_label_consistency(pos_label, y_true): | |
# when elements in the two arrays are not comparable. | ||
if pos_label is None: | ||
# Compute classes only if pos_label is not specified: | ||
classes = np.unique(y_true) | ||
if classes.dtype.kind in "OUS" or not ( | ||
np.array_equal(classes, [0, 1]) | ||
or np.array_equal(classes, [-1, 1]) | ||
or np.array_equal(classes, [0]) | ||
or np.array_equal(classes, [-1]) | ||
or np.array_equal(classes, [1]) | ||
xp, _, device = get_namespace_and_device(y_true) | ||
classes = xp.unique_values(y_true) | ||
if (_is_numpy_namespace(xp) and classes.dtype.kind in "OUS") or not ( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I seem to recall seeing a similar kind of change in another PR? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the part from #30878. |
||
xp.all(classes == xp.asarray([0, 1], device=device)) | ||
or xp.all(classes == xp.asarray([-1, 1], device=device)) | ||
or xp.all(classes == xp.asarray([0], device=device)) | ||
or xp.all(classes == xp.asarray([-1], device=device)) | ||
or xp.all(classes == xp.asarray([1], device=device)) | ||
): | ||
classes = _convert_to_numpy(classes, xp=xp) | ||
classes_repr = ", ".join([repr(c) for c in classes.tolist()]) | ||
raise ValueError( | ||
f"y_true takes value in {{{classes_repr}}} and pos_label is not " | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we really need this change aside from just replacing np with xp in the floatdata types? Is there some other float dtype that we want to support?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did this since some libraries like PyTorch MPS don't support xp.float32.
(Also float16 is not in the array API standard. Should we make a special exception for np.float16?)
Maybe it would be good to put this in a helper in the array API utils module?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So far, we used
_find_matching_floating_dtype
for this use case. We could update that utility to leverage__array_namespace_info__
as you did here.We could also improve
check_array
to acceptdtype="floating"
and do device/namespace specific conversion when provided with integer inputs.