Commit 28bd2cf

qinhanmin2014 authored and Jeremiah Johnson committed
[MRG+1] Improve the error message for some metrics when the shape of sample_weight is inappropriate (scikit-learn#9903)
1 parent 4d35478 commit 28bd2cf

3 files changed: 23 additions, 6 deletions
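In effect, the metrics touched by this commit validate sample_weight against y_true and y_pred up front via check_consistent_length, so a shape mismatch is reported as an explicit ValueError rather than some arbitrary error deeper in the computation. A minimal sketch of the resulting behavior (illustrative values, not part of the commit):

import numpy as np
from sklearn.metrics import accuracy_score

y_true = [0, 1, 1, 0]
y_pred = [0, 1, 0, 0]
bad_weight = np.ones(8)  # twice as many weights as samples

# Raises:
# ValueError: Found input variables with inconsistent numbers of samples: [4, 4, 8]
accuracy_score(y_true, y_pred, sample_weight=bad_weight)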

sklearn/metrics/classification.py

8 additions, 2 deletions
@@ -174,6 +174,7 @@ def accuracy_score(y_true, y_pred, normalize=True, sample_weight=None):
 
     # Compute accuracy for each possible representation
     y_type, y_true, y_pred = _check_targets(y_true, y_pred)
+    check_consistent_length(y_true, y_pred, sample_weight)
     if y_type.startswith('multilabel'):
         differing_labels = count_nonzero(y_true - y_pred, axis=1)
         score = differing_labels == 0
@@ -337,7 +338,7 @@ def confusion_matrix(y_true, y_pred, labels=None, sample_weight=None):
     else:
         sample_weight = np.asarray(sample_weight)
 
-    check_consistent_length(sample_weight, y_true, y_pred)
+    check_consistent_length(y_true, y_pred, sample_weight)
 
     n_labels = labels.size
     label_to_ind = dict((y, x) for x, y in enumerate(labels))
@@ -518,6 +519,7 @@ def jaccard_similarity_score(y_true, y_pred, normalize=True,
 
     # Compute accuracy for each possible representation
     y_type, y_true, y_pred = _check_targets(y_true, y_pred)
+    check_consistent_length(y_true, y_pred, sample_weight)
     if y_type.startswith('multilabel'):
         with np.errstate(divide='ignore', invalid='ignore'):
             # oddly, we may get an "invalid" rather than a "divide" error here
@@ -593,6 +595,7 @@ def matthews_corrcoef(y_true, y_pred, sample_weight=None):
     -0.33...
     """
     y_type, y_true, y_pred = _check_targets(y_true, y_pred)
+    check_consistent_length(y_true, y_pred, sample_weight)
     if y_type not in {"binary", "multiclass"}:
         raise ValueError("%s is not supported" % y_type)
 
@@ -1097,6 +1100,7 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
         raise ValueError("beta should be >0 in the F-beta score")
 
     y_type, y_true, y_pred = _check_targets(y_true, y_pred)
+    check_consistent_length(y_true, y_pred, sample_weight)
     present_labels = unique_labels(y_true, y_pred)
 
     if average == 'binary':
@@ -1624,6 +1628,7 @@ def hamming_loss(y_true, y_pred, labels=None, sample_weight=None,
         labels = classes
 
     y_type, y_true, y_pred = _check_targets(y_true, y_pred)
+    check_consistent_length(y_true, y_pred, sample_weight)
 
     if labels is None:
         labels = unique_labels(y_true, y_pred)
@@ -1712,7 +1717,7 @@ def log_loss(y_true, y_pred, eps=1e-15, normalize=True, sample_weight=None,
     The logarithm used is the natural logarithm (base-e).
     """
     y_pred = check_array(y_pred, ensure_2d=False)
-    check_consistent_length(y_pred, y_true)
+    check_consistent_length(y_pred, y_true, sample_weight)
 
     lb = LabelBinarizer()
 
@@ -1985,6 +1990,7 @@ def brier_score_loss(y_true, y_prob, sample_weight=None, pos_label=None):
     y_prob = column_or_1d(y_prob)
     assert_all_finite(y_true)
     assert_all_finite(y_prob)
+    check_consistent_length(y_true, y_prob, sample_weight)
 
     if pos_label is None:
         pos_label = y_true.max()
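All of the classification call sites above forward sample_weight unconditionally; check_consistent_length (importable from sklearn.utils) skips arguments that are None, which is presumably why it is safe to pass sample_weight even when the caller did not supply one. A rough sketch of that calling pattern, with illustrative values:

import numpy as np
from sklearn.utils import check_consistent_length

y_true = np.array([0, 1, 1])
y_pred = np.array([0, 1, 0])

check_consistent_length(y_true, y_pred, None)        # sample_weight not supplied: the None entry is ignored
check_consistent_length(y_true, y_pred, np.ones(3))  # consistent lengths: passes silently
check_consistent_length(y_true, y_pred, np.ones(5))  # length mismatch: raises ValueError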

sklearn/metrics/regression.py

5 additions, 0 deletions
@@ -168,6 +168,7 @@ def mean_absolute_error(y_true, y_pred,
     """
     y_type, y_true, y_pred, multioutput = _check_reg_targets(
         y_true, y_pred, multioutput)
+    check_consistent_length(y_true, y_pred, sample_weight)
     output_errors = np.average(np.abs(y_pred - y_true),
                                weights=sample_weight, axis=0)
     if isinstance(multioutput, string_types):
@@ -236,6 +237,7 @@ def mean_squared_error(y_true, y_pred,
     """
     y_type, y_true, y_pred, multioutput = _check_reg_targets(
         y_true, y_pred, multioutput)
+    check_consistent_length(y_true, y_pred, sample_weight)
     output_errors = np.average((y_true - y_pred) ** 2, axis=0,
                                weights=sample_weight)
     if isinstance(multioutput, string_types):
@@ -306,6 +308,7 @@ def mean_squared_log_error(y_true, y_pred,
     """
     y_type, y_true, y_pred, multioutput = _check_reg_targets(
         y_true, y_pred, multioutput)
+    check_consistent_length(y_true, y_pred, sample_weight)
 
     if not (y_true >= 0).all() and not (y_pred >= 0).all():
         raise ValueError("Mean Squared Logarithmic Error cannot be used when "
@@ -409,6 +412,7 @@ def explained_variance_score(y_true, y_pred,
     """
     y_type, y_true, y_pred, multioutput = _check_reg_targets(
         y_true, y_pred, multioutput)
+    check_consistent_length(y_true, y_pred, sample_weight)
 
     y_diff_avg = np.average(y_true - y_pred, weights=sample_weight, axis=0)
     numerator = np.average((y_true - y_pred - y_diff_avg) ** 2,
@@ -528,6 +532,7 @@ def r2_score(y_true, y_pred, sample_weight=None,
     """
     y_type, y_true, y_pred, multioutput = _check_reg_targets(
         y_true, y_pred, multioutput)
+    check_consistent_length(y_true, y_pred, sample_weight)
 
     if sample_weight is not None:
         sample_weight = column_or_1d(sample_weight)
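The regression metrics follow the same pattern. For example (illustrative values, not part of the commit), a weight vector of the wrong length passed to mean_squared_error now surfaces the same explicit message:

import numpy as np
from sklearn.metrics import mean_squared_error

y_true = np.array([3.0, -0.5, 2.0, 7.0])
y_pred = np.array([2.5, 0.0, 2.0, 8.0])

# Raises:
# ValueError: Found input variables with inconsistent numbers of samples: [4, 4, 3]
mean_squared_error(y_true, y_pred, sample_weight=np.ones(3))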

sklearn/metrics/tests/test_common.py

10 additions, 4 deletions
@@ -9,6 +9,7 @@
 from sklearn.datasets import make_multilabel_classification
 from sklearn.preprocessing import LabelBinarizer
 from sklearn.utils.multiclass import type_of_target
+from sklearn.utils.validation import _num_samples
 from sklearn.utils.validation import check_random_state
 from sklearn.utils import shuffle
 
@@ -1005,10 +1006,15 @@ def check_sample_weight_invariance(name, metric, y1, y2):
         err_msg="%s sample_weight is not invariant "
                 "under scaling" % name)
 
-    # Check that if sample_weight.shape[0] != y_true.shape[0], it raised an
-    # error
-    assert_raises(Exception, metric, y1, y2,
-                  sample_weight=np.hstack([sample_weight, sample_weight]))
+    # Check that if number of samples in y_true and sample_weight are not
+    # equal, meaningful error is raised.
+    error_message = ("Found input variables with inconsistent numbers of "
+                     "samples: [{}, {}, {}]".format(
+                         _num_samples(y1), _num_samples(y2),
+                         _num_samples(sample_weight) * 2))
+    assert_raise_message(ValueError, error_message, metric, y1, y2,
+                         sample_weight=np.hstack([sample_weight,
+                                                  sample_weight]))
 
 
 def test_sample_weight_invariance(n_samples=50):
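The updated check builds the expected message from _num_samples, so it asserts both the exception type and the exact sample counts, instead of merely expecting some Exception. A standalone analogue of what the test now verifies, using hamming_loss as one of the covered metrics (a sketch assuming the sklearn.utils.testing helpers of this era; values are illustrative):

import numpy as np
from sklearn.metrics import hamming_loss
from sklearn.utils.testing import assert_raise_message
from sklearn.utils.validation import _num_samples

y1 = np.array([0, 1, 1, 0, 1])
y2 = np.array([0, 1, 0, 0, 1])
sample_weight = np.ones(5)
doubled = np.hstack([sample_weight, sample_weight])  # deliberately the wrong length

# Expected: "Found input variables with inconsistent numbers of samples: [5, 5, 10]"
expected = ("Found input variables with inconsistent numbers of "
            "samples: [{}, {}, {}]".format(_num_samples(y1), _num_samples(y2),
                                           _num_samples(doubled)))
assert_raise_message(ValueError, expected, hamming_loss, y1, y2,
                     sample_weight=doubled)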
