Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit f4e692c

Browse filesBrowse files
ENH Raises error in hinge_loss when 'pred_decision' is invalid (#19643)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent 15fd026 commit f4e692c
Copy full SHA for f4e692c

File tree

3 files changed

+53
-4
lines changed
Filter options

3 files changed

+53
-4
lines changed

‎doc/whats_new/v1.0.rst

Copy file name to clipboardExpand all lines: doc/whats_new/v1.0.rst
+5Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,11 @@ Changelog
165165
class methods and will be removed in 1.2.
166166
:pr:`18543` by `Guillaume Lemaitre`_.
167167

168+
- |Enhancement| A fix to raise an error in :func:`metrics.hinge_loss` when
169+
``pred_decision`` is 1d whereas it is a multiclass classification or when
170+
``pred_decision`` parameter is not consistent with the ``labels`` parameter.
171+
:pr:`19643` by :user:`Pierre Attard <PierreAttard>`.
172+
168173
- |Feature| :func:`metrics.mean_pinball_loss` exposes the pinball loss for
169174
quantile regression. :pr:`19415` by :user:`Xavier Dupré <sdpython>`
170175
and :user:`Oliver Grisel <ogrisel>`.

‎sklearn/metrics/_classification.py

Copy file name to clipboardExpand all lines: sklearn/metrics/_classification.py
+22-4Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2378,11 +2378,29 @@ def hinge_loss(y_true, pred_decision, *, labels=None, sample_weight=None):
23782378
pred_decision = check_array(pred_decision, ensure_2d=False)
23792379
y_true = column_or_1d(y_true)
23802380
y_true_unique = np.unique(labels if labels is not None else y_true)
2381+
23812382
if y_true_unique.size > 2:
2382-
if (labels is None and pred_decision.ndim > 1 and
2383-
(np.size(y_true_unique) != pred_decision.shape[1])):
2384-
raise ValueError("Please include all labels in y_true "
2385-
"or pass labels as third argument")
2383+
2384+
if pred_decision.ndim <= 1:
2385+
raise ValueError("The shape of pred_decision cannot be 1d array"
2386+
"with a multiclass target. pred_decision shape "
2387+
"must be (n_samples, n_classes), that is "
2388+
f"({y_true.shape[0]}, {y_true_unique.size})."
2389+
f" Got: {pred_decision.shape}")
2390+
2391+
# pred_decision.ndim > 1 is true
2392+
if y_true_unique.size != pred_decision.shape[1]:
2393+
if labels is None:
2394+
raise ValueError("Please include all labels in y_true "
2395+
"or pass labels as third argument")
2396+
else:
2397+
raise ValueError("The shape of pred_decision is not "
2398+
"consistent with the number of classes. "
2399+
"With a multiclass target, pred_decision "
2400+
"shape must be "
2401+
"(n_samples, n_classes), that is "
2402+
f"({y_true.shape[0]}, {y_true_unique.size}). "
2403+
f"Got: {pred_decision.shape}")
23862404
if labels is None:
23872405
labels = y_true_unique
23882406
le = LabelEncoder()

‎sklearn/metrics/tests/test_classification.py

Copy file name to clipboardExpand all lines: sklearn/metrics/tests/test_classification.py
+26Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from itertools import chain
55
from itertools import permutations
66
import warnings
7+
import re
78

89
import numpy as np
910
from scipy import linalg
@@ -2135,6 +2136,31 @@ def test_hinge_loss_multiclass_missing_labels_with_labels_none():
21352136
hinge_loss(y_true, pred_decision)
21362137

21372138

2139+
def test_hinge_loss_multiclass_no_consistent_pred_decision_shape():
2140+
# test for inconsistency between multiclass problem and pred_decision
2141+
# argument
2142+
y_true = np.array([2, 1, 0, 1, 0, 1, 1])
2143+
pred_decision = np.array([0, 1, 2, 1, 0, 2, 1])
2144+
error_message = ("The shape of pred_decision cannot be 1d array"
2145+
"with a multiclass target. pred_decision shape "
2146+
"must be (n_samples, n_classes), that is "
2147+
"(7, 3). Got: (7,)")
2148+
with pytest.raises(ValueError, match=re.escape(error_message)):
2149+
hinge_loss(y_true=y_true, pred_decision=pred_decision)
2150+
2151+
# test for inconsistency between pred_decision shape and labels number
2152+
pred_decision = np.array([[0, 1], [0, 1], [0, 1], [0, 1],
2153+
[2, 0], [0, 1], [1, 0]])
2154+
labels = [0, 1, 2]
2155+
error_message = ("The shape of pred_decision is not "
2156+
"consistent with the number of classes. "
2157+
"With a multiclass target, pred_decision "
2158+
"shape must be (n_samples, n_classes), that is "
2159+
"(7, 3). Got: (7, 2)")
2160+
with pytest.raises(ValueError, match=re.escape(error_message)):
2161+
hinge_loss(y_true=y_true, pred_decision=pred_decision, labels=labels)
2162+
2163+
21382164
def test_hinge_loss_multiclass_with_missing_labels():
21392165
pred_decision = np.array([
21402166
[+0.36, -0.17, -0.58, -0.99],

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.