Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit dbae1ab

Browse filesBrowse files
committed
use method name in param arg
1 parent 47f9d2b commit dbae1ab
Copy full SHA for dbae1ab

File tree

Expand file treeCollapse file tree

3 files changed

+24
-30
lines changed
Filter options
Expand file treeCollapse file tree

3 files changed

+24
-30
lines changed

‎doc/whats_new/v1.0.rst

Copy file name to clipboardExpand all lines: doc/whats_new/v1.0.rst
+8Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,14 @@ Changelog
249249
Use ``var_`` instead.
250250
:pr:`18842` by :user:`Hong Shao Yang <hongshaoyang>`.
251251

252+
:mod:`sklearn.pipeline`
253+
.......................
254+
255+
- |API| The `predict_proba` and `predict_log_proba` methods of the
256+
:class:`Pipeline` class now support passing prediction kwargs to
257+
the final estimator.
258+
:pr:`19790` by :user:`Christopher Flynn <crflynn>`.
259+
252260
:mod:`sklearn.preprocessing`
253261
............................
254262

‎sklearn/pipeline.py

Copy file name to clipboardExpand all lines: sklearn/pipeline.py
+8-6Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,7 @@ def fit_predict(self, X, y=None, **fit_params):
456456
return y_pred
457457

458458
@if_delegate_has_method(delegate='_final_estimator')
459-
def predict_proba(self, X, **predict_params):
459+
def predict_proba(self, X, **predict_proba_params):
460460
"""Apply transforms, and predict_proba of the final estimator
461461
462462
Parameters
@@ -465,7 +465,7 @@ def predict_proba(self, X, **predict_params):
465465
Data to predict on. Must fulfill input requirements of first step
466466
of the pipeline.
467467
468-
**predict_params : dict of string -> object
468+
**predict_proba_params : dict of string -> object
469469
Parameters to the ``predict_proba`` called at the end of all
470470
transformations in the pipeline.
471471
@@ -476,7 +476,7 @@ def predict_proba(self, X, **predict_params):
476476
Xt = X
477477
for _, name, transform in self._iter(with_final=False):
478478
Xt = transform.transform(Xt)
479-
return self.steps[-1][-1].predict_proba(Xt, **predict_params)
479+
return self.steps[-1][-1].predict_proba(Xt, **predict_proba_params)
480480

481481
@if_delegate_has_method(delegate='_final_estimator')
482482
def decision_function(self, X):
@@ -517,7 +517,7 @@ def score_samples(self, X):
517517
return self.steps[-1][-1].score_samples(Xt)
518518

519519
@if_delegate_has_method(delegate='_final_estimator')
520-
def predict_log_proba(self, X, **predict_params):
520+
def predict_log_proba(self, X, **predict_log_proba_params):
521521
"""Apply transforms, and predict_log_proba of the final estimator
522522
523523
Parameters
@@ -526,7 +526,7 @@ def predict_log_proba(self, X, **predict_params):
526526
Data to predict on. Must fulfill input requirements of first step
527527
of the pipeline.
528528
529-
**predict_params : dict of string -> object
529+
**predict_log_proba_params : dict of string -> object
530530
Parameters to the ``predict_log_proba`` called at the end of all
531531
transformations in the pipeline.
532532
@@ -537,7 +537,9 @@ def predict_log_proba(self, X, **predict_params):
537537
Xt = X
538538
for _, name, transform in self._iter(with_final=False):
539539
Xt = transform.transform(Xt)
540-
return self.steps[-1][-1].predict_log_proba(Xt, **predict_params)
540+
return self.steps[-1][-1].predict_log_proba(
541+
Xt, **predict_log_proba_params
542+
)
541543

542544
@property
543545
def transform(self):

‎sklearn/tests/test_pipeline.py

Copy file name to clipboardExpand all lines: sklearn/tests/test_pipeline.py
+8-24Lines changed: 8 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -457,32 +457,16 @@ def test_fit_predict_with_intermediate_fit_params():
457457
assert 'should_succeed' not in pipe.named_steps['transf'].fit_params
458458

459459

460-
def test_predict_with_predict_params():
461-
# tests that Pipeline passes predict_params to the final estimator
462-
# when predict is invoked
460+
@pytest.mark.parametrize("method_name", [
461+
"predict", "predict_proba", "predict_log_proba"
462+
])
463+
def test_predict_methods_with_predict_params(method_name):
464+
# tests that Pipeline passes predict_* to the final estimator
465+
# when predict_* is invoked
463466
pipe = Pipeline([('transf', Transf()), ('clf', DummyEstimatorParams())])
464467
pipe.fit(None, None)
465-
pipe.predict(X=None, got_attribute=True)
466-
467-
assert pipe.named_steps['clf'].got_attribute
468-
469-
470-
def test_predict_proba_with_predict_params():
471-
# tests that Pipeline passes predict_params to the final estimator
472-
# when predict is invoked
473-
pipe = Pipeline([('transf', Transf()), ('clf', DummyEstimatorParams())])
474-
pipe.fit(None, None)
475-
pipe.predict_proba(X=None, got_attribute=True)
476-
477-
assert pipe.named_steps['clf'].got_attribute
478-
479-
480-
def test_predict_log_proba_with_predict_params():
481-
# tests that Pipeline passes predict_params to the final estimator
482-
# when predict is invoked
483-
pipe = Pipeline([('transf', Transf()), ('clf', DummyEstimatorParams())])
484-
pipe.fit(None, None)
485-
pipe.predict_log_proba(X=None, got_attribute=True)
468+
method = getattr(pipe, method_name)
469+
method(X=None, got_attribute=True)
486470

487471
assert pipe.named_steps['clf'].got_attribute
488472

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.