Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 8a5d8c3

Browse filesBrowse files
authored
FEA Add metadata routing for TransformedTargetRegressor (#29136)
1 parent 30cf4a0 commit 8a5d8c3
Copy full SHA for 8a5d8c3

File tree

4 files changed

+88
-21
lines changed
Filter options

4 files changed

+88
-21
lines changed

‎doc/metadata_routing.rst

Copy file name to clipboardExpand all lines: doc/metadata_routing.rst
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,7 @@ Meta-estimators and functions supporting metadata routing:
276276

277277
- :class:`sklearn.calibration.CalibratedClassifierCV`
278278
- :class:`sklearn.compose.ColumnTransformer`
279+
- :class:`sklearn.compose.TransformedTargetRegressor`
279280
- :class:`sklearn.covariance.GraphicalLassoCV`
280281
- :class:`sklearn.ensemble.StackingClassifier`
281282
- :class:`sklearn.ensemble.StackingRegressor`
@@ -316,7 +317,6 @@ Meta-estimators and functions supporting metadata routing:
316317

317318
Meta-estimators and tools not supporting metadata routing yet:
318319

319-
- :class:`sklearn.compose.TransformedTargetRegressor`
320320
- :class:`sklearn.ensemble.AdaBoostClassifier`
321321
- :class:`sklearn.ensemble.AdaBoostRegressor`
322322
- :class:`sklearn.feature_selection.RFE`

‎doc/whats_new/v1.6.rst

Copy file name to clipboardExpand all lines: doc/whats_new/v1.6.rst
+5Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,11 @@ more details.
6767
``**fit_params`` to the underlying estimators via their `fit` methods.
6868
:pr:`28701` by :user:`Stefanie Senger <StefanieSenger>`.
6969

70+
- |Feature| :class:`compose.TransformedTargetRegressor` now supports metadata
71+
routing in its `fit` and `predict` methods and routes the corresponding
72+
params to the underlying regressor.
73+
:pr:`29136` by :user:`Omar Salman <OmarManzoor>`.
74+
7075
Dropping official support for PyPy
7176
----------------------------------
7277

‎sklearn/compose/_target.py

Copy file name to clipboardExpand all lines: sklearn/compose/_target.py
+74-19Lines changed: 74 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,18 @@
88

99
from ..base import BaseEstimator, RegressorMixin, _fit_context, clone
1010
from ..exceptions import NotFittedError
11+
from ..linear_model import LinearRegression
1112
from ..preprocessing import FunctionTransformer
12-
from ..utils import _safe_indexing, check_array
13+
from ..utils import Bunch, _safe_indexing, check_array
14+
from ..utils._metadata_requests import (
15+
MetadataRouter,
16+
MethodMapping,
17+
_routing_enabled,
18+
process_routing,
19+
)
1320
from ..utils._param_validation import HasMethods
1421
from ..utils._tags import _safe_tags
1522
from ..utils.metadata_routing import (
16-
_raise_for_unsupported_routing,
1723
_RoutingNotSupportedMixin,
1824
)
1925
from ..utils.validation import check_is_fitted
@@ -230,15 +236,25 @@ def fit(self, X, y, **fit_params):
230236
Target values.
231237
232238
**fit_params : dict
233-
Parameters passed to the `fit` method of the underlying
234-
regressor.
239+
- If `enable_metadata_routing=False` (default):
240+
241+
Parameters directly passed to the `fit` method of the
242+
underlying regressor.
243+
244+
- If `enable_metadata_routing=True`:
245+
246+
Parameters safely routed to the `fit` method of the
247+
underlying regressor.
248+
249+
.. versionchanged:: 1.6
250+
See :ref:`Metadata Routing User Guide <metadata_routing>` for
251+
more details.
235252
236253
Returns
237254
-------
238255
self : object
239256
Fitted estimator.
240257
"""
241-
_raise_for_unsupported_routing(self, "fit", **fit_params)
242258
if y is None:
243259
raise ValueError(
244260
f"This {self.__class__.__name__} estimator "
@@ -274,14 +290,13 @@ def fit(self, X, y, **fit_params):
274290
if y_trans.ndim == 2 and y_trans.shape[1] == 1:
275291
y_trans = y_trans.squeeze(axis=1)
276292

277-
if self.regressor is None:
278-
from ..linear_model import LinearRegression
279-
280-
self.regressor_ = LinearRegression()
293+
self.regressor_ = self._get_regressor(get_clone=True)
294+
if _routing_enabled():
295+
routed_params = process_routing(self, "fit", **fit_params)
281296
else:
282-
self.regressor_ = clone(self.regressor)
297+
routed_params = Bunch(regressor=Bunch(fit=fit_params))
283298

284-
self.regressor_.fit(X, y_trans, **fit_params)
299+
self.regressor_.fit(X, y_trans, **routed_params.regressor.fit)
285300

286301
if hasattr(self.regressor_, "feature_names_in_"):
287302
self.feature_names_in_ = self.regressor_.feature_names_in_
@@ -300,16 +315,32 @@ def predict(self, X, **predict_params):
300315
Samples.
301316
302317
**predict_params : dict of str -> object
303-
Parameters passed to the `predict` method of the underlying
304-
regressor.
318+
- If `enable_metadata_routing=False` (default):
319+
320+
Parameters directly passed to the `predict` method of the
321+
underlying regressor.
322+
323+
- If `enable_metadata_routing=True`:
324+
325+
Parameters safely routed to the `predict` method of the
326+
underlying regressor.
327+
328+
.. versionchanged:: 1.6
329+
See :ref:`Metadata Routing User Guide <metadata_routing>`
330+
for more details.
305331
306332
Returns
307333
-------
308334
y_hat : ndarray of shape (n_samples,)
309335
Predicted values.
310336
"""
311337
check_is_fitted(self)
312-
pred = self.regressor_.predict(X, **predict_params)
338+
if _routing_enabled():
339+
routed_params = process_routing(self, "predict", **predict_params)
340+
else:
341+
routed_params = Bunch(regressor=Bunch(predict=predict_params))
342+
343+
pred = self.regressor_.predict(X, **routed_params.regressor.predict)
313344
if pred.ndim == 1:
314345
pred_trans = self.transformer_.inverse_transform(pred.reshape(-1, 1))
315346
else:
@@ -324,11 +355,7 @@ def predict(self, X, **predict_params):
324355
return pred_trans
325356

326357
def _more_tags(self):
327-
regressor = self.regressor
328-
if regressor is None:
329-
from ..linear_model import LinearRegression
330-
331-
regressor = LinearRegression()
358+
regressor = self._get_regressor()
332359

333360
return {
334361
"poor_score": True,
@@ -350,3 +377,31 @@ def n_features_in_(self):
350377
) from nfe
351378

352379
return self.regressor_.n_features_in_
380+
381+
def get_metadata_routing(self):
382+
"""Get metadata routing of this object.
383+
384+
Please check :ref:`User Guide <metadata_routing>` on how the routing
385+
mechanism works.
386+
387+
.. versionadded:: 1.6
388+
389+
Returns
390+
-------
391+
routing : MetadataRouter
392+
A :class:`~sklearn.utils.metadata_routing.MetadataRouter` encapsulating
393+
routing information.
394+
"""
395+
router = MetadataRouter(owner=self.__class__.__name__).add(
396+
regressor=self._get_regressor(),
397+
method_mapping=MethodMapping()
398+
.add(caller="fit", callee="fit")
399+
.add(caller="predict", callee="predict"),
400+
)
401+
return router
402+
403+
def _get_regressor(self, get_clone=False):
404+
if self.regressor is None:
405+
return LinearRegression()
406+
407+
return clone(self.regressor) if get_clone else self.regressor

‎sklearn/tests/test_metaestimators_metadata_routing.py

Copy file name to clipboardExpand all lines: sklearn/tests/test_metaestimators_metadata_routing.py
+8-1Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,14 @@ def enable_slep006():
382382
"cv_name": "cv",
383383
"cv_routing_methods": ["fit"],
384384
},
385+
{
386+
"metaestimator": TransformedTargetRegressor,
387+
"estimator": "regressor",
388+
"estimator_name": "regressor",
389+
"X": X,
390+
"y": y,
391+
"estimator_routing_methods": ["fit", "predict"],
392+
},
385393
]
386394
"""List containing all metaestimators to be tested and their settings
387395
@@ -427,7 +435,6 @@ def enable_slep006():
427435
RFECV(ConsumingClassifier()),
428436
SelfTrainingClassifier(ConsumingClassifier()),
429437
SequentialFeatureSelector(ConsumingClassifier()),
430-
TransformedTargetRegressor(),
431438
]
432439

433440

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.