8
8
9
9
from ..base import BaseEstimator , RegressorMixin , _fit_context , clone
10
10
from ..exceptions import NotFittedError
11
+ from ..linear_model import LinearRegression
11
12
from ..preprocessing import FunctionTransformer
12
- from ..utils import _safe_indexing , check_array
13
+ from ..utils import Bunch , _safe_indexing , check_array
14
+ from ..utils ._metadata_requests import (
15
+ MetadataRouter ,
16
+ MethodMapping ,
17
+ _routing_enabled ,
18
+ process_routing ,
19
+ )
13
20
from ..utils ._param_validation import HasMethods
14
21
from ..utils ._tags import _safe_tags
15
22
from ..utils .metadata_routing import (
16
- _raise_for_unsupported_routing ,
17
23
_RoutingNotSupportedMixin ,
18
24
)
19
25
from ..utils .validation import check_is_fitted
@@ -230,15 +236,25 @@ def fit(self, X, y, **fit_params):
230
236
Target values.
231
237
232
238
**fit_params : dict
233
- Parameters passed to the `fit` method of the underlying
234
- regressor.
239
+ - If `enable_metadata_routing=False` (default):
240
+
241
+ Parameters directly passed to the `fit` method of the
242
+ underlying regressor.
243
+
244
+ - If `enable_metadata_routing=True`:
245
+
246
+ Parameters safely routed to the `fit` method of the
247
+ underlying regressor.
248
+
249
+ .. versionchanged:: 1.6
250
+ See :ref:`Metadata Routing User Guide <metadata_routing>` for
251
+ more details.
235
252
236
253
Returns
237
254
-------
238
255
self : object
239
256
Fitted estimator.
240
257
"""
241
- _raise_for_unsupported_routing (self , "fit" , ** fit_params )
242
258
if y is None :
243
259
raise ValueError (
244
260
f"This { self .__class__ .__name__ } estimator "
@@ -274,14 +290,13 @@ def fit(self, X, y, **fit_params):
274
290
if y_trans .ndim == 2 and y_trans .shape [1 ] == 1 :
275
291
y_trans = y_trans .squeeze (axis = 1 )
276
292
277
- if self .regressor is None :
278
- from ..linear_model import LinearRegression
279
-
280
- self .regressor_ = LinearRegression ()
293
+ self .regressor_ = self ._get_regressor (get_clone = True )
294
+ if _routing_enabled ():
295
+ routed_params = process_routing (self , "fit" , ** fit_params )
281
296
else :
282
- self . regressor_ = clone ( self . regressor )
297
+ routed_params = Bunch ( regressor = Bunch ( fit = fit_params ) )
283
298
284
- self .regressor_ .fit (X , y_trans , ** fit_params )
299
+ self .regressor_ .fit (X , y_trans , ** routed_params . regressor . fit )
285
300
286
301
if hasattr (self .regressor_ , "feature_names_in_" ):
287
302
self .feature_names_in_ = self .regressor_ .feature_names_in_
@@ -300,16 +315,32 @@ def predict(self, X, **predict_params):
300
315
Samples.
301
316
302
317
**predict_params : dict of str -> object
303
- Parameters passed to the `predict` method of the underlying
304
- regressor.
318
+ - If `enable_metadata_routing=False` (default):
319
+
320
+ Parameters directly passed to the `predict` method of the
321
+ underlying regressor.
322
+
323
+ - If `enable_metadata_routing=True`:
324
+
325
+ Parameters safely routed to the `predict` method of the
326
+ underlying regressor.
327
+
328
+ .. versionchanged:: 1.6
329
+ See :ref:`Metadata Routing User Guide <metadata_routing>`
330
+ for more details.
305
331
306
332
Returns
307
333
-------
308
334
y_hat : ndarray of shape (n_samples,)
309
335
Predicted values.
310
336
"""
311
337
check_is_fitted (self )
312
- pred = self .regressor_ .predict (X , ** predict_params )
338
+ if _routing_enabled ():
339
+ routed_params = process_routing (self , "predict" , ** predict_params )
340
+ else :
341
+ routed_params = Bunch (regressor = Bunch (predict = predict_params ))
342
+
343
+ pred = self .regressor_ .predict (X , ** routed_params .regressor .predict )
313
344
if pred .ndim == 1 :
314
345
pred_trans = self .transformer_ .inverse_transform (pred .reshape (- 1 , 1 ))
315
346
else :
@@ -324,11 +355,7 @@ def predict(self, X, **predict_params):
324
355
return pred_trans
325
356
326
357
def _more_tags (self ):
327
- regressor = self .regressor
328
- if regressor is None :
329
- from ..linear_model import LinearRegression
330
-
331
- regressor = LinearRegression ()
358
+ regressor = self ._get_regressor ()
332
359
333
360
return {
334
361
"poor_score" : True ,
@@ -350,3 +377,31 @@ def n_features_in_(self):
350
377
) from nfe
351
378
352
379
return self .regressor_ .n_features_in_
380
+
381
+ def get_metadata_routing (self ):
382
+ """Get metadata routing of this object.
383
+
384
+ Please check :ref:`User Guide <metadata_routing>` on how the routing
385
+ mechanism works.
386
+
387
+ .. versionadded:: 1.6
388
+
389
+ Returns
390
+ -------
391
+ routing : MetadataRouter
392
+ A :class:`~sklearn.utils.metadata_routing.MetadataRouter` encapsulating
393
+ routing information.
394
+ """
395
+ router = MetadataRouter (owner = self .__class__ .__name__ ).add (
396
+ regressor = self ._get_regressor (),
397
+ method_mapping = MethodMapping ()
398
+ .add (caller = "fit" , callee = "fit" )
399
+ .add (caller = "predict" , callee = "predict" ),
400
+ )
401
+ return router
402
+
403
+ def _get_regressor (self , get_clone = False ):
404
+ if self .regressor is None :
405
+ return LinearRegression ()
406
+
407
+ return clone (self .regressor ) if get_clone else self .regressor
0 commit comments