Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 2f103af

Browse filesBrowse files
authored
FIX: jax_intro timeout: use lax.fori_loop instead of Python for loop (#442)
* Fix jax_intro timeout: use lax.fori_loop instead of Python for loop The compute_call_price_jax function was timing out during cache.yml builds because JAX unrolls Python for loops during JIT compilation. With large arrays (M=10M), this causes excessive compilation time. Solution: Replace Python for loop with jax.lax.fori_loop, which compiles the loop efficiently without unrolling. Fixes cell execution timeout in jax_intro.md * style: use jstac's fori_loop naming conventions - loop_body -> update - state -> loop_state - Added explicit new_loop_state and final_loop_state variables - More verbose but clearer for first-time fori_loop readers * style: loop_state -> initial_loop_state
1 parent 39297ba commit 2f103af
Copy full SHA for 2f103af

File tree

Expand file treeCollapse file tree

1 file changed

+16
-1
lines changed
Open diff view settings
Filter options
Expand file treeCollapse file tree

1 file changed

+16
-1
lines changed
Open diff view settings
Collapse file

‎lectures/jax_intro.md‎

Copy file name to clipboardExpand all lines: lectures/jax_intro.md
+16-1Lines changed: 16 additions & 1 deletion
  • Display the source diff
  • Display the rich diff
Original file line numberDiff line numberDiff line change
@@ -832,16 +832,31 @@ def compute_call_price_jax(β=β,
832832
833833
s = jnp.full(M, np.log(S0))
834834
h = jnp.full(M, h0)
835-
for t in range(n):
835+
836+
def update(i, loop_state):
837+
s, h, key = loop_state
836838
key, subkey = jax.random.split(key)
837839
Z = jax.random.normal(subkey, (2, M))
838840
s = s + μ + jnp.exp(h) * Z[0, :]
839841
h = ρ * h + ν * Z[1, :]
842+
new_loop_state = s, h, key
843+
return new_loop_state
844+
845+
initial_loop_state = s, h, key
846+
final_loop_state = jax.lax.fori_loop(0, n, update, initial_loop_state)
847+
s, h, key = final_loop_state
848+
840849
expectation = jnp.mean(jnp.maximum(jnp.exp(s) - K, 0))
841850
842851
return β**n * expectation
843852
```
844853

854+
```{note}
855+
We use `jax.lax.fori_loop` instead of a Python `for` loop.
856+
This allows JAX to compile the loop efficiently without unrolling it,
857+
which significantly reduces compilation time for large arrays.
858+
```
859+
845860
Let's run it once to compile it:
846861

847862
```{code-cell} ipython3

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.