Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 296a631

Browse filesBrowse files
committed
Pass predict attributes to last estimator in pipeline
Fixes #9293 by passing the attributes provided in `predict` to the last estimator.
1 parent 96a2c10 commit 296a631
Copy full SHA for 296a631

File tree

3 files changed

+35
-2
lines changed
Filter options

3 files changed

+35
-2
lines changed

‎doc/whats_new/v0.20.rst

Copy file name to clipboardExpand all lines: doc/whats_new/v0.20.rst
+4Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,10 @@ Model evaluation and meta-estimators
156156
group-based CV strategies. :issue:`9085` by :user:`Laurent Direr <ldirer>`
157157
and `Andreas Müller`_.
158158

159+
- A paramenter `predict_params` was added to :class:`pipeline.Pipeline` allowing
160+
that parameters passed to `predict` propagate to the very last estimator of
161+
the pipeline. :issue:`9304` by :user:`Breno Freitas <brenolf>`.
162+
159163
Metrics
160164

161165
- :func:`metrics.roc_auc_score` now supports binary ``y_true`` other than

‎sklearn/pipeline.py

Copy file name to clipboardExpand all lines: sklearn/pipeline.py
+10-2Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ def fit_transform(self, X, y=None, **fit_params):
287287
return last_step.fit(Xt, y, **fit_params).transform(Xt)
288288

289289
@if_delegate_has_method(delegate='_final_estimator')
290-
def predict(self, X):
290+
def predict(self, X, **predict_params):
291291
"""Apply transforms to the data, and predict with the final estimator
292292
293293
Parameters
@@ -296,6 +296,14 @@ def predict(self, X):
296296
Data to predict on. Must fulfill input requirements of first step
297297
of the pipeline.
298298
299+
**predict_params : dict of string -> object
300+
Parameters to the ``predict`` called at the end of all
301+
transformations in the pipeline. Note that while this may be
302+
used to return uncertainties from some models with return_std
303+
or return_cov, uncertainties that are generated by the
304+
transformations in the pipeline are not propagated to the
305+
final estimator.
306+
299307
Returns
300308
-------
301309
y_pred : array-like
@@ -304,7 +312,7 @@ def predict(self, X):
304312
for name, transform in self.steps[:-1]:
305313
if transform is not None:
306314
Xt = transform.transform(Xt)
307-
return self.steps[-1][-1].predict(Xt)
315+
return self.steps[-1][-1].predict(Xt, **predict_params)
308316

309317
@if_delegate_has_method(delegate='_final_estimator')
310318
def fit_predict(self, X, y=None, **fit_params):

‎sklearn/tests/test_pipeline.py

Copy file name to clipboardExpand all lines: sklearn/tests/test_pipeline.py
+21Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,17 @@ def fit(self, X, y):
144144
return self
145145

146146

147+
class DummyEstimatorParams(BaseEstimator):
148+
"""Mock classifier that takes params on predict"""
149+
150+
def fit(self, X, y):
151+
return self
152+
153+
def predict(self, X, got_attribute=False):
154+
self.got_attribute = got_attribute
155+
return self
156+
157+
147158
def test_pipeline_init():
148159
# Test the various init parameters of the pipeline.
149160
assert_raises(TypeError, Pipeline)
@@ -398,6 +409,16 @@ def test_fit_predict_with_intermediate_fit_params():
398409
assert_false('should_succeed' in pipe.named_steps['transf'].fit_params)
399410

400411

412+
def test_predict_with_predict_params():
413+
# tests that Pipeline passes predict_params to the final estimator
414+
# when predict is invoked
415+
pipe = Pipeline([('transf', Transf()), ('clf', DummyEstimatorParams())])
416+
pipe.fit(None, None)
417+
pipe.predict(X=None, got_attribute=True)
418+
419+
assert_true(pipe.named_steps['clf'].got_attribute)
420+
421+
401422
def test_feature_union():
402423
# basic sanity check for feature union
403424
iris = load_iris()

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.