Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 80e1d27

Browse filesBrowse files
committed
Enforce n_folds >= 2 for k-fold cross-validation
1 parent 7bc5d1a commit 80e1d27
Copy full SHA for 80e1d27

File tree

Expand file treeCollapse file tree

3 files changed

+19
-9
lines changed
Filter options
Expand file treeCollapse file tree

3 files changed

+19
-9
lines changed

‎doc/whats_new.rst

Copy file name to clipboardExpand all lines: doc/whats_new.rst
+3Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,9 @@ API changes summary
129129
- Sparse matrix support in :class:`sklearn.decomposition.RandomizedPCA`
130130
is now deprecated in favor of the new ``TruncatedSVD``.
131131

132+
- :class:`cross_valiation.KFold` and
133+
:class:`cross_valiation.StratifiedKFold` now enforce `n_folds >= 2`
134+
otherwise a ``ValueError`` is raised. By `Olivier Grisel`_.
132135

133136
.. _changes_0_13_1:
134137

‎sklearn/cross_validation.py

Copy file name to clipboardExpand all lines: sklearn/cross_validation.py
+10-7Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -244,12 +244,15 @@ def __init__(self, n, n_folds, indices, k=None):
244244
raise ValueError("n_folds must be an integer")
245245
self.n_folds = n_folds = int(n_folds)
246246

247-
if n_folds <= 0:
248-
raise ValueError("Cannot have number of folds below 1.")
247+
if n_folds <= 1:
248+
raise ValueError(
249+
"k-fold cross validation requires at least one"
250+
" train / test split by setting n_folds=2 or more,"
251+
" got n_folds=%d.".format(n_folds))
249252
if n_folds > self.n:
250-
raise ValueError("Cannot have number of folds n_folds=%d greater "
251-
"than the number of samples: %d."
252-
% (self.n_folds, self.n))
253+
raise ValueError(
254+
("Cannot have number of folds n_folds={0} greater"
255+
"than the number of samples: {1}.").format(n_folds, n))
253256

254257

255258
class KFold(_BaseKFold):
@@ -267,7 +270,7 @@ class KFold(_BaseKFold):
267270
Total number of elements.
268271
269272
n_folds : int, default=3
270-
Number of folds.
273+
Number of folds. Must be at least 2.
271274
272275
indices : boolean, optional (default True)
273276
Return train/test split as arrays of indices, rather than a boolean
@@ -355,7 +358,7 @@ class StratifiedKFold(_BaseKFold):
355358
Samples to split in K folds.
356359
357360
n_folds : int, default=3
358-
Number of folds.
361+
Number of folds. Must be at least 2.
359362
360363
indices : boolean, optional (default True)
361364
Return train/test split as arrays of indices, rather than a boolean

‎sklearn/tests/test_cross_validation.py

Copy file name to clipboardExpand all lines: sklearn/tests/test_cross_validation.py
+6-2Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,14 +107,18 @@ def test_kfold_valueerrors():
107107
# a characteristic of the code and not a behavior
108108
assert_true("The least populated class" in str(w[0]))
109109

110-
# Error when number of folds is <= 0
110+
# Error when number of folds is <= 1
111111
assert_raises(ValueError, cval.KFold, 2, 0)
112+
assert_raises(ValueError, cval.KFold, 2, 1)
113+
assert_raises(ValueError, cval.StratifiedKFold, y, 0)
114+
assert_raises(ValueError, cval.StratifiedKFold, y, 1)
112115

113116
# When n is not integer:
114-
assert_raises(ValueError, cval.KFold, 2.5, 1)
117+
assert_raises(ValueError, cval.KFold, 2.5, 2)
115118

116119
# When n_folds is not integer:
117120
assert_raises(ValueError, cval.KFold, 5, 1.5)
121+
assert_raises(ValueError, cval.StratifiedKFold, y, 1.5)
118122

119123

120124
def test_kfold_indices():

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.