[None][feat] Use a replay method for state rollback in Mamba-2 speculative decoding #13453

Merged

lucaslie merged 38 commits into NVIDIA:main from hnover-nv:mtp_state_computation on Apr 30, 2026
Conversation

@hnover-nv (Collaborator) commented Apr 24, 2026

Description

Mamba is an alternative to attention that keeps a large but constant-sized state, which it destructively updates on each token, whereas standard attention has an ever-growing KV cache.

In speculative decoding, draft tokens are produced from some source and then run through the target model in one pass. At the end, the model decides which tokens to accept, and any model state must be rolled back to stay consistent with the accepted tokens.

For attention layers this is easy: just invalidate the now-stale KV cache entries by decrementing a counter. For Mamba models, the state was destructively updated, so it's not that simple.
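To make the contrast concrete, the attention-side rollback is pure bookkeeping (a toy sketch):

```python
# Per-request logical KV length; entries past the new length are simply
# overwritten by later writes, so rollback is just a decrement.
kv_len = {request_id: 37 for request_id in range(4)}
num_rejected = 3
kv_len[0] -= num_rejected
```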

Before this PR, at each state update we wrote to an intermediate cache, for each token, the state that would be correct if that token were the last accepted one. Then, at the end of the step, once we knew which token was the last accepted, we copied that "winning" state into the per-request main Mamba cache.

In this PR, we instead adopt a "replay" method, one of the proposals in Snakes and Ladders and Mamba in the Llama. The per-request cache now holds the Mamba state from two steps back, all the inputs that went into last step's state update, and the number of tokens accepted last step. At state update time, we first use the cached contents to advance the state to the correct end-of-last-step state given the number of accepted tokens and write it back; then we use that state and the new inputs to generate this step's output tokens. We also write this step's inputs to the cache, to serve as the "old" inputs next step. At step end, all we need to write to the cache is the number of accepted tokens.
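As a rough illustration, here is the replay recurrence in plain NumPy (the names and the scalar-decay simplification are ours; the real implementation is a fused Triton kernel):

```python
import numpy as np

def replay_state_update(state_two_back, old_dA, old_dBx, num_accepted,
                        new_dA, new_dBx):
    """Simplified Mamba-2 recurrence: state_t = exp(dA_t) * state_{t-1} + dBx_t."""
    state = state_two_back
    # Replay only the accepted prefix of last step's cached inputs; this
    # recovers the correct end-of-last-step state with no rollback copy.
    for t in range(num_accepted):
        state = np.exp(old_dA[t]) * state + old_dBx[t]
    # Then apply this step's tokens as usual.
    for t in range(len(new_dA)):
        state = np.exp(new_dA[t]) * state + new_dBx[t]
    return state
```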

This saves us the end-of-step state copy, which is 8% of step runtime at batch size 256. It also lets us write substantially faster state-update kernels, for a few reasons. First, we no longer need to write out all the intermediate states. Second, we don't even need to materialize most of them: we can borrow tricks from the prefill kernel and go from state_{i-2} to state_{i-1 | num accepted at i-1} in one Tensor Core accelerated jump. We can also refactor the state-update and output-generation equations to produce this step's outputs without materializing any state for this step, using more Tensor Core operations to produce the outputs directly. Even at 6 input tokens (draft length = 5), where many matrices must be padded to at least 16 in each dimension, this is a major win.
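For intuition, here is the state-free output computation in NumPy under simplifying assumptions of ours (scalar A per head, illustrative shapes); it mirrors the C@state^T and CB_scaled@x matmul structure described in the commits below:

```python
import numpy as np

T, head_dim, dstate = 6, 64, 128           # draft length 5 -> 6 input tokens
rng = np.random.default_rng(0)
C = rng.standard_normal((T, dstate))
B = rng.standard_normal((T, dstate))
x = rng.standard_normal((T, head_dim))
state = rng.standard_normal((head_dim, dstate))  # replayed end-of-last-step state
dt = rng.random(T)
A = -1.0                                    # scalar per-head decay rate

cum = np.cumsum(dt * A)
decay_to_t = np.exp(cum)                    # total decay from step start to token t
# Inter-step term: contribution of the carried-in state (one matmul).
y_state = (C @ state.T) * decay_to_t[:, None]
# Intra-step term: causal, decay- and dt-weighted C @ B^T applied to x.
CB_scaled = ((C @ B.T)
             * np.exp(cum[:, None] - cum[None, :])
             * np.tril(np.ones((T, T)))
             * dt[None, :])
y = y_state + CB_scaled @ x                 # outputs, no per-token states built
```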

Runtime aside, the memory savings let us push the draft length much further.

For Nemotron v3 Super BF16 on B200 with TP=8, here are both end-to-end run times and kernel microbenchmarks.

[Figures: e2e_speedup_bf16, speedup_slide_bf16_tp8]

Also, because the kernel needs to pad to 16 anyway, it is insensitive to longer draft lengths:

[Figure: mtp_normalized_side_by_side]

Some other optimizations, mostly for small batch: remove needless format conversions and .contiguous() calls, which are worth a few percent at low batch. We could do the same on the legacy flashinfer path with some small flashinfer changes, but since we hope that path will be replaced soon, it's probably not worth it. We also tune the grid for the conv1d that we use on the MTP path.

Limitations

This PR only uses replay when using the Python Mamba cache manager, and only if advanced features like disaggregated serving, KV cache reuse, etc. are off. It also only supports the PyTorch backend, not AutoDeploy. Support will be extended in follow-on PRs; the legacy path is maintained in the meantime.

Test Coverage

Added a unit test for the replay function. Integration covered by existing Nemotron-H tests.

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

Summary by CodeRabbit

Release Notes

  • New Features

    • Introduced replay-based selective state update optimization for Mamba models
    • Added hardware capability detection for GPU SM versions with dynamic feature gating
  • Performance

    • Optimized speculative decoding path for improved SSM state updates
  • Tests

    • Added comprehensive unit tests and benchmarks for replay state update functionality
  • Documentation

    • Updated docstrings to reflect hardware-aware feature gating behavior

Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
@hnover-nv (Collaborator, Author)

/bot run

Key changes:
- Tune BLOCK_SIZE_M=8, num_warps=1 for dstate<=128 (was M=4, warps=4).
  With (M, dstate) state in registers, fewer warps avoids local memory spills.
- Compile-time bounded replay loop: for t in range(T) with an if guard
  enables compiler unrolling since T is constexpr (see the sketch after this list).
- Indexed pointer addressing (t * stride) instead of pointer increment.
- Fix test: new tokens use T=6 matching intermediate buffer (was T=1).
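A schematic Triton fragment of that loop pattern (illustrative names; 1D block and scalar decay for brevity):

```python
import triton
import triton.language as tl

@triton.jit
def _replay(state, dA_ptr, dBx_ptr, num_accepted, stride_t,
            BLOCK: tl.constexpr, T: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    # T is constexpr, so the compiler can fully unroll the loop; the
    # runtime guard skips tokens past the accepted prefix. Indexed
    # addressing (t * stride) replaces pointer increments.
    for t in range(T):
        if t < num_accepted:
            dA = tl.load(dA_ptr + t * stride_t)           # scalar decay term
            dBx = tl.load(dBx_ptr + t * stride_t + offs)  # per-element update
            state = tl.exp(dA) * state + dBx
    return state
```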


Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
…at M=8).

Add benchmark CLI params: --block-size-m, --num-warps, --fast-forward-replay, --cb-output.
ncu profiling: compute-bound 63% SM, 80 regs/thread, 32% occupancy, 11% DRAM.
CB with L1 reload: ~40% slower than sequential (84 vs 59 us at b=128/k=0).
Next: try TMA/shared memory for CB preloads.


Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Output phase restructured as three tl.dot calls (C@B^T, C@state^T, CB_scaled@x),
following the ssd_chunk_scan convention. All dots use bf16 inputs with fp32
accumulation, matching context-phase Mamba2 precision.

Key results (mtp=5, prev_k=2):
  batch=1:   14.3 -> 10.2 us (1.4x)
  batch=128: 157.7 -> 40.4 us (3.9x, M=32) or 45.1 us (3.5x, M=8)

Also: benchmark CLI now supports M and warps sweeps (comma-separated values),
test tolerance widened to match bf16 tl.dot precision.


Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
…fered cache

Two-kernel architecture:
1. Precompute kernel (grid: batch×nheads): CB=C@B^T via tl.dot, scale with
   decay/dt/causal mask, store CB_scaled+decay_vec. Also stores B, dt_proc,
   cumAdt to double-buffered cache for next step's replay.
2. Main kernel (grid: dim_tiles×batch×nheads): tl.dot replay using precomputed
   cumAdt (no softplus/cumsum), tl.dot output using precomputed CB_scaled.

Cache redesign: separate tensors with double-buffering for precompute-written
data (old_B, old_dt_proc, old_cumAdt), single-buffer for main-kernel-written
data (old_x). cache_buf_idx per slot enables independent buffer management.

Test uses random cache_buf_idx (0 or 1) with garbage in the other buffer
to verify correct indexing.
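A minimal sketch of that bookkeeping (hypothetical sizes; tensor names follow the commit above):

```python
import torch

num_slots, nheads, T, head_dim = 256, 16, 6, 64
# Precompute-written data is double-buffered: an extra dim of size 2.
old_cumAdt = torch.zeros(num_slots, 2, nheads, T)
# Main-kernel-written data is single-buffered.
old_x = torch.zeros(num_slots, T, nheads, head_dim)
# Per-slot index saying which buffer holds last step's ("old") inputs.
cache_buf_idx = torch.zeros(num_slots, dtype=torch.long)

def flip(state_indices: torch.Tensor) -> None:
    # After each step, the freshly written buffer becomes the "old" one.
    cache_buf_idx[state_indices] ^= 1
```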


Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Move gdc_launch_dependents() to start of precompute kernel (not end), allowing
main kernel replay to overlap with precompute computation. Main kernel gdc_wait()
gates only the output phase reads of cb_scaled/decay_vec.

Sweep results with PDL:
- Precompute: pW=2 best for batch>=128 (8% gain), num_stages no impact
- Main kernel: M=32/W=2 best for batch=128 (43 us), num_stages no impact
- batch=1/MTP=5: M=4/W=2 gives 8.2 us (best ever)
- PDL hides precompute at small batch: batch=1 flat at 10.2 us regardless of MTP

Add benchmark CLI: --pdl, --num-stages, --precompute-num-warps, --precompute-num-stages


Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
…evert dB_scaled experiment (neutral/worse)

Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
…fered cache

Two-kernel architecture:
1. Precompute kernel (grid: batch×nheads): CB=C@B^T via tl.dot, scale with
   decay/dt/causal mask, store CB_scaled+decay_vec. Also stores B, dt_proc,
   cumAdt to double-buffered cache for next step's replay.
2. Main kernel (grid: dim_tiles×batch×nheads): tl.dot replay using precomputed
   cumAdt (no softplus/cumsum), tl.dot output using precomputed CB_scaled.

Cache: separate tensors with double-buffering for precompute-written data
(old_B, old_dt_proc, old_cumAdt), single-buffer for main-kernel-written (old_x).
Layout: old_dt_proc/old_cumAdt use (cache, 2, nheads, T) for coalesced T access.

Internal PDL: precompute calls gdc_launch_dependents() at start, main kernel
overlaps state/x/C loads with precompute, then gdc_wait() before reading
cb_scaled/decay_vec. Default on.

Batch-adaptive tuning (B200, nheads=16, head_dim=64, dstate=128):
  batch=1:    M=4/W=4  (8.2 µs)
  batch=2-4:  M=4/W=2  (10.2 µs)
  batch=8-16: M=16/W=1 (12-16 µs)
  batch>=32:  M=32/W=2 (18-233 µs)
Max gap vs per-config optimal: 2.7%.


Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Cache manager (mamba_cache_manager.py):
- Replace packed intermediate_ssm_update_inputs with separate tensors:
  old_x (single-buffered), old_B/old_dt_proc/old_cumAdt (double-buffered)
- Add cache_buf_idx for double-buffer management, flipped in update_mamba_states
- Update at_layer_idx to treat cache_buf_idx as shared (not per-layer)

Mixer (mamba2_mixer.py):
- Split conv1d + dt flow: MTP path calls conv1d directly (no dt fp32 conversion),
  non-MTP path keeps maybe_execute_in_parallel for flashinfer fp32 dt requirement
- Update incremental kernel call to pass separate cache tensors
- Our kernel handles bf16 dt natively (bias + softplus applied internally)


Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Parameterize T in the test instead of hardcoding T=6. Bump atol from
0.5 to 1.0: the bf16 tl.dot truncation of fp32 intermediates introduces
up to ~1 ULP error at output magnitude. Add detailed precision analysis
documenting where and why our kernel differs from the baselines, prefill
equivalence, and future options (TF32, philox).


Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Key on total_heads (batch*nheads) and BLOCK_SIZE_T instead of batch alone.
Swept M={4..64}, W={1..4} across batch={1..512}, T={6,16,32}, TP={1,4,8}
on B200. At TP=1 batch=1 (total_heads=128), changes M=4,W=4 to M=8,W=1,
giving 46% speedup. At TP=8 batch<=4, small improvements (3-8%) from
W=4/W=2 to W=1. No regressions at any config.


Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
…conv1d benchmark mode

Mixer: move .contiguous() calls on x_d/B_d/C_d into the non-MTP (flashinfer)
branch only. Without this, the copy kernels between conv1d and our precompute
kernel consume the PDL signal, breaking the external PDL chain. Our kernel
handles non-contiguous inputs correctly via explicit stride parameters.
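A small PyTorch illustration of why the copies are avoidable (kernel launch elided; names illustrative):

```python
import torch

batch, T, conv_dim = 4, 6, 1024
y = torch.randn(batch * T, conv_dim)             # in_proj output layout
x = y.view(batch, T, conv_dim).transpose(1, 2)   # (batch, conv_dim, T) view, no copy
assert not x.is_contiguous()                     # .contiguous() here would launch a copy
# Instead of copying (an extra kernel that would also consume the PDL
# signal), pass the view's strides so the kernel can index it directly:
strides = (x.stride(0), x.stride(1), x.stride(2))
```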

Benchmark: add --with-conv1d mode that includes conv1d before the incremental
kernel with production-matching tensor layouts (no .contiguous() copies).
Realistic L2 flush: cold caches flushed, hot in_proj output kept warm.
Add --external-pdl/--no-external-pdl to toggle the PDL chain independently.


Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Precompute kernel processes multiple heads per block, sharing the raw
CB = C @ B^T computation. Two-loop structure: loop 1 computes per-head
dt/cumAdt/decay before gdc_wait (overlaps with conv1d via external PDL).
Loop 2 reloads from cache and scales raw_CB per-head after the wait.

Adds batch-adaptive heuristic for H and pW (precompute warps):
  H=1 at total_heads<=128, H=2 at 256-512, H=4 at >=1024
  pW=4 at BLOCK_SIZE_T<=16 for total_heads<=64 or >=512, else pW=1
Tuned on B200 for Nemotron-3-Super-120B, assumes external PDL with conv1d.

Sweep results with conv1d + external PDL (T=6, TP=8, B200):
  b*h <= 128: ~0% (precomp hidden by overlap)
  b*h 512: ~1-2%
  b*h >= 1024: 3-7% savings on total span


Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Move x/C loads after gdc_wait() in main kernel to fix data race when
conv1d → precompute → main PDL chain is active. With external PDL,
main kernel starts before conv1d completes; loads before gdc_wait()
could read stale conv output. After gdc_wait(), precompute has
completed (which transitively guarantees conv1d completed).

Resweep M/W/pW/H heuristic on B200 with conv1d + chained PDL using
production-matching tensor layout (contiguous batch*T viewed then
transposed). Max gap vs per-config optimal: ≤0.2%.

Key changes from previous heuristic:
  - H=1 dominates small/medium batch (production layout reduces HPB benefit)
  - M=64 at large batch (bh>=512) for bigger tl.dot tiles
  - W=2 at bh=128 for better warp parallelism with new strides
  - pW=2 replaces pW=4 at most configs

Add TODO for PDL chain integration test.


Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Reduce BLOCK_N from 256 to 128 and explicitly set num_warps=4.
Tuned with production-matching tensor layout (contiguous (batch*T,
conv_dim) viewed as (batch, conv_dim, T), matching the mixer's
in_proj → view → transpose data flow).

nsys sweep across batch={1..64}, T={2..32}, BN={32..256}, W={1..8}
on B200 at TP=8: BN=128 W=4 has ≤12% max gap vs per-config optimal,
10-55% faster than BN=256 at T≥6.

Only affects the MTP spec decoding verification path
(causal_conv1d_update_triton in mamba2_mixer.py). Non-MTP decode
uses the separate causal_conv1d_update and is unaffected.

Also add _block_n/_num_warps/_num_stages override params for
benchmarking (not used in production).


Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
…ntless casts, update docs

Decouple internal/external PDL flags in precompute kernel:
  Add LAUNCH_DEPENDENT_KERNELS constexpr gated by use_internal_pdl.
  Previously gdc_launch_dependents() was unconditional — functionally
  correct but now explicit. Precompute kernel now has two independent
  PDL constexprs matching the wrapper's two flags:
    LAUNCH_WITH_PDL: gdc_wait() for conv1d (external PDL)
    LAUNCH_DEPENDENT_KERNELS: gdc_launch_dependents() for main (internal)

Rename abbreviated variables for readability:
  ox_base → old_x_base, oB_base → old_B_base,
  dv_base → decay_vec_base, cb_base → cb_scaled_base

Remove pointless fp32 roundtrips on bf16 data:
  old_x_all and C_all were loaded bf16 → cast fp32 → cast bf16 for
  tl.dot. Now kept as native dtype from load. Defensive .to(tl.bfloat16)
  on tl.dot inputs retained (free no-op if already bf16, safe if not).

Update documentation:
  Fix precompute grid comment (nheads // HEADS_PER_BLOCK).
  Add PDL chain explanation and full argument docs to wrapper docstring.
  Fix old_dt_proc/old_cumAdt shape docs (nheads, T) not (T, nheads).


Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
…sweep

Remove dead intermediate_update_inputs/interm_work/xbc_out tensors.
Extract _conv1d_split() helper shared by baseline and incremental
with-conv1d paths. Hoist sweep parsing above prev_k loop. Replace
6-deep nested for loop with itertools.product. Remove no-op
.contiguous() calls. Fix 'flashifner' typo, fix shadowed variable.


Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Add _stochastic_round_fp16x2 helper using PTX cvt.rs.f16x2.f32
instruction for stochastic rounding of fp32 state to fp16. Applied
at the state store in the main kernel when rand_seed is provided.
Matching the flashinfer API: rand_seed is a single-element int64
CUDA tensor (graph-compatible), philox_rounds defaults to 10.

Philox-4x32 amortization: tl.randint4x on quarter-sized dstate
offsets + tl.join + tl.reshape to reconstruct the full (M, dstate)
random tensor. 4x fewer PRNG rounds, saving ~1.2us at batch=1.

Benchmark: add fp16 state dtype, --philox-rounding flag.
Philox wired to both incremental and flashinfer baselines;
errors if used with triton baseline (unsupported).

Tests:
  - Add fp16 to existing state_dtype parametrization (24 new tests)
  - test_incremental_selective_state_update_philox: verifies
    rounding vs no-rounding produce matching output and state
  - test_philox_rounding_unbiased: statistical test over 2M
    elements verifying stochastic rounding is unbiased (~33% of
    elements round differently from deterministic, mean residual ~ 0;
    see the sketch after this list)
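A rough sketch of that statistical check in plain PyTorch (round_fn is a hypothetical stand-in for the kernel's stochastic rounding; the real test drives the Triton kernel):

```python
import torch

def check_unbiased(round_fn, n: int = 2_000_000) -> None:
    # Values in [1, 2) land between fp16 grid points (spacing 2^-10 there).
    x = torch.rand(n, dtype=torch.float32) + 1.0
    stoch = round_fn(x).to(torch.float32)
    det = x.to(torch.float16).to(torch.float32)       # deterministic round-to-nearest
    frac_diff = (stoch != det).float().mean().item()  # expect a sizeable fraction
    mean_resid = (stoch - x).mean().item()            # expect ~0 if unbiased
    print(f"differs from deterministic: {frac_diff:.1%}, mean residual: {mean_resid:+.2e}")
```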


Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
@tensorrt-cicd (Collaborator)

PR_Github #45551 [ run ] completed with state SUCCESS. Commit: 3650984
/LLM/main/L0_MergeRequest_PR pipeline #35769 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@hnover-nv (Collaborator, Author)

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #45601 [ run ] triggered by Bot. Commit: 3650984 Link to invocation

@tensorrt-cicd (Collaborator)

PR_Github #45601 [ run ] completed with state SUCCESS. Commit: 3650984
/LLM/main/L0_MergeRequest_PR pipeline #35817 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@hnover-nv force-pushed the mtp_state_computation branch from cf68eef to 63730c5 (April 27, 2026 17:32)
@hnover-nv (Collaborator, Author)

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #45759 [ run ] triggered by Bot. Commit: 63730c5 Link to invocation

@mikeiovine (Collaborator) left a comment

Stamping changes under pyexecutor/ on behalf of runtime/model devs. Did not review the rest; that should be done by the Nemotron devs.


PR NVIDIA#13151 refactored the cache manager API:
- update_mamba_states now takes state_indices: Tensor as a parameter (caller-provided)
- _prepare_mamba_cache_blocks no longer eagerly builds state_indices_list / state_indices
- get_state_indices is now a list comprehension; padding shares dummy slot
- shutdown no longer resets self.state_indices

Adapted our changes to the new API:
- Kept PNAT=0 reset on cache miss in _prepare_mamba_cache_blocks
- Kept replay-vs-legacy branch in update_mamba_states, reading the
  state_indices parameter instead of self.state_indices
- _util.py: applied num_layers -> num_full_attention_layers rename

Also, fix fast-path importing in the benchmark; our addition of the SM version helper broke it.

Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
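For orientation, the replay-vs-legacy branch has roughly this shape (hypothetical attribute names; see mamba_cache_manager.py for the real code):

```python
import torch

def update_mamba_states(self, state_indices: torch.Tensor,
                        num_accepted: torch.Tensor) -> None:
    if self.use_replay:
        # Replay path: record this step's acceptance counts and flip the
        # double buffer; next step's kernel replays the accepted prefix.
        self.prev_num_accepted[state_indices] = num_accepted
        self.cache_buf_idx[state_indices] ^= 1
    else:
        # Legacy path: copy each request's "winning" intermediate state
        # into the main per-request Mamba cache.
        self.ssm_states[state_indices] = self.intermediate_states[
            state_indices, num_accepted - 1]
```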
@hnover-nv force-pushed the mtp_state_computation branch from 63730c5 to 5b8d147 (April 28, 2026 00:55)
@hnover-nv (Collaborator, Author)

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #45759 [ run ] completed with state SUCCESS. Commit: 63730c5
/LLM/main/L0_MergeRequest_PR pipeline #35952 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@tensorrt-cicd (Collaborator)

PR_Github #45805 [ run ] triggered by Bot. Commit: 5b8d147 Link to invocation

@hnover-nv (Collaborator, Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #45805 [ run ] completed with state SUCCESS. Commit: 5b8d147
/LLM/main/L0_MergeRequest_PR pipeline #35995 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@tensorrt-cicd (Collaborator)

PR_Github #45869 [ run ] triggered by Bot. Commit: 5b8d147 Link to invocation

@tensorrt-cicd (Collaborator)

PR_Github #45869 [ run ] completed with state FAILURE. Commit: 5b8d147
/LLM/main/L0_MergeRequest_PR pipeline #36045 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@hnover-nv (Collaborator, Author)

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #45975 [ run ] triggered by Bot. Commit: 5b8d147 Link to invocation

@tensorrt-cicd (Collaborator)

PR_Github #45975 [ run ] completed with state SUCCESS. Commit: 5b8d147
/LLM/main/L0_MergeRequest_PR pipeline #36128 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@hnover-nv (Collaborator, Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #46041 [ run ] triggered by Bot. Commit: 5b8d147 Link to invocation

@hnover-nv (Collaborator, Author)

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #46189 [ run ] triggered by Bot. Commit: 5b8d147 Link to invocation

@tensorrt-cicd (Collaborator)

PR_Github #46189 [ run ] completed with state SUCCESS. Commit: 5b8d147
/LLM/main/L0_MergeRequest_PR pipeline #36305 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

CI Report

Link to invocation

@lucaslie merged commit 5221c7b into NVIDIA:main on Apr 30, 2026
5 checks passed
evezhier pushed a commit to evezhier/TensorRT-LLM that referenced this pull request May 4, 2026
…ative decoding (NVIDIA#13453)

Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
galagam added a commit to nv-auto-deploy/TensorRT-LLM that referenced this pull request May 8, 2026
…pdate chain

Wire the kernel-level PDL flags for the full conv1d → precompute → state_update
PDL chain, matching the PyTorch backend (NVIDIA#13453):

  triton_backend_causal_conv.py (extend path):
    causal_conv1d_update(..., launch_dependent_kernels=True)
    Kernel emits gdc_launch_dependents() at block start.

  flashinfer_backend_mamba.py (replay path):
    replay_selective_state_update(..., launch_with_pdl=True)
    Precompute kernel: gdc_wait() for conv1d signal, gdc_launch_dependents()
    for state_update.  State_update kernel: gdc_wait() for precompute.

Both functions silently disable PDL on SM < 90 (so it is active only on Hopper/Blackwell-class GPUs such as B200).

Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>