24
24
# TODO test both preconditioner methods
25
25
26
26
27
+ # list of all solvers under test
28
+ _SOLVERS = [bicg , bicgstab , cg , cgs , gcrotmk , gmres , lgmres ,
29
+ minres , qmr , tfqmr ]
30
+ # create parametrized fixture for easy reuse in tests
31
+ @pytest .fixture (params = _SOLVERS , scope = "session" )
32
+ def solver (request ):
33
+ """
34
+ Fixture for all solvers in scipy.sparse.linalg._isolve
35
+ """
36
+ return request .param
37
+
38
+
27
39
class Case :
28
40
def __init__ (self , name , A , b = None , skip = None , nonconvergence = None ):
29
41
self .name = name
@@ -47,16 +59,11 @@ def __repr__(self):
47
59
48
60
class IterativeParams :
49
61
def __init__ (self ):
50
- # list of tuples (solver, symmetric, positive_definite )
51
- solvers = [cg , cgs , bicg , bicgstab , gmres , qmr , minres , lgmres ,
52
- gcrotmk , tfqmr ]
53
62
sym_solvers = [minres , cg ]
54
63
posdef_solvers = [cg ]
55
64
real_solvers = [minres ]
56
65
57
- self .solvers = solvers
58
-
59
- # list of tuples (A, symmetric, positive_definite )
66
+ # list of Cases
60
67
self .cases = []
61
68
62
69
# Symmetric and Positive Definite
@@ -66,7 +73,6 @@ def __init__(self):
66
73
data [1 , :] = - 1
67
74
data [2 , :] = - 1
68
75
Poisson1D = spdiags (data , [0 , - 1 , 1 ], N , N , format = 'csr' )
69
- self .Poisson1D = Case ("poisson1d" , Poisson1D )
70
76
self .cases .append (Case ("poisson1d" , Poisson1D ))
71
77
# note: minres fails for single precision
72
78
self .cases .append (Case ("poisson1d-F" , Poisson1D .astype ('f' ),
@@ -81,7 +87,6 @@ def __init__(self):
81
87
82
88
# 2-dimensional Poisson equations
83
89
Poisson2D = kronsum (Poisson1D , Poisson1D )
84
- self .Poisson2D = Case ("poisson2d" , Poisson2D )
85
90
# note: minres fails for 2-d poisson problem,
86
91
# it will be fixed in the future PR
87
92
self .cases .append (Case ("poisson2d" , Poisson2D , skip = [minres ]))
@@ -180,7 +185,13 @@ def __init__(self):
180
185
)
181
186
182
187
183
- params = IterativeParams ()
188
+ cases = IterativeParams ().cases
189
+ @pytest .fixture (params = cases , ids = [x .name for x in cases ], scope = "session" )
190
+ def case (request ):
191
+ """
192
+ Fixture for all cases in IterativeParams
193
+ """
194
+ return request .param
184
195
185
196
186
197
def check_maxiter (solver , case ):
@@ -201,14 +212,12 @@ def callback(x):
201
212
assert info == 1
202
213
203
214
204
- def test_maxiter ():
205
- for case in params .cases :
206
- for solver in params .solvers :
207
- if solver in case .skip + case .nonconvergence :
208
- continue
209
- with suppress_warnings () as sup :
210
- sup .filter (DeprecationWarning , ".*called without specifying.*" )
211
- check_maxiter (solver , case )
215
+ def test_maxiter (solver , case ):
216
+ if solver in case .skip + case .nonconvergence :
217
+ pytest .skip ("unsupported combination" )
218
+ with suppress_warnings () as sup :
219
+ sup .filter (DeprecationWarning , ".*called without specifying.*" )
220
+ check_maxiter (solver , case )
212
221
213
222
214
223
def assert_normclose (a , b , tol = 1e-8 ):
@@ -239,14 +248,12 @@ def check_convergence(solver, case):
239
248
assert np .linalg .norm (A @ x - b ) <= np .linalg .norm (b )
240
249
241
250
242
- def test_convergence ():
243
- for solver in params .solvers :
244
- for case in params .cases :
245
- if solver in case .skip :
246
- continue
247
- with suppress_warnings () as sup :
248
- sup .filter (DeprecationWarning , ".*called without specifying.*" )
249
- check_convergence (solver , case )
251
+ def test_convergence (solver , case ):
252
+ if solver in case .skip :
253
+ pytest .skip ("unsupported combination" )
254
+ with suppress_warnings () as sup :
255
+ sup .filter (DeprecationWarning , ".*called without specifying.*" )
256
+ check_convergence (solver , case )
250
257
251
258
252
259
def check_precond_dummy (solver , case ):
@@ -286,14 +293,12 @@ def identity(b, which=None):
286
293
assert_normclose (A @ x , b , tol = tol )
287
294
288
295
289
- def test_precond_dummy ():
290
- for case in params .cases :
291
- for solver in params .solvers :
292
- if solver in case .skip + case .nonconvergence :
293
- continue
294
- with suppress_warnings () as sup :
295
- sup .filter (DeprecationWarning , ".*called without specifying.*" )
296
- check_precond_dummy (solver , case )
296
+ def test_precond_dummy (solver , case ):
297
+ if solver in case .skip + case .nonconvergence :
298
+ pytest .skip ("unsupported combination" )
299
+ with suppress_warnings () as sup :
300
+ sup .filter (DeprecationWarning , ".*called without specifying.*" )
301
+ check_precond_dummy (solver , case )
297
302
298
303
299
304
def check_precond_inverse (solver , case ):
@@ -340,25 +345,20 @@ def rmatvec(b):
340
345
assert matvec_count [0 ] <= 3
341
346
342
347
343
- @pytest .mark .parametrize ("case" , [params .Poisson1D , params .Poisson2D ])
344
- def test_precond_inverse (case ):
345
- for solver in params .solvers :
346
- if solver in case .skip :
347
- continue
348
- if solver is qmr :
349
- continue
350
- with suppress_warnings () as sup :
351
- sup .filter (DeprecationWarning , ".*called without specifying.*" )
352
- check_precond_inverse (solver , case )
348
+ def test_precond_inverse (solver , case ):
349
+ if (solver in case .skip or solver is qmr
350
+ or case .name not in ("poisson1d" , "poisson2d" )):
351
+ pytest .skip ("unsupported combination" )
352
+ with suppress_warnings () as sup :
353
+ sup .filter (DeprecationWarning , ".*called without specifying.*" )
354
+ check_precond_inverse (solver , case )
353
355
354
356
355
- def test_reentrancy ():
356
- non_reentrant = [cg , cgs , bicg , bicgstab , gmres , qmr ]
357
+ def test_reentrancy (solver ):
357
358
reentrant = [lgmres , minres , gcrotmk , tfqmr ]
358
- for solver in reentrant + non_reentrant :
359
- with suppress_warnings () as sup :
360
- sup .filter (DeprecationWarning , ".*called without specifying.*" )
361
- _check_reentrancy (solver , solver in reentrant )
359
+ with suppress_warnings () as sup :
360
+ sup .filter (DeprecationWarning , ".*called without specifying.*" )
361
+ _check_reentrancy (solver , solver in reentrant )
362
362
363
363
364
364
def _check_reentrancy (solver , is_reentrant ):
@@ -379,11 +379,11 @@ def matvec(x):
379
379
assert_allclose (y , [1 , 1 , 1 ])
380
380
381
381
382
- @pytest .mark .parametrize ("solver" , [cg , cgs , bicg , bicgstab , gmres , qmr ,
383
- lgmres , gcrotmk ])
384
382
def test_atol (solver ):
385
- # TODO: minres. It didn't historically use absolute tolerances, so
383
+ # TODO: minres / tfqmr . It didn't historically use absolute tolerances, so
386
384
# fixing it is less urgent.
385
+ if solver in (minres , tfqmr ):
386
+ pytest .skip ("TODO" )
387
387
388
388
np .random .seed (1234 )
389
389
A = np .random .rand (10 , 10 )
@@ -421,8 +421,6 @@ def test_atol(solver):
421
421
assert err <= 1.00025 * max (atol , atol2 )
422
422
423
423
424
- @pytest .mark .parametrize ("solver" , [cg , cgs , bicg , bicgstab , gmres , qmr ,
425
- minres , lgmres , gcrotmk , tfqmr ])
426
424
def test_zero_rhs (solver ):
427
425
np .random .seed (1234 )
428
426
A = np .random .rand (10 , 10 )
@@ -457,25 +455,21 @@ def test_zero_rhs(solver):
457
455
assert_allclose (x , 0 , atol = 1e-300 )
458
456
459
457
460
- @pytest .mark .parametrize ("solver" , [
461
- pytest .param (gmres , marks = pytest .mark .xfail (platform .machine () == 'aarch64'
462
- and sys .version_info [1 ] == 9 ,
463
- reason = "gh-13019" )),
464
- qmr ,
465
- pytest .param (lgmres , marks = pytest .mark .xfail (
466
- platform .machine () not in ['x86_64' 'x86' , 'aarch64' , 'arm64' ],
467
- reason = "fails on at least ppc64le, ppc64 and riscv64, see gh-17839" )
468
- ),
469
- pytest .param (cgs , marks = pytest .mark .xfail ),
470
- pytest .param (bicg , marks = pytest .mark .xfail ),
471
- pytest .param (bicgstab , marks = pytest .mark .xfail ),
472
- pytest .param (gcrotmk , marks = pytest .mark .xfail ),
473
- pytest .param (tfqmr , marks = pytest .mark .xfail )])
474
458
def test_maxiter_worsening (solver ):
459
+ if solver not in (gmres , lgmres ):
460
+ # these were skipped from the very beginning, see gh-9201; gh-14160
461
+ pytest .skip ("unsupported combination" )
475
462
# Check error does not grow (boundlessly) with increasing maxiter.
476
463
# This can occur due to the solvers hitting close to breakdown,
477
464
# which they should detect and halt as necessary.
478
465
# cf. gh-9100
466
+ if (solver is gmres and platform .machine () == 'aarch64'
467
+ and sys .version_info [1 ] == 9 ):
468
+ pytest .xfail (reason = "gh-13019" )
469
+ if (solver is lgmres and
470
+ platform .machine () not in ['x86_64' 'x86' , 'aarch64' , 'arm64' ]):
471
+ # see gh-17839
472
+ pytest .xfail (reason = "fails on at least ppc64le, ppc64 and riscv64" )
479
473
480
474
# Singular matrix, rhs numerically not in range
481
475
A = np .array ([[- 0.1112795288033378 , 0 , 0 , 0.16127952880333685 ],
@@ -499,8 +493,6 @@ def test_maxiter_worsening(solver):
499
493
assert error <= tol * best_error
500
494
501
495
502
- @pytest .mark .parametrize ("solver" , [cg , cgs , bicg , bicgstab , gmres , qmr ,
503
- minres , lgmres , gcrotmk , tfqmr ])
504
496
def test_x0_working (solver ):
505
497
# Easy problem
506
498
np .random .seed (1 )
@@ -524,46 +516,40 @@ def test_x0_working(solver):
524
516
assert np .linalg .norm (A @ x - b ) <= 2e-6 * np .linalg .norm (b )
525
517
526
518
527
- @pytest .mark .parametrize ('solver' , [cg , cgs , bicg , bicgstab , gmres , qmr ,
528
- minres , lgmres , gcrotmk ])
529
- def test_x0_equals_Mb (solver ):
530
- for case in params .cases :
531
- if solver in case .skip :
532
- continue
533
- with suppress_warnings () as sup :
534
- sup .filter (DeprecationWarning , ".*called without specifying.*" )
535
- A = case .A
536
- b = case .b
537
- x0 = 'Mb'
538
- tol = 1e-8
539
- x , info = solver (A , b , x0 = x0 , tol = tol )
519
+ def test_x0_equals_Mb (solver , case ):
520
+ if solver in case .skip or solver is tfqmr :
521
+ pytest .skip ("unsupported combination" )
522
+ with suppress_warnings () as sup :
523
+ sup .filter (DeprecationWarning , ".*called without specifying.*" )
524
+ A = case .A
525
+ b = case .b
526
+ x0 = 'Mb'
527
+ tol = 1e-8
528
+ x , info = solver (A , b , x0 = x0 , tol = tol )
540
529
541
- assert_array_equal (x0 , 'Mb' ) # ensure that x0 is not overwritten
542
- assert info == 0
543
- assert_normclose (A @ x , b , tol = tol )
530
+ assert_array_equal (x0 , 'Mb' ) # ensure that x0 is not overwritten
531
+ assert info == 0
532
+ assert_normclose (A @ x , b , tol = tol )
544
533
545
534
546
- @pytest .mark .parametrize (('solver' , 'solverstring' ), [(tfqmr , 'TFQMR' )])
547
- def test_show (solver , solverstring , capsys ):
535
+ def test_show (case , capsys ):
548
536
def cb (x ):
549
- count [0 ] += 1
550
-
551
- for i in [0 , 20 ]:
552
- case = params .cases [i ]
553
- A = case .A
554
- b = case .b
555
- count = [0 ]
556
- x , info = solver (A , b , callback = cb , show = True )
557
- out , err = capsys .readouterr ()
558
- if i == 20 : # Asymmetric and Positive Definite
559
- exp = (f"{ solverstring } : Linear solve not converged "
560
- f"due to reach MAXIT iterations { count [0 ]} \n " )
561
- assert out == exp
562
- else : # 1-D Poisson equations
563
- exp = (f"{ solverstring } : Linear solve converged due to "
564
- f"reach TOL iterations { count [0 ]} \n " )
565
- assert out == exp
566
- assert err == ""
537
+ pass
538
+
539
+ x , info = tfqmr (case .A , case .b , callback = cb , show = True )
540
+ out , err = capsys .readouterr ()
541
+
542
+ if case .name == "sym-nonpd" :
543
+ # no logs for some reason
544
+ exp = ""
545
+ elif case .name in ("nonsymposdef" , "nonsymposdef-F" ):
546
+ # Asymmetric and Positive Definite
547
+ exp = "TFQMR: Linear solve not converged due to reach MAXIT iterations"
548
+ else : # all other cases
549
+ exp = "TFQMR: Linear solve converged due to reach TOL iterations"
550
+
551
+ assert out .startswith (exp )
552
+ assert err == ""
567
553
568
554
569
555
# -----------------------------------------------------------------------------
0 commit comments