Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit ab21254

Browse filesBrowse files
joclementglemaitre
authored andcommitted
FIX mislabelling multiclass target when labels is provided in top_k_accuracy_score (#19721)
1 parent 0143fe4 commit ab21254
Copy full SHA for ab21254

File tree

Expand file treeCollapse file tree

3 files changed

+35
-1
lines changed
Filter options
Expand file treeCollapse file tree

3 files changed

+35
-1
lines changed

‎doc/whats_new/v0.24.rst

Copy file name to clipboardExpand all lines: doc/whats_new/v0.24.rst
+8Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,14 @@ Changelog
5959
- |Fix|: Fixed a bug in :class:`linear_model.LogisticRegression`: the
6060
sample_weight object is not modified anymore. :pr:`19182` by
6161
:user:`Yosuke KOBAYASHI <m7142yosuke>`.
62+
63+
:mod:`sklearn.metrics`
64+
......................
65+
66+
- |Fix| :func:`metrics.top_k_accuracy_score` now supports multiclass
67+
problems where only two classes appear in `y_true` and all the classes
68+
are specified in `labels`.
69+
:pr:`19721` by :user:`Joris Clement <flyingdutchman23>`.
6270

6371
:mod:`sklearn.model_selection`
6472
..............................

‎sklearn/metrics/_ranking.py

Copy file name to clipboardExpand all lines: sklearn/metrics/_ranking.py
+3-1Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1589,7 +1589,7 @@ def top_k_accuracy_score(y_true, y_score, *, k=2, normalize=True,
15891589
non-thresholded decision values (as returned by
15901590
:term:`decision_function` on some classifiers). The binary case expects
15911591
scores with shape (n_samples,) while the multiclass case expects scores
1592-
with shape (n_samples, n_classes). In the nulticlass case, the order of
1592+
with shape (n_samples, n_classes). In the multiclass case, the order of
15931593
the class scores must correspond to the order of ``labels``, if
15941594
provided, or else to the numerical or lexicographical order of the
15951595
labels in ``y_true``.
@@ -1646,6 +1646,8 @@ def top_k_accuracy_score(y_true, y_score, *, k=2, normalize=True,
16461646
y_true = check_array(y_true, ensure_2d=False, dtype=None)
16471647
y_true = column_or_1d(y_true)
16481648
y_type = type_of_target(y_true)
1649+
if y_type == "binary" and labels is not None and len(labels) > 2:
1650+
y_type = "multiclass"
16491651
y_score = check_array(y_score, ensure_2d=False)
16501652
y_score = column_or_1d(y_score) if y_type == 'binary' else y_score
16511653
check_consistent_length(y_true, y_score, sample_weight)

‎sklearn/metrics/tests/test_ranking.py

Copy file name to clipboardExpand all lines: sklearn/metrics/tests/test_ranking.py
+24Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1650,6 +1650,30 @@ def test_top_k_accuracy_score_binary(y_score, k, true_score):
16501650
assert score == score_acc == pytest.approx(true_score)
16511651

16521652

1653+
@pytest.mark.parametrize('y_true, true_score, labels', [
1654+
(np.array([0, 1, 1, 2]), 0.75, [0, 1, 2, 3]),
1655+
(np.array([0, 1, 1, 1]), 0.5, [0, 1, 2, 3]),
1656+
(np.array([1, 1, 1, 1]), 0.5, [0, 1, 2, 3]),
1657+
(np.array(['a', 'e', 'e', 'a']), 0.75, ['a', 'b', 'd', 'e']),
1658+
])
1659+
@pytest.mark.parametrize("labels_as_ndarray", [True, False])
1660+
def test_top_k_accuracy_score_multiclass_with_labels(
1661+
y_true, true_score, labels, labels_as_ndarray
1662+
):
1663+
"""Test when labels and y_score are multiclass."""
1664+
if labels_as_ndarray:
1665+
labels = np.asarray(labels)
1666+
y_score = np.array([
1667+
[0.4, 0.3, 0.2, 0.1],
1668+
[0.1, 0.3, 0.4, 0.2],
1669+
[0.4, 0.1, 0.2, 0.3],
1670+
[0.3, 0.2, 0.4, 0.1],
1671+
])
1672+
1673+
score = top_k_accuracy_score(y_true, y_score, k=2, labels=labels)
1674+
assert score == pytest.approx(true_score)
1675+
1676+
16531677
def test_top_k_accuracy_score_increasing():
16541678
# Make sure increasing k leads to a higher score
16551679
X, y = datasets.make_classification(n_classes=10, n_samples=1000,

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.