Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit f1018c6

Browse filesBrowse files
xiaoyuchaiShawnNicolasHug
authored
FIX BaseSuccessiveHalving class groups support (#19847)
Co-authored-by: Shawn <shawn@mpirica.com> Co-authored-by: Nicolas Hug <nicolashug@fb.com>
1 parent b1d686d commit f1018c6
Copy full SHA for f1018c6

File tree

3 files changed

+42
-1
lines changed
Filter options

3 files changed

+42
-1
lines changed

‎doc/whats_new/v1.0.rst

Copy file name to clipboardExpand all lines: doc/whats_new/v1.0.rst
+5Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,11 @@ Changelog
276276
:pr:`18649` by `Leandro Hermida <hermidalc>` and
277277
`Rodion Martynov <marrodion>`.
278278

279+
- |Fix| The `fit` method of the successive halving parameter search
280+
(:class:`model_selection.HalvingGridSearchCV`, and
281+
:class:`model_selection.HalvingRandomSearchCV`) now correctly handles the
282+
`groups` parameter. :pr:`19847` by :user:`Xiaoyu Chai <xiaoyuchai>`.
283+
279284
:mod:`sklearn.naive_bayes`
280285
..........................
281286

‎sklearn/model_selection/_search_successive_halving.py

Copy file name to clipboardExpand all lines: sklearn/model_selection/_search_successive_halving.py
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def fit(self, X, y=None, groups=None, **fit_params):
210210

211211
self._n_samples_orig = _num_samples(X)
212212

213-
super().fit(X, y=y, groups=None, **fit_params)
213+
super().fit(X, y=y, groups=groups, **fit_params)
214214

215215
# Set best_score_: BaseSearchCV does not set it, as refit is a callable
216216
self.best_score_ = (

‎sklearn/model_selection/tests/test_successive_halving.py

Copy file name to clipboardExpand all lines: sklearn/model_selection/tests/test_successive_halving.py
+36Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,16 @@
77
from sklearn.datasets import make_classification
88
from sklearn.dummy import DummyClassifier
99
from sklearn.experimental import enable_halving_search_cv # noqa
10+
from sklearn.model_selection import StratifiedKFold
11+
from sklearn.model_selection import StratifiedShuffleSplit
12+
from sklearn.model_selection import LeaveOneGroupOut
13+
from sklearn.model_selection import LeavePGroupsOut
14+
from sklearn.model_selection import GroupKFold
15+
from sklearn.model_selection import GroupShuffleSplit
1016
from sklearn.model_selection import HalvingGridSearchCV
1117
from sklearn.model_selection import HalvingRandomSearchCV
1218
from sklearn.model_selection import KFold, ShuffleSplit
19+
from sklearn.svm import LinearSVC
1320
from sklearn.model_selection._search_successive_halving import (
1421
_SubsampleMetaSplitter, _top_k, _refit_callable)
1522

@@ -562,3 +569,32 @@ def set_params(self, **params):
562569

563570
assert (cv_results_df['params'] == passed_params).all()
564571
assert (cv_results_df['n_resources'] == passed_n_samples).all()
572+
573+
574+
@pytest.mark.parametrize('Est', (HalvingGridSearchCV, HalvingRandomSearchCV))
575+
def test_groups_support(Est):
576+
# Check if ValueError (when groups is None) propagates to
577+
# HalvingGridSearchCV and HalvingRandomSearchCV
578+
# And also check if groups is correctly passed to the cv object
579+
rng = np.random.RandomState(0)
580+
581+
X, y = make_classification(n_samples=50, n_classes=2, random_state=0)
582+
groups = rng.randint(0, 3, 50)
583+
584+
clf = LinearSVC(random_state=0)
585+
grid = {'C': [1]}
586+
587+
group_cvs = [LeaveOneGroupOut(), LeavePGroupsOut(2),
588+
GroupKFold(n_splits=3), GroupShuffleSplit(random_state=0)]
589+
error_msg = "The 'groups' parameter should not be None."
590+
for cv in group_cvs:
591+
gs = Est(clf, grid, cv=cv)
592+
with pytest.raises(ValueError, match=error_msg):
593+
gs.fit(X, y)
594+
gs.fit(X, y, groups=groups)
595+
596+
non_group_cvs = [StratifiedKFold(), StratifiedShuffleSplit(random_state=0)]
597+
for cv in non_group_cvs:
598+
gs = Est(clf, grid, cv=cv)
599+
# Should not raise an error
600+
gs.fit(X, y)

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.