Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 1411232

Browse filesBrowse files
authored
API Adds predict_params for Pipeline proba delegates (#19790)
1 parent f479269 commit 1411232
Copy full SHA for 1411232

File tree

Expand file treeCollapse file tree

3 files changed

+38
-8
lines changed
Filter options
Expand file treeCollapse file tree

3 files changed

+38
-8
lines changed

‎doc/whats_new/v1.0.rst

Copy file name to clipboardExpand all lines: doc/whats_new/v1.0.rst
+8Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,14 @@ Changelog
266266
Use ``var_`` instead.
267267
:pr:`18842` by :user:`Hong Shao Yang <hongshaoyang>`.
268268

269+
:mod:`sklearn.pipeline`
270+
.......................
271+
272+
- |API| The `predict_proba` and `predict_log_proba` methods of the
273+
:class:`Pipeline` class now support passing prediction kwargs to
274+
the final estimator.
275+
:pr:`19790` by :user:`Christopher Flynn <crflynn>`.
276+
269277
:mod:`sklearn.preprocessing`
270278
............................
271279

‎sklearn/pipeline.py

Copy file name to clipboardExpand all lines: sklearn/pipeline.py
+14-4Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,7 @@ def fit_predict(self, X, y=None, **fit_params):
456456
return y_pred
457457

458458
@if_delegate_has_method(delegate='_final_estimator')
459-
def predict_proba(self, X):
459+
def predict_proba(self, X, **predict_proba_params):
460460
"""Apply transforms, and predict_proba of the final estimator
461461
462462
Parameters
@@ -465,14 +465,18 @@ def predict_proba(self, X):
465465
Data to predict on. Must fulfill input requirements of first step
466466
of the pipeline.
467467
468+
**predict_proba_params : dict of string -> object
469+
Parameters to the ``predict_proba`` called at the end of all
470+
transformations in the pipeline.
471+
468472
Returns
469473
-------
470474
y_proba : array-like of shape (n_samples, n_classes)
471475
"""
472476
Xt = X
473477
for _, name, transform in self._iter(with_final=False):
474478
Xt = transform.transform(Xt)
475-
return self.steps[-1][-1].predict_proba(Xt)
479+
return self.steps[-1][-1].predict_proba(Xt, **predict_proba_params)
476480

477481
@if_delegate_has_method(delegate='_final_estimator')
478482
def decision_function(self, X):
@@ -513,7 +517,7 @@ def score_samples(self, X):
513517
return self.steps[-1][-1].score_samples(Xt)
514518

515519
@if_delegate_has_method(delegate='_final_estimator')
516-
def predict_log_proba(self, X):
520+
def predict_log_proba(self, X, **predict_log_proba_params):
517521
"""Apply transforms, and predict_log_proba of the final estimator
518522
519523
Parameters
@@ -522,14 +526,20 @@ def predict_log_proba(self, X):
522526
Data to predict on. Must fulfill input requirements of first step
523527
of the pipeline.
524528
529+
**predict_log_proba_params : dict of string -> object
530+
Parameters to the ``predict_log_proba`` called at the end of all
531+
transformations in the pipeline.
532+
525533
Returns
526534
-------
527535
y_score : array-like of shape (n_samples, n_classes)
528536
"""
529537
Xt = X
530538
for _, name, transform in self._iter(with_final=False):
531539
Xt = transform.transform(Xt)
532-
return self.steps[-1][-1].predict_log_proba(Xt)
540+
return self.steps[-1][-1].predict_log_proba(
541+
Xt, **predict_log_proba_params
542+
)
533543

534544
@property
535545
def transform(self):

‎sklearn/tests/test_pipeline.py

Copy file name to clipboardExpand all lines: sklearn/tests/test_pipeline.py
+16-4Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,14 @@ def predict(self, X, got_attribute=False):
159159
self.got_attribute = got_attribute
160160
return self
161161

162+
def predict_proba(self, X, got_attribute=False):
163+
self.got_attribute = got_attribute
164+
return self
165+
166+
def predict_log_proba(self, X, got_attribute=False):
167+
self.got_attribute = got_attribute
168+
return self
169+
162170

163171
def test_pipeline_init():
164172
# Test the various init parameters of the pipeline.
@@ -448,12 +456,16 @@ def test_fit_predict_with_intermediate_fit_params():
448456
assert 'should_succeed' not in pipe.named_steps['transf'].fit_params
449457

450458

451-
def test_predict_with_predict_params():
452-
# tests that Pipeline passes predict_params to the final estimator
453-
# when predict is invoked
459+
@pytest.mark.parametrize("method_name", [
460+
"predict", "predict_proba", "predict_log_proba"
461+
])
462+
def test_predict_methods_with_predict_params(method_name):
463+
# tests that Pipeline passes predict_* to the final estimator
464+
# when predict_* is invoked
454465
pipe = Pipeline([('transf', Transf()), ('clf', DummyEstimatorParams())])
455466
pipe.fit(None, None)
456-
pipe.predict(X=None, got_attribute=True)
467+
method = getattr(pipe, method_name)
468+
method(X=None, got_attribute=True)
457469

458470
assert pipe.named_steps['clf'].got_attribute
459471

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.