Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit da3c2d2

Browse filesBrowse files
thomasjpfanogrisel
andauthored
FIX MultiOutputRegressor correctly ducktypes fitted estimators (#19308)
Co-authored-by: Olivier Grisel <olivier.grisel@gmail.com>
1 parent 80e985b commit da3c2d2
Copy full SHA for da3c2d2

File tree

3 files changed

+26
-1
lines changed
Filter options

3 files changed

+26
-1
lines changed

‎doc/whats_new/v0.24.rst

Copy file name to clipboardExpand all lines: doc/whats_new/v0.24.rst
+7Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,13 @@ Changelog
4848
`'use_encoded_value'` strategies.
4949
:pr:`19234` by `Guillaume Lemaitre <glemaitre>`.
5050

51+
:mod:`sklearn.multioutput`
52+
..........................
53+
54+
- |Fix| :class:`multioutput.MultiOutputRegressor` now works with estimators
55+
that dynamically define `predict` during fitting, such as
56+
:class:`ensemble.StackingRegressor`. :pr:`19308` by `Thomas Fan`_.
57+
5158
:mod:`sklearn.semi_supervised`
5259
..............................
5360

‎sklearn/multioutput.py

Copy file name to clipboardExpand all lines: sklearn/multioutput.py
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def predict(self, X):
198198
Note: Separate models are generated for each predictor.
199199
"""
200200
check_is_fitted(self)
201-
if not hasattr(self.estimator, "predict"):
201+
if not hasattr(self.estimators_[0], "predict"):
202202
raise ValueError("The base estimator should implement"
203203
" a predict method")
204204

‎sklearn/tests/test_multioutput.py

Copy file name to clipboardExpand all lines: sklearn/tests/test_multioutput.py
+18Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from sklearn import datasets
1111
from sklearn.base import clone
1212
from sklearn.datasets import make_classification
13+
from sklearn.datasets import load_linnerud
1314
from sklearn.ensemble import GradientBoostingRegressor, RandomForestClassifier
1415
from sklearn.exceptions import NotFittedError
1516
from sklearn.linear_model import Lasso
@@ -30,6 +31,7 @@
3031
from sklearn.dummy import DummyRegressor, DummyClassifier
3132
from sklearn.pipeline import make_pipeline
3233
from sklearn.impute import SimpleImputer
34+
from sklearn.ensemble import StackingRegressor
3335

3436

3537
def test_multi_target_regression():
@@ -658,3 +660,19 @@ def test_classifier_chain_tuple_invalid_order():
658660

659661
with pytest.raises(ValueError, match='invalid order'):
660662
chain.fit(X, y)
663+
664+
665+
def test_multioutputregressor_ducktypes_fitted_estimator():
666+
"""Test that MultiOutputRegressor checks the fitted estimator for
667+
predict. Non-regression test for #16549."""
668+
X, y = load_linnerud(return_X_y=True)
669+
stacker = StackingRegressor(
670+
estimators=[("sgd", SGDRegressor(random_state=1))],
671+
final_estimator=Ridge(),
672+
cv=2
673+
)
674+
675+
reg = MultiOutputRegressor(estimator=stacker).fit(X, y)
676+
677+
# Does not raise
678+
reg.predict(X)

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.