Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit b025a82

Browse filesBrowse files
committed
Pass predict attributes to last estimator in pipeline
Fixes #9293 by passing the attributes provided in `predict` to the last estimator.
1 parent a08555a commit b025a82
Copy full SHA for b025a82

File tree

2 files changed

+29
-2
lines changed
Filter options

2 files changed

+29
-2
lines changed

‎sklearn/pipeline.py

Copy file name to clipboardExpand all lines: sklearn/pipeline.py
+8-2Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ def fit_transform(self, X, y=None, **fit_params):
296296
return last_step.fit(Xt, y, **fit_params).transform(Xt)
297297

298298
@if_delegate_has_method(delegate='_final_estimator')
299-
def predict(self, X):
299+
def predict(self, X, **predict_params):
300300
"""Apply transforms to the data, and predict with the final estimator
301301
302302
Parameters
@@ -305,6 +305,12 @@ def predict(self, X):
305305
Data to predict on. Must fulfill input requirements of first step
306306
of the pipeline.
307307
308+
**predict_params : dict of string -> object
309+
Parameters passed to the final ``predict`` in the pipeline. Note
310+
that uncertainties that are generated by the transformations
311+
in the pipeline are not propagated to the final estimator when
312+
this method is called in a pipeline object.
313+
308314
Returns
309315
-------
310316
y_pred : array-like
@@ -313,7 +319,7 @@ def predict(self, X):
313319
for name, transform in self.steps[:-1]:
314320
if transform is not None:
315321
Xt = transform.transform(Xt)
316-
return self.steps[-1][-1].predict(Xt)
322+
return self.steps[-1][-1].predict(Xt, **predict_params)
317323

318324
@if_delegate_has_method(delegate='_final_estimator')
319325
def fit_predict(self, X, y=None, **fit_params):

‎sklearn/tests/test_pipeline.py

Copy file name to clipboardExpand all lines: sklearn/tests/test_pipeline.py
+21Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,17 @@ def fit(self, X, y):
142142
return self
143143

144144

145+
class DummyEstimatorParams(BaseEstimator):
146+
"""Mock classifier that takes params on predict"""
147+
148+
def fit(self, X, y):
149+
return self
150+
151+
def predict(self, X, got_attribute=False):
152+
self.got_attribute = got_attribute
153+
return self
154+
155+
145156
def test_pipeline_init():
146157
# Test the various init parameters of the pipeline.
147158
assert_raises(TypeError, Pipeline)
@@ -384,6 +395,16 @@ def test_fit_predict_with_intermediate_fit_params():
384395
assert_false('should_succeed' in pipe.named_steps['transf'].fit_params)
385396

386397

398+
def test_predict_with_predict_params():
399+
# tests that Pipeline passes predict_params to the final estimator
400+
# when predict is invoked
401+
pipe = Pipeline([('transf', Transf()), ('clf', DummyEstimatorParams())])
402+
pipe.fit(None, None)
403+
pipe.predict(X=None, got_attribute=True)
404+
405+
assert_true(pipe.named_steps['clf'].got_attribute)
406+
407+
387408
def test_feature_union():
388409
# basic sanity check for feature union
389410
iris = load_iris()

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.