Commit 9045b9f

Authored by mmcky, HumphreyYang, jstac, and claude
ENH: Enable RunsOn GPU support for lecture builds (#437)
* Enable RunsOn GPU support for lecture builds
  - Add scripts/test-jax-install.py to verify the JAX/GPU installation
  - Add .github/runs-on.yml with the QuantEcon Ubuntu 24.04 AMI configuration
  - Update cache.yml to use the RunsOn g4dn.2xlarge GPU runner
  - Update ci.yml to use the RunsOn g4dn.2xlarge GPU runner
  - Update publish.yml to use the RunsOn g4dn.2xlarge GPU runner
  - Install JAX with CUDA 13 support and Numpyro in all workflows
  - Add an nvidia-smi check to verify GPU availability
  This mirrors the setup used in the lecture-python.myst repository.

* DOC: Update JAX lectures with GPU admonition and narrative
  - Add the standard GPU admonition to jax_intro.md and numpy_vs_numba_vs_jax.md
  - Update the introduction in jax_intro.md to reflect GPU access
  - Update conditional GPU language to reflect that lectures now run on a GPU
  - Follow the QuantEcon style guide for JAX lectures

* DEBUG: Add hardware benchmark script to diagnose performance
  - Add benchmark-hardware.py with CPU, NumPy, Numba, and JAX benchmarks
  - Works on both GPU (RunsOn) and CPU-only (GitHub Actions) runners
  - Include warm-up vs compiled timing to isolate JIT overhead
  - Add system info collection (CPU model, frequency, GPU detection)

* Add multi-pathway benchmark tests (bare metal, Jupyter, jupyter-book)

* Fix: Add content to benchmark-jupyter.ipynb (was empty)

* Fix: Add benchmark content to benchmark-jupyter.ipynb

* Add JSON output to benchmarks and upload as artifacts
  - Update benchmark-hardware.py to save results to JSON
  - Update benchmark-jupyter.ipynb to save results to JSON
  - Update benchmark-jupyterbook.md to save results to JSON
  - Add a CI step to collect and display benchmark results
  - Add a CI step to upload benchmark results as an artifact

* Fix syntax errors in benchmark-hardware.py
  - Remove extra triple quote at the start of the file
  - Remove stray parentheses in result assignments

* Sync benchmark scripts with the CPU branch for comparable results
  - Copy benchmark-hardware.py from debug/benchmark-github-actions
  - Copy benchmark-jupyter.ipynb from debug/benchmark-github-actions
  - Copy benchmark-jupyterbook.md from debug/benchmark-github-actions
  - Update ci.yml to use matching file names
  The test scripts are now identical between both branches; only the CI workflow differs (runner type and JAX installation).

* ENH: Force the lax.scan sequential operation to run on CPU
  Add device=cpu to the qm_jax function decorator to avoid the known XLA limitation where lax.scan with millions of lightweight iterations performs poorly on GPU due to CPU-GPU synchronization overhead. Added an explanatory note about this pattern.
  Co-authored-by: HumphreyYang <Humphrey.Yang@anu.edu.au>

* Update note

* Add lax.scan profiler to CI for GPU debugging
  - Add scripts/profile_lax_scan.py: profiles lax.scan performance on GPU vs CPU to investigate the synchronization overhead issue (JAX Issue #2491)
  - Add a CI step to run the profiler with 100K iterations in the RunsOn GPU environment
  - The script supports multiple profiling modes: basic timing, Nsight, JAX profiler, XLA dumps

* Add diagnostic mode to the lax.scan profiler
  - Add a --diagnose flag that tests how runtime scales across iteration counts (see the sketch after this message)
  - If time scales linearly with iterations (not compute), this indicates a constant per-iteration overhead (CPU-GPU synchronization)
  - Also add a --verbose flag for CUDA/XLA logging
  - Update CI to run with the --diagnose flag

* Add Nsight Systems profiling to CI
  - Run nsys profile with 1000 iterations if nsys is available
  - Capture CUDA, NVTX, and OS runtime traces
  - Upload the .nsys-rep file as an artifact for visual analysis
  - Use continue-on-error: true so CI doesn't fail if nsys is unavailable

* Address @jstac comment

* Improve JAX lecture content and pedagogy
  - Reorganize jax_intro.md to introduce JAX features upfront with a clearer structure
  - Expand the JAX introduction with a bulleted list of key capabilities (parallelization, JIT, autodiff)
  - Add explicit GPU performance notes in the vmap sections
  - Enhance the vmap explanation with a detailed function-composition breakdown
  - Clarify memory-efficiency tradeoffs between different vmap approaches
  🤖 Generated with [Claude Code](https://claude.com/claude-code)
  Co-Authored-By: Claude <noreply@anthropic.com>

* Remove benchmark scripts (moved to QuantEcon/benchmarks)
  - Remove profile_lax_scan.py, benchmark-hardware.py, benchmark-jupyter.ipynb, benchmark-jupyterbook.md
  - Remove profiling/benchmarking steps from ci.yml
  - Keep test-jax-install.py for JAX installation verification
  Benchmark scripts are now maintained in: https://github.com/QuantEcon/benchmarks

* Update lectures/numpy_vs_numba_vs_jax.md

* Add GPU and JAX hardware details to the status page
  - Add nvidia-smi output to show GPU availability
  - Add a JAX backend check to confirm GPU usage
  - Matches the format used in lecture-python.myst

---------

Co-authored-by: HumphreyYang <Humphrey.Yang@anu.edu.au>
Co-authored-by: Humphrey Yang <u6474961@anu.edu.au>
Co-authored-by: John Stachurski <john.stachurski@gmail.com>
Co-authored-by: Claude <noreply@anthropic.com>
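The scaling check described in the --diagnose bullet above can be illustrated in a few lines of JAX. This is a minimal sketch of the idea only; the function names and parameters below are ours, not the actual API of scripts/profile_lax_scan.py:

```python
import time
from functools import partial

import jax
from jax import lax

# Sketch: time lax.scan for increasing iteration counts. If wall time grows
# roughly linearly with the iteration count even though each step is trivially
# cheap, the cost is dominated by constant per-iteration overhead rather than
# by arithmetic.

@partial(jax.jit, static_argnums=(0,))
def run_scan(n_steps, x0=0.5):
    def step(x, _):
        x_new = 4.0 * x * (1 - x)   # one cheap logistic-map update
        return x_new, x_new
    _, xs = lax.scan(step, x0, None, length=n_steps)
    return xs

for n in (1_000, 10_000, 100_000):
    run_scan(n).block_until_ready()          # compile / warm up for this n
    t0 = time.perf_counter()
    run_scan(n).block_until_ready()
    print(f"{n:>7} iterations: {time.perf_counter() - t0:.4f} s")
```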
Parent commit: a4b89d4

File tree: 7 files changed (+141, -42 lines)
.github/workflows/cache.yml (+11, -1)
@@ -6,7 +6,7 @@ on:
   workflow_dispatch:
 jobs:
   cache:
-    runs-on: ubuntu-latest
+    runs-on: "runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=quantecon_ubuntu2404/disk=large"
     steps:
       - uses: actions/checkout@v6
       - name: Setup Anaconda
@@ -18,6 +18,16 @@ jobs:
           python-version: "3.13"
           environment-file: environment.yml
           activate-environment: quantecon
+      - name: Install JAX and Numpyro
+        shell: bash -l {0}
+        run: |
+          pip install -U "jax[cuda13]"
+          pip install numpyro
+          python scripts/test-jax-install.py
+      - name: Check nvidia drivers
+        shell: bash -l {0}
+        run: |
+          nvidia-smi
       - name: Build HTML
         shell: bash -l {0}
         run: |
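The workflows verify the installation with scripts/test-jax-install.py (shown in full at the end of this commit), which prints the available devices. A stricter variant, not part of this commit and only a sketch, could fail the job outright if JAX falls back to the CPU:

```python
# Sketch only: fail fast if the CUDA-enabled JAX wheel did not pick up a GPU.
import jax

backend = jax.default_backend()   # "gpu", "tpu", or "cpu"
print(f"JAX backend: {backend}")
print(f"Devices: {jax.devices()}")

if backend != "gpu":
    raise SystemExit("Expected a GPU backend but JAX selected: " + backend)
```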
.github/workflows/ci.yml (+10, -1)
@@ -2,7 +2,7 @@ name: Build Project [using jupyter-book]
 on: [pull_request]
 jobs:
   preview:
-    runs-on: ubuntu-latest
+    runs-on: "runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=quantecon_ubuntu2404/disk=large"
     steps:
       - uses: actions/checkout@v6
         with:
@@ -16,6 +16,15 @@ jobs:
           python-version: "3.13"
           environment-file: environment.yml
           activate-environment: quantecon
+      - name: Check nvidia Drivers
+        shell: bash -l {0}
+        run: nvidia-smi
+      - name: Install JAX and Numpyro
+        shell: bash -l {0}
+        run: |
+          pip install -U "jax[cuda13]"
+          pip install numpyro
+          python scripts/test-jax-install.py
       - name: Install latex dependencies
         run: |
           sudo apt-get -qq update
.github/workflows/publish.yml (+11, -1)
@@ -6,7 +6,7 @@ on:
 jobs:
   publish:
     if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags')
-    runs-on: ubuntu-latest
+    runs-on: "runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=quantecon_ubuntu2404/disk=large"
     steps:
       - name: Checkout
         uses: actions/checkout@v6
@@ -21,6 +21,16 @@ jobs:
           python-version: "3.13"
           environment-file: environment.yml
           activate-environment: quantecon
+      - name: Install JAX and Numpyro
+        shell: bash -l {0}
+        run: |
+          pip install -U "jax[cuda13]"
+          pip install numpyro
+          python scripts/test-jax-install.py
+      - name: Check nvidia drivers
+        shell: bash -l {0}
+        run: |
+          nvidia-smi
       - name: Install latex dependencies
         run: |
           sudo apt-get -qq update
lectures/jax_intro.md (+26, -26)
@@ -13,6 +13,18 @@ kernelspec:
 
 # JAX
 
+This lecture provides a short introduction to [Google JAX](https://github.com/jax-ml/jax).
+
+JAX is a high-performance scientific computing library that provides
+
+* a NumPy-like interface that can automatically parallelize across CPUs and GPUs,
+* a just-in-time compiler for accelerating a large range of numerical
+  operations, and
+* automatic differentiation.
+
+Increasingly, JAX also maintains and provides more specialized scientific
+computing routines, such as those originally found in SciPy.
+
 In addition to what's in Anaconda, this lecture will need the following libraries:
 
 ```{code-cell} ipython3
@@ -21,28 +33,24 @@ In addition to what's in Anaconda, this lecture will need the following librarie
 !pip install jax quantecon
 ```
 
-This lecture provides a short introduction to [Google JAX](https://github.com/jax-ml/jax).
-
-Here we are focused on using JAX on the CPU, rather than on accelerators such as
-GPUs or TPUs.
-
-This means we will only see a small amount of the possible benefits from using JAX.
-
-However, JAX seamlessly handles transitions across different hardware platforms.
+```{admonition} GPU
+:class: warning
 
-As a result, if you run this code on a machine with a GPU and a GPU-aware
-version of JAX installed, your code will be automatically accelerated and you
-will receive the full benefits.
+This lecture is accelerated via [hardware](status:machine-details) that has access to a GPU, and targets JAX for GPU programming.
 
-For a discussion of JAX on GPUs, see [our JAX lecture series](https://jax.quantecon.org/intro.html).
+Free GPUs are available on Google Colab.
+To use this option, please click on the play icon top right, select Colab, and set the runtime environment to include a GPU.
 
+Alternatively, if you have your own GPU, you can follow the [instructions](https://github.com/google/jax) for installing JAX with GPU support.
+If you would like to install JAX running on the `cpu` only, you can use `pip install jax[cpu]`.
+```
 
 ## JAX as a NumPy Replacement
 
-One of the attractive features of JAX is that, whenever possible, it conforms to
-the NumPy API for array operations.
+One of the attractive features of JAX is that, whenever possible, its array
+processing operations conform to the NumPy API.
 
-This means that, to a large extent, we can use JAX as a drop-in NumPy replacement.
+This means that, in many cases, we can use JAX as a drop-in NumPy replacement.
 
 Let's look at the similarities and differences between JAX and NumPy.
 
@@ -523,16 +531,9 @@ with qe.Timer():
     jax.block_until_ready(y);
 ```
 
-If you are running this on a GPU the code will run much faster than its NumPy
-equivalent, which ran on the CPU.
-
-Even if you are running on a machine with many CPUs, the second JAX run should
-be substantially faster with JAX.
-
-Also, typically, the second run is faster than the first.
+On a GPU, this code runs much faster than its NumPy equivalent.
 
-(This might not be noticeable on the CPU but it should definitely be noticeable on
-the GPU.)
+Also, typically, the second run is faster than the first due to JIT compilation.
 
 This is because even built-in functions like `jnp.cos` are JIT-compiled --- and the
 first run includes compile time.
@@ -634,8 +635,7 @@ with qe.Timer():
     jax.block_until_ready(y);
 ```
 
-The outcome is similar to the `cos` example --- JAX is faster, especially if you
-use a GPU and especially on the second run.
+The outcome is similar to the `cos` example --- JAX is faster, especially on the second run after JIT compilation.
 
 Moreover, with JAX, we have another trick up our sleeve:
 
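The revised text keeps the point that even built-in functions such as `jnp.cos` are JIT-compiled, so the first call includes compile time. A minimal timing sketch in the style of the lecture illustrates this; the array size is arbitrary and this is not the lecture's exact cell:

```python
import jax
import jax.numpy as jnp
import quantecon as qe

x = jnp.linspace(0, 10, 10_000_000)

with qe.Timer():                        # first call: includes JIT compilation
    jax.block_until_ready(jnp.cos(x))

with qe.Timer():                        # second call: runs the compiled kernel only
    jax.block_until_ready(jnp.cos(x))
```

On a GPU the second timing should be substantially smaller than the first.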

lectures/numpy_vs_numba_vs_jax.md (+48, -13)
@@ -48,6 +48,18 @@ tags: [hide-output]
 !pip install quantecon jax
 ```
 
+```{admonition} GPU
+:class: warning
+
+This lecture is accelerated via [hardware](status:machine-details) that has access to a GPU, and targets JAX for GPU programming.
+
+Free GPUs are available on Google Colab.
+To use this option, please click on the play icon top right, select Colab, and set the runtime environment to include a GPU.
+
+Alternatively, if you have your own GPU, you can follow the [instructions](https://github.com/google/jax) for installing JAX with GPU support.
+If you would like to install JAX running on the `cpu` only, you can use `pip install jax[cpu]`.
+```
+
 We will use the following imports.
 
 ```{code-cell} ipython3
@@ -317,7 +329,7 @@ with qe.Timer(precision=8):
     z_max = jnp.max(f(x_mesh, y_mesh)).block_until_ready()
 ```
 
-Once compiled, JAX will be significantly faster than NumPy, especially if you are using a GPU.
+Once compiled, JAX is significantly faster than NumPy due to GPU acceleration.
 
 The compilation overhead is a one-time cost that pays off when the function is called repeatedly.
 
@@ -370,23 +382,29 @@ with qe.Timer(precision=8):
     z_max.block_until_ready()
 ```
 
-The execution time is similar to the mesh operation but, by avoiding the large input arrays `x_mesh` and `y_mesh`,
-we are using far less memory.
+By avoiding the large input arrays `x_mesh` and `y_mesh`, this `vmap` version uses far less memory.
+
+When run on a CPU, its runtime is similar to that of the meshgrid version.
 
-In addition, `vmap` allows us to break vectorization up into stages, which is
-often easier to comprehend than the traditional approach.
+When run on a GPU, it is usually significantly faster.
 
-This will become more obvious when we tackle larger problems.
+In fact, using `vmap` has another advantage: it allows us to break vectorization up into stages.
+
+This leads to code that is often easier to comprehend than traditional vectorized code.
+
+We will investigate these ideas more when we tackle larger problems.
 
 
 ### vmap version 2
 
 We can be still more memory efficient using vmap.
 
-While we avoided large input arrays in the preceding version,
+While we avoid large input arrays in the preceding version,
 we still create the large output array `f(x,y)` before we compute the max.
 
-Let's use a slightly different approach that takes the max to the inside.
+Let's try a slightly different approach that takes the max to the inside.
+
+Because of this change, we never compute the two-dimensional array `f(x,y)`.
 
 ```{code-cell} ipython3
 @jax.jit
@@ -399,23 +417,28 @@ def compute_max_vmap_v2(grid):
     return jnp.max(f_vec_max(grid))
 ```
 
-Let's try it
+Here
+
+* `f_vec_x_max` computes the max along any given row
+* `f_vec_max` is a vectorized version that can compute the max of all rows in parallel.
+
+We apply this function to all rows and then take the max of the row maxes.
+
+Let's try it.
 
 ```{code-cell} ipython3
 with qe.Timer(precision=8):
     z_max = compute_max_vmap_v2(grid).block_until_ready()
 ```
 
-
 Let's run it again to eliminate compilation time:
 
 ```{code-cell} ipython3
 with qe.Timer(precision=8):
     z_max = compute_max_vmap_v2(grid).block_until_ready()
 ```
 
-We don't get much speed gain but we do save some memory.
-
+If you are running this on a GPU, as we are, you should see another nontrivial speed gain.
 
 
 ### Summary
@@ -497,7 +520,9 @@ Now let's create a JAX version using `lax.scan`:
 from jax import lax
 from functools import partial
 
-@partial(jax.jit, static_argnums=(1,))
+cpu = jax.devices("cpu")[0]
+
+@partial(jax.jit, static_argnums=(1,), device=cpu)
 def qm_jax(x0, n, α=4.0):
     def update(x, t):
         x_new = α * x * (1 - x)
@@ -509,6 +534,16 @@ def qm_jax(x0, n, α=4.0):
 
 This code is not easy to read but, in essence, `lax.scan` repeatedly calls `update` and accumulates the returns `x_new` into an array.
 
+```{note}
+Sharp readers will notice that we specify `device=cpu` in the `jax.jit` decorator.
+
+The computation consists of many very small `lax.scan` iterations that must run sequentially, leaving little opportunity for the GPU to exploit parallelism.
+
+As a result, kernel-launch overhead tends to dominate on the GPU, making the CPU a better fit for this workload.
+
+Curious readers can try removing this option to see how performance changes.
+```
+
 Let's time it with the same parameters:
 
 ```{code-cell} ipython3
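The diff only shows the final `return` line of `compute_max_vmap_v2`, so for orientation here is one way the composition described in the added prose could be written. The function `f` and the `grid` below are stand-ins, and the body is an inferred sketch rather than the lecture's actual code:

```python
import jax
import jax.numpy as jnp

# Stand-in objective and grid (the lecture defines its own f and grid).
def f(x, y):
    return jnp.cos(x**2 + y**2) / (1 + x**2 + y**2)

grid = jnp.linspace(-3, 3, 1_000)

@jax.jit
def compute_max_vmap_v2(grid):
    # Max of f(x, .) over the grid for one fixed x -- the "max along a row".
    def f_vec_x_max(x):
        return jnp.max(jax.vmap(f, in_axes=(None, 0))(x, grid))
    # Vectorized version: compute the row max for every x in parallel.
    f_vec_max = jax.vmap(f_vec_x_max)
    # Take the max of the row maxes.
    return jnp.max(f_vec_max(grid))

print(compute_max_vmap_v2(grid).block_until_ready())
```

As the added note suggests, readers can also time `qm_jax` with and without `device=cpu` to see how the sequential `lax.scan` workload behaves on each device.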
lectures/status.md (+14, -0)
@@ -31,4 +31,18 @@ and the following package versions
 ```{code-cell} ipython
 :tags: [hide-output]
 !conda list
+```
+
+This lecture series has access to the following GPU:
+
+```{code-cell} ipython
+!nvidia-smi
+```
+
+You can check the backend used by JAX using:
+
+```{code-cell} ipython3
+import jax
+# Check if JAX is using GPU
+print(f"JAX backend: {jax.devices()[0].platform}")
 ```
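The backend check above prints only the platform of the first device. A slightly fuller report, just a sketch and not part of the committed cell, could list each device's kind as well:

```python
import jax

# List every device JAX can see, with its platform and hardware kind
# (a g4dn instance typically reports an NVIDIA T4 GPU).
for d in jax.devices():
    print(f"id={d.id}  platform={d.platform}  kind={d.device_kind}")
```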
scripts/test-jax-install.py (new file, +21)
@@ -0,0 +1,21 @@
+import jax
+import jax.numpy as jnp
+
+devices = jax.devices()
+print(f"The available devices are: {devices}")
+
+@jax.jit
+def matrix_multiply(a, b):
+    return jnp.dot(a, b)
+
+# Example usage:
+key = jax.random.PRNGKey(0)
+x = jax.random.normal(key, (1000, 1000))
+y = jax.random.normal(key, (1000, 1000))
+z = matrix_multiply(x, y)
+
+# Now the function is JIT compiled and will likely run on GPU (if available)
+print(z)
+
+devices = jax.devices()
+print(f"The available devices are: {devices}")
