Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 72844cd

Browse files
authored
TST improve metadata routing tests (#29226)
1 parent 5ced13c commit 72844cd
Copy full SHA for 72844cd

File tree

Expand file tree / Collapse file tree

9 files changed

+160
-92
lines changed
Filter options
Expand file tree / Collapse file tree

9 files changed

+160
-92
lines changed

‎sklearn/compose/tests/test_column_transformer.py

Copy file name to clipboard | Expand all lines: sklearn/compose/tests/test_column_transformer.py
+6 -2 | Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2640,15 +2640,19 @@ def test_metadata_routing_for_column_transformer(method):
26402640
)
26412641

26422642
if method == "transform":
2643-
trs.fit(X, y)
2643+
trs.fit(X, y, sample_weight=sample_weight, metadata=metadata)
26442644
trs.transform(X, sample_weight=sample_weight, metadata=metadata)
26452645
else:
26462646
getattr(trs, method)(X, y, sample_weight=sample_weight, metadata=metadata)
26472647

26482648
assert len(registry)
26492649
for _trs in registry:
26502650
check_recorded_metadata(
2651-
obj=_trs, method=method, sample_weight=sample_weight, metadata=metadata
2651+
obj=_trs,
2652+
method=method,
2653+
parent=method,
2654+
sample_weight=sample_weight,
2655+
metadata=metadata,
26522656
)
26532657

26542658

‎sklearn/ensemble/tests/test_stacking.py

Copy file name to clipboard | Expand all lines: sklearn/ensemble/tests/test_stacking.py
+10-2Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -973,13 +973,21 @@ def test_metadata_routing_for_stacking_estimators(Estimator, Child, prop, prop_v
973973
assert len(registry)
974974
for sub_est in registry:
975975
check_recorded_metadata(
976-
obj=sub_est, method="fit", split_params=(prop), **{prop: prop_value}
976+
obj=sub_est,
977+
method="fit",
978+
parent="fit",
979+
split_params=(prop),
980+
**{prop: prop_value},
977981
)
978982
# access final_estimator:
979983
registry = est.final_estimator_.registry
980984
assert len(registry)
981985
check_recorded_metadata(
982-
obj=registry[-1], method="predict", split_params=(prop), **{prop: prop_value}
986+
obj=registry[-1],
987+
method="predict",
988+
parent="predict",
989+
split_params=(prop),
990+
**{prop: prop_value},
983991
)
984992

985993

‎sklearn/ensemble/tests/test_voting.py

Copy file name to clipboard | Expand all lines: sklearn/ensemble/tests/test_voting.py
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -759,7 +759,7 @@ def test_metadata_routing_for_voting_estimators(Estimator, Child, prop):
759759
registry = estimator[1].registry
760760
assert len(registry)
761761
for sub_est in registry:
762-
check_recorded_metadata(obj=sub_est, method="fit", **kwargs)
762+
check_recorded_metadata(obj=sub_est, method="fit", parent="fit", **kwargs)
763763

764764

765765
@pytest.mark.usefixtures("enable_slep006")

‎sklearn/model_selection/tests/test_search.py

Copy file name to clipboard | Expand all lines: sklearn/model_selection/tests/test_search.py
+1Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2614,6 +2614,7 @@ def test_multi_metric_search_forwards_metadata(SearchCV, param_search):
26142614
check_recorded_metadata(
26152615
obj=_scorer,
26162616
method="score",
2617+
parent="_score",
26172618
split_params=("sample_weight", "metadata"),
26182619
sample_weight=score_weights,
26192620
metadata=score_metadata,

‎sklearn/model_selection/tests/test_validation.py

Copy file name to clipboard | Expand all lines: sklearn/model_selection/tests/test_validation.py
+4Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2601,6 +2601,7 @@ def test_validation_functions_routing(func):
26012601
check_recorded_metadata(
26022602
obj=_scorer,
26032603
method="score",
2604+
parent=func.__name__,
26042605
split_params=("sample_weight", "metadata"),
26052606
sample_weight=score_weights,
26062607
metadata=score_metadata,
@@ -2611,6 +2612,7 @@ def test_validation_functions_routing(func):
26112612
check_recorded_metadata(
26122613
obj=_splitter,
26132614
method="split",
2615+
parent=func.__name__,
26142616
groups=split_groups,
26152617
metadata=split_metadata,
26162618
)
@@ -2620,6 +2622,7 @@ def test_validation_functions_routing(func):
26202622
check_recorded_metadata(
26212623
obj=_estimator,
26222624
method="fit",
2625+
parent=func.__name__,
26232626
split_params=("sample_weight", "metadata"),
26242627
sample_weight=fit_sample_weight,
26252628
metadata=fit_metadata,
@@ -2657,6 +2660,7 @@ def test_learning_curve_exploit_incremental_learning_routing():
26572660
check_recorded_metadata(
26582661
obj=_estimator,
26592662
method="partial_fit",
2663+
parent="learning_curve",
26602664
split_params=("sample_weight", "metadata"),
26612665
sample_weight=fit_sample_weight,
26622666
metadata=fit_metadata,

‎sklearn/tests/metadata_routing_common.py

Copy file name to clipboard | Expand all lines: sklearn/tests/metadata_routing_common.py
+65-51Lines changed: 65 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import inspect
2+
from collections import defaultdict
13
from functools import partial
24

35
import numpy as np
@@ -25,55 +27,69 @@
2527
from sklearn.utils.multiclass import _check_partial_fit_first_call
2628

2729

28-
def record_metadata(obj, method, record_default=True, **kwargs):
29-
"""Utility function to store passed metadata to a method.
30+
def record_metadata(obj, record_default=True, **kwargs):
31+
"""Utility function to store passed metadata to a method of obj.
3032
3133
If record_default is False, kwargs whose values are "default" are skipped.
3234
This is so that checks on keyword arguments whose default was not changed
3335
are skipped.
3436
3537
"""
38+
stack = inspect.stack()
39+
callee = stack[1].function
40+
caller = stack[2].function
3641
if not hasattr(obj, "_records"):
37-
obj._records = {}
42+
obj._records = defaultdict(lambda: defaultdict(list))
3843
if not record_default:
3944
kwargs = {
4045
key: val
4146
for key, val in kwargs.items()
4247
if not isinstance(val, str) or (val != "default")
4348
}
44-
obj._records[method] = kwargs
49+
obj._records[callee][caller].append(kwargs)
4550

4651

47-
def check_recorded_metadata(obj, method, split_params=tuple(), **kwargs):
52+
def check_recorded_metadata(obj, method, parent, split_params=tuple(), **kwargs):
4853
"""Check whether the expected metadata is passed to the object's method.
4954
5055
Parameters
5156
----------
5257
obj : estimator object
5358
sub-estimator to check routed params for
5459
method : str
55-
sub-estimator's method where metadata is routed to
60+
sub-estimator's method where metadata is routed to, or otherwise in
61+
the context of metadata routing referred to as 'callee'
62+
parent : str
63+
the parent method which should have called `method`, or otherwise in
64+
the context of metadata routing referred to as 'caller'
5665
split_params : tuple, default=empty
5766
specifies any parameters which are to be checked as being a subset
5867
of the original values
5968
**kwargs : dict
6069
passed metadata
6170
"""
62-
records = getattr(obj, "_records", dict()).get(method, dict())
63-
assert set(kwargs.keys()) == set(
64-
records.keys()
65-
), f"Expected {kwargs.keys()} vs {records.keys()}"
66-
for key, value in kwargs.items():
67-
recorded_value = records[key]
68-
# The following condition is used to check for any specified parameters
69-
# being a subset of the original values
70-
if key in split_params and recorded_value is not None:
71-
assert np.isin(recorded_value, value).all()
72-
else:
73-
if isinstance(recorded_value, np.ndarray):
74-
assert_array_equal(recorded_value, value)
71+
all_records = (
72+
getattr(obj, "_records", dict()).get(method, dict()).get(parent, list())
73+
)
74+
for record in all_records:
75+
# first check that the names of the metadata passed are the same as
76+
# expected. The names are stored as keys in `record`.
77+
assert set(kwargs.keys()) == set(
78+
record.keys()
79+
), f"Expected {kwargs.keys()} vs {record.keys()}"
80+
for key, value in kwargs.items():
81+
recorded_value = record[key]
82+
# The following condition is used to check for any specified parameters
83+
# being a subset of the original values
84+
if key in split_params and recorded_value is not None:
85+
assert np.isin(recorded_value, value).all()
7586
else:
76-
assert recorded_value is value, f"Expected {recorded_value} vs {value}"
87+
if isinstance(recorded_value, np.ndarray):
88+
assert_array_equal(recorded_value, value)
89+
else:
90+
assert (
91+
recorded_value is value
92+
), f"Expected {recorded_value} vs {value}. Method: {method}"
7793

7894

7995
record_metadata_not_default = partial(record_metadata, record_default=False)
@@ -151,7 +167,7 @@ def partial_fit(self, X, y, sample_weight="default", metadata="default"):
151167
self.registry.append(self)
152168

153169
record_metadata_not_default(
154-
self, "partial_fit", sample_weight=sample_weight, metadata=metadata
170+
self, sample_weight=sample_weight, metadata=metadata
155171
)
156172
return self
157173

@@ -160,19 +176,19 @@ def fit(self, X, y, sample_weight="default", metadata="default"):
160176
self.registry.append(self)
161177

162178
record_metadata_not_default(
163-
self, "fit", sample_weight=sample_weight, metadata=metadata
179+
self, sample_weight=sample_weight, metadata=metadata
164180
)
165181
return self
166182

167183
def predict(self, X, y=None, sample_weight="default", metadata="default"):
168184
record_metadata_not_default(
169-
self, "predict", sample_weight=sample_weight, metadata=metadata
185+
self, sample_weight=sample_weight, metadata=metadata
170186
)
171187
return np.zeros(shape=(len(X),))
172188

173189
def score(self, X, y, sample_weight="default", metadata="default"):
174190
record_metadata_not_default(
175-
self, "score", sample_weight=sample_weight, metadata=metadata
191+
self, sample_weight=sample_weight, metadata=metadata
176192
)
177193
return 1
178194

@@ -240,7 +256,7 @@ def partial_fit(
240256
self.registry.append(self)
241257

242258
record_metadata_not_default(
243-
self, "partial_fit", sample_weight=sample_weight, metadata=metadata
259+
self, sample_weight=sample_weight, metadata=metadata
244260
)
245261
_check_partial_fit_first_call(self, classes)
246262
return self
@@ -250,15 +266,15 @@ def fit(self, X, y, sample_weight="default", metadata="default"):
250266
self.registry.append(self)
251267

252268
record_metadata_not_default(
253-
self, "fit", sample_weight=sample_weight, metadata=metadata
269+
self, sample_weight=sample_weight, metadata=metadata
254270
)
255271

256272
self.classes_ = np.unique(y)
257273
return self
258274

259275
def predict(self, X, sample_weight="default", metadata="default"):
260276
record_metadata_not_default(
261-
self, "predict", sample_weight=sample_weight, metadata=metadata
277+
self, sample_weight=sample_weight, metadata=metadata
262278
)
263279
y_score = np.empty(shape=(len(X),), dtype="int8")
264280
y_score[len(X) // 2 :] = 0
@@ -267,7 +283,7 @@ def predict(self, X, sample_weight="default", metadata="default"):
267283

268284
def predict_proba(self, X, sample_weight="default", metadata="default"):
269285
record_metadata_not_default(
270-
self, "predict_proba", sample_weight=sample_weight, metadata=metadata
286+
self, sample_weight=sample_weight, metadata=metadata
271287
)
272288
y_proba = np.empty(shape=(len(X), 2))
273289
y_proba[: len(X) // 2, :] = np.asarray([1.0, 0.0])
@@ -279,13 +295,13 @@ def predict_log_proba(self, X, sample_weight="default", metadata="default"):
279295

280296
# uncomment when needed
281297
# record_metadata_not_default(
282-
# self, "predict_log_proba", sample_weight=sample_weight, metadata=metadata
298+
# self, sample_weight=sample_weight, metadata=metadata
283299
# )
284300
# return np.zeros(shape=(len(X), 2))
285301

286302
def decision_function(self, X, sample_weight="default", metadata="default"):
287303
record_metadata_not_default(
288-
self, "predict_proba", sample_weight=sample_weight, metadata=metadata
304+
self, sample_weight=sample_weight, metadata=metadata
289305
)
290306
y_score = np.empty(shape=(len(X),))
291307
y_score[len(X) // 2 :] = 0
@@ -295,7 +311,7 @@ def decision_function(self, X, sample_weight="default", metadata="default"):
295311
# uncomment when needed
296312
# def score(self, X, y, sample_weight="default", metadata="default"):
297313
# record_metadata_not_default(
298-
# self, "score", sample_weight=sample_weight, metadata=metadata
314+
# self, sample_weight=sample_weight, metadata=metadata
299315
# )
300316
# return 1
301317

@@ -315,38 +331,38 @@ class ConsumingTransformer(TransformerMixin, BaseEstimator):
315331
def __init__(self, registry=None):
316332
self.registry = registry
317333

318-
def fit(self, X, y=None, sample_weight=None, metadata=None):
334+
def fit(self, X, y=None, sample_weight="default", metadata="default"):
319335
if self.registry is not None:
320336
self.registry.append(self)
321337

322338
record_metadata_not_default(
323-
self, "fit", sample_weight=sample_weight, metadata=metadata
339+
self, sample_weight=sample_weight, metadata=metadata
324340
)
325341
return self
326342

327-
def transform(self, X, sample_weight=None, metadata=None):
328-
record_metadata(
329-
self, "transform", sample_weight=sample_weight, metadata=metadata
343+
def transform(self, X, sample_weight="default", metadata="default"):
344+
record_metadata_not_default(
345+
self, sample_weight=sample_weight, metadata=metadata
330346
)
331-
return X
347+
return X + 1
332348

333-
def fit_transform(self, X, y, sample_weight=None, metadata=None):
349+
def fit_transform(self, X, y, sample_weight="default", metadata="default"):
334350
# implementing ``fit_transform`` is necessary since
335351
# ``TransformerMixin.fit_transform`` doesn't route any metadata to
336352
# ``transform``, while here we want ``transform`` to receive
337353
# ``sample_weight`` and ``metadata``.
338-
record_metadata(
339-
self, "fit_transform", sample_weight=sample_weight, metadata=metadata
354+
record_metadata_not_default(
355+
self, sample_weight=sample_weight, metadata=metadata
340356
)
341357
return self.fit(X, y, sample_weight=sample_weight, metadata=metadata).transform(
342358
X, sample_weight=sample_weight, metadata=metadata
343359
)
344360

345361
def inverse_transform(self, X, sample_weight=None, metadata=None):
346-
record_metadata(
347-
self, "inverse_transform", sample_weight=sample_weight, metadata=metadata
362+
record_metadata_not_default(
363+
self, sample_weight=sample_weight, metadata=metadata
348364
)
349-
return X
365+
return X - 1
350366

351367

352368
class ConsumingNoFitTransformTransformer(BaseEstimator):
@@ -361,14 +377,12 @@ def fit(self, X, y=None, sample_weight=None, metadata=None):
361377
if self.registry is not None:
362378
self.registry.append(self)
363379

364-
record_metadata(self, "fit", sample_weight=sample_weight, metadata=metadata)
380+
record_metadata(self, sample_weight=sample_weight, metadata=metadata)
365381

366382
return self
367383

368384
def transform(self, X, sample_weight=None, metadata=None):
369-
record_metadata(
370-
self, "transform", sample_weight=sample_weight, metadata=metadata
371-
)
385+
record_metadata(self, sample_weight=sample_weight, metadata=metadata)
372386
return X
373387

374388

@@ -383,7 +397,7 @@ def _score(self, method_caller, clf, X, y, **kwargs):
383397
if self.registry is not None:
384398
self.registry.append(self)
385399

386-
record_metadata_not_default(self, "score", **kwargs)
400+
record_metadata_not_default(self, **kwargs)
387401

388402
sample_weight = kwargs.get("sample_weight", None)
389403
return super()._score(method_caller, clf, X, y, sample_weight=sample_weight)
@@ -397,7 +411,7 @@ def split(self, X, y=None, groups="default", metadata="default"):
397411
if self.registry is not None:
398412
self.registry.append(self)
399413

400-
record_metadata_not_default(self, "split", groups=groups, metadata=metadata)
414+
record_metadata_not_default(self, groups=groups, metadata=metadata)
401415

402416
split_index = len(X) // 2
403417
train_indices = list(range(0, split_index))
@@ -445,7 +459,7 @@ def fit(self, X, y, sample_weight=None, **fit_params):
445459
if self.registry is not None:
446460
self.registry.append(self)
447461

448-
record_metadata(self, "fit", sample_weight=sample_weight)
462+
record_metadata(self, sample_weight=sample_weight)
449463
params = process_routing(self, "fit", sample_weight=sample_weight, **fit_params)
450464
self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit)
451465
return self
@@ -479,7 +493,7 @@ def fit(self, X, y, sample_weight=None, **kwargs):
479493
if self.registry is not None:
480494
self.registry.append(self)
481495

482-
record_metadata(self, "fit", sample_weight=sample_weight)
496+
record_metadata(self, sample_weight=sample_weight)
483497
params = process_routing(self, "fit", sample_weight=sample_weight, **kwargs)
484498
self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit)
485499
return self

0 commit comments

Comments
0 (0)
Morty Proxy: This is a proxified and sanitized view of the page; visit the original site.