Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 33cb520

Browse filesBrowse files
committed
MAINT: parametrize scipy/sparse/linalg/_isolve/tests/test_iterative.py
1 parent c1a819c commit 33cb520
Copy full SHA for 33cb520

File tree

Expand file treeCollapse file tree

1 file changed

+92
-106
lines changed
Filter options
Expand file treeCollapse file tree

1 file changed

+92
-106
lines changed

‎scipy/sparse/linalg/_isolve/tests/test_iterative.py

Copy file name to clipboardExpand all lines: scipy/sparse/linalg/_isolve/tests/test_iterative.py
+92-106Lines changed: 92 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,18 @@
2424
# TODO test both preconditioner methods
2525

2626

27+
# list of all solvers under test
28+
_SOLVERS = [bicg, bicgstab, cg, cgs, gcrotmk, gmres, lgmres,
29+
minres, qmr, tfqmr]
30+
# create parametrized fixture for easy reuse in tests
31+
@pytest.fixture(params=_SOLVERS, scope="session")
32+
def solver(request):
33+
"""
34+
Fixture for all solvers in scipy.sparse.linalg._isolve
35+
"""
36+
return request.param
37+
38+
2739
class Case:
2840
def __init__(self, name, A, b=None, skip=None, nonconvergence=None):
2941
self.name = name
@@ -47,16 +59,11 @@ def __repr__(self):
4759

4860
class IterativeParams:
4961
def __init__(self):
50-
# list of tuples (solver, symmetric, positive_definite )
51-
solvers = [cg, cgs, bicg, bicgstab, gmres, qmr, minres, lgmres,
52-
gcrotmk, tfqmr]
5362
sym_solvers = [minres, cg]
5463
posdef_solvers = [cg]
5564
real_solvers = [minres]
5665

57-
self.solvers = solvers
58-
59-
# list of tuples (A, symmetric, positive_definite )
66+
# list of Cases
6067
self.cases = []
6168

6269
# Symmetric and Positive Definite
@@ -66,7 +73,6 @@ def __init__(self):
6673
data[1, :] = -1
6774
data[2, :] = -1
6875
Poisson1D = spdiags(data, [0, -1, 1], N, N, format='csr')
69-
self.Poisson1D = Case("poisson1d", Poisson1D)
7076
self.cases.append(Case("poisson1d", Poisson1D))
7177
# note: minres fails for single precision
7278
self.cases.append(Case("poisson1d-F", Poisson1D.astype('f'),
@@ -81,7 +87,6 @@ def __init__(self):
8187

8288
# 2-dimensional Poisson equations
8389
Poisson2D = kronsum(Poisson1D, Poisson1D)
84-
self.Poisson2D = Case("poisson2d", Poisson2D)
8590
# note: minres fails for 2-d poisson problem,
8691
# it will be fixed in the future PR
8792
self.cases.append(Case("poisson2d", Poisson2D, skip=[minres]))
@@ -180,7 +185,13 @@ def __init__(self):
180185
)
181186

182187

183-
params = IterativeParams()
188+
cases = IterativeParams().cases
189+
@pytest.fixture(params=cases, ids=[x.name for x in cases], scope="session")
190+
def case(request):
191+
"""
192+
Fixture for all cases in IterativeParams
193+
"""
194+
return request.param
184195

185196

186197
def check_maxiter(solver, case):
@@ -201,14 +212,12 @@ def callback(x):
201212
assert info == 1
202213

203214

204-
def test_maxiter():
205-
for case in params.cases:
206-
for solver in params.solvers:
207-
if solver in case.skip + case.nonconvergence:
208-
continue
209-
with suppress_warnings() as sup:
210-
sup.filter(DeprecationWarning, ".*called without specifying.*")
211-
check_maxiter(solver, case)
215+
def test_maxiter(solver, case):
216+
if solver in case.skip + case.nonconvergence:
217+
pytest.skip("unsupported combination")
218+
with suppress_warnings() as sup:
219+
sup.filter(DeprecationWarning, ".*called without specifying.*")
220+
check_maxiter(solver, case)
212221

213222

214223
def assert_normclose(a, b, tol=1e-8):
@@ -239,14 +248,12 @@ def check_convergence(solver, case):
239248
assert np.linalg.norm(A @ x - b) <= np.linalg.norm(b)
240249

241250

242-
def test_convergence():
243-
for solver in params.solvers:
244-
for case in params.cases:
245-
if solver in case.skip:
246-
continue
247-
with suppress_warnings() as sup:
248-
sup.filter(DeprecationWarning, ".*called without specifying.*")
249-
check_convergence(solver, case)
251+
def test_convergence(solver, case):
252+
if solver in case.skip:
253+
pytest.skip("unsupported combination")
254+
with suppress_warnings() as sup:
255+
sup.filter(DeprecationWarning, ".*called without specifying.*")
256+
check_convergence(solver, case)
250257

251258

252259
def check_precond_dummy(solver, case):
@@ -286,14 +293,12 @@ def identity(b, which=None):
286293
assert_normclose(A @ x, b, tol=tol)
287294

288295

289-
def test_precond_dummy():
290-
for case in params.cases:
291-
for solver in params.solvers:
292-
if solver in case.skip + case.nonconvergence:
293-
continue
294-
with suppress_warnings() as sup:
295-
sup.filter(DeprecationWarning, ".*called without specifying.*")
296-
check_precond_dummy(solver, case)
296+
def test_precond_dummy(solver, case):
297+
if solver in case.skip + case.nonconvergence:
298+
pytest.skip("unsupported combination")
299+
with suppress_warnings() as sup:
300+
sup.filter(DeprecationWarning, ".*called without specifying.*")
301+
check_precond_dummy(solver, case)
297302

298303

299304
def check_precond_inverse(solver, case):
@@ -340,25 +345,20 @@ def rmatvec(b):
340345
assert matvec_count[0] <= 3
341346

342347

343-
@pytest.mark.parametrize("case", [params.Poisson1D, params.Poisson2D])
344-
def test_precond_inverse(case):
345-
for solver in params.solvers:
346-
if solver in case.skip:
347-
continue
348-
if solver is qmr:
349-
continue
350-
with suppress_warnings() as sup:
351-
sup.filter(DeprecationWarning, ".*called without specifying.*")
352-
check_precond_inverse(solver, case)
348+
def test_precond_inverse(solver, case):
349+
if (solver in case.skip or solver is qmr
350+
or case.name not in ("poisson1d", "poisson2d")):
351+
pytest.skip("unsupported combination")
352+
with suppress_warnings() as sup:
353+
sup.filter(DeprecationWarning, ".*called without specifying.*")
354+
check_precond_inverse(solver, case)
353355

354356

355-
def test_reentrancy():
356-
non_reentrant = [cg, cgs, bicg, bicgstab, gmres, qmr]
357+
def test_reentrancy(solver):
357358
reentrant = [lgmres, minres, gcrotmk, tfqmr]
358-
for solver in reentrant + non_reentrant:
359-
with suppress_warnings() as sup:
360-
sup.filter(DeprecationWarning, ".*called without specifying.*")
361-
_check_reentrancy(solver, solver in reentrant)
359+
with suppress_warnings() as sup:
360+
sup.filter(DeprecationWarning, ".*called without specifying.*")
361+
_check_reentrancy(solver, solver in reentrant)
362362

363363

364364
def _check_reentrancy(solver, is_reentrant):
@@ -379,11 +379,11 @@ def matvec(x):
379379
assert_allclose(y, [1, 1, 1])
380380

381381

382-
@pytest.mark.parametrize("solver", [cg, cgs, bicg, bicgstab, gmres, qmr,
383-
lgmres, gcrotmk])
384382
def test_atol(solver):
385-
# TODO: minres. It didn't historically use absolute tolerances, so
383+
# TODO: minres / tfqmr. It didn't historically use absolute tolerances, so
386384
# fixing it is less urgent.
385+
if solver in (minres, tfqmr):
386+
pytest.skip("TODO")
387387

388388
np.random.seed(1234)
389389
A = np.random.rand(10, 10)
@@ -421,8 +421,6 @@ def test_atol(solver):
421421
assert err <= 1.00025 * max(atol, atol2)
422422

423423

424-
@pytest.mark.parametrize("solver", [cg, cgs, bicg, bicgstab, gmres, qmr,
425-
minres, lgmres, gcrotmk, tfqmr])
426424
def test_zero_rhs(solver):
427425
np.random.seed(1234)
428426
A = np.random.rand(10, 10)
@@ -457,25 +455,21 @@ def test_zero_rhs(solver):
457455
assert_allclose(x, 0, atol=1e-300)
458456

459457

460-
@pytest.mark.parametrize("solver", [
461-
pytest.param(gmres, marks=pytest.mark.xfail(platform.machine() == 'aarch64'
462-
and sys.version_info[1] == 9,
463-
reason="gh-13019")),
464-
qmr,
465-
pytest.param(lgmres, marks=pytest.mark.xfail(
466-
platform.machine() not in ['x86_64' 'x86', 'aarch64', 'arm64'],
467-
reason="fails on at least ppc64le, ppc64 and riscv64, see gh-17839")
468-
),
469-
pytest.param(cgs, marks=pytest.mark.xfail),
470-
pytest.param(bicg, marks=pytest.mark.xfail),
471-
pytest.param(bicgstab, marks=pytest.mark.xfail),
472-
pytest.param(gcrotmk, marks=pytest.mark.xfail),
473-
pytest.param(tfqmr, marks=pytest.mark.xfail)])
474458
def test_maxiter_worsening(solver):
459+
if solver not in (gmres, lgmres):
460+
# these were skipped from the very beginning, see gh-9201; gh-14160
461+
pytest.skip("unsupported combination")
475462
# Check error does not grow (boundlessly) with increasing maxiter.
476463
# This can occur due to the solvers hitting close to breakdown,
477464
# which they should detect and halt as necessary.
478465
# cf. gh-9100
466+
if (solver is gmres and platform.machine() == 'aarch64'
467+
and sys.version_info[1] == 9):
468+
pytest.xfail(reason="gh-13019")
469+
if (solver is lgmres and
470+
platform.machine() not in ['x86_64' 'x86', 'aarch64', 'arm64']):
471+
# see gh-17839
472+
pytest.xfail(reason="fails on at least ppc64le, ppc64 and riscv64")
479473

480474
# Singular matrix, rhs numerically not in range
481475
A = np.array([[-0.1112795288033378, 0, 0, 0.16127952880333685],
@@ -499,8 +493,6 @@ def test_maxiter_worsening(solver):
499493
assert error <= tol * best_error
500494

501495

502-
@pytest.mark.parametrize("solver", [cg, cgs, bicg, bicgstab, gmres, qmr,
503-
minres, lgmres, gcrotmk, tfqmr])
504496
def test_x0_working(solver):
505497
# Easy problem
506498
np.random.seed(1)
@@ -524,46 +516,40 @@ def test_x0_working(solver):
524516
assert np.linalg.norm(A @ x - b) <= 2e-6 * np.linalg.norm(b)
525517

526518

527-
@pytest.mark.parametrize('solver', [cg, cgs, bicg, bicgstab, gmres, qmr,
528-
minres, lgmres, gcrotmk])
529-
def test_x0_equals_Mb(solver):
530-
for case in params.cases:
531-
if solver in case.skip:
532-
continue
533-
with suppress_warnings() as sup:
534-
sup.filter(DeprecationWarning, ".*called without specifying.*")
535-
A = case.A
536-
b = case.b
537-
x0 = 'Mb'
538-
tol = 1e-8
539-
x, info = solver(A, b, x0=x0, tol=tol)
519+
def test_x0_equals_Mb(solver, case):
520+
if solver in case.skip or solver is tfqmr:
521+
pytest.skip("unsupported combination")
522+
with suppress_warnings() as sup:
523+
sup.filter(DeprecationWarning, ".*called without specifying.*")
524+
A = case.A
525+
b = case.b
526+
x0 = 'Mb'
527+
tol = 1e-8
528+
x, info = solver(A, b, x0=x0, tol=tol)
540529

541-
assert_array_equal(x0, 'Mb') # ensure that x0 is not overwritten
542-
assert info == 0
543-
assert_normclose(A @ x, b, tol=tol)
530+
assert_array_equal(x0, 'Mb') # ensure that x0 is not overwritten
531+
assert info == 0
532+
assert_normclose(A @ x, b, tol=tol)
544533

545534

546-
@pytest.mark.parametrize(('solver', 'solverstring'), [(tfqmr, 'TFQMR')])
547-
def test_show(solver, solverstring, capsys):
535+
def test_show(case, capsys):
548536
def cb(x):
549-
count[0] += 1
550-
551-
for i in [0, 20]:
552-
case = params.cases[i]
553-
A = case.A
554-
b = case.b
555-
count = [0]
556-
x, info = solver(A, b, callback=cb, show=True)
557-
out, err = capsys.readouterr()
558-
if i == 20: # Asymmetric and Positive Definite
559-
exp = (f"{solverstring}: Linear solve not converged "
560-
f"due to reach MAXIT iterations {count[0]}\n")
561-
assert out == exp
562-
else: # 1-D Poisson equations
563-
exp = (f"{solverstring}: Linear solve converged due to "
564-
f"reach TOL iterations {count[0]}\n")
565-
assert out == exp
566-
assert err == ""
537+
pass
538+
539+
x, info = tfqmr(case.A, case.b, callback=cb, show=True)
540+
out, err = capsys.readouterr()
541+
542+
if case.name == "sym-nonpd":
543+
# no logs for some reason
544+
exp = ""
545+
elif case.name in ("nonsymposdef", "nonsymposdef-F"):
546+
# Asymmetric and Positive Definite
547+
exp = "TFQMR: Linear solve not converged due to reach MAXIT iterations"
548+
else: # all other cases
549+
exp = "TFQMR: Linear solve converged due to reach TOL iterations"
550+
551+
assert out.startswith(exp)
552+
assert err == ""
567553

568554

569555
# -----------------------------------------------------------------------------

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.