1
+ import inspect
2
+ from collections import defaultdict
1
3
from functools import partial
2
4
3
5
import numpy as np
25
27
from sklearn .utils .multiclass import _check_partial_fit_first_call
26
28
27
29
28
def record_metadata(obj, record_default=True, **kwargs):
    """Utility function to store passed metadata to a method of obj.

    The name of the function calling this utility (the "callee" in metadata
    routing terms) and the name of the function which called that one (the
    "caller") are read from the call stack. The passed metadata is appended
    to ``obj._records[callee][caller]``, so every call is kept.

    If record_default is False, kwargs whose values are "default" are skipped.
    This is so that checks on keyword arguments whose default was not changed
    are skipped.

    Parameters
    ----------
    obj : object
        Object on which the records are stored, in its ``_records`` attribute.
        The attribute is created on first use.
    record_default : bool, default=True
        Whether to also record kwargs whose value is the string "default".
    **kwargs : dict
        The metadata to record.
    """
    # Only two frames are needed, so walk them directly instead of
    # materialising a FrameInfo for the whole stack with ``inspect.stack()``.
    callee_frame = inspect.currentframe().f_back
    caller_frame = callee_frame.f_back
    callee = callee_frame.f_code.co_name
    # There is no grandparent frame when the recording function runs in the
    # outermost frame; record a ``None`` caller instead of failing.
    caller = caller_frame.f_code.co_name if caller_frame is not None else None

    if not hasattr(obj, "_records"):
        # ``partial(defaultdict, list)`` rather than a lambda keeps the
        # records, and therefore the object holding them, picklable.
        obj._records = defaultdict(partial(defaultdict, list))

    if not record_default:
        # The ``isinstance`` guard comes first so that non-string values
        # (e.g. arrays) are never compared against the "default" sentinel.
        kwargs = {
            name: value
            for name, value in kwargs.items()
            if not isinstance(value, str) or value != "default"
        }

    obj._records[callee][caller].append(kwargs)
45
50
46
51
47
def check_recorded_metadata(obj, method, parent, split_params=tuple(), **kwargs):
    """Check whether the expected metadata is passed to the object's method.

    Every record stored for the ``method``/``parent`` pair is compared with
    ``kwargs``. When nothing was recorded for that pair, there is nothing to
    compare and the check passes.

    Parameters
    ----------
    obj : estimator object
        sub-estimator to check routed params for
    method : str
        sub-estimator's method where metadata is routed to, or otherwise in
        the context of metadata routing referred to as 'callee'
    parent : str
        the parent method which should have called `method`, or otherwise in
        the context of metadata routing referred to as 'caller'
    split_params : tuple, default=empty
        specifies any parameters which are to be checked as being a subset
        of the original values
    **kwargs : dict
        passed metadata
    """
    records_by_method = getattr(obj, "_records", dict())
    records_by_parent = records_by_method.get(method, dict())
    expected_names = set(kwargs.keys())

    for record in records_by_parent.get(parent, list()):
        # The names of the passed metadata are the keys of ``record``; they
        # must match the expected names exactly.
        assert expected_names == set(
            record.keys()
        ), f"Expected {kwargs.keys()} vs {record.keys()}"

        for key, value in kwargs.items():
            recorded_value = record[key]
            if key in split_params and recorded_value is not None:
                # Split parameters only need to be a subset of the original
                # values.
                assert np.isin(recorded_value, value).all()
            elif isinstance(recorded_value, np.ndarray):
                assert_array_equal(recorded_value, value)
            else:
                assert (
                    recorded_value is value
                ), f"Expected {recorded_value} vs {value}. Method: {method}"
77
93
78
94
79
95
# Variant of ``record_metadata`` with ``record_default=False`` pre-bound, i.e.
# metadata still set to the string "default" is not recorded.
# NOTE(review): ``record_metadata`` reads the callee/caller names from the
# call stack; this presumably relies on ``functools.partial`` not adding a
# Python-level frame of its own -- confirm before replacing it with a ``def``
# wrapper.
record_metadata_not_default = partial (record_metadata , record_default = False )
@@ -151,7 +167,7 @@ def partial_fit(self, X, y, sample_weight="default", metadata="default"):
151
167
self .registry .append (self )
152
168
153
169
record_metadata_not_default (
154
- self , "partial_fit" , sample_weight = sample_weight , metadata = metadata
170
+ self , sample_weight = sample_weight , metadata = metadata
155
171
)
156
172
return self
157
173
@@ -160,19 +176,19 @@ def fit(self, X, y, sample_weight="default", metadata="default"):
160
176
self .registry .append (self )
161
177
162
178
record_metadata_not_default (
163
- self , "fit" , sample_weight = sample_weight , metadata = metadata
179
+ self , sample_weight = sample_weight , metadata = metadata
164
180
)
165
181
return self
166
182
167
183
def predict(self, X, y=None, sample_weight="default", metadata="default"):
    """Record the received metadata and return an all-zero prediction.

    Metadata left at the "default" sentinel is not recorded. The returned
    array has ``len(X)`` entries.
    """
    received = {"sample_weight": sample_weight, "metadata": metadata}
    # Must be called directly from this method: the recorder reads the
    # calling function's name from the stack.
    record_metadata_not_default(self, **received)
    n_samples = len(X)
    return np.zeros(shape=(n_samples,))
172
188
173
189
def score(self, X, y, sample_weight="default", metadata="default"):
    """Record the received metadata and return a constant score of 1.

    Metadata left at the "default" sentinel is not recorded.
    """
    received = {"sample_weight": sample_weight, "metadata": metadata}
    # Must be called directly from this method: the recorder reads the
    # calling function's name from the stack.
    record_metadata_not_default(self, **received)
    return 1
178
194
@@ -240,7 +256,7 @@ def partial_fit(
240
256
self .registry .append (self )
241
257
242
258
record_metadata_not_default (
243
- self , "partial_fit" , sample_weight = sample_weight , metadata = metadata
259
+ self , sample_weight = sample_weight , metadata = metadata
244
260
)
245
261
_check_partial_fit_first_call (self , classes )
246
262
return self
@@ -250,15 +266,15 @@ def fit(self, X, y, sample_weight="default", metadata="default"):
250
266
self .registry .append (self )
251
267
252
268
record_metadata_not_default (
253
- self , "fit" , sample_weight = sample_weight , metadata = metadata
269
+ self , sample_weight = sample_weight , metadata = metadata
254
270
)
255
271
256
272
self .classes_ = np .unique (y )
257
273
return self
258
274
259
275
def predict (self , X , sample_weight = "default" , metadata = "default" ):
260
276
record_metadata_not_default (
261
- self , "predict" , sample_weight = sample_weight , metadata = metadata
277
+ self , sample_weight = sample_weight , metadata = metadata
262
278
)
263
279
y_score = np .empty (shape = (len (X ),), dtype = "int8" )
264
280
y_score [len (X ) // 2 :] = 0
@@ -267,7 +283,7 @@ def predict(self, X, sample_weight="default", metadata="default"):
267
283
268
284
def predict_proba (self , X , sample_weight = "default" , metadata = "default" ):
269
285
record_metadata_not_default (
270
- self , "predict_proba" , sample_weight = sample_weight , metadata = metadata
286
+ self , sample_weight = sample_weight , metadata = metadata
271
287
)
272
288
y_proba = np .empty (shape = (len (X ), 2 ))
273
289
y_proba [: len (X ) // 2 , :] = np .asarray ([1.0 , 0.0 ])
@@ -279,13 +295,13 @@ def predict_log_proba(self, X, sample_weight="default", metadata="default"):
279
295
280
296
# uncomment when needed
281
297
# record_metadata_not_default(
282
- # self, "predict_log_proba", sample_weight=sample_weight, metadata=metadata
298
+ # self, sample_weight=sample_weight, metadata=metadata
283
299
# )
284
300
# return np.zeros(shape=(len(X), 2))
285
301
286
302
def decision_function (self , X , sample_weight = "default" , metadata = "default" ):
287
303
record_metadata_not_default (
288
- self , "predict_proba" , sample_weight = sample_weight , metadata = metadata
304
+ self , sample_weight = sample_weight , metadata = metadata
289
305
)
290
306
y_score = np .empty (shape = (len (X ),))
291
307
y_score [len (X ) // 2 :] = 0
@@ -295,7 +311,7 @@ def decision_function(self, X, sample_weight="default", metadata="default"):
295
311
# uncomment when needed
296
312
# def score(self, X, y, sample_weight="default", metadata="default"):
297
313
# record_metadata_not_default(
298
- # self, "score", sample_weight=sample_weight, metadata=metadata
314
+ # self, sample_weight=sample_weight, metadata=metadata
299
315
# )
300
316
# return 1
301
317
@@ -315,38 +331,38 @@ class ConsumingTransformer(TransformerMixin, BaseEstimator):
315
331
def __init__ (self , registry = None ):
316
332
self .registry = registry
317
333
318
def fit(self, X, y=None, sample_weight="default", metadata="default"):
    """Register this instance, record the received metadata, return self.

    The instance is appended to ``self.registry`` when a registry was given.
    Metadata left at the "default" sentinel is not recorded.
    """
    registry = self.registry
    if registry is not None:
        registry.append(self)

    received = {"sample_weight": sample_weight, "metadata": metadata}
    # Must be called directly from this method: the recorder reads the
    # calling function's name from the stack.
    record_metadata_not_default(self, **received)
    return self
326
342
327
def transform(self, X, sample_weight="default", metadata="default"):
    """Record the received metadata and return ``X + 1``.

    Metadata left at the "default" sentinel is not recorded.
    """
    received = {"sample_weight": sample_weight, "metadata": metadata}
    # Must be called directly from this method: the recorder reads the
    # calling function's name from the stack.
    record_metadata_not_default(self, **received)
    shifted = X + 1
    return shifted
332
348
333
def fit_transform(self, X, y, sample_weight="default", metadata="default"):
    """Record the received metadata, then fit and transform ``X``.

    The same ``sample_weight`` and ``metadata`` are forwarded to both ``fit``
    and ``transform``.
    """
    # ``TransformerMixin.fit_transform`` doesn't route any metadata to
    # ``transform``; it is implemented here so that ``transform`` receives
    # ``sample_weight`` and ``metadata`` as well.
    routed = {"sample_weight": sample_weight, "metadata": metadata}
    record_metadata_not_default(self, **routed)
    fitted = self.fit(X, y, **routed)
    return fitted.transform(X, **routed)
344
360
345
361
def inverse_transform(self, X, sample_weight=None, metadata=None):
    """Record the received metadata and return ``X - 1``."""
    # NOTE(review): the defaults here are ``None`` while ``fit``,
    # ``transform`` and ``fit_transform`` use "default"; only the string
    # "default" is skipped by the recorder, so omitted metadata is recorded
    # as ``None`` here -- confirm this is intended.
    received = {"sample_weight": sample_weight, "metadata": metadata}
    record_metadata_not_default(self, **received)
    restored = X - 1
    return restored
350
366
351
367
352
368
class ConsumingNoFitTransformTransformer (BaseEstimator ):
@@ -361,14 +377,12 @@ def fit(self, X, y=None, sample_weight=None, metadata=None):
361
377
if self .registry is not None :
362
378
self .registry .append (self )
363
379
364
- record_metadata (self , "fit" , sample_weight = sample_weight , metadata = metadata )
380
+ record_metadata (self , sample_weight = sample_weight , metadata = metadata )
365
381
366
382
return self
367
383
368
384
def transform(self, X, sample_weight=None, metadata=None):
    """Record the received metadata (defaults included) and return ``X``."""
    received = {"sample_weight": sample_weight, "metadata": metadata}
    # Must be called directly from this method: the recorder reads the
    # calling function's name from the stack.
    record_metadata(self, **received)
    return X
373
387
374
388
@@ -383,7 +397,7 @@ def _score(self, method_caller, clf, X, y, **kwargs):
383
397
if self .registry is not None :
384
398
self .registry .append (self )
385
399
386
- record_metadata_not_default (self , "score" , ** kwargs )
400
+ record_metadata_not_default (self , ** kwargs )
387
401
388
402
sample_weight = kwargs .get ("sample_weight" , None )
389
403
return super ()._score (method_caller , clf , X , y , sample_weight = sample_weight )
@@ -397,7 +411,7 @@ def split(self, X, y=None, groups="default", metadata="default"):
397
411
if self .registry is not None :
398
412
self .registry .append (self )
399
413
400
- record_metadata_not_default (self , "split" , groups = groups , metadata = metadata )
414
+ record_metadata_not_default (self , groups = groups , metadata = metadata )
401
415
402
416
split_index = len (X ) // 2
403
417
train_indices = list (range (0 , split_index ))
@@ -445,7 +459,7 @@ def fit(self, X, y, sample_weight=None, **fit_params):
445
459
if self .registry is not None :
446
460
self .registry .append (self )
447
461
448
- record_metadata (self , "fit" , sample_weight = sample_weight )
462
+ record_metadata (self , sample_weight = sample_weight )
449
463
params = process_routing (self , "fit" , sample_weight = sample_weight , ** fit_params )
450
464
self .estimator_ = clone (self .estimator ).fit (X , y , ** params .estimator .fit )
451
465
return self
@@ -479,7 +493,7 @@ def fit(self, X, y, sample_weight=None, **kwargs):
479
493
if self .registry is not None :
480
494
self .registry .append (self )
481
495
482
- record_metadata (self , "fit" , sample_weight = sample_weight )
496
+ record_metadata (self , sample_weight = sample_weight )
483
497
params = process_routing (self , "fit" , sample_weight = sample_weight , ** kwargs )
484
498
self .estimator_ = clone (self .estimator ).fit (X , y , ** params .estimator .fit )
485
499
return self
0 commit comments