Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit b5e55f7

Browse filesBrowse files
authored
TST replace assert_warns* by pytest.warns in model_selection/tests (#19458)
1 parent 43241b1 commit b5e55f7
Copy full SHA for b5e55f7

File tree

4 files changed

+43
-25
lines changed
Filter options

4 files changed

+43
-25
lines changed

‎sklearn/model_selection/_validation.py

Copy file name to clipboardExpand all lines: sklearn/model_selection/_validation.py
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1473,7 +1473,7 @@ def _translate_train_sizes(train_sizes, n_max_training_samples):
14731473
if n_ticks > train_sizes_abs.shape[0]:
14741474
warnings.warn("Removed duplicate entries from 'train_sizes'. Number "
14751475
"of ticks will be less than the size of "
1476-
"'train_sizes' %d instead of %d)."
1476+
"'train_sizes': %d instead of %d."
14771477
% (train_sizes_abs.shape[0], n_ticks), RuntimeWarning)
14781478

14791479
return train_sizes_abs

‎sklearn/model_selection/tests/test_search.py

Copy file name to clipboardExpand all lines: sklearn/model_selection/tests/test_search.py
+14-6Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414
import pytest
1515

1616
from sklearn.utils._testing import (
17-
assert_warns,
18-
assert_warns_message,
1917
assert_raise_message,
2018
assert_array_equal,
2119
assert_array_almost_equal,
@@ -1433,7 +1431,12 @@ def test_grid_search_failing_classifier():
14331431
# error in this test.
14341432
gs = GridSearchCV(clf, [{'parameter': [0, 1, 2]}], scoring='accuracy',
14351433
refit=False, error_score=0.0)
1436-
assert_warns(FitFailedWarning, gs.fit, X, y)
1434+
warning_message = (
1435+
"Estimator fit failed. The score on this train-test partition "
1436+
"for these parameters will be set to 0.0.*."
1437+
)
1438+
with pytest.warns(FitFailedWarning, match=warning_message):
1439+
gs.fit(X, y)
14371440
n_candidates = len(gs.cv_results_['params'])
14381441

14391442
# Ensure that grid scores were set to zero as required for those fits
@@ -1449,7 +1452,12 @@ def get_cand_scores(i):
14491452

14501453
gs = GridSearchCV(clf, [{'parameter': [0, 1, 2]}], scoring='accuracy',
14511454
refit=False, error_score=float('nan'))
1452-
assert_warns(FitFailedWarning, gs.fit, X, y)
1455+
warning_message = (
1456+
"Estimator fit failed. The score on this train-test partition "
1457+
"for these parameters will be set to nan."
1458+
)
1459+
with pytest.warns(FitFailedWarning, match=warning_message):
1460+
gs.fit(X, y)
14531461
n_candidates = len(gs.cv_results_['params'])
14541462
assert all(np.all(np.isnan(get_cand_scores(cand_i)))
14551463
for cand_i in range(n_candidates)
@@ -1492,8 +1500,8 @@ def test_parameters_sampler_replacement():
14921500
'than n_iter=%d. Running %d iterations. For '
14931501
'exhaustive searches, use GridSearchCV.'
14941502
% (grid_size, n_iter, grid_size))
1495-
assert_warns_message(UserWarning, expected_warning,
1496-
list, sampler)
1503+
with pytest.warns(UserWarning, match=expected_warning):
1504+
list(sampler)
14971505

14981506
# degenerates to GridSearchCV if n_iter the same as grid_size
14991507
sampler = ParameterSampler(params, n_iter=8)

‎sklearn/model_selection/tests/test_split.py

Copy file name to clipboardExpand all lines: sklearn/model_selection/tests/test_split.py
+2-3Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from sklearn.utils._testing import assert_raises_regexp
1515
from sklearn.utils._testing import assert_array_almost_equal
1616
from sklearn.utils._testing import assert_array_equal
17-
from sklearn.utils._testing import assert_warns_message
1817
from sklearn.utils._testing import assert_raise_message
1918
from sklearn.utils._testing import ignore_warnings
2019
from sklearn.utils.validation import _num_samples
@@ -193,8 +192,8 @@ def test_kfold_valueerrors():
193192
y = np.array([3, 3, -1, -1, 3])
194193

195194
skf_3 = StratifiedKFold(3)
196-
assert_warns_message(Warning, "The least populated class",
197-
next, skf_3.split(X2, y))
195+
with pytest.warns(Warning, match="The least populated class"):
196+
next(skf_3.split(X2, y))
198197

199198
# Check that despite the warning the folds are still computed even
200199
# though all the classes are not necessarily represented at on each

‎sklearn/model_selection/tests/test_validation.py

Copy file name to clipboardExpand all lines: sklearn/model_selection/tests/test_validation.py
+26-15Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717
from sklearn.utils._testing import assert_almost_equal
1818
from sklearn.utils._testing import assert_raises
1919
from sklearn.utils._testing import assert_raise_message
20-
from sklearn.utils._testing import assert_warns
21-
from sklearn.utils._testing import assert_warns_message
2220
from sklearn.utils._testing import assert_raises_regex
2321
from sklearn.utils._testing import assert_array_almost_equal
2422
from sklearn.utils._testing import assert_array_equal
@@ -857,13 +855,12 @@ def split(self, X, y=None, groups=None):
857855

858856
X, y = load_iris(return_X_y=True)
859857

860-
warning_message = ('Number of classes in training fold (2) does '
861-
'not match total number of classes (3). '
858+
warning_message = (r'Number of classes in training fold \(2\) does '
859+
r'not match total number of classes \(3\). '
862860
'Results may not be appropriate for your use case.')
863-
assert_warns_message(RuntimeWarning, warning_message,
864-
cross_val_predict,
865-
LogisticRegression(solver="liblinear"),
866-
X, y, method='predict_proba', cv=KFold(2))
861+
with pytest.warns(RuntimeWarning, match=warning_message):
862+
cross_val_predict(LogisticRegression(solver="liblinear"),
863+
X, y, method='predict_proba', cv=KFold(2))
867864

868865

869866
def test_cross_val_predict_decision_function_shape():
@@ -1210,9 +1207,13 @@ def test_learning_curve_remove_duplicate_sample_sizes():
12101207
n_redundant=0, n_classes=2,
12111208
n_clusters_per_class=1, random_state=0)
12121209
estimator = MockImprovingEstimator(2)
1213-
train_sizes, _, _ = assert_warns(
1214-
RuntimeWarning, learning_curve, estimator, X, y, cv=3,
1215-
train_sizes=np.linspace(0.33, 1.0, 3))
1210+
warning_message = (
1211+
"Removed duplicate entries from 'train_sizes'. Number of ticks "
1212+
"will be less than the size of 'train_sizes': 2 instead of 3."
1213+
)
1214+
with pytest.warns(RuntimeWarning, match=warning_message):
1215+
train_sizes, _, _ = learning_curve(
1216+
estimator, X, y, cv=3, train_sizes=np.linspace(0.33, 1.0, 3))
12161217
assert_array_equal(train_sizes, [1, 2])
12171218

12181219

@@ -1753,8 +1754,13 @@ def test_fit_and_score_failing():
17531754
# passing error score to trigger the warning message
17541755
fit_and_score_kwargs = {'error_score': 0}
17551756
# check if the warning message type is as expected
1756-
assert_warns(FitFailedWarning, _fit_and_score, *fit_and_score_args,
1757-
**fit_and_score_kwargs)
1757+
warning_message = (
1758+
"Estimator fit failed. The score on this train-test partition for "
1759+
"these parameters will be set to %f."
1760+
% (fit_and_score_kwargs['error_score'])
1761+
)
1762+
with pytest.warns(FitFailedWarning, match=warning_message):
1763+
_fit_and_score(*fit_and_score_args, **fit_and_score_kwargs)
17581764
# since we're using FailingClassfier, our error will be the following
17591765
error_message = "ValueError: Failing classifier failed as required"
17601766
# the warning message we're expecting to see
@@ -1769,8 +1775,13 @@ def test_warn_trace(msg):
17691775
mtb = split[0] + '\n' + split[-1]
17701776
return warning_message in mtb
17711777
# check traceback is included
1772-
assert_warns_message(FitFailedWarning, test_warn_trace, _fit_and_score,
1773-
*fit_and_score_args, **fit_and_score_kwargs)
1778+
warning_message = (
1779+
"Estimator fit failed. The score on this train-test partition for "
1780+
"these parameters will be set to %f."
1781+
% (fit_and_score_kwargs['error_score'])
1782+
)
1783+
with pytest.warns(FitFailedWarning, match=warning_message):
1784+
_fit_and_score(*fit_and_score_args, **fit_and_score_kwargs)
17741785

17751786
fit_and_score_kwargs = {'error_score': 'raise'}
17761787
# check if exception was raised, with default error_score='raise'

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.