Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 5ccfabf

Browse filesBrowse files
TST Change assert from sklearn to pytest style in module linear_model/tests (#19565)
1 parent bfd7b58 commit 5ccfabf
Copy full SHA for 5ccfabf

10 files changed

+133
-66
lines changed

‎sklearn/linear_model/_omp.py

Copy file name to clipboardExpand all lines: sklearn/linear_model/_omp.py
+5-3Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,11 @@
2020
from ..utils.fixes import delayed
2121
from ..model_selection import check_cv
2222

23-
premature = """ Orthogonal matching pursuit ended prematurely due to linear
24-
dependence in the dictionary. The requested precision might not have been met.
25-
"""
23+
premature = (
24+
"Orthogonal matching pursuit ended prematurely due to linear"
25+
" dependence in the dictionary. The requested precision might"
26+
" not have been met."
27+
)
2628

2729

2830
def _cholesky_omp(X, y, n_nonzero_coefs, tol=None, copy_X=True,

‎sklearn/linear_model/tests/test_coordinate_descent.py

Copy file name to clipboardExpand all lines: sklearn/linear_model/tests/test_coordinate_descent.py
+26-10Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@
1919
from sklearn.utils._testing import assert_allclose
2020
from sklearn.utils._testing import assert_array_almost_equal
2121
from sklearn.utils._testing import assert_almost_equal
22-
from sklearn.utils._testing import assert_warns
23-
from sklearn.utils._testing import assert_warns_message
2422
from sklearn.utils._testing import ignore_warnings
2523
from sklearn.utils._testing import assert_array_equal
2624
from sklearn.utils._testing import _convert_container
@@ -646,7 +644,13 @@ def test_lasso_alpha_warning():
646644
Y = [-1, 0, 1] # just a straight line
647645

648646
clf = Lasso(alpha=0)
649-
assert_warns(UserWarning, clf.fit, X, Y)
647+
warning_message = (
648+
"With alpha=0, this algorithm does not "
649+
"converge well. You are advised to use the "
650+
"LinearRegression estimator"
651+
)
652+
with pytest.warns(UserWarning, match=warning_message):
653+
clf.fit(X, Y)
650654

651655

652656
def test_lasso_positive_constraint():
@@ -733,7 +737,12 @@ def test_multi_task_lasso_and_enet():
733737
assert_array_almost_equal(clf.coef_[0], clf.coef_[1])
734738

735739
clf = MultiTaskElasticNet(alpha=1.0, tol=1e-8, max_iter=1)
736-
assert_warns_message(ConvergenceWarning, 'did not converge', clf.fit, X, Y)
740+
warning_message = (
741+
"Objective did not converge. You might want to "
742+
"increase the number of iterations."
743+
)
744+
with pytest.warns(ConvergenceWarning, match=warning_message):
745+
clf.fit(X, Y)
737746

738747

739748
def test_lasso_readonly_data():
@@ -1075,11 +1084,13 @@ def test_overrided_gram_matrix():
10751084
X, y, _, _ = build_dataset(n_samples=20, n_features=10)
10761085
Gram = X.T.dot(X)
10771086
clf = ElasticNet(selection='cyclic', tol=1e-8, precompute=Gram)
1078-
assert_warns_message(UserWarning,
1079-
"Gram matrix was provided but X was centered"
1080-
" to fit intercept, "
1081-
"or X was normalized : recomputing Gram matrix.",
1082-
clf.fit, X, y)
1087+
warning_message = (
1088+
"Gram matrix was provided but X was centered"
1089+
" to fit intercept, "
1090+
"or X was normalized : recomputing Gram matrix."
1091+
)
1092+
with pytest.warns(UserWarning, match=warning_message):
1093+
clf.fit(X, y)
10831094

10841095

10851096
@pytest.mark.parametrize('model', [ElasticNet, Lasso])
@@ -1214,7 +1225,12 @@ def test_enet_coordinate_descent(klass, n_classes, kwargs):
12141225
y = np.ones((n_samples, n_classes))
12151226
if klass == Lasso:
12161227
y = y.ravel()
1217-
assert_warns(ConvergenceWarning, clf.fit, X, y)
1228+
warning_message = (
1229+
"Objective did not converge. You might want to"
1230+
" increase the number of iterations."
1231+
)
1232+
with pytest.warns(ConvergenceWarning, match=warning_message):
1233+
clf.fit(X, y)
12181234

12191235

12201236
def test_convergence_warnings():

‎sklearn/linear_model/tests/test_least_angle.py

Copy file name to clipboardExpand all lines: sklearn/linear_model/tests/test_least_angle.py
+5-2Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from sklearn.utils._testing import assert_array_almost_equal
1111
from sklearn.utils._testing import assert_raises
1212
from sklearn.utils._testing import ignore_warnings
13-
from sklearn.utils._testing import assert_warns
1413
from sklearn.utils._testing import TempMemmap
1514
from sklearn.utils.fixes import np_version, parse_version
1615
from sklearn.exceptions import ConvergenceWarning
@@ -372,7 +371,11 @@ def objective_function(coef):
372371
+ alpha * linalg.norm(coef, 1))
373372

374373
lars = linear_model.LassoLars(alpha=alpha, normalize=False)
375-
assert_warns(ConvergenceWarning, lars.fit, X, y)
374+
warning_message = (
375+
"Regressors in active set degenerate."
376+
)
377+
with pytest.warns(ConvergenceWarning, match=warning_message):
378+
lars.fit(X, y)
376379
lars_coef_ = lars.coef_
377380
lars_obj = objective_function(lars_coef_)
378381

‎sklearn/linear_model/tests/test_logistic.py

Copy file name to clipboardExpand all lines: sklearn/linear_model/tests/test_logistic.py
+41-29Lines changed: 41 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,7 @@
1919
from sklearn.model_selection import cross_val_score
2020
from sklearn.preprocessing import LabelEncoder, StandardScaler
2121
from sklearn.utils import compute_class_weight, _IS_32BIT
22-
from sklearn.utils._testing import assert_warns
2322
from sklearn.utils._testing import ignore_warnings
24-
from sklearn.utils._testing import assert_warns_message
2523
from sklearn.utils import shuffle
2624
from sklearn.linear_model import SGDClassifier
2725
from sklearn.preprocessing import scale
@@ -155,11 +153,13 @@ def test_lr_liblinear_warning():
155153
target = iris.target_names[iris.target]
156154

157155
lr = LogisticRegression(solver='liblinear', n_jobs=2)
158-
assert_warns_message(UserWarning,
159-
"'n_jobs' > 1 does not have any effect when"
160-
" 'solver' is set to 'liblinear'. Got 'n_jobs'"
161-
" = 2.",
162-
lr.fit, iris.data, target)
156+
warning_message = (
157+
"'n_jobs' > 1 does not have any effect when"
158+
" 'solver' is set to 'liblinear'. Got 'n_jobs'"
159+
" = 2."
160+
)
161+
with pytest.warns(UserWarning, match=warning_message):
162+
lr.fit(iris.data, target)
163163

164164

165165
def test_predict_3_classes():
@@ -1188,23 +1188,34 @@ def test_logreg_predict_proba_multinomial():
11881188
assert clf_wrong_loss > clf_multi_loss
11891189

11901190

1191-
def test_max_iter():
1191+
@pytest.mark.parametrize("max_iter", np.arange(1, 5))
1192+
@pytest.mark.parametrize("multi_class", ['ovr', 'multinomial'])
1193+
@pytest.mark.parametrize(
1194+
"solver, message",
1195+
[("newton-cg", "newton-cg failed to converge. Increase the "
1196+
"number of iterations."),
1197+
("liblinear", "Liblinear failed to converge, increase the "
1198+
"number of iterations."),
1199+
("sag", "The max_iter was reached which means the "
1200+
"coef_ did not converge"),
1201+
("saga", "The max_iter was reached which means the "
1202+
"coef_ did not converge"),
1203+
("lbfgs", "lbfgs failed to converge")])
1204+
def test_max_iter(max_iter, multi_class, solver, message):
11921205
# Test that the maximum number of iteration is reached
11931206
X, y_bin = iris.data, iris.target.copy()
11941207
y_bin[y_bin == 2] = 0
11951208

1196-
solvers = ['newton-cg', 'liblinear', 'sag', 'saga', 'lbfgs']
1209+
if solver == 'liblinear' and multi_class == 'multinomial':
1210+
pytest.skip("'multinomial' is unavailable when solver='liblinear'")
1211+
1212+
lr = LogisticRegression(max_iter=max_iter, tol=1e-15,
1213+
multi_class=multi_class,
1214+
random_state=0, solver=solver)
1215+
with pytest.warns(ConvergenceWarning, match=message):
1216+
lr.fit(X, y_bin)
11971217

1198-
for max_iter in range(1, 5):
1199-
for solver in solvers:
1200-
for multi_class in ['ovr', 'multinomial']:
1201-
if solver == 'liblinear' and multi_class == 'multinomial':
1202-
continue
1203-
lr = LogisticRegression(max_iter=max_iter, tol=1e-15,
1204-
multi_class=multi_class,
1205-
random_state=0, solver=solver)
1206-
assert_warns(ConvergenceWarning, lr.fit, X, y_bin)
1207-
assert lr.n_iter_[0] == max_iter
1218+
assert lr.n_iter_[0] == max_iter
12081219

12091220

12101221
@pytest.mark.parametrize('solver',
@@ -1644,12 +1655,11 @@ def test_l1_ratio_param(l1_ratio):
16441655
l1_ratio=l1_ratio).fit(X, Y1)
16451656

16461657
if l1_ratio is not None:
1647-
msg = ("l1_ratio parameter is only used when penalty is 'elasticnet'."
1648-
" Got (penalty=l1)")
1649-
1650-
assert_warns_message(UserWarning, msg,
1651-
LogisticRegression(penalty='l1', solver='saga',
1652-
l1_ratio=l1_ratio).fit, X, Y1)
1658+
msg = (r"l1_ratio parameter is only used when penalty is"
1659+
r" 'elasticnet'\. Got \(penalty=l1\)")
1660+
with pytest.warns(UserWarning, match=msg):
1661+
LogisticRegression(penalty='l1', solver='saga',
1662+
l1_ratio=l1_ratio).fit(X, Y1)
16531663

16541664

16551665
@pytest.mark.parametrize('l1_ratios', ([], [.5, 2], None, 'something_wrong'))
@@ -1664,11 +1674,12 @@ def test_l1_ratios_param(l1_ratios):
16641674
l1_ratios=l1_ratios, cv=2).fit(X, Y1)
16651675

16661676
if l1_ratios is not None:
1667-
msg = ("l1_ratios parameter is only used when penalty is "
1668-
"'elasticnet'. Got (penalty=l1)")
1677+
msg = (r"l1_ratios parameter is only used when penalty"
1678+
r" is 'elasticnet'. Got \(penalty=l1\)")
16691679
function = LogisticRegressionCV(penalty='l1', solver='saga',
16701680
l1_ratios=l1_ratios, cv=2).fit
1671-
assert_warns_message(UserWarning, msg, function, X, Y1)
1681+
with pytest.warns(UserWarning, match=msg):
1682+
function(X, Y1)
16721683

16731684

16741685
@pytest.mark.parametrize('C', np.logspace(-3, 2, 4))
@@ -1769,7 +1780,8 @@ def test_penalty_none(solver):
17691780

17701781
msg = "Setting penalty='none' will ignore the C"
17711782
lr = LogisticRegression(penalty='none', solver=solver, C=4)
1772-
assert_warns_message(UserWarning, msg, lr.fit, X, y)
1783+
with pytest.warns(UserWarning, match=msg):
1784+
lr.fit(X, y)
17731785

17741786
lr_none = LogisticRegression(penalty='none', solver=solver,
17751787
random_state=0)

‎sklearn/linear_model/tests/test_omp.py

Copy file name to clipboardExpand all lines: sklearn/linear_model/tests/test_omp.py
+18-8Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22
# License: BSD 3 clause
33

44
import numpy as np
5+
import pytest
56

67
from sklearn.utils._testing import assert_raises
78
from sklearn.utils._testing import assert_array_equal
89
from sklearn.utils._testing import assert_array_almost_equal
9-
from sklearn.utils._testing import assert_warns
1010
from sklearn.utils._testing import ignore_warnings
1111

1212

@@ -76,12 +76,16 @@ def test_unreachable_accuracy():
7676
assert_array_almost_equal(
7777
orthogonal_mp(X, y, tol=0),
7878
orthogonal_mp(X, y, n_nonzero_coefs=n_features))
79-
80-
assert_array_almost_equal(
81-
assert_warns(RuntimeWarning, orthogonal_mp, X, y, tol=0,
82-
precompute=True),
83-
orthogonal_mp(X, y, precompute=True,
84-
n_nonzero_coefs=n_features))
79+
warning_message = (
80+
"Orthogonal matching pursuit ended prematurely "
81+
"due to linear dependence in the dictionary. "
82+
"The requested precision might not have been met."
83+
)
84+
with pytest.warns(RuntimeWarning, match=warning_message):
85+
assert_array_almost_equal(
86+
orthogonal_mp(X, y, tol=0, precompute=True),
87+
orthogonal_mp(X, y, precompute=True,
88+
n_nonzero_coefs=n_features))
8589

8690

8791
def test_bad_input():
@@ -155,7 +159,13 @@ def test_identical_regressors():
155159
gamma = np.zeros(n_features)
156160
gamma[0] = gamma[1] = 1.
157161
newy = np.dot(newX, gamma)
158-
assert_warns(RuntimeWarning, orthogonal_mp, newX, newy, 2)
162+
warning_message = (
163+
"Orthogonal matching pursuit ended prematurely "
164+
"due to linear dependence in the dictionary. "
165+
"The requested precision might not have been met."
166+
)
167+
with pytest.warns(RuntimeWarning, match=warning_message):
168+
orthogonal_mp(newX, newy, 2)
159169

160170

161171
def test_swapped_regressors():

‎sklearn/linear_model/tests/test_ransac.py

Copy file name to clipboardExpand all lines: sklearn/linear_model/tests/test_ransac.py
+8-3Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from numpy.testing import assert_array_equal
77

88
from sklearn.utils import check_random_state
9-
from sklearn.utils._testing import assert_warns
109
from sklearn.utils._testing import assert_raises_regexp
1110
from sklearn.utils._testing import assert_allclose
1211
from sklearn.datasets import make_regression
@@ -232,8 +231,14 @@ def is_data_valid(X, y):
232231
is_data_valid=is_data_valid,
233232
max_skips=3,
234233
max_trials=5)
235-
236-
assert_warns(ConvergenceWarning, ransac_estimator.fit, X, y)
234+
warning_message = (
235+
"RANSAC found a valid consensus set but exited "
236+
"early due to skipping more iterations than "
237+
"`max_skips`. See estimator attributes for "
238+
"diagnostics."
239+
)
240+
with pytest.warns(ConvergenceWarning, match=warning_message):
241+
ransac_estimator.fit(X, y)
237242
assert ransac_estimator.n_skips_no_inliers_ == 0
238243
assert ransac_estimator.n_skips_invalid_data_ == 4
239244
assert ransac_estimator.n_skips_invalid_model_ == 0

‎sklearn/linear_model/tests/test_ridge.py

Copy file name to clipboardExpand all lines: sklearn/linear_model/tests/test_ridge.py
+8-5Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from sklearn.utils._testing import assert_array_almost_equal
1111
from sklearn.utils._testing import assert_array_equal
1212
from sklearn.utils._testing import ignore_warnings
13-
from sklearn.utils._testing import assert_warns
1413

1514
from sklearn.exceptions import ConvergenceWarning
1615

@@ -162,10 +161,14 @@ def test_ridge_regression_convergence_fail():
162161
rng = np.random.RandomState(0)
163162
y = rng.randn(5)
164163
X = rng.randn(5, 10)
165-
166-
assert_warns(ConvergenceWarning, ridge_regression,
167-
X, y, alpha=1.0, solver="sparse_cg",
168-
tol=0., max_iter=None, verbose=1)
164+
warning_message = (
165+
r"sparse_cg did not converge after"
166+
r" [0-9]+ iterations."
167+
)
168+
with pytest.warns(ConvergenceWarning, match=warning_message):
169+
ridge_regression(X, y,
170+
alpha=1.0, solver="sparse_cg",
171+
tol=0., max_iter=None, verbose=1)
169172

170173

171174
def test_ridge_sample_weights():

‎sklearn/linear_model/tests/test_sgd.py

Copy file name to clipboardExpand all lines: sklearn/linear_model/tests/test_sgd.py
+7-2Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from sklearn.utils._testing import assert_almost_equal
1010
from sklearn.utils._testing import assert_array_almost_equal
1111
from sklearn.utils._testing import assert_raises_regexp
12-
from sklearn.utils._testing import assert_warns
1312
from sklearn.utils._testing import ignore_warnings
1413
from sklearn.utils.fixes import parse_version
1514

@@ -1446,7 +1445,13 @@ def test_tol_parameter():
14461445

14471446
# Strict tolerance and small max_iter should trigger a warning
14481447
model_3 = SGDClassifier(max_iter=3, tol=1e-3, random_state=0)
1449-
model_3 = assert_warns(ConvergenceWarning, model_3.fit, X, y)
1448+
warning_message = (
1449+
"Maximum number of iteration reached before "
1450+
"convergence. Consider increasing max_iter to "
1451+
"improve the fit."
1452+
)
1453+
with pytest.warns(ConvergenceWarning, match=warning_message):
1454+
model_3.fit(X, y)
14501455
assert model_3.n_iter_ == 3
14511456

14521457

‎sklearn/linear_model/tests/test_sparse_coordinate_descent.py

Copy file name to clipboardExpand all lines: sklearn/linear_model/tests/test_sparse_coordinate_descent.py
+7-2Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import numpy as np
2+
import pytest
23
import scipy.sparse as sp
34

45
from sklearn.utils._testing import assert_array_almost_equal
56
from sklearn.utils._testing import assert_almost_equal
67

78
from sklearn.utils._testing import ignore_warnings
8-
from sklearn.utils._testing import assert_warns
99
from sklearn.exceptions import ConvergenceWarning
1010

1111
from sklearn.linear_model import Lasso, ElasticNet, LassoCV, ElasticNetCV
@@ -297,4 +297,9 @@ def test_sparse_enet_coordinate_descent():
297297
n_features = 2
298298
X = sp.csc_matrix((n_samples, n_features)) * 1e50
299299
y = np.ones(n_samples)
300-
assert_warns(ConvergenceWarning, clf.fit, X, y)
300+
warning_message = (
301+
"Objective did not converge. You might want "
302+
"to increase the number of iterations."
303+
)
304+
with pytest.warns(ConvergenceWarning, match=warning_message):
305+
clf.fit(X, y)

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.