Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit bf08cb3

Browse filesBrowse files
MarcoGorellilestevejeremiedbb
authored
Fix a regression in GridSearchCV for parameter grids that have arrays of different sizes as parameter values (#29314)
Co-authored-by: Loïc Estève <loic.esteve@ymail.com> Co-authored-by: Jérémie du Boisberranger <jeremie@probabl.ai>
1 parent 3ef8bf5 commit bf08cb3
Copy full SHA for bf08cb3

File tree

Expand file treeCollapse file tree

3 files changed

+172
-41
lines changed
Filter options
Expand file treeCollapse file tree

3 files changed

+172
-41
lines changed

‎doc/whats_new/v1.5.rst

Copy file name to clipboardExpand all lines: doc/whats_new/v1.5.rst
+4Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ Changelog
6363
grids that have estimators as parameter values.
6464
:pr:`29179` by :user:`Marco Gorelli<MarcoGorelli>`.
6565

66+
- |Fix| Fix a regression in :class:`model_selection.GridSearchCV` for parameter
67+
grids that have arrays of different sizes as parameter values.
68+
:pr:`29314` by :user:`Marco Gorelli<MarcoGorelli>`.
69+
6670
:mod:`sklearn.tree`
6771
...................
6872

‎sklearn/model_selection/_search.py

Copy file name to clipboardExpand all lines: sklearn/model_selection/_search.py
+52-38Lines changed: 52 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,56 @@ def check(self):
379379
return check
380380

381381

382+
def _yield_masked_array_for_each_param(candidate_params):
383+
"""
384+
Yield a masked array for each candidate param.
385+
386+
`candidate_params` is a sequence of params which were used in
387+
a `GridSearchCV`. We use masked arrays for the results, as not
388+
all params are necessarily present in each element of
389+
`candidate_params`. For example, if using `GridSearchCV` with
390+
a `SVC` model, then one might search over params like:
391+
392+
- kernel=["rbf"], gamma=[0.1, 1]
393+
- kernel=["poly"], degree=[1, 2]
394+
395+
and then param `'gamma'` would not be present in entries of
396+
`candidate_params` corresponding to `kernel='poly'`.
397+
"""
398+
n_candidates = len(candidate_params)
399+
param_results = defaultdict(dict)
400+
401+
for cand_idx, params in enumerate(candidate_params):
402+
for name, value in params.items():
403+
param_results["param_%s" % name][cand_idx] = value
404+
405+
for key, param_result in param_results.items():
406+
param_list = list(param_result.values())
407+
try:
408+
arr = np.array(param_list)
409+
except ValueError:
410+
# This can happen when param_list contains lists of different
411+
# lengths, for example:
412+
# param_list=[[1], [2, 3]]
413+
arr_dtype = np.dtype(object)
414+
else:
415+
# There are two cases when we don't use the automatically inferred
416+
# dtype when creating the array and we use object instead:
417+
# - string dtype
418+
# - when array.ndim > 1, that means that param_list was something
419+
# like a list of same-size sequences, which gets turned into a
420+
# multi-dimensional array but we want a 1d array
421+
arr_dtype = arr.dtype if arr.dtype.kind != "U" and arr.ndim == 1 else object
422+
423+
# Use one MaskedArray and mask all the places where the param is not
424+
# applicable for that candidate (which may not contain all the params).
425+
ma = MaskedArray(np.empty(n_candidates), mask=True, dtype=arr_dtype)
426+
for index, value in param_result.items():
427+
# Setting the value at an index unmasks that index
428+
ma[index] = value
429+
yield (key, ma)
430+
431+
382432
class BaseSearchCV(MetaEstimatorMixin, BaseEstimator, metaclass=ABCMeta):
383433
"""Abstract base class for hyper parameter search with cross-validation."""
384434

@@ -1079,45 +1129,9 @@ def _store(key_name, array, weights=None, splits=False, rank=False):
10791129

10801130
_store("fit_time", out["fit_time"])
10811131
_store("score_time", out["score_time"])
1082-
param_results = defaultdict(dict)
1083-
for cand_idx, params in enumerate(candidate_params):
1084-
for name, value in params.items():
1085-
param_results["param_%s" % name][cand_idx] = value
1086-
for key, param_result in param_results.items():
1087-
param_list = list(param_result.values())
1088-
try:
1089-
with warnings.catch_warnings():
1090-
warnings.filterwarnings(
1091-
"ignore",
1092-
message="in the future the `.dtype` attribute",
1093-
category=DeprecationWarning,
1094-
)
1095-
# Warning raised by NumPy 1.20+
1096-
arr_dtype = np.result_type(*param_list)
1097-
except (TypeError, ValueError):
1098-
arr_dtype = np.dtype(object)
1099-
else:
1100-
if any(np.min_scalar_type(x) == object for x in param_list):
1101-
# `np.result_type` might get thrown off by `.dtype` properties
1102-
# (which some estimators have).
1103-
# If finding the result dtype this way would give object,
1104-
# then we use object.
1105-
# https://github.com/scikit-learn/scikit-learn/issues/29157
1106-
arr_dtype = np.dtype(object)
1107-
if len(param_list) == n_candidates and arr_dtype != object:
1108-
# Exclude `object` else the numpy constructor might infer a list of
1109-
# tuples to be a 2d array.
1110-
results[key] = MaskedArray(param_list, mask=False, dtype=arr_dtype)
1111-
else:
1112-
# Use one MaskedArray and mask all the places where the param is not
1113-
# applicable for that candidate (which may not contain all the params).
1114-
ma = MaskedArray(np.empty(n_candidates), mask=True, dtype=arr_dtype)
1115-
for index, value in param_result.items():
1116-
# Setting the value at an index unmasks that index
1117-
ma[index] = value
1118-
results[key] = ma
1119-
11201132
# Store a list of param dicts at the key 'params'
1133+
for param, ma in _yield_masked_array_for_each_param(candidate_params):
1134+
results[param] = ma
11211135
results["params"] = candidate_params
11221136

11231137
test_scores_dict = _normalize_score_results(out["test_scores"])

‎sklearn/model_selection/tests/test_search.py

Copy file name to clipboardExpand all lines: sklearn/model_selection/tests/test_search.py
+116-3Lines changed: 116 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,20 @@
6161
StratifiedShuffleSplit,
6262
train_test_split,
6363
)
64-
from sklearn.model_selection._search import BaseSearchCV
64+
from sklearn.model_selection._search import (
65+
BaseSearchCV,
66+
_yield_masked_array_for_each_param,
67+
)
6568
from sklearn.model_selection.tests.common import OneTimeSplitter
6669
from sklearn.naive_bayes import ComplementNB
6770
from sklearn.neighbors import KernelDensity, KNeighborsClassifier, LocalOutlierFactor
68-
from sklearn.pipeline import Pipeline
69-
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder, StandardScaler
71+
from sklearn.pipeline import Pipeline, make_pipeline
72+
from sklearn.preprocessing import (
73+
OneHotEncoder,
74+
OrdinalEncoder,
75+
SplineTransformer,
76+
StandardScaler,
77+
)
7078
from sklearn.svm import SVC, LinearSVC
7179
from sklearn.tests.metadata_routing_common import (
7280
ConsumingScorer,
@@ -2724,6 +2732,37 @@ def test_search_with_estimators_issue_29157():
27242732
assert grid_search.cv_results_["param_enc__enc"].dtype == object
27252733

27262734

2735+
def test_cv_results_multi_size_array():
2736+
"""Check that GridSearchCV works with params that are arrays of different sizes.
2737+
2738+
Non-regression test for #29277.
2739+
"""
2740+
n_features = 10
2741+
X, y = make_classification(n_features=10)
2742+
2743+
spline_reg_pipe = make_pipeline(
2744+
SplineTransformer(extrapolation="periodic"),
2745+
LogisticRegression(),
2746+
)
2747+
2748+
n_knots_list = [n_features * i for i in [10, 11, 12]]
2749+
knots_list = [
2750+
np.linspace(0, np.pi * 2, n_knots).reshape((-1, n_features))
2751+
for n_knots in n_knots_list
2752+
]
2753+
spline_reg_pipe_cv = GridSearchCV(
2754+
estimator=spline_reg_pipe,
2755+
param_grid={
2756+
"splinetransformer__knots": knots_list,
2757+
},
2758+
)
2759+
2760+
spline_reg_pipe_cv.fit(X, y)
2761+
assert (
2762+
spline_reg_pipe_cv.cv_results_["param_splinetransformer__knots"].dtype == object
2763+
)
2764+
2765+
27272766
@pytest.mark.parametrize(
27282767
"array_namespace, device, dtype", yield_namespace_device_dtype_combinations()
27292768
)
@@ -2747,3 +2786,77 @@ def test_array_api_search_cv_classifier(SearchCV, array_namespace, device, dtype
27472786
)
27482787
searcher.fit(X_xp, y_xp)
27492788
searcher.score(X_xp, y_xp)
2789+
2790+
2791+
# Construct these outside the tests so that the same object is used
2792+
# for both input and `expected`
2793+
one_hot_encoder = OneHotEncoder()
2794+
ordinal_encoder = OrdinalEncoder()
2795+
2796+
# If we construct this directly via `MaskedArray`, the list of tuples
2797+
# gets auto-converted to a 2D array.
2798+
ma_with_tuples = np.ma.MaskedArray(np.empty(2), mask=True, dtype=object)
2799+
ma_with_tuples[0] = (1, 2)
2800+
ma_with_tuples[1] = (3, 4)
2801+
2802+
2803+
@pytest.mark.parametrize(
2804+
("candidate_params", "expected"),
2805+
[
2806+
pytest.param(
2807+
[{"foo": 1}, {"foo": 2}],
2808+
[
2809+
("param_foo", np.ma.MaskedArray(np.array([1, 2]))),
2810+
],
2811+
id="simple numeric, single param",
2812+
),
2813+
pytest.param(
2814+
[{"foo": 1, "bar": 3}, {"foo": 2, "bar": 4}, {"foo": 3}],
2815+
[
2816+
("param_foo", np.ma.MaskedArray(np.array([1, 2, 3]))),
2817+
(
2818+
"param_bar",
2819+
np.ma.MaskedArray(np.array([3, 4, 0]), mask=[False, False, True]),
2820+
),
2821+
],
2822+
id="simple numeric, one param is missing in one round",
2823+
),
2824+
pytest.param(
2825+
[{"foo": [[1], [2], [3]]}, {"foo": [[1], [2]]}],
2826+
[
2827+
(
2828+
"param_foo",
2829+
np.ma.MaskedArray([[[1], [2], [3]], [[1], [2]]], dtype=object),
2830+
),
2831+
],
2832+
id="lists of different lengths",
2833+
),
2834+
pytest.param(
2835+
[{"foo": (1, 2)}, {"foo": (3, 4)}],
2836+
[
2837+
(
2838+
"param_foo",
2839+
ma_with_tuples,
2840+
),
2841+
],
2842+
id="lists tuples",
2843+
),
2844+
pytest.param(
2845+
[{"foo": ordinal_encoder}, {"foo": one_hot_encoder}],
2846+
[
2847+
(
2848+
"param_foo",
2849+
np.ma.MaskedArray([ordinal_encoder, one_hot_encoder], dtype=object),
2850+
),
2851+
],
2852+
id="estimators",
2853+
),
2854+
],
2855+
)
2856+
def test_yield_masked_array_for_each_param(candidate_params, expected):
2857+
result = list(_yield_masked_array_for_each_param(candidate_params))
2858+
for (key, value), (expected_key, expected_value) in zip(result, expected):
2859+
assert key == expected_key
2860+
assert value.dtype == expected_value.dtype
2861+
np.testing.assert_array_equal(value, expected_value)
2862+
np.testing.assert_array_equal(value.mask, expected_value.mask)

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.