Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 851c0d6

Browse filesBrowse files
EdAbatiOmarManzoor
authored andcommitted
FIX: accuracy and zero_loss support for multilabel with Array API (#29336)
Co-authored-by: Omar Salman <omar.salman2007@gmail.com> Co-authored-by: Omar Salman <omar.salman@arbisoft.com>
1 parent 99d8a32 commit 851c0d6
Copy full SHA for 851c0d6

File tree

Expand file treeCollapse file tree

4 files changed

+72
-7
lines changed
Filter options
Expand file treeCollapse file tree

4 files changed

+72
-7
lines changed

‎doc/whats_new/v1.5.rst

Copy file name to clipboardExpand all lines: doc/whats_new/v1.5.rst
+3-3Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,9 @@ Changelog
4848
instead of implicitly converting those inputs as regular NumPy arrays.
4949
:pr:`29119` by :user:`Olivier Grisel`.
5050

51-
- |Fix| Fix a regression in :func:`metrics.zero_one_loss` causing an error
52-
for Array API dispatch with multilabel inputs.
53-
:pr:`29269` by :user:`Yaroslav Korobko <Tialo>`.
51+
- |Fix| Fix a regression in :func:`metrics.accuracy_score` and in :func:`metrics.zero_one_loss`
52+
causing an error for Array API dispatch with multilabel inputs.
53+
:pr:`29269` by :user:`Yaroslav Korobko <Tialo>` and :pr:`29336` by :user:`Edoardo Abati <EdAbati>`.
5454

5555
:mod:`sklearn.model_selection`
5656
..............................

‎sklearn/metrics/_classification.py

Copy file name to clipboardExpand all lines: sklearn/metrics/_classification.py
+16-3Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@
4040
)
4141
from ..utils._array_api import (
4242
_average,
43+
_count_nonzero,
44+
_is_numpy_namespace,
4345
_union1d,
4446
get_namespace,
4547
get_namespace_and_device,
@@ -97,6 +99,7 @@ def _check_targets(y_true, y_pred):
9799
98100
y_pred : array or indicator matrix
99101
"""
102+
xp, _ = get_namespace(y_true, y_pred)
100103
check_consistent_length(y_true, y_pred)
101104
type_true = type_of_target(y_true, input_name="y_true")
102105
type_pred = type_of_target(y_pred, input_name="y_pred")
@@ -142,8 +145,13 @@ def _check_targets(y_true, y_pred):
142145
y_type = "multiclass"
143146

144147
if y_type.startswith("multilabel"):
145-
y_true = csr_matrix(y_true)
146-
y_pred = csr_matrix(y_pred)
148+
if _is_numpy_namespace(xp):
149+
# XXX: do we really want to sparse-encode multilabel indicators when
150+
# they are passed as a dense arrays? This is not possible for array
151+
# API inputs in general hence we only do it for NumPy inputs. But even
152+
# for NumPy the usefulness is questionable.
153+
y_true = csr_matrix(y_true)
154+
y_pred = csr_matrix(y_pred)
147155
y_type = "multilabel-indicator"
148156

149157
return y_type, y_true, y_pred
@@ -223,7 +231,12 @@ def accuracy_score(y_true, y_pred, *, normalize=True, sample_weight=None):
223231
y_type, y_true, y_pred = _check_targets(y_true, y_pred)
224232
check_consistent_length(y_true, y_pred, sample_weight)
225233
if y_type.startswith("multilabel"):
226-
differing_labels = count_nonzero(y_true - y_pred, axis=1)
234+
if _is_numpy_namespace(xp):
235+
differing_labels = count_nonzero(y_true - y_pred, axis=1)
236+
else:
237+
differing_labels = _count_nonzero(
238+
y_true - y_pred, xp=xp, device=device, axis=1
239+
)
227240
score = xp.asarray(differing_labels == 0, device=device)
228241
else:
229242
score = y_true == y_pred

‎sklearn/utils/_array_api.py

Copy file name to clipboardExpand all lines: sklearn/utils/_array_api.py
+17Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -841,3 +841,20 @@ def indexing_dtype(xp):
841841
# TODO: once sufficiently adopted, we might want to instead rely on the
842842
# newer inspection API: https://github.com/data-apis/array-api/issues/640
843843
return xp.asarray(0).dtype
844+
845+
846+
def _count_nonzero(X, xp, device, axis=None, sample_weight=None):
847+
"""A variant of `sklearn.utils.sparsefuncs.count_nonzero` for the Array API.
848+
849+
It only supports 2D arrays.
850+
"""
851+
assert X.ndim == 2
852+
853+
weights = xp.ones_like(X, device=device)
854+
if sample_weight is not None:
855+
sample_weight = xp.asarray(sample_weight, device=device)
856+
sample_weight = xp.reshape(sample_weight, (sample_weight.shape[0], 1))
857+
weights = xp.astype(weights, sample_weight.dtype) * sample_weight
858+
859+
zero_scalar = xp.asarray(0, device=device, dtype=weights.dtype)
860+
return xp.sum(xp.where(X != 0, weights, zero_scalar), axis=axis)

‎sklearn/utils/tests/test_array_api.py

Copy file name to clipboardExpand all lines: sklearn/utils/tests/test_array_api.py
+36-1Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
_atol_for_type,
1414
_average,
1515
_convert_to_numpy,
16+
_count_nonzero,
1617
_estimator_with_converted_arrays,
1718
_is_numpy_namespace,
1819
_nanmax,
@@ -30,7 +31,7 @@
3031
_array_api_for_tests,
3132
skip_if_array_api_compat_not_configured,
3233
)
33-
from sklearn.utils.fixes import _IS_32BIT
34+
from sklearn.utils.fixes import _IS_32BIT, CSR_CONTAINERS
3435

3536

3637
@pytest.mark.parametrize("X", [numpy.asarray([1, 2, 3]), [1, 2, 3]])
@@ -530,3 +531,37 @@ def test_get_namespace_and_device():
530531
assert namespace is xp_torch
531532
assert is_array_api
532533
assert device == some_torch_tensor.device
534+
535+
536+
@pytest.mark.parametrize(
537+
"array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations()
538+
)
539+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
540+
@pytest.mark.parametrize("axis", [0, 1, None, -1, -2])
541+
@pytest.mark.parametrize("sample_weight_type", [None, "int", "float"])
542+
def test_count_nonzero(
543+
array_namespace, device, dtype_name, csr_container, axis, sample_weight_type
544+
):
545+
546+
from sklearn.utils.sparsefuncs import count_nonzero as sparse_count_nonzero
547+
548+
xp = _array_api_for_tests(array_namespace, device)
549+
array = numpy.array([[0, 3, 0], [2, -1, 0], [0, 0, 0], [9, 8, 7], [4, 0, 5]])
550+
if sample_weight_type == "int":
551+
sample_weight = numpy.asarray([1, 2, 2, 3, 1])
552+
elif sample_weight_type == "float":
553+
sample_weight = numpy.asarray([0.5, 1.5, 0.8, 3.2, 2.4], dtype=dtype_name)
554+
else:
555+
sample_weight = None
556+
expected = sparse_count_nonzero(
557+
csr_container(array), axis=axis, sample_weight=sample_weight
558+
)
559+
array_xp = xp.asarray(array, device=device)
560+
561+
with config_context(array_api_dispatch=True):
562+
result = _count_nonzero(
563+
array_xp, xp=xp, device=device, axis=axis, sample_weight=sample_weight
564+
)
565+
566+
assert_allclose(_convert_to_numpy(result, xp=xp), expected)
567+
assert getattr(array_xp, "device", None) == getattr(result, "device", None)

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.