-
-
Notifications
You must be signed in to change notification settings - Fork 25.9k
API Adds predict_params for Pipeline proba delegates #19790
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the PR @crflynn !
Please add an entry to the change log at doc/whats_new/v1.0.rst
with tag ||. Like the other entries there, please reference this pull request with :pr:
and credit yourself (and other contributors if applicable) with :user:
.
sklearn/pipeline.py
Outdated
@@ -456,7 +456,7 @@ def fit_predict(self, X, y=None, **fit_params): | ||
return y_pred | ||
|
||
@if_delegate_has_method(delegate='_final_estimator') | ||
def predict_proba(self, X): | ||
def predict_proba(self, X, **predict_params): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I know it is longer, but I slightly prefer:
def predict_proba(self, X, **predict_params): | |
def predict_proba(self, X, **predict_proba_params): |
What do you think?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Your suggestion is more explicit, so I've made the changes.
sklearn/pipeline.py
Outdated
@@ -513,7 +517,7 @@ def score_samples(self, X): | ||
return self.steps[-1][-1].score_samples(Xt) | ||
|
||
@if_delegate_has_method(delegate='_final_estimator') | ||
def predict_log_proba(self, X): | ||
def predict_log_proba(self, X, **predict_params): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here
def predict_log_proba(self, X, **predict_params): | |
def predict_log_proba(self, X, **predict_log_proba_params): |
sklearn/tests/test_pipeline.py
Outdated
@@ -459,6 +467,26 @@ def test_predict_with_predict_params(): | ||
assert pipe.named_steps['clf'].got_attribute | ||
|
||
|
||
def test_predict_proba_with_predict_params(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can update test_predict_with_predict_params
above into the following:
@pytest.mark.parametrize("method_name", [
"predict", "predict_proba", "predict_log_proba"
])
def test_predict_methods_with_predict_params(method_name):
# tests that Pipeline passes predict_* to the final estimator
# when predict_* is invoked
pipe = Pipeline([('transf', Transf()), ('clf', DummyEstimatorParams())])
pipe.fit(None, None)
method = getattr(pipe, method_name)
method(X=None, got_attribute=True)
assert pipe.named_steps['clf'].got_attribute
This takes advantage of pytest.mark.parametrize
to test all the methods at once.
1242b40
to
dbae1ab
Compare
I assume you mean |API| tag here. If not, I'll adjust it. |
Yes, that is what I meant. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Thanks @crflynn |
Reference Issues/PRs
Implements some of the changes discussed in #12006.
What does this implement/fix? Explain your changes.
Extends the
**predict_params
functionality ofPipeline.predict
toPipeline.predict_proba
andPipeline.predict_log_proba
.Any other comments?
As noted in #12006, there is currently no use case for this within sklearn. However, given the extensibility of the library and implementations in lightgbm and xgboost it seems appropriate and provides a consistent signature across predict delegates.
I'd be happy to implement similar changes for the
decision_function
delegate as well if requested.