Commit 2f103af
authored
FIX: jax_intro timeout: use lax.fori_loop instead of Python for loop (#442)
* Fix jax_intro timeout: use lax.fori_loop instead of Python for loop
The compute_call_price_jax function was timing out during cache.yml builds
because JAX unrolls Python for loops during JIT compilation. With large
arrays (M=10M), this causes excessive compilation time.
Solution: Replace Python for loop with jax.lax.fori_loop, which compiles
the loop efficiently without unrolling.
Fixes cell execution timeout in jax_intro.md
* style: use jstac's fori_loop naming conventions
- loop_body -> update
- state -> loop_state
- Added explicit new_loop_state and final_loop_state variables
- More verbose but clearer for first-time fori_loop readers
* style: loop_state -> initial_loop_state1 parent 39297ba commit 2f103afCopy full SHA for 2f103af
File tree
Expand file treeCollapse file tree
1 file changed
+16
-1
lines changedOpen diff view settings
Filter options
- lectures
Expand file treeCollapse file tree
1 file changed
+16
-1
lines changedOpen diff view settings
Collapse file
+16-1Lines changed: 16 additions & 1 deletion
- Display the source diff
- Display the rich diff
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| ||
832 | 832 | |
833 | 833 | |
834 | 834 | |
835 | | - |
| 835 | + |
| 836 | + |
| 837 | + |
836 | 838 | |
837 | 839 | |
838 | 840 | |
839 | 841 | |
| 842 | + |
| 843 | + |
| 844 | + |
| 845 | + |
| 846 | + |
| 847 | + |
| 848 | + |
840 | 849 | |
841 | 850 | |
842 | 851 | |
843 | 852 | |
844 | 853 | |
| 854 | + |
| 855 | + |
| 856 | + |
| 857 | + |
| 858 | + |
| 859 | + |
845 | 860 | |
846 | 861 | |
847 | 862 | |
|
0 commit comments