4
4
import pytest
5
5
6
6
from scipy .sparse import issparse
7
- from sklearn .utils ._testing import assert_warns
8
- from sklearn .utils ._testing import assert_no_warnings
9
7
from sklearn .semi_supervised import _label_propagation as label_propagation
10
8
from sklearn .metrics .pairwise import rbf_kernel
11
9
from sklearn .model_selection import train_test_split
@@ -143,18 +141,25 @@ def test_convergence_warning():
143
141
X = np .array ([[1. , 0. ], [0. , 1. ], [1. , 2.5 ]])
144
142
y = np .array ([0 , 1 , - 1 ])
145
143
mdl = label_propagation .LabelSpreading (kernel = 'rbf' , max_iter = 1 )
146
- assert_warns (ConvergenceWarning , mdl .fit , X , y )
144
+ warn_msg = ('max_iter=1 was reached without convergence.' )
145
+ with pytest .warns (ConvergenceWarning , match = warn_msg ):
146
+ mdl .fit (X , y )
147
147
assert mdl .n_iter_ == mdl .max_iter
148
148
149
149
mdl = label_propagation .LabelPropagation (kernel = 'rbf' , max_iter = 1 )
150
- assert_warns (ConvergenceWarning , mdl .fit , X , y )
150
+ with pytest .warns (ConvergenceWarning , match = warn_msg ):
151
+ mdl .fit (X , y )
151
152
assert mdl .n_iter_ == mdl .max_iter
152
153
153
154
mdl = label_propagation .LabelSpreading (kernel = 'rbf' , max_iter = 500 )
154
- assert_no_warnings (mdl .fit , X , y )
155
+ with pytest .warns (None ) as record :
156
+ mdl .fit (X , y )
157
+ assert len (record ) == 0
155
158
156
159
mdl = label_propagation .LabelPropagation (kernel = 'rbf' , max_iter = 500 )
157
- assert_no_warnings (mdl .fit , X , y )
160
+ with pytest .warns (None ) as record :
161
+ mdl .fit (X , y )
162
+ assert len (record ) == 0
158
163
159
164
160
165
@pytest .mark .parametrize ("LabelPropagationCls" ,
@@ -170,7 +175,9 @@ def test_label_propagation_non_zero_normalizer(LabelPropagationCls):
170
175
mdl = LabelPropagationCls (kernel = 'knn' ,
171
176
max_iter = 100 ,
172
177
n_neighbors = 1 )
173
- assert_no_warnings (mdl .fit , X , y )
178
+ with pytest .warns (None ) as record :
179
+ mdl .fit (X , y )
180
+ assert len (record ) == 0
174
181
175
182
176
183
def test_predict_sparse_callable_kernel ():
0 commit comments