Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 99d8a32

Browse filesBrowse files
Tialojeremiedbb
authored andcommitted
FIX zero_one_loss breaks with multilabel and Array API (#29269)
1 parent 059070b commit 99d8a32
Copy full SHA for 99d8a32

File tree

2 files changed

+34
-2
lines changed
Filter options

2 files changed

+34
-2
lines changed

‎sklearn/metrics/_classification.py

Copy file name to clipboardExpand all lines: sklearn/metrics/_classification.py
+3-2Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
_average,
4343
_union1d,
4444
get_namespace,
45+
get_namespace_and_device,
4546
)
4647
from ..utils._param_validation import (
4748
Hidden,
@@ -217,13 +218,13 @@ def accuracy_score(y_true, y_pred, *, normalize=True, sample_weight=None):
217218
>>> accuracy_score(np.array([[0, 1], [1, 1]]), np.ones((2, 2)))
218219
0.5
219220
"""
220-
221+
xp, _, device = get_namespace_and_device(y_true, y_pred, sample_weight)
221222
# Compute accuracy for each possible representation
222223
y_type, y_true, y_pred = _check_targets(y_true, y_pred)
223224
check_consistent_length(y_true, y_pred, sample_weight)
224225
if y_type.startswith("multilabel"):
225226
differing_labels = count_nonzero(y_true - y_pred, axis=1)
226-
score = differing_labels == 0
227+
score = xp.asarray(differing_labels == 0, device=device)
227228
else:
228229
score = y_true == y_pred
229230

‎sklearn/metrics/tests/test_common.py

Copy file name to clipboardExpand all lines: sklearn/metrics/tests/test_common.py
+31Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1826,6 +1826,35 @@ def check_array_api_multiclass_classification_metric(
18261826
)
18271827

18281828

1829+
def check_array_api_multilabel_classification_metric(
1830+
metric, array_namespace, device, dtype_name
1831+
):
1832+
y_true_np = np.array([[1, 1], [0, 1], [0, 0]], dtype=dtype_name)
1833+
y_pred_np = np.array([[1, 1], [1, 1], [1, 1]], dtype=dtype_name)
1834+
1835+
check_array_api_metric(
1836+
metric,
1837+
array_namespace,
1838+
device,
1839+
dtype_name,
1840+
a_np=y_true_np,
1841+
b_np=y_pred_np,
1842+
sample_weight=None,
1843+
)
1844+
1845+
sample_weight = np.array([0.0, 0.1, 2.0], dtype=dtype_name)
1846+
1847+
check_array_api_metric(
1848+
metric,
1849+
array_namespace,
1850+
device,
1851+
dtype_name,
1852+
a_np=y_true_np,
1853+
b_np=y_pred_np,
1854+
sample_weight=sample_weight,
1855+
)
1856+
1857+
18291858
def check_array_api_regression_metric(metric, array_namespace, device, dtype_name):
18301859
y_true_np = np.array([[1, 3], [1, 2]], dtype=dtype_name)
18311860
y_pred_np = np.array([[1, 4], [1, 1]], dtype=dtype_name)
@@ -1871,10 +1900,12 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
18711900
accuracy_score: [
18721901
check_array_api_binary_classification_metric,
18731902
check_array_api_multiclass_classification_metric,
1903+
check_array_api_multilabel_classification_metric,
18741904
],
18751905
zero_one_loss: [
18761906
check_array_api_binary_classification_metric,
18771907
check_array_api_multiclass_classification_metric,
1908+
check_array_api_multilabel_classification_metric,
18781909
],
18791910
r2_score: [check_array_api_regression_metric],
18801911
cosine_similarity: [check_array_api_metric_pairwise],

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.