Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 212fe51

Browse filesBrowse files
EdAbatiOmarManzoor
authored and
Shruti Nath
committed
FIX: accuracy_score and zero_one_loss support for multilabel with Array API (scikit-learn#29336)
Co-authored-by: Omar Salman <omar.salman2007@gmail.com> Co-authored-by: Omar Salman <omar.salman@arbisoft.com>
1 parent 345358a commit 212fe51
Copy full SHA for 212fe51

File tree

4 files changed

+72
-7
lines changed
Filter options

4 files changed

+72
-7
lines changed

‎doc/whats_new/v1.5.rst

Copy file name to clipboardExpand all lines: doc/whats_new/v1.5.rst
+3-3Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,9 @@ Changelog
4848
instead of implicitly converting those inputs as regular NumPy arrays.
4949
:pr:`29119` by :user:`Olivier Grisel`.
5050

51-
- |Fix| Fix a regression in :func:`metrics.zero_one_loss` causing an error
52-
for Array API dispatch with multilabel inputs.
53-
:pr:`29269` by :user:`Yaroslav Korobko <Tialo>`.
51+
- |Fix| Fix a regression in :func:`metrics.accuracy_score` and in :func:`metrics.zero_one_loss`
52+
causing an error for Array API dispatch with multilabel inputs.
53+
:pr:`29269` by :user:`Yaroslav Korobko <Tialo>` and :pr:`29336` by :user:`Edoardo Abati <EdAbati>`.
5454

5555
:mod:`sklearn.model_selection`
5656
..............................

‎sklearn/metrics/_classification.py

Copy file name to clipboardExpand all lines: sklearn/metrics/_classification.py
+16-3Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
)
2929
from ..utils._array_api import (
3030
_average,
31+
_count_nonzero,
32+
_is_numpy_namespace,
3133
_union1d,
3234
get_namespace,
3335
get_namespace_and_device,
@@ -85,6 +87,7 @@ def _check_targets(y_true, y_pred):
8587
8688
y_pred : array or indicator matrix
8789
"""
90+
xp, _ = get_namespace(y_true, y_pred)
8891
check_consistent_length(y_true, y_pred)
8992
type_true = type_of_target(y_true, input_name="y_true")
9093
type_pred = type_of_target(y_pred, input_name="y_pred")
@@ -130,8 +133,13 @@ def _check_targets(y_true, y_pred):
130133
y_type = "multiclass"
131134

132135
if y_type.startswith("multilabel"):
133-
y_true = csr_matrix(y_true)
134-
y_pred = csr_matrix(y_pred)
136+
if _is_numpy_namespace(xp):
137+
# XXX: do we really want to sparse-encode multilabel indicators when
138+
# they are passed as dense arrays? This is not possible for array
139+
# API inputs in general hence we only do it for NumPy inputs. But even
140+
# for NumPy the usefulness is questionable.
141+
y_true = csr_matrix(y_true)
142+
y_pred = csr_matrix(y_pred)
135143
y_type = "multilabel-indicator"
136144

137145
return y_type, y_true, y_pred
@@ -211,7 +219,12 @@ def accuracy_score(y_true, y_pred, *, normalize=True, sample_weight=None):
211219
y_type, y_true, y_pred = _check_targets(y_true, y_pred)
212220
check_consistent_length(y_true, y_pred, sample_weight)
213221
if y_type.startswith("multilabel"):
214-
differing_labels = count_nonzero(y_true - y_pred, axis=1)
222+
if _is_numpy_namespace(xp):
223+
differing_labels = count_nonzero(y_true - y_pred, axis=1)
224+
else:
225+
differing_labels = _count_nonzero(
226+
y_true - y_pred, xp=xp, device=device, axis=1
227+
)
215228
score = xp.asarray(differing_labels == 0, device=device)
216229
else:
217230
score = y_true == y_pred

‎sklearn/utils/_array_api.py

Copy file name to clipboardExpand all lines: sklearn/utils/_array_api.py
+17Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -967,3 +967,20 @@ def _in1d(ar1, ar2, xp, assume_unique=False, invert=False):
967967
return ret[: ar1.shape[0]]
968968
else:
969969
return xp.take(ret, rev_idx, axis=0)
970+
971+
972+
def _count_nonzero(X, xp, device, axis=None, sample_weight=None):
    """Count non-zero entries of a 2D array using Array API operations.

    This is the Array API counterpart of
    `sklearn.utils.sparsefuncs.count_nonzero`. Only 2D inputs are accepted.

    Parameters
    ----------
    X : array
        Two-dimensional array whose non-zero entries are counted.
    xp : module
        Array namespace providing `ones_like`, `asarray`, `reshape`,
        `astype`, `where` and `sum`.
    device : device or None
        Forwarded as `device=` to every array created inside the function.
    axis : int or None, default=None
        Forwarded to `xp.sum`; `None` reduces over the whole array.
    sample_weight : array-like or None, default=None
        Optional per-row weights. When provided, each non-zero entry
        contributes the weight of its row instead of 1.

    Returns
    -------
    array
        The (optionally weighted) number of non-zero entries, reduced along
        `axis`.
    """
    assert X.ndim == 2

    # Every entry starts with a unit weight, in the dtype of X.
    unit_weights = xp.ones_like(X, device=device)
    if sample_weight is None:
        weights = unit_weights
    else:
        # Turn the weights into a column of shape (n, 1) so that they
        # broadcast across the columns of X, and adopt their dtype.
        row_weights = xp.asarray(sample_weight, device=device)
        row_weights = xp.reshape(row_weights, (row_weights.shape[0], 1))
        weights = xp.astype(unit_weights, row_weights.dtype) * row_weights

    # Entries equal to zero contribute nothing to the sum.
    nothing = xp.asarray(0, device=device, dtype=weights.dtype)
    contributions = xp.where(X != 0, weights, nothing)
    return xp.sum(contributions, axis=axis)

‎sklearn/utils/tests/test_array_api.py

Copy file name to clipboardExpand all lines: sklearn/utils/tests/test_array_api.py
+36-1Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
_atol_for_type,
1414
_average,
1515
_convert_to_numpy,
16+
_count_nonzero,
1617
_estimator_with_converted_arrays,
1718
_is_numpy_namespace,
1819
_isin,
@@ -32,7 +33,7 @@
3233
assert_array_equal,
3334
skip_if_array_api_compat_not_configured,
3435
)
35-
from sklearn.utils.fixes import _IS_32BIT
36+
from sklearn.utils.fixes import _IS_32BIT, CSR_CONTAINERS
3637

3738

3839
@pytest.mark.parametrize("X", [numpy.asarray([1, 2, 3]), [1, 2, 3]])
@@ -566,3 +567,37 @@ def test_get_namespace_and_device():
566567
assert namespace is xp_torch
567568
assert is_array_api
568569
assert device == some_torch_tensor.device
570+
571+
572+
@pytest.mark.parametrize(
    "array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations()
)
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
@pytest.mark.parametrize("axis", [0, 1, None, -1, -2])
@pytest.mark.parametrize("sample_weight_type", [None, "int", "float"])
def test_count_nonzero(
    array_namespace, device, dtype_name, csr_container, axis, sample_weight_type
):
    """Check `_count_nonzero` against the sparse `count_nonzero` reference.

    The same data is counted once through
    `sklearn.utils.sparsefuncs.count_nonzero` on a CSR container and once
    through `_count_nonzero` on an array of the namespace under test. Both
    results must agree, and the result must be on the input's device.
    """
    from sklearn.utils.sparsefuncs import count_nonzero as sparse_count_nonzero

    xp = _array_api_for_tests(array_namespace, device)
    data = numpy.array([[0, 3, 0], [2, -1, 0], [0, 0, 0], [9, 8, 7], [4, 0, 5]])

    # Build the per-row weights requested by the parametrization, if any.
    if sample_weight_type == "float":
        sample_weight = numpy.asarray([0.5, 1.5, 0.8, 3.2, 2.4], dtype=dtype_name)
    elif sample_weight_type == "int":
        sample_weight = numpy.asarray([1, 2, 2, 3, 1])
    else:
        sample_weight = None

    # Reference value computed by the sparse implementation.
    sparse_data = csr_container(data)
    expected = sparse_count_nonzero(
        sparse_data, axis=axis, sample_weight=sample_weight
    )

    data_xp = xp.asarray(data, device=device)
    with config_context(array_api_dispatch=True):
        result = _count_nonzero(
            data_xp, xp=xp, device=device, axis=axis, sample_weight=sample_weight
        )

    result_np = _convert_to_numpy(result, xp=xp)
    assert_allclose(result_np, expected)
    # Arrays without a `device` attribute compare as None on both sides.
    assert getattr(data_xp, "device", None) == getattr(result, "device", None)

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.