Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 70af34c

Browse filesBrowse files
authored
TST Use pytest.warns in sklearn.semi_supervised tests (#19510)
1 parent 15d2df4 commit 70af34c
Copy full SHA for 70af34c

File tree

1 file changed

+14
-7
lines changed
Filter options

1 file changed

+14
-7
lines changed

‎sklearn/semi_supervised/tests/test_label_propagation.py

Copy file name to clipboardExpand all lines: sklearn/semi_supervised/tests/test_label_propagation.py
+14-7Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
import pytest
55

66
from scipy.sparse import issparse
7-
from sklearn.utils._testing import assert_warns
8-
from sklearn.utils._testing import assert_no_warnings
97
from sklearn.semi_supervised import _label_propagation as label_propagation
108
from sklearn.metrics.pairwise import rbf_kernel
119
from sklearn.model_selection import train_test_split
@@ -143,18 +141,25 @@ def test_convergence_warning():
143141
X = np.array([[1., 0.], [0., 1.], [1., 2.5]])
144142
y = np.array([0, 1, -1])
145143
mdl = label_propagation.LabelSpreading(kernel='rbf', max_iter=1)
146-
assert_warns(ConvergenceWarning, mdl.fit, X, y)
144+
warn_msg = ('max_iter=1 was reached without convergence.')
145+
with pytest.warns(ConvergenceWarning, match=warn_msg):
146+
mdl.fit(X, y)
147147
assert mdl.n_iter_ == mdl.max_iter
148148

149149
mdl = label_propagation.LabelPropagation(kernel='rbf', max_iter=1)
150-
assert_warns(ConvergenceWarning, mdl.fit, X, y)
150+
with pytest.warns(ConvergenceWarning, match=warn_msg):
151+
mdl.fit(X, y)
151152
assert mdl.n_iter_ == mdl.max_iter
152153

153154
mdl = label_propagation.LabelSpreading(kernel='rbf', max_iter=500)
154-
assert_no_warnings(mdl.fit, X, y)
155+
with pytest.warns(None) as record:
156+
mdl.fit(X, y)
157+
assert len(record) == 0
155158

156159
mdl = label_propagation.LabelPropagation(kernel='rbf', max_iter=500)
157-
assert_no_warnings(mdl.fit, X, y)
160+
with pytest.warns(None) as record:
161+
mdl.fit(X, y)
162+
assert len(record) == 0
158163

159164

160165
@pytest.mark.parametrize("LabelPropagationCls",
@@ -170,7 +175,9 @@ def test_label_propagation_non_zero_normalizer(LabelPropagationCls):
170175
mdl = LabelPropagationCls(kernel='knn',
171176
max_iter=100,
172177
n_neighbors=1)
173-
assert_no_warnings(mdl.fit, X, y)
178+
with pytest.warns(None) as record:
179+
mdl.fit(X, y)
180+
assert len(record) == 0
174181

175182

176183
def test_predict_sparse_callable_kernel():

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.