From bdc5f4032f4bb78d1a342ee37473ca22f6e12f17 Mon Sep 17 00:00:00 2001
From: Thomas Li <47963215+lithomas1@users.noreply.github.com>
Date: Sun, 13 Apr 2025 16:53:59 -0400
Subject: [PATCH 1/2] ENH: Make brier score Array API compatible

---
 sklearn/metrics/_classification.py   | 12 +++--
 sklearn/metrics/tests/test_common.py |  1 +
 sklearn/utils/validation.py          | 69 +++++++++++++++++-----------
 3 files changed, 49 insertions(+), 33 deletions(-)

diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py
index 2a08a1893766e..d8ec60c1fcdec 100644
--- a/sklearn/metrics/_classification.py
+++ b/sklearn/metrics/_classification.py
@@ -3442,23 +3442,25 @@ def brier_score_loss(
             f"is {y_type}."
         )
 
-    if y_proba.max() > 1:
+    xp, _ = get_namespace(y_true, y_proba, sample_weight)
+
+    if xp.max(y_proba) > 1:
         raise ValueError("y_proba contains values greater than 1.")
-    if y_proba.min() < 0:
+    if xp.min(y_proba) < 0:
         raise ValueError("y_proba contains values less than 0.")
 
     try:
         pos_label = _check_pos_label_consistency(pos_label, y_true)
     except ValueError:
-        classes = np.unique(y_true)
+        classes = xp.unique(y_true)
         if classes.dtype.kind not in ("O", "U", "S"):
             # for backward compatibility, if classes are not string then
             # `pos_label` will correspond to the greater label
             pos_label = classes[-1]
         else:
             raise
-    y_true = np.array(y_true == pos_label, int)
-    return float(np.average((y_true - y_proba) ** 2, weights=sample_weight))
+    y_true = xp.astype(y_true == pos_label, xp.int64)
+    return float(_average((y_true - y_proba) ** 2, weights=sample_weight))
 
 
 @validate_params(
diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py
index 5f44e7b212105..0a027c48b067a 100644
--- a/sklearn/metrics/tests/test_common.py
+++ b/sklearn/metrics/tests/test_common.py
@@ -2195,6 +2195,7 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
         check_array_api_regression_metric_multioutput,
     ],
     sigmoid_kernel: [check_array_api_metric_pairwise],
+    brier_score_loss: [check_array_api_binary_classification_metric],
 }
 
 
diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py
index d6e9412712ca8..d9b2829487e55 100644
--- a/sklearn/utils/validation.py
+++ b/sklearn/utils/validation.py
@@ -18,7 +18,14 @@
 
 from .. import get_config as _get_config
 from ..exceptions import DataConversionWarning, NotFittedError, PositiveSpectrumWarning
-from ..utils._array_api import _asarray_with_order, _is_numpy_namespace, get_namespace
+from ..utils._array_api import (
+    _asarray_with_order,
+    _convert_to_numpy,
+    _is_numpy_namespace,
+    _max_precision_float_dtype,
+    get_namespace,
+    get_namespace_and_device,
+)
 from ..utils.deprecation import _deprecate_force_all_finite
 from ..utils.fixes import ComplexWarning, _preserve_dia_indices_dtype
 from ._isfinite import FiniteStatus, cy_isfinite
@@ -988,7 +995,7 @@ def is_sparse(dtype):
     # When all dataframe columns are sparse, convert to a sparse array
     if hasattr(array, "sparse") and array.ndim > 1:
         with suppress(ImportError):
-            from pandas import SparseDtype  # noqa: F811
+            from pandas import SparseDtype
 
             def is_sparse(dtype):
                 return isinstance(dtype, SparseDtype)
@@ -1916,7 +1923,7 @@ def type_name(t):
     expected_include_boundaries = ("left", "right", "both", "neither")
     if include_boundaries not in expected_include_boundaries:
         raise ValueError(
-            f"Unknown value for `include_boundaries`: {repr(include_boundaries)}. "
+            f"Unknown value for `include_boundaries`: {include_boundaries!r}. "
" f"Possible values are: {expected_include_boundaries}." ) @@ -2127,7 +2134,7 @@ def _check_psd_eigenvalues(lambdas, enable_warnings=False): def _check_sample_weight( - sample_weight, X, dtype=None, copy=False, ensure_non_negative=False + sample_weight, X, *, dtype=None, ensure_non_negative=False, copy=False ): """Validate sample weights. @@ -2144,18 +2151,23 @@ def _check_sample_weight( X : {ndarray, list, sparse matrix} Input data. + dtype : dtype, default=None + dtype of the validated `sample_weight`. + If None, and `sample_weight` is an array: + + - If `sample_weight.dtype` is one of `{xp.float64, xp.float32}`, + then the dtype is preserved. + - Otherwise, the output has the highest precision floating point dtype + supported by the array namespace/device of the input arrays. + + If `dtype` is not `{np.float32, np.float64, None}`, then output will + be `np.float64`. + ensure_non_negative : bool, default=False, Whether or not the weights are expected to be non-negative. .. versionadded:: 1.0 - dtype : dtype, default=None - dtype of the validated `sample_weight`. - If None, and the input `sample_weight` is an array, the dtype of the - input is preserved; otherwise an array with the default numpy dtype - is be allocated. If `dtype` is not one of `float32`, `float64`, - `None`, the output will be of dtype `float64`. - copy : bool, default=False If True, a copy of sample_weight will be created. @@ -2165,17 +2177,18 @@ def _check_sample_weight( Validated sample weight. It is guaranteed to be "C" contiguous. """ n_samples = _num_samples(X) + xp, _, device = get_namespace_and_device(X, sample_weight) - if dtype is not None and dtype not in [np.float32, np.float64]: - dtype = np.float64 + if dtype is not None and dtype not in [xp.float32, xp.float64]: + dtype = _max_precision_float_dtype(xp, device) if sample_weight is None: - sample_weight = np.ones(n_samples, dtype=dtype) + sample_weight = xp.ones(n_samples, dtype=dtype) elif isinstance(sample_weight, numbers.Number): - sample_weight = np.full(n_samples, sample_weight, dtype=dtype) + sample_weight = xp.full(n_samples, sample_weight, dtype=dtype) else: if dtype is None: - dtype = [np.float64, np.float32] + dtype = [xp.float64, xp.float32] sample_weight = check_array( sample_weight, accept_sparse=False, @@ -2311,10 +2324,8 @@ def _check_method_params(X, params, indices=None): method_params_validated = {} for param_key, param_value in params.items(): if ( - not _is_arraylike(param_value) - and not sp.issparse(param_value) - or _num_samples(param_value) != _num_samples(X) - ): + not _is_arraylike(param_value) and not sp.issparse(param_value) + ) or _num_samples(param_value) != _num_samples(X): # Non-indexable pass-through (for now for backward-compatibility). # https://github.com/scikit-learn/scikit-learn/issues/15805 method_params_validated[param_key] = param_value @@ -2627,14 +2638,16 @@ def _check_pos_label_consistency(pos_label, y_true): # when elements in the two arrays are not comparable. 
     if pos_label is None:
         # Compute classes only if pos_label is not specified:
-        classes = np.unique(y_true)
-        if classes.dtype.kind in "OUS" or not (
-            np.array_equal(classes, [0, 1])
-            or np.array_equal(classes, [-1, 1])
-            or np.array_equal(classes, [0])
-            or np.array_equal(classes, [-1])
-            or np.array_equal(classes, [1])
+        xp, _, device = get_namespace_and_device(y_true)
+        classes = xp.unique_values(y_true)
+        if (_is_numpy_namespace(xp) and classes.dtype.kind in "OUS") or not (
+            xp.all(classes == xp.asarray([0, 1], device=device))
+            or xp.all(classes == xp.asarray([-1, 1], device=device))
+            or xp.all(classes == xp.asarray([0], device=device))
+            or xp.all(classes == xp.asarray([-1], device=device))
+            or xp.all(classes == xp.asarray([1], device=device))
         ):
+            classes = _convert_to_numpy(classes, xp=xp)
             classes_repr = ", ".join([repr(c) for c in classes.tolist()])
             raise ValueError(
                 f"y_true takes value in {{{classes_repr}}} and pos_label is not "
@@ -2923,7 +2936,7 @@ def validate_data(
     )
 
     no_val_X = isinstance(X, str) and X == "no_validation"
-    no_val_y = y is None or isinstance(y, str) and y == "no_validation"
+    no_val_y = y is None or (isinstance(y, str) and y == "no_validation")
 
     if no_val_X and no_val_y:
         raise ValueError("Validation should be done on X, y or both.")

From 402001658d2166bd07042599bf0973469ab7dc04 Mon Sep 17 00:00:00 2001
From: Thomas Li <47963215+lithomas1@users.noreply.github.com>
Date: Sun, 13 Apr 2025 17:57:03 -0400
Subject: [PATCH 2/2] whatsnew

---
 doc/modules/array_api.rst                                  | 1 +
 doc/whats_new/upcoming_changes/array-api/31191.feature.rst | 2 ++
 2 files changed, 3 insertions(+)
 create mode 100644 doc/whats_new/upcoming_changes/array-api/31191.feature.rst

diff --git a/doc/modules/array_api.rst b/doc/modules/array_api.rst
index b4940eccec2fc..3da40dc3450d0 100644
--- a/doc/modules/array_api.rst
+++ b/doc/modules/array_api.rst
@@ -131,6 +131,7 @@ base estimator also does:
 Metrics
 -------
 
+- :func:`sklearn.metrics.brier_score_loss` (only the binary class case is supported)
 - :func:`sklearn.metrics.cluster.entropy`
 - :func:`sklearn.metrics.accuracy_score`
 - :func:`sklearn.metrics.d2_tweedie_score`

diff --git a/doc/whats_new/upcoming_changes/array-api/31191.feature.rst b/doc/whats_new/upcoming_changes/array-api/31191.feature.rst
new file mode 100644
index 0000000000000..8c1d1822cd329
--- /dev/null
+++ b/doc/whats_new/upcoming_changes/array-api/31191.feature.rst
@@ -0,0 +1,2 @@
+- :func:`sklearn.metrics.brier_score_loss` now supports Array API compatible inputs for the binary class case.
+  By :user:`Thomas Li <lithomas1>`
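
Not part of the patch: below is a minimal usage sketch of the code path this change enables, assuming the branch above is installed together with `array-api-compat` and PyTorch. The function names and `config_context(array_api_dispatch=True)` are existing scikit-learn API; the tensor values are illustrative only.

```python
# Illustrative only: exercises the Array API path added by this patch.
# Assumes this branch is installed along with `array-api-compat` and PyTorch.
import torch

from sklearn import config_context
from sklearn.metrics import brier_score_loss

y_true = torch.tensor([0, 1, 1, 0])
y_proba = torch.tensor([0.1, 0.9, 0.8, 0.3])
sample_weight = torch.tensor([1.0, 2.0, 1.0, 1.0])

# With array_api_dispatch enabled, the metric operates on the torch tensors
# through their array API namespace instead of converting to NumPy first.
with config_context(array_api_dispatch=True):
    loss = brier_score_loss(y_true, y_proba, sample_weight=sample_weight)

print(loss)  # a plain Python float, ~0.032 for these example values
```

With `array_api_dispatch` left at its default (disabled), the call continues to go through the NumPy path, so existing callers should be unaffected.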