[None][feat] Use a replay method for state rollback in Mamba-2 speculative decoding #13453
lucaslie merged 38 commits into NVIDIA:main from hnover-nv:mtp_state_computation
Conversation
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
/bot run
Key changes:
- Tune BLOCK_SIZE_M=8, num_warps=1 for dstate<=128 (was M=4, warps=4). With the (M, dstate) state in registers, fewer warps avoids local memory spills.
- Compile-time bounded replay loop: `for t in range(T)` with an `if` guard enables compiler unrolling since T is constexpr.
- Indexed pointer addressing (t * stride) instead of pointer increment.
- Fix test: new tokens use T=6, matching the intermediate buffer (was T=1).
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
…at M=8).
- Add benchmark CLI params: --block-size-m, --num-warps, --fast-forward-replay, --cb-output.
- ncu profiling: compute-bound, 63% SM, 80 regs/thread, 32% occupancy, 11% DRAM.
- CB with L1 reload: ~40% slower than sequential (84 vs 59 us at b=128/k=0).
- Next: try TMA/shared memory for CB preloads.
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Output phase restructured as three tl.dot calls (C@B^T, C@state^T, CB_scaled@x), following the ssd_chunk_scan convention. All dots use bf16 inputs with fp32 accumulation, matching context-phase Mamba2 precision.
Key results (mtp=5, prev_k=2):
- batch=1: 14.3 -> 10.2 us (1.4x)
- batch=128: 157.7 -> 40.4 us (3.9x, M=32) or 45.1 us (3.5x, M=8)
Also: the benchmark CLI now supports M and warps sweeps (comma-separated values); test tolerance widened to match bf16 tl.dot precision.
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
…fered cache
Two-kernel architecture:
1. Precompute kernel (grid: batch×nheads): CB = C@B^T via tl.dot, scaled with decay/dt/causal mask; stores CB_scaled + decay_vec. Also stores B, dt_proc, cumAdt to the double-buffered cache for the next step's replay.
2. Main kernel (grid: dim_tiles×batch×nheads): tl.dot replay using precomputed cumAdt (no softplus/cumsum); tl.dot output using precomputed CB_scaled.
Cache redesign: separate tensors with double-buffering for precompute-written data (old_B, old_dt_proc, old_cumAdt), single-buffer for main-kernel-written data (old_x). cache_buf_idx per slot enables independent buffer management. Test uses a random cache_buf_idx (0 or 1) with garbage in the other buffer to verify correct indexing.
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Move gdc_launch_dependents() to the start of the precompute kernel (not the end), allowing main-kernel replay to overlap with precompute computation. The main kernel's gdc_wait() gates only the output-phase reads of cb_scaled/decay_vec.
Sweep results with PDL:
- Precompute: pW=2 best for batch>=128 (8% gain); num_stages no impact
- Main kernel: M=32/W=2 best for batch=128 (43 us); num_stages no impact
- batch=1/MTP=5: M=4/W=2 gives 8.2 us (best ever)
- PDL hides precompute at small batch: batch=1 flat at 10.2 us regardless of MTP
Add benchmark CLI: --pdl, --num-stages, --precompute-num-warps, --precompute-num-stages
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
…evert dB_scaled experiment (neutral/worse) Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
…fered cache
Two-kernel architecture:
1. Precompute kernel (grid: batch×nheads): CB = C@B^T via tl.dot, scaled with decay/dt/causal mask; stores CB_scaled + decay_vec. Also stores B, dt_proc, cumAdt to the double-buffered cache for the next step's replay.
2. Main kernel (grid: dim_tiles×batch×nheads): tl.dot replay using precomputed cumAdt (no softplus/cumsum); tl.dot output using precomputed CB_scaled.
Cache: separate tensors with double-buffering for precompute-written data (old_B, old_dt_proc, old_cumAdt), single-buffer for main-kernel-written data (old_x). Layout: old_dt_proc/old_cumAdt use (cache, 2, nheads, T) for coalesced T access.
Internal PDL: precompute calls gdc_launch_dependents() at start; the main kernel overlaps state/x/C loads with precompute, then gdc_wait() before reading cb_scaled/decay_vec. Default on.
Batch-adaptive tuning (B200, nheads=16, head_dim=64, dstate=128):
- batch=1: M=4/W=4 (8.2 µs)
- batch=2-4: M=4/W=2 (10.2 µs)
- batch=8-16: M=16/W=1 (12-16 µs)
- batch>=32: M=32/W=2 (18-233 µs)
Max gap vs per-config optimal: 2.7%.
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Cache manager (mamba_cache_manager.py):
- Replace the packed intermediate_ssm_update_inputs with separate tensors: old_x (single-buffered), old_B/old_dt_proc/old_cumAdt (double-buffered)
- Add cache_buf_idx for double-buffer management, flipped in update_mamba_states
- Update at_layer_idx to treat cache_buf_idx as shared (not per-layer)
Mixer (mamba2_mixer.py):
- Split the conv1d + dt flow: the MTP path calls conv1d directly (no dt fp32 conversion); the non-MTP path keeps maybe_execute_in_parallel for flashinfer's fp32 dt requirement
- Update the incremental kernel call to pass separate cache tensors
- Our kernel handles bf16 dt natively (bias + softplus applied internally)
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
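The double-buffering scheme above can be sketched in plain Python. This is an illustrative model only, not the production cache manager: the class and method names (`ReplayCacheSlot`, `write_step`, `end_step`) are hypothetical, while the field names (`cache_buf_idx`, `old_B`, `old_dt_proc`, `old_cumAdt`, `old_x`) follow the commit message.

```python
# Hypothetical sketch of the per-slot double-buffered replay cache.
# Precompute-written tensors are double-buffered so the buffer being
# read for this step's replay is never the one being written for the
# next step; main-kernel-written old_x is single-buffered.

class ReplayCacheSlot:
    def __init__(self):
        self.cache_buf_idx = 0            # which buffer is "current" for replay
        self.old_B = [None, None]         # double-buffered (precompute-written)
        self.old_dt_proc = [None, None]
        self.old_cumAdt = [None, None]
        self.old_x = None                 # single-buffered (main-kernel-written)

    def write_step(self, B, dt_proc, cumAdt, x):
        # Writes for the *next* step go to the other buffer, leaving the
        # current buffer intact for this step's replay reads.
        nxt = 1 - self.cache_buf_idx
        self.old_B[nxt] = B
        self.old_dt_proc[nxt] = dt_proc
        self.old_cumAdt[nxt] = cumAdt
        self.old_x = x

    def end_step(self):
        # Mirrors the flip done in update_mamba_states: the freshly
        # written buffer becomes current for the next step.
        self.cache_buf_idx = 1 - self.cache_buf_idx


slot = ReplayCacheSlot()
slot.write_step("B0", "dt0", "cumAdt0", "x0")  # lands in buffer 1
slot.end_step()                                # buffer 1 is now current
```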
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Parameterize T in the test instead of hardcoding T=6. Bump atol from 0.5 to 1.0: the bf16 tl.dot truncation of fp32 intermediates introduces up to ~1 ULP error at output magnitude. Add detailed precision analysis documenting where and why our kernel differs from the baselines, prefill equivalence, and future options (TF32, philox). Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
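The ~1 ULP argument behind the atol bump can be checked with a few lines of arithmetic. A rough sketch (the `bf16_ulp` helper is illustrative, not from the codebase): bf16 keeps 8 significand bits (7 stored plus the implicit leading one), so one ULP at magnitude `m` is `2**(floor(log2(m)) - 7)`.

```python
# Back-of-envelope ULP size for bfloat16, which has an 8-bit significand.
import math

def bf16_ulp(m: float) -> float:
    """One bf16 unit in the last place at magnitude m (normal range)."""
    return 2.0 ** (math.floor(math.log2(abs(m))) - 7)

# At output magnitudes around 64-128, one bf16 ULP is 0.5-1.0, so an
# atol of 1.0 tolerates roughly one ULP of truncation from feeding
# fp32 intermediates into bf16 tl.dot inputs.
print(bf16_ulp(100.0), bf16_ulp(128.0))
```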
Key on total_heads (batch*nheads) and BLOCK_SIZE_T instead of batch alone.
Swept M={4..64}, W={1..4} across batch={1..512}, T={6,16,32}, TP={1,4,8}
on B200. At TP=1 batch=1 (total_heads=128), changes M=4,W=4 to M=8,W=1,
giving 46% speedup. At TP=8 batch<=4, small improvements (3-8%) from
W=4/W=2 to W=1. No regressions at any config.
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
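The keying change above amounts to selecting (M, W) by a `total_heads = batch * nheads` bucket rather than by batch. A minimal sketch, with bucket boundaries and values chosen for illustration from the sweep summary (the helper name `pick_config` and the exact thresholds are assumptions, not the production heuristic):

```python
# Illustrative batch-adaptive config lookup keyed on total_heads.
# Only the TP=1 batch=1 data point (total_heads=128 -> M=8, W=1) is
# taken directly from the commit message; other buckets are placeholders.

def pick_config(batch: int, nheads: int) -> tuple[int, int]:
    """Return (BLOCK_SIZE_M, num_warps) for the state-update kernel."""
    total_heads = batch * nheads
    if total_heads <= 128:
        return 8, 1      # e.g. TP=1, batch=1, nheads=128 (46% speedup vs M=4/W=4)
    if total_heads <= 512:
        return 16, 1     # placeholder mid-range bucket
    return 32, 2         # placeholder large-batch bucket
```

Keying on `total_heads` makes the heuristic portable across TP degrees: at TP=8 each rank sees 1/8 of the heads, so the same batch lands in a different bucket than at TP=1, which is exactly what the sweep showed matters.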
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
…conv1d benchmark mode
Mixer: move .contiguous() calls on x_d/B_d/C_d into the non-MTP (flashinfer) branch only. Without this, the copy kernels between conv1d and our precompute kernel consume the PDL signal, breaking the external PDL chain. Our kernel handles non-contiguous inputs correctly via explicit stride parameters.
Benchmark: add a --with-conv1d mode that includes conv1d before the incremental kernel with production-matching tensor layouts (no .contiguous() copies). Realistic L2 flush: cold caches flushed, hot in_proj output kept warm. Add --external-pdl/--no-external-pdl to toggle the PDL chain independently.
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
The precompute kernel now processes multiple heads per block, sharing the raw CB = C @ B^T computation. Two-loop structure: loop 1 computes per-head dt/cumAdt/decay before gdc_wait (overlapping with conv1d via external PDL); loop 2 reloads from cache and scales raw_CB per head after the wait.
Adds a batch-adaptive heuristic for H and pW (precompute warps):
- H=1 at total_heads<=128, H=2 at 256-512, H=4 at >=1024
- pW=4 at BLOCK_SIZE_T<=16 for total_heads<=64 or >=512, else pW=1
Tuned on B200 for Nemotron-3-Super-120B; assumes external PDL with conv1d.
Sweep results with conv1d + external PDL (T=6, TP=8, B200):
- b*h <= 128: ~0% (precompute hidden by overlap)
- b*h = 512: ~1-2%
- b*h >= 1024: 3-7% savings on total span
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Move x/C loads after gdc_wait() in the main kernel to fix a data race when the conv1d → precompute → main PDL chain is active. With external PDL, the main kernel starts before conv1d completes; loads before gdc_wait() could read stale conv output. After gdc_wait(), precompute has completed, which transitively guarantees conv1d has completed.
Resweep the M/W/pW/H heuristic on B200 with conv1d + chained PDL using a production-matching tensor layout (contiguous batch*T viewed then transposed). Max gap vs per-config optimal: ≤0.2%. Key changes from the previous heuristic:
- H=1 dominates small/medium batch (the production layout reduces the HPB benefit)
- M=64 at large batch (bh>=512) for bigger tl.dot tiles
- W=2 at bh=128 for better warp parallelism with the new strides
- pW=2 replaces pW=4 at most configs
Add a TODO for a PDL chain integration test.
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Reduce BLOCK_N from 256 to 128 and explicitly set num_warps=4.
Tuned with production-matching tensor layout (contiguous (batch*T,
conv_dim) viewed as (batch, conv_dim, T), matching the mixer's
in_proj → view → transpose data flow).
nsys sweep across batch={1..64}, T={2..32}, BN={32..256}, W={1..8}
on B200 at TP=8: BN=128 W=4 has ≤12% max gap vs per-config optimal,
10-55% faster than BN=256 at T≥6.
Only affects the MTP spec decoding verification path
(causal_conv1d_update_triton in mamba2_mixer.py). Non-MTP decode
uses the separate causal_conv1d_update and is unaffected.
Also add _block_n/_num_warps/_num_stages override params for
benchmarking (not used in production).
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
…ntless casts, update docs
Decouple internal/external PDL flags in precompute kernel:
Add LAUNCH_DEPENDENT_KERNELS constexpr gated by use_internal_pdl.
Previously gdc_launch_dependents() was unconditional — functionally
correct but now explicit. Precompute kernel now has two independent
PDL constexprs matching the wrapper's two flags:
LAUNCH_WITH_PDL: gdc_wait() for conv1d (external PDL)
LAUNCH_DEPENDENT_KERNELS: gdc_launch_dependents() for main (internal)
Rename abbreviated variables for readability:
ox_base → old_x_base, oB_base → old_B_base,
dv_base → decay_vec_base, cb_base → cb_scaled_base
Remove pointless fp32 roundtrips on bf16 data:
old_x_all and C_all were loaded bf16 → cast fp32 → cast bf16 for
tl.dot. Now kept as native dtype from load. Defensive .to(tl.bfloat16)
on tl.dot inputs retained (free no-op if already bf16, safe if not).
Update documentation:
Fix precompute grid comment (nheads // HEADS_PER_BLOCK).
Add PDL chain explanation and full argument docs to wrapper docstring.
Fix old_dt_proc/old_cumAdt shape docs (nheads, T) not (T, nheads).
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
…sweep
- Remove dead intermediate_update_inputs/interm_work/xbc_out tensors.
- Extract a _conv1d_split() helper shared by the baseline and incremental with-conv1d paths.
- Hoist sweep parsing above the prev_k loop.
- Replace a 6-deep nested for loop with itertools.product.
- Remove no-op .contiguous() calls.
- Fix 'flashifner' typo, fix a shadowed variable.
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Add _stochastic_round_fp16x2 helper using PTX cvt.rs.f16x2.f32
instruction for stochastic rounding of fp32 state to fp16. Applied
at the state store in the main kernel when rand_seed is provided.
Matching the flashinfer API: rand_seed is a single-element int64
CUDA tensor (graph-compatible), philox_rounds defaults to 10.
Philox-4x32 amortization: tl.randint4x on quarter-sized dstate
offsets + tl.join + tl.reshape to reconstruct the full (M, dstate)
random tensor. 4x fewer PRNG rounds, saving ~1.2us at batch=1.
Benchmark: add fp16 state dtype, --philox-rounding flag.
Philox wired to both incremental and flashinfer baselines;
errors if used with triton baseline (unsupported).
Tests:
- Add fp16 to existing state_dtype parametrization (24 new tests)
- test_incremental_selective_state_update_philox: verifies
rounding vs no-rounding produce matching output and state
- test_philox_rounding_unbiased: statistical test over 2M
elements verifying stochastic rounding is unbiased (~33% of
elements round differently from deterministic, mean residual ~ 0)
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
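The unbiasedness property that `test_philox_rounding_unbiased` checks can be demonstrated with a pure-Python model. This is a sketch only: the real kernel uses the PTX `cvt.rs.f16x2.f32` instruction driven by Philox random bits, whereas here we round to a coarse grid with `random.random()` to show the expectation property (`stochastic_round` is a hypothetical helper, not the kernel's code).

```python
# Toy model of stochastic rounding: round x to a multiple of `step`,
# rounding up with probability equal to the fractional distance, so
# that E[rounded] == x. Deterministic round-to-nearest would instead
# introduce a systematic bias at values between grid points.
import random

def stochastic_round(x: float, step: float = 0.25) -> float:
    lo = (x // step) * step
    frac = (x - lo) / step          # in [0, 1): distance to the next grid point
    return lo + step if random.random() < frac else lo

random.seed(0)
x = 0.3
n = 200_000
mean = sum(stochastic_round(x) for _ in range(n)) / n
# mean is close to 0.3, even though every individual result is 0.25 or 0.5
```

The same argument is why stochastically rounding the fp32 state down to fp16 at the store avoids the drift that deterministic truncation would accumulate across many decode steps.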
mikeiovine left a comment:
Stamping changes under pyexecutor/ on behalf of runtime/model devs. Did not review the rest; that should be done by Nemotron devs.
PR NVIDIA#13151 refactored the cache manager API:
- update_mamba_states now takes state_indices: Tensor as a parameter (caller-provided)
- _prepare_mamba_cache_blocks no longer eagerly builds state_indices_list / state_indices
- get_state_indices is now a list comprehension; padding shares the dummy slot
- shutdown no longer resets self.state_indices
Adapted our changes to the new API:
- Kept the PNAT=0 reset on cache miss in _prepare_mamba_cache_blocks
- Kept the replay-vs-legacy branch in update_mamba_states, reading the state_indices parameter instead of self.state_indices
- _util.py: applied the num_layers -> num_full_attention_layers rename
Also fix fast-path importing in the benchmark; our addition of the sm-version helper broke it.
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
…ative decoding (NVIDIA#13453) Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
…pdate chain
Wire the kernel-level PDL flags for the full conv1d → precompute → state_update PDL chain, matching the PyTorch backend (NVIDIA#13453):
- triton_backend_causal_conv.py (extend path): causal_conv1d_update(..., launch_dependent_kernels=True). The kernel emits gdc_launch_dependents() at block start.
- flashinfer_backend_mamba.py (replay path): replay_selective_state_update(..., launch_with_pdl=True). Precompute kernel: gdc_wait() for the conv1d signal, gdc_launch_dependents() for state_update. State_update kernel: gdc_wait() for precompute.
Both functions silently disable PDL on sm < 90 (Ampere/B200 only).
Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
Description
Mamba is an alternative to attention that keeps a large but constant-sized state, which it destructively updates on each token, whereas standard attention keeps an ever-growing KV cache.
In speculative decoding, draft tokens are produced from some source and then run at once through the target model. At the end, the model decides which tokens to accept, and any model state needs to be rolled back to be consistent.
For attention layers this is easy: the now-invalid KV cache entries are dropped by decrementing a counter. For Mamba layers the state was destructively updated, so it's not that simple.
Before this PR, at each state update we wrote, for each token, an entry to an intermediate cache holding the state that would be correct if that token were the last accepted one. Then, at step end, once we knew which token was the last accepted, we copied that "winning" state into the per-request main Mamba cache.
In this PR, we instead adopt a "replay" method, one of the proposals in Snakes and Ladders and Mamba in the Llama. The per-request cache now holds the Mamba state from two steps back, all the inputs that went into last step's state update, and the number of tokens accepted last step. At state update, we first use the cache contents to advance the state to the correct end-of-last-step state given the number of accepted tokens, and write that back; we then use it together with the new inputs to generate this step's output tokens. We also write this step's inputs to the cache, to serve as the "old" inputs next step. At step end, we only need to write the number of accepted tokens to the cache.
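The replay idea can be sketched with a scalar toy version of the Mamba decode recurrence. This is an illustrative model, not the kernel: the real state is a (nheads, head_dim, dstate) tensor and the recurrence runs per element, but the rollback logic is the same.

```python
# Toy scalar Mamba-2 decode recurrence: state' = exp(A*dt)*state + dt*B*x.
# Baseline method: snapshot the state after every draft token, then copy
# the winner. Replay method: keep the old state plus the step's inputs,
# and re-advance by num_accepted tokens.
import math

A = -0.5  # fixed per-head decay, illustrative scalar

def step(state, dt, B, x):
    return math.exp(A * dt) * state + dt * B * x

def replay(state, cached_inputs, num_accepted):
    """Advance `state` through the first num_accepted cached (dt, B, x)."""
    for dt, B, x in cached_inputs[:num_accepted]:
        state = step(state, dt, B, x)
    return state

# Last step ran 6 draft tokens from state_prev; suppose 3 were accepted.
state_prev = 1.0
inputs = [(0.1 * (t + 1), 0.5, 1.0) for t in range(6)]
num_accepted = 3

# Baseline: one intermediate snapshot per token.
snapshots, s = [], state_prev
for dt, B, x in inputs:
    s = step(s, dt, B, x)
    snapshots.append(s)

# Replay reproduces the chosen snapshot without storing any of them.
assert replay(state_prev, inputs, num_accepted) == snapshots[num_accepted - 1]
```

In the toy version replay performs the identical floating-point operations as the snapshot path, so the results match exactly; in the real kernel the replay is fused into Tensor Core matmuls, which is where the bf16 tolerance discussion elsewhere in this PR comes from.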
This saves the end-of-step state copy, which is 8% of step runtime at batch size 256. It also lets us write substantially faster state-update kernels, for a few reasons. First, we no longer need to write out all the intermediate states. Second, we don't even need to materialize most of them; instead we can borrow tricks from the prefill kernel and go from state_{i-2} to state_{i-1 | num accepted at i-1} in one Tensor Core accelerated jump. We can also refactor the state-update and output-generation equations to produce this step's outputs without materializing any state for this step, using more Tensor Core operations to produce the outputs directly. Even with 6 input tokens (draft length = 5), where many matrices must be padded to at least 16 in each dimension, this is a major win.
Run-time aside, the memory savings let us push the draft length much farther.
For Nemotron v3 Super BF16 on B200 with TP=8, here are both end-to-end run times and kernel microbenchmarks.
Also, because the kernel pads to 16 internally, it is insensitive to longer draft lengths.
Some other optimizations, mostly for small batch: remove needless format conversions and .contiguous() calls, worth a few percent at low batch. We could do the same on the legacy flashinfer path with some small flashinfer changes, but since we hope that path will be replaced soon, it's probably not worth it. Also, tune the grid for the conv1d used on the MTP path.
Limitations
This PR only uses replay when using the Python Mamba cache manager, and only if advanced features like disaggregated serving, KV cache reuse, etc. are off. It also only supports the PyTorch backend, not AutoDeploy. Support will be extended in follow-on PRs. The legacy path is maintained in the meantime.
Test Coverage
Added a unit test for the replay function. Integration covered by existing Nemotron-H tests.
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.