TST replaces assert_raises* by pytest.raises in model_selection/tests/test_split.py #19585

Merged
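
The conversion pattern is the same throughout the file: calls to the legacy `assert_raises` / `assert_raises_regexp` helpers from `sklearn.utils._testing` become `pytest.raises` context managers, with the `match` argument taking over the regular-expression check on the exception message. A minimal illustrative sketch of the two idioms (the toy `divide` function is only for illustration and is not part of the PR):

```python
import pytest


def divide(a, b):
    # Toy function used only to demonstrate the two testing idioms.
    return a / b


def test_divide_raises():
    # Legacy style removed by this PR:
    #     assert_raises(ZeroDivisionError, divide, 1, 0)
    #     assert_raises_regexp(ZeroDivisionError, "division", divide, 1, 0)
    # pytest-native replacements, as used throughout the diff below:
    with pytest.raises(ZeroDivisionError):
        divide(1, 0)
    with pytest.raises(ZeroDivisionError, match="division"):
        divide(1, 0)
```

The context-manager form keeps the failing call and its arguments readable as an ordinary expression instead of packing them into positional arguments of a helper.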
79 changes (48 additions, 31 deletions) in sklearn/model_selection/tests/test_split.py
@@ -10,8 +10,6 @@
 from itertools import permutations
 
 from sklearn.utils._testing import assert_allclose
-from sklearn.utils._testing import assert_raises
-from sklearn.utils._testing import assert_raises_regexp
 from sklearn.utils._testing import assert_array_almost_equal
 from sklearn.utils._testing import assert_array_equal
 from sklearn.utils._testing import assert_raise_message
@@ -206,11 +204,14 @@ def test_kfold_valueerrors():
     # classes are less than n_splits.
     y = np.array([3, 3, -1, -1, 2])
 
-    assert_raises(ValueError, next, skf_3.split(X2, y))
+    with pytest.raises(ValueError):
+        next(skf_3.split(X2, y))
 
     # Error when number of folds is <= 1
-    assert_raises(ValueError, KFold, 0)
-    assert_raises(ValueError, KFold, 1)
+    with pytest.raises(ValueError):
+        KFold(0)
+    with pytest.raises(ValueError):
+        KFold(1)
     error_string = ("k-fold cross-validation requires at least one"
                     " train/test split")
     assert_raise_message(ValueError, error_string,
@@ -219,13 +220,18 @@
                          StratifiedKFold, 1)
 
     # When n_splits is not integer:
-    assert_raises(ValueError, KFold, 1.5)
-    assert_raises(ValueError, KFold, 2.0)
-    assert_raises(ValueError, StratifiedKFold, 1.5)
-    assert_raises(ValueError, StratifiedKFold, 2.0)
+    with pytest.raises(ValueError):
+        KFold(1.5)
+    with pytest.raises(ValueError):
+        KFold(2.0)
+    with pytest.raises(ValueError):
+        StratifiedKFold(1.5)
+    with pytest.raises(ValueError):
+        StratifiedKFold(2.0)
 
     # When shuffle is not a bool:
-    assert_raises(TypeError, KFold, n_splits=4, shuffle=None)
+    with pytest.raises(TypeError):
+        KFold(n_splits=4, shuffle=None)
 
 
 def test_kfold_indices():
@@ -565,24 +571,25 @@ def test_stratified_shuffle_split_init():
     X = np.arange(7)
     y = np.asarray([0, 1, 1, 1, 2, 2, 2])
     # Check that error is raised if there is a class with only one sample
-    assert_raises(ValueError, next,
-                  StratifiedShuffleSplit(3, 0.2).split(X, y))
+    with pytest.raises(ValueError):
+        next(StratifiedShuffleSplit(3, 0.2).split(X, y))
 
     # Check that error is raised if the test set size is smaller than n_classes
-    assert_raises(ValueError, next, StratifiedShuffleSplit(3, 2).split(X, y))
+    with pytest.raises(ValueError):
+        next(StratifiedShuffleSplit(3, 2).split(X, y))
     # Check that error is raised if the train set size is smaller than
     # n_classes
-    assert_raises(ValueError, next,
-                  StratifiedShuffleSplit(3, 3, 2).split(X, y))
+    with pytest.raises(ValueError):
+        next(StratifiedShuffleSplit(3, 3, 2).split(X, y))
 
     X = np.arange(9)
     y = np.asarray([0, 0, 0, 1, 1, 1, 2, 2, 2])
 
     # Train size or test size too small
-    assert_raises(ValueError, next,
-                  StratifiedShuffleSplit(train_size=2).split(X, y))
-    assert_raises(ValueError, next,
-                  StratifiedShuffleSplit(test_size=2).split(X, y))
+    with pytest.raises(ValueError):
+        next(StratifiedShuffleSplit(train_size=2).split(X, y))
+    with pytest.raises(ValueError):
+        next(StratifiedShuffleSplit(test_size=2).split(X, y))
 
 
 def test_stratified_shuffle_split_respects_test_size():
@@ -845,9 +852,9 @@ def test_leave_one_p_group_out():
     assert lpgo_1.get_n_splits(groups=np.arange(4)) == 4
 
     # raise ValueError if a `groups` parameter is illegal
-    with assert_raises(ValueError):
+    with pytest.raises(ValueError):
         logo.get_n_splits(None, None, [0.0, np.nan, 0.0])
-    with assert_raises(ValueError):
+    with pytest.raises(ValueError):
         lpgo_2.get_n_splits(None, None, [0.0, np.inf, 0.0])
 
     msg = "The 'groups' parameter should not be None."
@@ -911,8 +918,10 @@ def test_leave_one_p_group_out_error_on_fewer_number_of_groups():
 def test_repeated_cv_value_errors():
     # n_repeats is not integer or <= 0
     for cv in (RepeatedKFold, RepeatedStratifiedKFold):
-        assert_raises(ValueError, cv, n_repeats=0)
-        assert_raises(ValueError, cv, n_repeats=1.5)
+        with pytest.raises(ValueError):
+            cv(n_repeats=0)
+        with pytest.raises(ValueError):
+            cv(n_repeats=1.5)
 
 
 @pytest.mark.parametrize(
@@ -954,7 +963,8 @@ def test_repeated_kfold_determinstic_split():
         assert_array_equal(train, [2, 3, 4])
         assert_array_equal(test, [0, 1])
 
-        assert_raises(StopIteration, next, splits)
+        with pytest.raises(StopIteration):
+            next(splits)
 
 
 def test_get_n_splits_for_repeated_kfold():
@@ -1002,7 +1012,8 @@ def test_repeated_stratified_kfold_determinstic_split():
         assert_array_equal(train, [0, 1, 4])
         assert_array_equal(test, [2, 3])
 
-        assert_raises(StopIteration, next, splits)
+        with pytest.raises(StopIteration):
+            next(splits)
 
 
 def test_train_test_split_errors():
@@ -1258,7 +1269,8 @@ def test_check_cv():
     cv = check_cv(3, y_multioutput, classifier=True)
     np.testing.assert_equal(list(KFold(3).split(X)), list(cv.split(X)))
 
-    assert_raises(ValueError, check_cv, cv="lolo")
+    with pytest.raises(ValueError):
+        check_cv(cv="lolo")
 
 
 def test_cv_iterable_wrapper():
@@ -1375,17 +1387,22 @@ def test_group_kfold():
     # Should fail if there are more folds than groups
     groups = np.array([1, 1, 1, 2, 2])
     X = y = np.ones(len(groups))
-    assert_raises_regexp(ValueError, "Cannot have number of splits.*greater",
-                         next, GroupKFold(n_splits=3).split(X, y, groups))
+    with pytest.raises(
+        ValueError,
+        match="Cannot have number of splits.*greater"
+    ):
+        next(GroupKFold(n_splits=3).split(X, y, groups))
 
 
 def test_time_series_cv():
     X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14]]
 
     # Should fail if there are more folds than samples
-    assert_raises_regexp(ValueError, "Cannot have number of folds.*greater",
-                         next,
-                         TimeSeriesSplit(n_splits=7).split(X))
+    with pytest.raises(
+        ValueError,
+        match="Cannot have number of folds.*greater"
+    ):
+        next(TimeSeriesSplit(n_splits=7).split(X))
 
     tscv = TimeSeriesSplit(2)
 
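One detail behind the `next(...)` wrapping used in several of these tests: the `split` methods of the CV splitters are generators, so invalid parameters only raise once the first fold is actually requested. A small sketch of that behaviour, reusing the same toy data and match pattern as the `test_group_kfold` hunk above:

```python
import numpy as np
import pytest
from sklearn.model_selection import GroupKFold

groups = np.array([1, 1, 1, 2, 2])   # only two distinct groups
X = y = np.ones(len(groups))

# split() is lazy: building the generator does not validate n_splits yet.
gen = GroupKFold(n_splits=3).split(X, y, groups)

# The ValueError only surfaces on the first iteration, which is why the
# tests call next(...) inside the pytest.raises block.
with pytest.raises(ValueError, match="Cannot have number of splits.*greater"):
    next(gen)
```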