Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 65d42c9

Browse filesBrowse files
glemaitrejeremiedbbadrinjalali
authored
MAINT refactor scorer using _get_response_values (#26037)
Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com> Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com>
1 parent 948b49d commit 65d42c9
Copy full SHA for 65d42c9

File tree

4 files changed

+61
-100
lines changed
Filter options

4 files changed

+61
-100
lines changed

‎sklearn/metrics/_scorer.py

Copy file name to clipboardExpand all lines: sklearn/metrics/_scorer.py
+26-65Lines changed: 26 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@
1818
# Arnaud Joly <arnaud.v.joly@gmail.com>
1919
# License: Simplified BSD
2020

21-
from functools import partial
2221
from collections import Counter
22+
from inspect import signature
23+
from functools import partial
2324
from traceback import format_exc
2425

2526
import numpy as np
@@ -64,20 +65,23 @@
6465

6566
from ..utils.multiclass import type_of_target
6667
from ..base import is_regressor
68+
from ..utils._response import _get_response_values
6769
from ..utils._param_validation import HasMethods, StrOptions, validate_params
6870

6971

70-
def _cached_call(cache, estimator, method, *args, **kwargs):
72+
def _cached_call(cache, estimator, response_method, *args, **kwargs):
7173
"""Call estimator with method and args and kwargs."""
72-
if cache is None:
73-
return getattr(estimator, method)(*args, **kwargs)
74+
if cache is not None and response_method in cache:
75+
return cache[response_method]
76+
77+
result, _ = _get_response_values(
78+
estimator, *args, response_method=response_method, **kwargs
79+
)
80+
81+
if cache is not None:
82+
cache[response_method] = result
7483

75-
try:
76-
return cache[method]
77-
except KeyError:
78-
result = getattr(estimator, method)(*args, **kwargs)
79-
cache[method] = result
80-
return result
84+
return result
8185

8286

8387
class _MultimetricScorer:
@@ -162,40 +166,13 @@ def __init__(self, score_func, sign, kwargs):
162166
self._score_func = score_func
163167
self._sign = sign
164168

165-
@staticmethod
166-
def _check_pos_label(pos_label, classes):
167-
if pos_label not in list(classes):
168-
raise ValueError(f"pos_label={pos_label} is not a valid label: {classes}")
169-
170-
def _select_proba_binary(self, y_pred, classes):
171-
"""Select the column of the positive label in `y_pred` when
172-
probabilities are provided.
173-
174-
Parameters
175-
----------
176-
y_pred : ndarray of shape (n_samples, n_classes)
177-
The prediction given by `predict_proba`.
178-
179-
classes : ndarray of shape (n_classes,)
180-
The class labels for the estimator.
181-
182-
Returns
183-
-------
184-
y_pred : ndarray of shape (n_samples,)
185-
Probability predictions of the positive class.
186-
"""
187-
if y_pred.shape[1] == 2:
188-
pos_label = self._kwargs.get("pos_label", classes[1])
189-
self._check_pos_label(pos_label, classes)
190-
col_idx = np.flatnonzero(classes == pos_label)[0]
191-
return y_pred[:, col_idx]
192-
193-
err_msg = (
194-
f"Got predict_proba of shape {y_pred.shape}, but need "
195-
f"classifier with two classes for {self._score_func.__name__} "
196-
"scoring"
197-
)
198-
raise ValueError(err_msg)
169+
def _get_pos_label(self):
170+
if "pos_label" in self._kwargs:
171+
return self._kwargs["pos_label"]
172+
score_func_params = signature(self._score_func).parameters
173+
if "pos_label" in score_func_params:
174+
return score_func_params["pos_label"].default
175+
return None
199176

200177
def __repr__(self):
201178
kwargs_string = "".join(
@@ -311,14 +288,7 @@ def _score(self, method_caller, clf, X, y, sample_weight=None):
311288
score : float
312289
Score function applied to prediction of estimator on X.
313290
"""
314-
315-
y_type = type_of_target(y)
316-
y_pred = method_caller(clf, "predict_proba", X)
317-
if y_type == "binary" and y_pred.shape[1] <= 2:
318-
# `y_type` could be equal to "binary" even in a multi-class
319-
# problem: (when only 2 class are given to `y_true` during scoring)
320-
# Thus, we need to check for the shape of `y_pred`.
321-
y_pred = self._select_proba_binary(y_pred, clf.classes_)
291+
y_pred = method_caller(clf, "predict_proba", X, pos_label=self._get_pos_label())
322292
if sample_weight is not None:
323293
return self._sign * self._score_func(
324294
y, y_pred, sample_weight=sample_weight, **self._kwargs
@@ -369,26 +339,17 @@ def _score(self, method_caller, clf, X, y, sample_weight=None):
369339
if is_regressor(clf):
370340
y_pred = method_caller(clf, "predict", X)
371341
else:
342+
pos_label = self._get_pos_label()
372343
try:
373-
y_pred = method_caller(clf, "decision_function", X)
344+
y_pred = method_caller(clf, "decision_function", X, pos_label=pos_label)
374345

375346
if isinstance(y_pred, list):
376347
# For multi-output multi-class estimator
377348
y_pred = np.vstack([p for p in y_pred]).T
378-
elif y_type == "binary" and "pos_label" in self._kwargs:
379-
self._check_pos_label(self._kwargs["pos_label"], clf.classes_)
380-
if self._kwargs["pos_label"] == clf.classes_[0]:
381-
# The implicit positive class of the binary classifier
382-
# does not match `pos_label`: we need to invert the
383-
# predictions
384-
y_pred *= -1
385349

386350
except (NotImplementedError, AttributeError):
387-
y_pred = method_caller(clf, "predict_proba", X)
388-
389-
if y_type == "binary":
390-
y_pred = self._select_proba_binary(y_pred, clf.classes_)
391-
elif isinstance(y_pred, list):
351+
y_pred = method_caller(clf, "predict_proba", X, pos_label=pos_label)
352+
if isinstance(y_pred, list):
392353
y_pred = np.vstack([p[:, -1] for p in y_pred]).T
393354

394355
if sample_weight is not None:

‎sklearn/metrics/tests/test_score_objects.py

Copy file name to clipboardExpand all lines: sklearn/metrics/tests/test_score_objects.py
+10-5Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -759,13 +759,18 @@ def test_multimetric_scorer_calls_method_once(
759759
X, y = np.array([[1], [1], [0], [0], [0]]), np.array([0, 1, 1, 1, 0])
760760

761761
mock_est = Mock()
762-
fit_func = Mock(return_value=mock_est)
763-
predict_func = Mock(return_value=y)
762+
mock_est._estimator_type = "classifier"
763+
fit_func = Mock(return_value=mock_est, name="fit")
764+
fit_func.__name__ = "fit"
765+
predict_func = Mock(return_value=y, name="predict")
766+
predict_func.__name__ = "predict"
764767

765768
pos_proba = np.random.rand(X.shape[0])
766769
proba = np.c_[1 - pos_proba, pos_proba]
767-
predict_proba_func = Mock(return_value=proba)
768-
decision_function_func = Mock(return_value=pos_proba)
770+
predict_proba_func = Mock(return_value=proba, name="predict_proba")
771+
predict_proba_func.__name__ = "predict_proba"
772+
decision_function_func = Mock(return_value=pos_proba, name="decision_function")
773+
decision_function_func.__name__ = "decision_function"
769774

770775
mock_est.fit = fit_func
771776
mock_est.predict = predict_func
@@ -961,7 +966,7 @@ def test_multiclass_roc_no_proba_scorer_errors(scorer_name):
961966
n_classes=3, n_informative=3, n_samples=20, random_state=0
962967
)
963968
lr = Perceptron().fit(X, y)
964-
msg = "'Perceptron' object has no attribute 'predict_proba'"
969+
msg = "Perceptron has none of the following attributes: predict_proba."
965970
with pytest.raises(AttributeError, match=msg):
966971
scorer(lr, X, y)
967972

‎sklearn/utils/_response.py

Copy file name to clipboardExpand all lines: sklearn/utils/_response.py
-10Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,6 @@ def _get_response_values(
1919
The response values are predictions, one scalar value for each sample in X
2020
that depends on the specific choice of `response_method`.
2121
22-
This helper only accepts multiclass classifiers with the `predict` response
23-
method.
24-
2522
If `estimator` is a binary classifier, also return the label for the
2623
effective positive class.
2724
@@ -75,15 +72,8 @@ def _get_response_values(
7572
if is_classifier(estimator):
7673
prediction_method = _check_response_method(estimator, response_method)
7774
classes = estimator.classes_
78-
7975
target_type = "binary" if len(classes) <= 2 else "multiclass"
8076

81-
if target_type == "multiclass" and prediction_method.__name__ != "predict":
82-
raise ValueError(
83-
"With a multiclass estimator, the response method should be "
84-
f"predict, got {prediction_method.__name__} instead."
85-
)
86-
8777
if pos_label is not None and pos_label not in classes.tolist():
8878
raise ValueError(
8979
f"pos_label={pos_label} is not a valid label: It should be "

‎sklearn/utils/tests/test_response.py

Copy file name to clipboardExpand all lines: sklearn/utils/tests/test_response.py
+25-20Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
LinearRegression,
77
LogisticRegression,
88
)
9-
from sklearn.svm import SVC
9+
from sklearn.preprocessing import scale
1010
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
1111
from sklearn.utils._mocking import _MockEstimatorOnOffPrediction
1212
from sklearn.utils._testing import assert_allclose, assert_array_equal
@@ -15,6 +15,8 @@
1515

1616

1717
X, y = load_iris(return_X_y=True)
18+
# scale the data to avoid ConvergenceWarning with LogisticRegression
19+
X = scale(X, copy=False)
1820
X_binary, y_binary = X[:100], y[:100]
1921

2022

@@ -29,25 +31,6 @@ def test_get_response_values_regressor_error(response_method):
2931
_get_response_values(my_estimator, X, response_method=response_method)
3032

3133

32-
@pytest.mark.parametrize(
33-
"estimator, response_method",
34-
[
35-
(DecisionTreeClassifier(), "predict_proba"),
36-
(SVC(), "decision_function"),
37-
],
38-
)
39-
def test_get_response_values_error_multiclass_classifier(estimator, response_method):
40-
"""Check that we raise an error with multiclass classifier and requesting
41-
response values different from `predict`."""
42-
X, y = make_classification(
43-
n_samples=10, n_clusters_per_class=1, n_classes=3, random_state=0
44-
)
45-
classifier = estimator.fit(X, y)
46-
err_msg = "With a multiclass estimator, the response method should be predict"
47-
with pytest.raises(ValueError, match=err_msg):
48-
_get_response_values(classifier, X, response_method=response_method)
49-
50-
5134
def test_get_response_values_regressor():
5235
"""Check the behaviour of `_get_response_values` with regressor."""
5336
X, y = make_regression(n_samples=10, random_state=0)
@@ -227,3 +210,25 @@ def test_get_response_decision_function():
227210
)
228211
np.testing.assert_allclose(y_score, classifier.decision_function(X_binary) * -1)
229212
assert pos_label == 0
213+
214+
215+
@pytest.mark.parametrize(
216+
"estimator, response_method",
217+
[
218+
(DecisionTreeClassifier(max_depth=2, random_state=0), "predict_proba"),
219+
(LogisticRegression(), "decision_function"),
220+
],
221+
)
222+
def test_get_response_values_multiclass(estimator, response_method):
223+
"""Check that we can call `_get_response_values` with a multiclass estimator.
224+
It should return the predictions untouched.
225+
"""
226+
estimator.fit(X, y)
227+
predictions, pos_label = _get_response_values(
228+
estimator, X, response_method=response_method
229+
)
230+
231+
assert pos_label is None
232+
assert predictions.shape == (X.shape[0], len(estimator.classes_))
233+
if response_method == "predict_proba":
234+
assert np.logical_and(predictions >= 0, predictions <= 1).all()

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.