Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 16bd7ca

Browse filesBrowse files
committed
Pass predict attributes to last estimator in pipeline
Fixes #9293 by passing the attributes provided in `predict` to the last estimator.
1 parent a08555a commit 16bd7ca
Copy full SHA for 16bd7ca

File tree

2 files changed

+18
-2
lines changed
Filter options

2 files changed

+18
-2
lines changed

‎sklearn/pipeline.py

Copy file name to clipboardExpand all lines: sklearn/pipeline.py
+2-2Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ def fit_transform(self, X, y=None, **fit_params):
296296
return last_step.fit(Xt, y, **fit_params).transform(Xt)
297297

298298
@if_delegate_has_method(delegate='_final_estimator')
299-
def predict(self, X):
299+
def predict(self, X, **predict_params):
300300
"""Apply transforms to the data, and predict with the final estimator
301301
302302
Parameters
@@ -313,7 +313,7 @@ def predict(self, X):
313313
for name, transform in self.steps[:-1]:
314314
if transform is not None:
315315
Xt = transform.transform(Xt)
316-
return self.steps[-1][-1].predict(Xt)
316+
return self.steps[-1][-1].predict(Xt, **predict_params)
317317

318318
@if_delegate_has_method(delegate='_final_estimator')
319319
def fit_predict(self, X, y=None, **fit_params):

‎sklearn/tests/test_pipeline.py

Copy file name to clipboardExpand all lines: sklearn/tests/test_pipeline.py
+16Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,15 @@ def fit(self, X, y):
141141
self.timestamp_ = time.time()
142142
return self
143143

144+
class DummyEstimatorWithParams(BaseEstimator):
145+
"""Mock classifier that takes params on predict"""
146+
147+
def fit(self, X, y):
148+
return self
149+
150+
def predict(self, X, got_attribute=False):
151+
self.got_attribute = got_attribute
152+
return self
144153

145154
def test_pipeline_init():
146155
# Test the various init parameters of the pipeline.
@@ -383,6 +392,13 @@ def test_fit_predict_with_intermediate_fit_params():
383392
assert_true(pipe.named_steps['clf'].successful)
384393
assert_false('should_succeed' in pipe.named_steps['transf'].fit_params)
385394

395+
def test_predict_with_predict_params():
396+
# tests that Pipeline passes predict_params to the final estimator
397+
# when predict is invoked
398+
pipe = Pipeline([('transf', Transf()), ('clf', DummyEstimatorWithParams())])
399+
pipe.predict(X=None, got_attribute=True)
400+
401+
assert_true(pipe.named_steps['clf'].got_attribute)
386402

387403
def test_feature_union():
388404
# basic sanity check for feature union

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.