Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 4be28d4

Browse filesBrowse files
ashah002glemaitrejeremiedbb
authored
MAINT Parameter validation for sklearn.metrics.d2_pinball_score (#25414)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com> Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com>
1 parent ba16dbe commit 4be28d4
Copy full SHA for 4be28d4

File tree

3 files changed

+15
-11
lines changed
Filter options

3 files changed

+15
-11
lines changed

‎sklearn/metrics/_regression.py

Copy file name to clipboardExpand all lines: sklearn/metrics/_regression.py
+14-8Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1304,6 +1304,18 @@ def d2_tweedie_score(y_true, y_pred, *, sample_weight=None, power=0):
13041304
return 1 - numerator / denominator
13051305

13061306

1307+
@validate_params(
1308+
{
1309+
"y_true": ["array-like"],
1310+
"y_pred": ["array-like"],
1311+
"sample_weight": ["array-like", None],
1312+
"alpha": [Interval(Real, 0, 1, closed="both")],
1313+
"multioutput": [
1314+
StrOptions({"raw_values", "uniform_average"}),
1315+
"array-like",
1316+
],
1317+
}
1318+
)
13071319
def d2_pinball_score(
13081320
y_true, y_pred, *, sample_weight=None, alpha=0.5, multioutput="uniform_average"
13091321
):
@@ -1327,7 +1339,7 @@ def d2_pinball_score(
13271339
y_pred : array-like of shape (n_samples,) or (n_samples, n_outputs)
13281340
Estimated target values.
13291341
1330-
sample_weight : array-like of shape (n_samples,), optional
1342+
sample_weight : array-like of shape (n_samples,), default=None
13311343
Sample weights.
13321344
13331345
alpha : float, default=0.5
@@ -1434,15 +1446,9 @@ def d2_pinball_score(
14341446
if multioutput == "raw_values":
14351447
# return scores individually
14361448
return output_scores
1437-
elif multioutput == "uniform_average":
1449+
else: # multioutput == "uniform_average"
14381450
# passing None as weights to np.average results in uniform mean
14391451
avg_weights = None
1440-
else:
1441-
raise ValueError(
1442-
"multioutput is expected to be 'raw_values' "
1443-
"or 'uniform_average' but we got %r"
1444-
" instead." % multioutput
1445-
)
14461452
else:
14471453
avg_weights = multioutput
14481454

‎sklearn/metrics/tests/test_regression.py

Copy file name to clipboardExpand all lines: sklearn/metrics/tests/test_regression.py
-3Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -351,9 +351,6 @@ def test_regression_multioutput_array():
351351
with pytest.raises(ValueError, match=err_msg):
352352
mean_pinball_loss(y_true, y_pred, multioutput="variance_weighted")
353353

354-
with pytest.raises(ValueError, match=err_msg):
355-
d2_pinball_score(y_true, y_pred, multioutput="variance_weighted")
356-
357354
pbl = mean_pinball_loss(y_true, y_pred, multioutput="raw_values")
358355
mape = mean_absolute_percentage_error(y_true, y_pred, multioutput="raw_values")
359356
r = r2_score(y_true, y_pred, multioutput="raw_values")

‎sklearn/tests/test_public_functions.py

Copy file name to clipboardExpand all lines: sklearn/tests/test_public_functions.py
+1Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def _check_function_param_validation(
117117
"sklearn.metrics.cluster.contingency_matrix",
118118
"sklearn.metrics.cohen_kappa_score",
119119
"sklearn.metrics.confusion_matrix",
120+
"sklearn.metrics.d2_pinball_score",
120121
"sklearn.metrics.det_curve",
121122
"sklearn.metrics.hamming_loss",
122123
"sklearn.metrics.mean_absolute_error",

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.