6
6
import sys
7
7
import numpy as np
8
8
9
- from numpy .testing import (assert_equal , assert_array_equal ,
10
- assert_allclose , suppress_warnings )
9
+ from numpy .testing import (assert_array_equal , assert_allclose ,
10
+ suppress_warnings )
11
11
import pytest
12
12
13
13
@@ -197,8 +197,8 @@ def callback(x):
197
197
198
198
x , info = solver (A , b , x0 = x0 , tol = tol , maxiter = 1 , callback = callback )
199
199
200
- assert_equal ( len (residuals ), 1 )
201
- assert_equal ( info , 1 )
200
+ assert len (residuals ) == 1
201
+ assert info == 1
202
202
203
203
204
204
def test_maxiter ():
@@ -232,7 +232,7 @@ def check_convergence(solver, case):
232
232
233
233
assert_array_equal (x0 , 0 * b ) # ensure that x0 is not overwritten
234
234
if solver not in case .nonconvergence :
235
- assert_equal ( info , 0 )
235
+ assert info == 0
236
236
assert_normclose (A @ x , b , tol = tol )
237
237
else :
238
238
assert info != 0
@@ -274,15 +274,15 @@ def identity(b, which=None):
274
274
x , info = solver (A , b , M1 = precond , M2 = precond , x0 = x0 , tol = tol )
275
275
else :
276
276
x , info = solver (A , b , M = precond , x0 = x0 , tol = tol )
277
- assert_equal ( info , 0 )
277
+ assert info == 0
278
278
assert_normclose (A @ x , b , tol )
279
279
280
280
A = aslinearoperator (A )
281
281
A .psolve = identity
282
282
A .rpsolve = identity
283
283
284
284
x , info = solver (A , b , x0 = x0 , tol = tol )
285
- assert_equal ( info , 0 )
285
+ assert info == 0
286
286
assert_normclose (A @ x , b , tol = tol )
287
287
288
288
@@ -333,7 +333,7 @@ def rmatvec(b):
333
333
matvec_count = [0 ]
334
334
x , info = solver (A , b , M = precond , x0 = x0 , tol = tol )
335
335
336
- assert_equal ( info , 0 )
336
+ assert info == 0
337
337
assert_normclose (case .A @ x , b , tol )
338
338
339
339
# Solution should be nearly instant
@@ -365,7 +365,7 @@ def _check_reentrancy(solver, is_reentrant):
365
365
def matvec (x ):
366
366
A = np .array ([[1.0 , 0 , 0 ], [0 , 2.0 , 0 ], [0 , 0 , 3.0 ]])
367
367
y , info = solver (A , x )
368
- assert_equal ( info , 0 )
368
+ assert info == 0
369
369
return y
370
370
b = np .array ([1 , 1. / 2 , 1. / 3 ])
371
371
op = LinearOperator ((3 , 3 ), matvec = matvec , rmatvec = matvec ,
@@ -375,7 +375,7 @@ def matvec(x):
375
375
pytest .raises (RuntimeError , solver , op , b )
376
376
else :
377
377
y , info = solver (op , b )
378
- assert_equal ( info , 0 )
378
+ assert info == 0
379
379
assert_allclose (y , [1 , 1 , 1 ])
380
380
381
381
@@ -411,8 +411,8 @@ def test_atol(solver):
411
411
x , info = solver (A , b , M1 = M , M2 = M2 , tol = tol , atol = atol )
412
412
else :
413
413
x , info = solver (A , b , M = M , tol = tol , atol = atol )
414
- assert_equal (info , 0 )
415
414
415
+ assert info == 0
416
416
residual = A @ x - b
417
417
err = np .linalg .norm (residual )
418
418
atol2 = tol * b_norm
@@ -436,11 +436,11 @@ def test_zero_rhs(solver):
436
436
sup .filter (DeprecationWarning , ".*called without specifying.*" )
437
437
438
438
x , info = solver (A , b , tol = tol )
439
- assert_equal ( info , 0 )
439
+ assert info == 0
440
440
assert_allclose (x , 0. , atol = 1e-15 )
441
441
442
442
x , info = solver (A , b , tol = tol , x0 = ones (10 ))
443
- assert_equal ( info , 0 )
443
+ assert info == 0
444
444
assert_allclose (x , 0. , atol = tol )
445
445
446
446
if solver is not minres :
@@ -449,11 +449,11 @@ def test_zero_rhs(solver):
449
449
assert_allclose (x , 0 )
450
450
451
451
x , info = solver (A , b , tol = tol , atol = tol )
452
- assert_equal ( info , 0 )
452
+ assert info == 0
453
453
assert_allclose (x , 0 , atol = 1e-300 )
454
454
455
455
x , info = solver (A , b , tol = tol , atol = 0 )
456
- assert_equal ( info , 0 )
456
+ assert info == 0
457
457
assert_allclose (x , 0 , atol = 1e-300 )
458
458
459
459
@@ -516,11 +516,11 @@ def test_x0_working(solver):
516
516
kw = dict (atol = 0 , tol = 1e-6 )
517
517
518
518
x , info = solver (A , b , ** kw )
519
- assert_equal ( info , 0 )
519
+ assert info == 0
520
520
assert np .linalg .norm (A @ x - b ) <= 1e-6 * np .linalg .norm (b )
521
521
522
522
x , info = solver (A , b , x0 = x0 , ** kw )
523
- assert_equal ( info , 0 )
523
+ assert info == 0
524
524
assert np .linalg .norm (A @ x - b ) <= 2e-6 * np .linalg .norm (b )
525
525
526
526
@@ -539,7 +539,7 @@ def test_x0_equals_Mb(solver):
539
539
x , info = solver (A , b , x0 = x0 , tol = tol )
540
540
541
541
assert_array_equal (x0 , 'Mb' ) # ensure that x0 is not overwritten
542
- assert_equal ( info , 0 )
542
+ assert info == 0
543
543
assert_normclose (A @ x , b , tol = tol )
544
544
545
545
@@ -556,12 +556,14 @@ def cb(x):
556
556
x , info = solver (A , b , callback = cb , show = True )
557
557
out , err = capsys .readouterr ()
558
558
if i == 20 : # Asymmetric and Positive Definite
559
- assert_equal (out , f"{ solverstring } : Linear solve not converged "
560
- f"due to reach MAXIT iterations { count [0 ]} \n " )
559
+ exp = (f"{ solverstring } : Linear solve not converged "
560
+ f"due to reach MAXIT iterations { count [0 ]} \n " )
561
+ assert out == exp
561
562
else : # 1-D Poisson equations
562
- assert_equal (out , f"{ solverstring } : Linear solve converged due to "
563
- f"reach TOL iterations { count [0 ]} \n " )
564
- assert_equal (err , '' )
563
+ exp = (f"{ solverstring } : Linear solve converged due to "
564
+ f"reach TOL iterations { count [0 ]} \n " )
565
+ assert out == exp
566
+ assert err == ""
565
567
566
568
567
569
# -----------------------------------------------------------------------------
@@ -607,7 +609,7 @@ def UT_solve(b):
607
609
sup .filter (DeprecationWarning , ".*called without specifying.*" )
608
610
x , info = qmr (A , b , tol = 1e-8 , maxiter = 15 , M1 = M1 , M2 = M2 )
609
611
610
- assert_equal ( info , 0 )
612
+ assert info == 0
611
613
assert_normclose (A @ x , b , tol = 1e-8 )
612
614
613
615
0 commit comments