diff --git a/doc/modules/array_api.rst b/doc/modules/array_api.rst
index b4940eccec2fc..3da40dc3450d0 100644
--- a/doc/modules/array_api.rst
+++ b/doc/modules/array_api.rst
@@ -131,6 +131,7 @@ base estimator also does:
 Metrics
 -------
 
+- :func:`sklearn.metrics.brier_score_loss` (only the binary class case is supported)
 - :func:`sklearn.metrics.cluster.entropy`
 - :func:`sklearn.metrics.accuracy_score`
 - :func:`sklearn.metrics.d2_tweedie_score`
diff --git a/doc/whats_new/upcoming_changes/array-api/31191.feature.rst b/doc/whats_new/upcoming_changes/array-api/31191.feature.rst
new file mode 100644
index 0000000000000..8c1d1822cd329
--- /dev/null
+++ b/doc/whats_new/upcoming_changes/array-api/31191.feature.rst
@@ -0,0 +1,2 @@
+- :func:`sklearn.metrics.brier_score_loss` now supports Array API compatible inputs for the binary class case.
+  By :user:`Thomas Li `
diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py
index 30dd53bc16109..f7d32f1c42983
--- a/sklearn/metrics/_classification.py
+++ b/sklearn/metrics/_classification.py
@@ -194,10 +194,11 @@ def _validate_multiclass_probabilistic_prediction(
     y_prob = check_array(
         y_prob, ensure_2d=False, dtype=[np.float64, np.float32, np.float16]
     )
+    xp, _ = get_namespace(y_true, y_prob)
 
-    if y_prob.max() > 1:
+    if xp.max(y_prob) > 1:
         raise ValueError(f"y_prob contains values greater than 1: {y_prob.max()}")
-    if y_prob.min() < 0:
+    if xp.min(y_prob) < 0:
         raise ValueError(f"y_prob contains values lower than 0: {y_prob.min()}")
 
     check_consistent_length(y_prob, y_true, sample_weight)
@@ -3428,6 +3429,7 @@ def _validate_binary_probabilistic_prediction(y_true, y_prob, sample_weight, pos
 
     assert_all_finite(y_prob)
     check_consistent_length(y_prob, y_true, sample_weight)
+    xp, _, device = get_namespace_and_device(y_true, y_prob, sample_weight)
 
     y_type = type_of_target(y_true, input_name="y_true")
     if y_type != "binary":
@@ -3436,16 +3438,16 @@ def _validate_binary_probabilistic_prediction(y_true, y_prob, sample_weight, pos
             "binary according to the shape of y_prob."
         )
 
-    if y_prob.max() > 1:
+    if xp.max(y_prob) > 1:
         raise ValueError(f"y_prob contains values greater than 1: {y_prob.max()}")
-    if y_prob.min() < 0:
+    if xp.min(y_prob) < 0:
         raise ValueError(f"y_prob contains values less than 0: {y_prob.min()}")
 
     # check that pos_label is consistent with y_true
     try:
         pos_label = _check_pos_label_consistency(pos_label, y_true)
     except ValueError:
-        classes = np.unique(y_true)
+        classes = xp.unique_values(y_true)
         if classes.dtype.kind not in ("O", "U", "S"):
             # for backward compatibility, if classes are not string then
             # `pos_label` will correspond to the greater label
@@ -3454,9 +3456,9 @@ def _validate_binary_probabilistic_prediction(y_true, y_prob, sample_weight, pos
             raise
 
     # convert (n_samples,) to (n_samples, 2) shape
-    y_true = np.array(y_true == pos_label, int)
-    transformed_labels = np.column_stack((1 - y_true, y_true))
-    y_prob = np.column_stack((1 - y_prob, y_prob))
+    y_true = xp.asarray(y_true == pos_label, dtype=xp.int64, device=device)
+    transformed_labels = xp.stack((1 - y_true, y_true), axis=1)
+    y_prob = xp.stack((1 - y_prob, y_prob), axis=1)
 
     return transformed_labels, y_prob
 
@@ -3589,8 +3591,17 @@ def brier_score_loss(
     ... )
     0.146...
     """
+    xp, _, device = get_namespace_and_device(
+        y_true,
+        y_proba,
+        sample_weight,
+    )
     y_proba = check_array(
-        y_proba, ensure_2d=False, dtype=[np.float64, np.float32, np.float16]
+        y_proba,
+        ensure_2d=False,
+        dtype=tuple(
+            xp.__array_namespace_info__().dtypes(kind="real floating").values()
+        ),
     )
 
     if y_proba.ndim == 1 or y_proba.shape[1] == 1:
@@ -3601,9 +3612,14 @@ def brier_score_loss(
         transformed_labels, y_proba = _validate_multiclass_probabilistic_prediction(
             y_true, y_proba, sample_weight, labels
         )
-
-    brier_score = np.average(
-        np.sum((transformed_labels - y_proba) ** 2, axis=1), weights=sample_weight
+    transformed_labels = xp.asarray(transformed_labels, device=device)
+    y_proba = xp.asarray(y_proba, device=device)
+
+    # If transformed_labels is an integer array, cast it to the floating dtype of
+    # y_proba
+    transformed_labels = xp.astype(transformed_labels, y_proba.dtype, device=device)
+    brier_score = _average(
+        xp.sum((transformed_labels - y_proba) ** 2, axis=1), weights=sample_weight
     )
 
     if scale_by_half == "auto":
diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py
index 6f9e11d4f4780..37ac926845ace 100644
--- a/sklearn/metrics/tests/test_common.py
+++ b/sklearn/metrics/tests/test_common.py
@@ -2229,6 +2229,7 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
         check_array_api_regression_metric_multioutput,
     ],
     sigmoid_kernel: [check_array_api_metric_pairwise],
+    brier_score_loss: [check_array_api_binary_classification_metric],
 }
 
 
diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py
index 116d12fc5e8ad..d9b2829487e55 100644
--- a/sklearn/utils/validation.py
+++ b/sklearn/utils/validation.py
@@ -18,7 +18,14 @@
 
 from .. import get_config as _get_config
 from ..exceptions import DataConversionWarning, NotFittedError, PositiveSpectrumWarning
-from ..utils._array_api import _asarray_with_order, _is_numpy_namespace, get_namespace
+from ..utils._array_api import (
+    _asarray_with_order,
+    _convert_to_numpy,
+    _is_numpy_namespace,
+    _max_precision_float_dtype,
+    get_namespace,
+    get_namespace_and_device,
+)
 from ..utils.deprecation import _deprecate_force_all_finite
 from ..utils.fixes import ComplexWarning, _preserve_dia_indices_dtype
 from ._isfinite import FiniteStatus, cy_isfinite
@@ -2148,9 +2155,10 @@ def _check_sample_weight(
         dtype of the validated `sample_weight`.
         If None, and `sample_weight` is an array:
 
-        - If `sample_weight.dtype` is one of `{np.float64, np.float32}`,
+        - If `sample_weight.dtype` is one of `{xp.float64, xp.float32}`,
           then the dtype is preserved.
-        - Else the output has NumPy's default dtype: `np.float64`.
+        - Otherwise, the output has the highest precision floating point dtype
+          supported by the array namespace/device of the input arrays.
 
         If `dtype` is not `{np.float32, np.float64, None}`, then output will
         be `np.float64`.
@@ -2169,17 +2177,18 @@ def _check_sample_weight(
         Validated sample weight. It is guaranteed to be "C" contiguous.
     """
     n_samples = _num_samples(X)
+    xp, _, device = get_namespace_and_device(X, sample_weight)
 
-    if dtype is not None and dtype not in [np.float32, np.float64]:
-        dtype = np.float64
+    if dtype is not None and dtype not in [xp.float32, xp.float64]:
+        dtype = _max_precision_float_dtype(xp, device)
 
     if sample_weight is None:
-        sample_weight = np.ones(n_samples, dtype=dtype)
+        sample_weight = xp.ones(n_samples, dtype=dtype)
     elif isinstance(sample_weight, numbers.Number):
-        sample_weight = np.full(n_samples, sample_weight, dtype=dtype)
+        sample_weight = xp.full(n_samples, sample_weight, dtype=dtype)
     else:
         if dtype is None:
-            dtype = [np.float64, np.float32]
+            dtype = [xp.float64, xp.float32]
         sample_weight = check_array(
             sample_weight,
             accept_sparse=False,
@@ -2629,14 +2638,16 @@ def _check_pos_label_consistency(pos_label, y_true):
     # when elements in the two arrays are not comparable.
     if pos_label is None:
         # Compute classes only if pos_label is not specified:
-        classes = np.unique(y_true)
-        if classes.dtype.kind in "OUS" or not (
-            np.array_equal(classes, [0, 1])
-            or np.array_equal(classes, [-1, 1])
-            or np.array_equal(classes, [0])
-            or np.array_equal(classes, [-1])
-            or np.array_equal(classes, [1])
+        xp, _, device = get_namespace_and_device(y_true)
+        classes = xp.unique_values(y_true)
+        if (_is_numpy_namespace(xp) and classes.dtype.kind in "OUS") or not (
+            xp.all(classes == xp.asarray([0, 1], device=device))
+            or xp.all(classes == xp.asarray([-1, 1], device=device))
+            or xp.all(classes == xp.asarray([0], device=device))
+            or xp.all(classes == xp.asarray([-1], device=device))
+            or xp.all(classes == xp.asarray([1], device=device))
         ):
+            classes = _convert_to_numpy(classes, xp=xp)
             classes_repr = ", ".join([repr(c) for c in classes.tolist()])
             raise ValueError(
                 f"y_true takes value in {{{classes_repr}}} and pos_label is not "