Commit 70aab36

REVERT ENH add the parameter prefit in the FixedThresholdClassifier (#29067) (#30172)

1 parent 004cf9e, commit 70aab36

7 files changed: +165 -68 lines changed

‎doc/modules/classification_threshold.rst

3 additions & 1 deletion

@@ -144,7 +144,9 @@ Manually setting the decision threshold
 The previous sections discussed strategies to find an optimal decision threshold. It is
 also possible to manually set the decision threshold using the class
 :class:`~sklearn.model_selection.FixedThresholdClassifier`. In case that you don't want
-to refit the model when calling `fit`, you can set the parameter `prefit=True`.
+to refit the model when calling `fit`, wrap your sub-estimator with a
+:class:`~sklearn.frozen.FrozenEstimator` and do
+``FixedThresholdClassifier(FrozenEstimator(estimator), ...)``.
 
 Examples
 --------
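In user code, the documentation change above boils down to replacing the removed `prefit=True` flag with a `FrozenEstimator` wrapper. A minimal sketch of the migration, using an illustrative dataset and `LogisticRegression` estimator that are not part of this commit:

from sklearn.datasets import make_classification
from sklearn.frozen import FrozenEstimator
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import FixedThresholdClassifier

X, y = make_classification(random_state=0)
classifier = LogisticRegression().fit(X, y)  # an already fitted estimator

# Reverted API (no longer available):
#     FixedThresholdClassifier(classifier, threshold=0.9, prefit=True)
# Current API: freeze the fitted estimator so the meta-estimator never refits it.
model = FixedThresholdClassifier(estimator=FrozenEstimator(classifier), threshold=0.9)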

‎doc/whats_new/upcoming_changes/sklearn.model_selection/29067.enhancement.rst

0 additions & 4 deletions (this file was deleted)

New changelog fragment (4 additions & 0 deletions):

- There is no need to call `fit` on a
  :class:`~sklearn.model_selection.FixedThresholdClassifier` if the underlying
  estimator is already fitted.
  By :user:`Adrin Jalali <adrinjalali>`
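As a minimal sketch of the behaviour this changelog entry describes (the estimator and data below are illustrative, not part of the commit): an already fitted classifier can be wrapped and used for prediction without ever calling `fit` on the wrapper.

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import FixedThresholdClassifier

X, y = make_classification(random_state=0)
classifier = LogisticRegression().fit(X, y)

# No `fit` call on the wrapper: the fitted underlying estimator is used directly.
wrapper = FixedThresholdClassifier(estimator=classifier, threshold=0.75)
print(wrapper.predict(X[:5]))
print(wrapper.predict_proba(X[:5]).round(3))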
New example file (98 additions & 0 deletions):

"""
===================================
Examples of Using `FrozenEstimator`
===================================

This examples showcases some use cases of :class:`~sklearn.frozen.FrozenEstimator`.

:class:`~sklearn.frozen.FrozenEstimator` is a utility class that allows to freeze a
fitted estimator. This is useful, for instance, when we want to pass a fitted estimator
to a meta-estimator, such as :class:`~sklearn.model_selection.FixedThresholdClassifier`
without letting the meta-estimator refit the estimator.
"""

# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

# %%
# Setting a decision threshold for a pre-fitted classifier
# --------------------------------------------------------
# Fitted classifiers in scikit-learn use an arbitrary decision threshold to decide
# which class the given sample belongs to. The decision threshold is either `0.0` on the
# value returned by :term:`decision_function`, or `0.5` on the probability returned by
# :term:`predict_proba`.
#
# However, one might want to set a custom decision threshold. We can do this by
# using :class:`~sklearn.model_selection.FixedThresholdClassifier` and wrapping the
# classifier with :class:`~sklearn.frozen.FrozenEstimator`.
from sklearn.datasets import make_classification
from sklearn.frozen import FrozenEstimator
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import FixedThresholdClassifier, train_test_split

X, y = make_classification(n_samples=1000, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
classifier = LogisticRegression().fit(X_train, y_train)

print(
    "Probability estimates for three data points:\n"
    f"{classifier.predict_proba(X_test[-3:]).round(3)}"
)
print(
    "Predicted class for the same three data points:\n"
    f"{classifier.predict(X_test[-3:])}"
)

# %%
# Now imagine you'd want to set a different decision threshold on the probability
# estimates. We can do this by wrapping the classifier with
# :class:`~sklearn.frozen.FrozenEstimator` and passing it to
# :class:`~sklearn.model_selection.FixedThresholdClassifier`.

threshold_classifier = FixedThresholdClassifier(
    estimator=FrozenEstimator(classifier), threshold=0.9
)

# %%
# Note that in the above piece of code, calling `fit` on
# :class:`~sklearn.model_selection.FixedThresholdClassifier` does not refit the
# underlying classifier.
#
# Now, let's see how the predictions changed with respect to the probability
# threshold.
print(
    "Probability estimates for three data points with FixedThresholdClassifier:\n"
    f"{threshold_classifier.predict_proba(X_test[-3:]).round(3)}"
)
print(
    "Predicted class for the same three data points with FixedThresholdClassifier:\n"
    f"{threshold_classifier.predict(X_test[-3:])}"
)

# %%
# We see that the probability estimates stay the same, but since a different decision
# threshold is used, the predicted classes are different.
#
# Please refer to
# :ref:`sphx_glr_auto_examples_model_selection_plot_cost_sensitive_learning.py`
# to learn about cost-sensitive learning and decision threshold tuning.

# %%
# Calibration of a pre-fitted classifier
# --------------------------------------
# You can use :class:`~sklearn.frozen.FrozenEstimator` to calibrate a pre-fitted
# classifier using :class:`~sklearn.calibration.CalibratedClassifierCV`.
from sklearn.calibration import CalibratedClassifierCV
from sklearn.metrics import brier_score_loss

calibrated_classifier = CalibratedClassifierCV(
    estimator=FrozenEstimator(classifier)
).fit(X_train, y_train)

prob_pos_clf = classifier.predict_proba(X_test)[:, 1]
clf_score = brier_score_loss(y_test, prob_pos_clf)
print(f"No calibration: {clf_score:.3f}")

prob_pos_calibrated = calibrated_classifier.predict_proba(X_test)[:, 1]
calibrated_score = brier_score_loss(y_test, prob_pos_calibrated)
print(f"With calibration: {calibrated_score:.3f}")

‎examples/model_selection/plot_cost_sensitive_learning.py

6 additions & 3 deletions

@@ -660,15 +660,18 @@ def business_metric(y_true, y_pred, amount):
 #
 # The class :class:`~sklearn.model_selection.FixedThresholdClassifier` allows us to
 # manually set the decision threshold. At prediction time, it behave as the previous
-# tuned model but no search is performed during the fitting process.
+# tuned model but no search is performed during the fitting process. Note that here
+# we use :class:`~sklearn.frozen.FrozenEstimator` to wrap the predictive model to
+# avoid any refitting.
 #
 # Here, we will reuse the decision threshold found in the previous section to create a
 # new model and check that it gives the same results.
+from sklearn.frozen import FrozenEstimator
 from sklearn.model_selection import FixedThresholdClassifier
 
 model_fixed_threshold = FixedThresholdClassifier(
-    estimator=model, threshold=tuned_model.best_threshold_, prefit=True
-).fit(data_train, target_train)
+    estimator=FrozenEstimator(model), threshold=tuned_model.best_threshold_
+)
 
 # %%
 business_score = business_scorer(

‎sklearn/model_selection/_classification_threshold.py

34 additions & 25 deletions

@@ -43,6 +43,13 @@
 from ._split import StratifiedShuffleSplit, check_cv
 
 
+def _check_is_fitted(estimator):
+    try:
+        check_is_fitted(estimator.estimator)
+    except NotFittedError:
+        check_is_fitted(estimator, "estimator_")
+
+
 def _estimator_has(attr):
     """Check if we can delegate a method to the underlying estimator.
 
@@ -170,8 +177,9 @@ def predict_proba(self, X):
         probabilities : ndarray of shape (n_samples, n_classes)
             The class probabilities of the input samples.
         """
-        check_is_fitted(self, "estimator_")
-        return self.estimator_.predict_proba(X)
+        _check_is_fitted(self)
+        estimator = getattr(self, "estimator_", self.estimator)
+        return estimator.predict_proba(X)
 
     @available_if(_estimator_has("predict_log_proba"))
     def predict_log_proba(self, X):
@@ -188,8 +196,9 @@ def predict_log_proba(self, X):
         log_probabilities : ndarray of shape (n_samples, n_classes)
             The logarithm class probabilities of the input samples.
         """
-        check_is_fitted(self, "estimator_")
-        return self.estimator_.predict_log_proba(X)
+        _check_is_fitted(self)
+        estimator = getattr(self, "estimator_", self.estimator)
+        return estimator.predict_log_proba(X)
 
     @available_if(_estimator_has("decision_function"))
     def decision_function(self, X):
@@ -206,8 +215,9 @@ def decision_function(self, X):
         decisions : ndarray of shape (n_samples,)
             The decision function computed the fitted estimator.
         """
-        check_is_fitted(self, "estimator_")
-        return self.estimator_.decision_function(X)
+        _check_is_fitted(self)
+        estimator = getattr(self, "estimator_", self.estimator)
+        return estimator.decision_function(X)
 
     def __sklearn_tags__(self):
         tags = super().__sklearn_tags__()
@@ -264,13 +274,6 @@ class FixedThresholdClassifier(BaseThresholdClassifier):
         If the method is not implemented by the classifier, it will raise an
         error.
 
-    prefit : bool, default=False
-        Whether a pre-fitted model is expected to be passed into the constructor
-        directly or not. If `True`, `estimator` must be a fitted estimator. If `False`,
-        `estimator` is fitted and updated by calling `fit`.
-
-        .. versionadded:: 1.6
-
     Attributes
     ----------
     estimator_ : estimator instance
@@ -322,7 +325,6 @@ class FixedThresholdClassifier(BaseThresholdClassifier):
         **BaseThresholdClassifier._parameter_constraints,
         "threshold": [StrOptions({"auto"}), Real],
         "pos_label": [Real, str, "boolean", None],
-        "prefit": ["boolean"],
     }
 
     def __init__(
@@ -332,12 +334,22 @@ def __init__(
         threshold="auto",
         pos_label=None,
         response_method="auto",
-        prefit=False,
     ):
         super().__init__(estimator=estimator, response_method=response_method)
         self.pos_label = pos_label
         self.threshold = threshold
-        self.prefit = prefit
+
+    @property
+    def classes_(self):
+        if estimator := getattr(self, "estimator_", None):
+            return estimator.classes_
+        try:
+            check_is_fitted(self.estimator)
+            return self.estimator.classes_
+        except NotFittedError:
+            raise AttributeError(
+                "The underlying estimator is not fitted yet."
+            ) from NotFittedError
 
     def _fit(self, X, y, **params):
         """Fit the classifier.
@@ -360,13 +372,7 @@ def _fit(self, X, y, **params):
             Returns an instance of self.
         """
         routed_params = process_routing(self, "fit", **params)
-        if self.prefit:
-            check_is_fitted(self.estimator)
-            self.estimator_ = self.estimator
-        else:
-            self.estimator_ = clone(self.estimator).fit(
-                X, y, **routed_params.estimator.fit
-            )
+        self.estimator_ = clone(self.estimator).fit(X, y, **routed_params.estimator.fit)
         return self
 
     def predict(self, X):
@@ -382,9 +388,12 @@ def predict(self, X):
         class_labels : ndarray of shape (n_samples,)
             The predicted class.
         """
-        check_is_fitted(self, "estimator_")
+        _check_is_fitted(self)
+
+        estimator = getattr(self, "estimator_", self.estimator)
+
         y_score, _, response_method_used = _get_response_values_binary(
-            self.estimator_,
+            estimator,
             X,
             self._get_response_method(),
             pos_label=self.pos_label,
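The pattern introduced in this file (first try `check_is_fitted` on the wrapped `estimator`, then fall back to the wrapper's own fitted `estimator_` attribute) is what lets the prediction methods work both on a freshly constructed wrapper around an already fitted estimator and on a wrapper fitted the usual way. A standalone sketch of the idea with a hypothetical minimal meta-estimator (`DelegatingWrapper` is illustrative, not scikit-learn API):

from sklearn.base import BaseEstimator, ClassifierMixin, clone
from sklearn.datasets import make_classification
from sklearn.exceptions import NotFittedError
from sklearn.linear_model import LogisticRegression
from sklearn.utils.validation import check_is_fitted


class DelegatingWrapper(ClassifierMixin, BaseEstimator):
    """Hypothetical wrapper illustrating the fitted-estimator fallback."""

    def __init__(self, estimator):
        self.estimator = estimator

    def fit(self, X, y):
        # Fitting the wrapper clones and fits the inner estimator.
        self.estimator_ = clone(self.estimator).fit(X, y)
        return self

    def _check_is_fitted(self):
        # Accept either a pre-fitted inner estimator or a fitted wrapper.
        try:
            check_is_fitted(self.estimator)
        except NotFittedError:
            check_is_fitted(self, "estimator_")

    def predict(self, X):
        self._check_is_fitted()
        estimator = getattr(self, "estimator_", self.estimator)
        return estimator.predict(X)


X, y = make_classification(random_state=0)
# Works without fitting the wrapper, because the inner estimator is fitted...
print(DelegatingWrapper(LogisticRegression().fit(X, y)).predict(X[:3]))
# ...and also after fitting the wrapper around an unfitted estimator.
print(DelegatingWrapper(LogisticRegression()).fit(X, y).predict(X[:3]))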

‎sklearn/model_selection/tests/test_classification_threshold.py

20 additions & 35 deletions

@@ -2,7 +2,7 @@
 import pytest
 
 from sklearn import config_context
-from sklearn.base import BaseEstimator, ClassifierMixin, clone
+from sklearn.base import clone
 from sklearn.datasets import (
     load_breast_cancer,
     load_iris,
@@ -593,41 +593,26 @@ def test_fixed_threshold_classifier_metadata_routing():
     assert_allclose(classifier_default_threshold.estimator_.coef_, classifier.coef_)
 
 
-class ClassifierLoggingFit(ClassifierMixin, BaseEstimator):
-    """Classifier that logs the number of `fit` calls."""
-
-    def __init__(self, fit_calls=0):
-        self.fit_calls = fit_calls
-
-    def fit(self, X, y, **fit_params):
-        self.fit_calls += 1
-        self.is_fitted_ = True
-        return self
-
-    def predict_proba(self, X):
-        return np.ones((X.shape[0], 2), np.float64)  # pragma: nocover
-
-
-def test_fixed_threshold_classifier_prefit():
-    """Check the behaviour of the `FixedThresholdClassifier` with the `prefit`
-    parameter."""
+@pytest.mark.parametrize(
+    "method", ["predict_proba", "decision_function", "predict", "predict_log_proba"]
+)
+def test_fixed_threshold_classifier_fitted_estimator(method):
+    """Check that if the underlying estimator is already fitted, no fit is required."""
     X, y = make_classification(random_state=0)
+    classifier = LogisticRegression().fit(X, y)
+    fixed_threshold_classifier = FixedThresholdClassifier(estimator=classifier)
+    # This should not raise an error
+    getattr(fixed_threshold_classifier, method)(X)
 
-    estimator = ClassifierLoggingFit()
-    model = FixedThresholdClassifier(estimator=estimator, prefit=True)
-    with pytest.raises(NotFittedError):
-        model.fit(X, y)
 
-    # check that we don't clone the classifier when `prefit=True`.
-    estimator.fit(X, y)
-    model.fit(X, y)
-    assert estimator.fit_calls == 1
-    assert model.estimator_ is estimator
+def test_fixed_threshold_classifier_classes_():
+    """Check that the classes_ attribute is properly set."""
+    X, y = make_classification(random_state=0)
+    with pytest.raises(
+        AttributeError, match="The underlying estimator is not fitted yet."
+    ):
+        FixedThresholdClassifier(estimator=LogisticRegression()).classes_
 
-    # check that we clone the classifier when `prefit=False`.
-    estimator = ClassifierLoggingFit()
-    model = FixedThresholdClassifier(estimator=estimator, prefit=False)
-    model.fit(X, y)
-    assert estimator.fit_calls == 0
-    assert model.estimator_.fit_calls == 1
-    assert model.estimator_ is not estimator
+    classifier = LogisticRegression().fit(X, y)
+    fixed_threshold_classifier = FixedThresholdClassifier(estimator=classifier)
+    assert_array_equal(fixed_threshold_classifier.classes_, classifier.classes_)
