Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 6123e64

Browse filesBrowse files
committed
added _count_zero array api compatible
1 parent 93c1b49 commit 6123e64
Copy full SHA for 6123e64

File tree

3 files changed

+37
-4
lines changed
Filter options

3 files changed

+37
-4
lines changed

‎sklearn/metrics/_classification.py

Copy file name to clipboardExpand all lines: sklearn/metrics/_classification.py
+3-3Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
)
2929
from ..utils._array_api import (
3030
_average,
31+
_count_nonzero,
3132
_is_numpy_namespace,
3233
_union1d,
3334
get_namespace,
@@ -221,9 +222,8 @@ def accuracy_score(y_true, y_pred, *, normalize=True, sample_weight=None):
221222
if _is_numpy_namespace(xp):
222223
differing_labels = count_nonzero(y_true - y_pred, axis=1)
223224
else:
224-
differing_labels = xp.sum(
225-
xp.astype(xp.astype(y_true - y_pred, xp.bool), xp.int8),
226-
axis=1,
225+
differing_labels = _count_nonzero(
226+
y_true - y_pred, xp=xp, device=device, axis=1
227227
)
228228
score = xp.asarray(differing_labels == 0, device=device)
229229
else:

‎sklearn/utils/_array_api.py

Copy file name to clipboardExpand all lines: sklearn/utils/_array_api.py
+11Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -967,3 +967,14 @@ def _in1d(ar1, ar2, xp, assume_unique=False, invert=False):
967967
return ret[: ar1.shape[0]]
968968
else:
969969
return xp.take(ret, rev_idx, axis=0)
970+
971+
972+
def _count_nonzero(X, xp, device, axis=None):
973+
"""A variant of `sklearn.utils.sparsefuncs.count_nonzero` for the Array API."""
974+
if axis == -1:
975+
axis = 1
976+
elif axis == -2:
977+
axis = 0
978+
one_scalar = xp.asarray(1, device=device)
979+
zero_scalar = xp.asarray(0, device=device)
980+
return xp.sum(xp.where(X != 0, one_scalar, zero_scalar), axis=axis)

‎sklearn/utils/tests/test_array_api.py

Copy file name to clipboardExpand all lines: sklearn/utils/tests/test_array_api.py
+23-1Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
_atol_for_type,
1414
_average,
1515
_convert_to_numpy,
16+
_count_nonzero,
1617
_estimator_with_converted_arrays,
1718
_is_numpy_namespace,
1819
_isin,
@@ -32,7 +33,7 @@
3233
assert_array_equal,
3334
skip_if_array_api_compat_not_configured,
3435
)
35-
from sklearn.utils.fixes import _IS_32BIT
36+
from sklearn.utils.fixes import _IS_32BIT, CSR_CONTAINERS
3637

3738

3839
@pytest.mark.parametrize("X", [numpy.asarray([1, 2, 3]), [1, 2, 3]])
@@ -566,3 +567,24 @@ def test_get_namespace_and_device():
566567
assert namespace is xp_torch
567568
assert is_array_api
568569
assert device == some_torch_tensor.device
570+
571+
572+
@pytest.mark.parametrize(
573+
"array_namespace, device, _", yield_namespace_device_dtype_combinations()
574+
)
575+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
576+
@pytest.mark.parametrize("axis", [0, 1, None, -1, -2])
577+
def test_count_nonzero(array_namespace, device, _, csr_container, axis):
578+
579+
from sklearn.utils.sparsefuncs import count_nonzero as sparse_count_nonzero
580+
581+
xp = _array_api_for_tests(array_namespace, device)
582+
array = numpy.array([[0, 3, 0], [2, -1, 0], [0, 0, 0], [9, 8, 7], [4, 0, 5]])
583+
expected = sparse_count_nonzero(csr_container(array), axis=axis)
584+
array_xp = xp.asarray(array, device=device)
585+
586+
with config_context(array_api_dispatch=True):
587+
result = _count_nonzero(array_xp, xp=xp, device=device, axis=axis)
588+
589+
assert_array_equal(_convert_to_numpy(result, xp=xp), expected)
590+
assert getattr(array_xp, "device", None) == getattr(result, "device", None)

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.