Commit 9045b9f
ENH: Enable RunsOn GPU support for lecture builds (#437)
* Enable RunsOn GPU support for lecture builds
- Add scripts/test-jax-install.py to verify JAX/GPU installation
- Add .github/runs-on.yml with QuantEcon Ubuntu 24.04 AMI configuration
- Update cache.yml to use RunsOn g4dn.2xlarge GPU runner
- Update ci.yml to use RunsOn g4dn.2xlarge GPU runner
- Update publish.yml to use RunsOn g4dn.2xlarge GPU runner
- Install JAX with CUDA 13 support and Numpyro on all workflows
- Add nvidia-smi check to verify GPU availability
This mirrors the setup used in lecture-python.myst repository.
* DOC: Update JAX lectures with GPU admonition and narrative
- Add standard GPU admonition to jax_intro.md and numpy_vs_numba_vs_jax.md
- Update introduction in jax_intro.md to reflect GPU access
- Update conditional GPU language to reflect lectures now run on GPU
- Following QuantEcon style guide for JAX lectures
* DEBUG: Add hardware benchmark script to diagnose performance
- Add benchmark-hardware.py with CPU, NumPy, Numba, and JAX benchmarks
- Works on both GPU (RunsOn) and CPU-only (GitHub Actions) runners
- Include warm-up vs compiled timing to isolate JIT overhead
- Add system info collection (CPU model, frequency, GPU detection)
* Add multi-pathway benchmark tests (bare metal, Jupyter, jupyter-book)
* Fix: Add content to benchmark-jupyter.ipynb (was empty)
* Fix: Add benchmark content to benchmark-jupyter.ipynb
* Add JSON output to benchmarks and upload as artifacts
- Update benchmark-hardware.py to save results to JSON
- Update benchmark-jupyter.ipynb to save results to JSON
- Update benchmark-jupyterbook.md to save results to JSON
- Add CI step to collect and display benchmark results
- Add CI step to upload benchmark results as artifact
* Fix syntax errors in benchmark-hardware.py
- Remove extra triple quote at start of file
- Remove stray parentheses in result assignments
* Sync benchmark scripts with CPU branch for comparable results
- Copy benchmark-hardware.py from debug/benchmark-github-actions
- Copy benchmark-jupyter.ipynb from debug/benchmark-github-actions
- Copy benchmark-jupyterbook.md from debug/benchmark-github-actions
- Update ci.yml to use matching file names
The test scripts are now identical between both branches,
only the CI workflow differs (runner type and JAX installation).
* ENH: Force lax.scan sequential operation to run on CPU
Add device=cpu to the qm_jax function decorator to avoid the known
XLA limitation where lax.scan with millions of lightweight iterations
performs poorly on GPU due to CPU-GPU synchronization overhead.
Added explanatory note about this pattern.
Co-authored-by: HumphreyYang <Humphrey.Yang@anu.edu.au>
* update note
* Add lax.scan profiler to CI for GPU debugging
- Add scripts/profile_lax_scan.py: Profiles lax.scan performance on GPU vs CPU
to investigate the synchronization overhead issue (JAX Issue #2491)
- Add CI step to run profiler with 100K iterations on RunsOn GPU environment
- Script supports multiple profiling modes: basic timing, Nsight, JAX profiler, XLA dumps
* Add diagnostic mode to lax.scan profiler
- Add --diagnose flag that tests time scaling across iteration counts
- If time scales linearly with iterations (not compute), it proves
constant per-iteration overhead (CPU-GPU synchronization)
- Also add --verbose flag for CUDA/XLA logging
- Update CI to run with --diagnose flag
* Add Nsight Systems profiling to CI
- Run nsys profile with 1000 iterations if nsys is available
- Captures CUDA, NVTX, and OS runtime traces
- Uploads .nsys-rep file as artifact for visual analysis
- continue-on-error: true so CI doesn't fail if nsys unavailable
* address @jstac comment
* Improve JAX lecture content and pedagogy
- Reorganize jax_intro.md to introduce JAX features upfront with clearer structure
- Expand JAX introduction with bulleted list of key capabilities (parallelization, JIT, autodiff)
- Add explicit GPU performance notes in vmap sections
- Enhance vmap explanation with detailed function composition breakdown
- Clarify memory efficiency tradeoffs between different vmap approaches
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>
* Remove benchmark scripts (moved to QuantEcon/benchmarks)
- Remove profile_lax_scan.py, benchmark-hardware.py, benchmark-jupyter.ipynb, benchmark-jupyterbook.md
- Remove profiling/benchmarking steps from ci.yml
- Keep test-jax-install.py for JAX installation verification
Benchmark scripts are now maintained in: https://github.com/QuantEcon/benchmarks
* Update lectures/numpy_vs_numba_vs_jax.md
* Add GPU and JAX hardware details to status page
- Add nvidia-smi output to show GPU availability
- Add JAX backend check to confirm GPU usage
- Matches format used in lecture-python.myst
---------
Co-authored-by: HumphreyYang <Humphrey.Yang@anu.edu.au>
Co-authored-by: Humphrey Yang <u6474961@anu.edu.au>
Co-authored-by: John Stachurski <john.stachurski@gmail.com>
Co-authored-by: Claude <noreply@anthropic.com>1 parent a4b89d4 commit 9045b9fCopy full SHA for 9045b9f
File tree
Expand file treeCollapse file tree
7 files changed
+141
-42
lines changedOpen diff view settings
Filter options
- .github/workflows
- lectures
- scripts
Expand file treeCollapse file tree
7 files changed
+141
-42
lines changedOpen diff view settings
Collapse file
.github/workflows/cache.yml
Copy file name to clipboardExpand all lines: .github/workflows/cache.yml+11-1Lines changed: 11 additions & 1 deletion
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| ||
6 | 6 | |
7 | 7 | |
8 | 8 | |
9 | | - |
| 9 | + |
10 | 10 | |
11 | 11 | |
12 | 12 | |
| ||
18 | 18 | |
19 | 19 | |
20 | 20 | |
| 21 | + |
| 22 | + |
| 23 | + |
| 24 | + |
| 25 | + |
| 26 | + |
| 27 | + |
| 28 | + |
| 29 | + |
| 30 | + |
21 | 31 | |
22 | 32 | |
23 | 33 | |
|
Collapse file
+10-1Lines changed: 10 additions & 1 deletion
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| ||
2 | 2 | |
3 | 3 | |
4 | 4 | |
5 | | - |
| 5 | + |
6 | 6 | |
7 | 7 | |
8 | 8 | |
| ||
16 | 16 | |
17 | 17 | |
18 | 18 | |
| 19 | + |
| 20 | + |
| 21 | + |
| 22 | + |
| 23 | + |
| 24 | + |
| 25 | + |
| 26 | + |
| 27 | + |
19 | 28 | |
20 | 29 | |
21 | 30 | |
|
Collapse file
.github/workflows/publish.yml
Copy file name to clipboardExpand all lines: .github/workflows/publish.yml+11-1Lines changed: 11 additions & 1 deletion
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| ||
6 | 6 | |
7 | 7 | |
8 | 8 | |
9 | | - |
| 9 | + |
10 | 10 | |
11 | 11 | |
12 | 12 | |
| ||
21 | 21 | |
22 | 22 | |
23 | 23 | |
| 24 | + |
| 25 | + |
| 26 | + |
| 27 | + |
| 28 | + |
| 29 | + |
| 30 | + |
| 31 | + |
| 32 | + |
| 33 | + |
24 | 34 | |
25 | 35 | |
26 | 36 | |
|
Collapse file
+26-26Lines changed: 26 additions & 26 deletions
- Display the source diff
- Display the rich diff
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| ||
13 | 13 | |
14 | 14 | |
15 | 15 | |
| 16 | + |
| 17 | + |
| 18 | + |
| 19 | + |
| 20 | + |
| 21 | + |
| 22 | + |
| 23 | + |
| 24 | + |
| 25 | + |
| 26 | + |
| 27 | + |
16 | 28 | |
17 | 29 | |
18 | 30 | |
| ||
21 | 33 | |
22 | 34 | |
23 | 35 | |
24 | | - |
25 | | - |
26 | | - |
27 | | - |
28 | | - |
29 | | - |
30 | | - |
31 | | - |
| 36 | + |
| 37 | + |
32 | 38 | |
33 | | - |
34 | | - |
35 | | - |
| 39 | + |
36 | 40 | |
37 | | - |
| 41 | + |
| 42 | + |
38 | 43 | |
| 44 | + |
| 45 | + |
| 46 | + |
39 | 47 | |
40 | 48 | |
41 | 49 | |
42 | | - |
43 | | - |
| 50 | + |
| 51 | + |
44 | 52 | |
45 | | - |
| 53 | + |
46 | 54 | |
47 | 55 | |
48 | 56 | |
| ||
523 | 531 | |
524 | 532 | |
525 | 533 | |
526 | | - |
527 | | - |
528 | | - |
529 | | - |
530 | | - |
531 | | - |
532 | | - |
| 534 | + |
533 | 535 | |
534 | | - |
535 | | - |
| 536 | + |
536 | 537 | |
537 | 538 | |
538 | 539 | |
| ||
634 | 635 | |
635 | 636 | |
636 | 637 | |
637 | | - |
638 | | - |
| 638 | + |
639 | 639 | |
640 | 640 | |
641 | 641 | |
|
Collapse file
lectures/numpy_vs_numba_vs_jax.md
Copy file name to clipboardExpand all lines: lectures/numpy_vs_numba_vs_jax.md+48-13Lines changed: 48 additions & 13 deletions
- Display the source diff
- Display the rich diff
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| ||
48 | 48 | |
49 | 49 | |
50 | 50 | |
| 51 | + |
| 52 | + |
| 53 | + |
| 54 | + |
| 55 | + |
| 56 | + |
| 57 | + |
| 58 | + |
| 59 | + |
| 60 | + |
| 61 | + |
| 62 | + |
51 | 63 | |
52 | 64 | |
53 | 65 | |
| ||
317 | 329 | |
318 | 330 | |
319 | 331 | |
320 | | - |
| 332 | + |
321 | 333 | |
322 | 334 | |
323 | 335 | |
| ||
370 | 382 | |
371 | 383 | |
372 | 384 | |
373 | | - |
374 | | - |
| 385 | + |
| 386 | + |
| 387 | + |
375 | 388 | |
376 | | - |
377 | | - |
| 389 | + |
378 | 390 | |
379 | | - |
| 391 | + |
| 392 | + |
| 393 | + |
| 394 | + |
| 395 | + |
380 | 396 | |
381 | 397 | |
382 | 398 | |
383 | 399 | |
384 | 400 | |
385 | 401 | |
386 | | - |
| 402 | + |
387 | 403 | |
388 | 404 | |
389 | | - |
| 405 | + |
| 406 | + |
| 407 | + |
390 | 408 | |
391 | 409 | |
392 | 410 | |
| ||
399 | 417 | |
400 | 418 | |
401 | 419 | |
402 | | - |
| 420 | + |
| 421 | + |
| 422 | + |
| 423 | + |
| 424 | + |
| 425 | + |
| 426 | + |
| 427 | + |
403 | 428 | |
404 | 429 | |
405 | 430 | |
406 | 431 | |
407 | 432 | |
408 | 433 | |
409 | | - |
410 | 434 | |
411 | 435 | |
412 | 436 | |
413 | 437 | |
414 | 438 | |
415 | 439 | |
416 | 440 | |
417 | | - |
418 | | - |
| 441 | + |
419 | 442 | |
420 | 443 | |
421 | 444 | |
| ||
497 | 520 | |
498 | 521 | |
499 | 522 | |
500 | | - |
| 523 | + |
| 524 | + |
| 525 | + |
501 | 526 | |
502 | 527 | |
503 | 528 | |
| ||
509 | 534 | |
510 | 535 | |
511 | 536 | |
| 537 | + |
| 538 | + |
| 539 | + |
| 540 | + |
| 541 | + |
| 542 | + |
| 543 | + |
| 544 | + |
| 545 | + |
| 546 | + |
512 | 547 | |
513 | 548 | |
514 | 549 | |
|
Collapse file
+14Lines changed: 14 additions & 0 deletions
- Display the source diff
- Display the rich diff
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| ||
31 | 31 | |
32 | 32 | |
33 | 33 | |
| 34 | + |
| 35 | + |
| 36 | + |
| 37 | + |
| 38 | + |
| 39 | + |
| 40 | + |
| 41 | + |
| 42 | + |
| 43 | + |
| 44 | + |
| 45 | + |
| 46 | + |
| 47 | + |
34 | 48 | |
Collapse file
scripts/test-jax-install.py
Copy file name to clipboard+21Lines changed: 21 additions & 0 deletions
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| ||
| 1 | + |
| 2 | + |
| 3 | + |
| 4 | + |
| 5 | + |
| 6 | + |
| 7 | + |
| 8 | + |
| 9 | + |
| 10 | + |
| 11 | + |
| 12 | + |
| 13 | + |
| 14 | + |
| 15 | + |
| 16 | + |
| 17 | + |
| 18 | + |
| 19 | + |
| 20 | + |
| 21 | + |
0 commit comments