[None][feat] Add Mamba2 MTP SSM cache CUDA kernel for tree-based speculative decoding #12537
JadoTu merged 4 commits into NVIDIA/TensorRT-LLM:main from JadoTu/TensorRT-LLM:mamba2_tree_based_mtp_CUDA_kernel
Conversation
…ulative decoding Signed-off-by: jiant <107457950+JadoTu@users.noreply.github.com>
📝 Walkthrough
This pull request introduces a new Mamba2 MTP (Multi-Token Prediction) SSM cache update operation. It adds a complete CUDA kernel implementation with templated specializations, PyTorch custom operator bindings, and Python API wrappers to enable multi-token state caching for Mamba2 models.
Sequence Diagram(s)

sequenceDiagram
participant PY as Python Code
participant TORCH as PyTorch Layer
participant OP as Custom Operator<br/>(mamba2MTPSSMCacheOp)
participant DISPATCH as Runtime Dispatch<br/>(mamba2MTPSSMCache.cu)
participant KERNEL as CUDA Kernel<br/>(mamba2MTPSSMCacheKernel)
PY->>TORCH: Call selective_state_update_mtp_ssm_cache_trtllm()
TORCH->>OP: torch.ops.trtllm.mamba2_mtp_ssm_cache_update(...)
OP->>OP: Validate tensor shapes & device placement
OP->>OP: Extract raw pointers & optional parameters
OP->>OP: Map PyTorch dtypes → Mamba2Dtype
OP->>DISPATCH: invokeMamba2MTPSSMCacheUpdate(params, stream)
DISPATCH->>DISPATCH: Validate head_dim divisibility
DISPATCH->>DISPATCH: Select VEC_SIZE from ssm_dim
DISPATCH->>KERNEL: Launch kernel with grid/block/stream
KERNEL->>KERNEL: Process per (head_id, bs_id) pair
KERNEL->>KERNEL: Load state rows & iterate cache_steps
KERNEL->>KERNEL: Compute dt dynamics & accumulate B/C
KERNEL->>KERNEL: Optional D gating & z modulation
KERNEL->>KERNEL: Write results to intermediate_states & state
KERNEL-->>PY: Return updated tensors
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 1 passed | ❌ 2 failed (1 warning, 1 inconclusive)
Actionable comments posted: 7
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@cpp/tensorrt_llm/kernels/mamba2MTPSSMCache/mamba2MTPSSMCache.cu`:
- Around line 56-65: Add a precondition check in invokeMamba2MTPSSMCacheUpdate
to validate params.ngroups and the head/group ratio before dispatch: ensure
params.ngroups > 0 and params.nheads % params.ngroups == 0 (so
heads_groups_ratio = nheads/ngroups is integral and non-zero) to prevent
division-by-zero and incorrect grouping used by mamba2MTPSSMCacheKernel.cuh; use
the existing TLLM_CHECK_WITH_INFO pattern (same as the head_dim check) to fail
fast with a clear message referencing ngroups and nheads before calling
MTP_DISPATCH_VEC_SIZE and launchMamba2MTPSSMCacheKernel.
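The suggested precondition can be sketched host-side with a plain exception standing in for the TLLM_CHECK_WITH_INFO macro (the function name checkHeadGroupRatio is hypothetical, not part of the PR):

```cpp
#include <stdexcept>
#include <string>

// Sketch of the suggested dispatch-time precondition: validate ngroups before
// the nheads % ngroups modulo is evaluated, so a zero or non-dividing ngroups
// fails fast with a clear message instead of corrupting the grouping.
void checkHeadGroupRatio(int nheads, int ngroups)
{
    if (ngroups <= 0)
    {
        throw std::invalid_argument("ngroups must be > 0, got " + std::to_string(ngroups));
    }
    if (nheads % ngroups != 0)
    {
        throw std::invalid_argument("nheads (" + std::to_string(nheads)
            + ") must be divisible by ngroups (" + std::to_string(ngroups) + ")");
    }
}
```

In the real code this would sit in invokeMamba2MTPSSMCacheUpdate, alongside the existing head_dim check, before MTP_DISPATCH_VEC_SIZE is expanded.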
In `@cpp/tensorrt_llm/kernels/mamba2MTPSSMCache/mamba2MTPSSMCache.h`:
- Around line 35-79: Update the Mamba2MTPSSMCacheParams struct documentation to
reflect the actual tensor layouts used by mamba2MTPSSMCacheKernel: change dt, A,
D, and dt_bias comments to indicate they are indexed only by head (no
head_dim/ssm_dim axes) and correct their shapes accordingly (e.g., dt: [bs,
cache_steps, nheads], A: [nheads], D/dt_bias: [nheads] or similarly 1D per head
as used in kernel), and convert all public comments to Doxygen C++ style (use
//! for single-line comments and //!< for member annotations) including the
prototype for invokeMamba2MTPSSMCacheUpdate; reference struct name
Mamba2MTPSSMCacheParams and the kernel usage locations
(mamba2MTPSSMCacheKernel.cuh lines mentioned) so callers provision tensors with
the right layouts.
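The requested documentation style can be sketched as follows; the field list is abbreviated and the shapes follow the layouts described above (per-head indexing for dt, A, D, dt_bias), so treat this as an illustration of the Doxygen convention rather than the actual struct definition:

```cpp
// Abbreviated sketch of the params struct with Doxygen-style member comments,
// shapes matching the per-head layouts the kernel actually uses.
struct Mamba2MTPSSMCacheParams
{
    int batch_size;      //!< Number of sequences in the batch (bs)
    int cache_steps;     //!< Number of cached draft-token steps
    int nheads;          //!< Number of SSM heads
    void const* dt;      //!< Time deltas, shape [bs, cache_steps, nheads] (indexed by head only)
    void const* A;       //!< State decay parameters, shape [nheads]
    void const* D;       //!< Optional skip-connection gains, shape [nheads]
    void const* dt_bias; //!< Optional dt bias, shape [nheads]
};
```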
In `@cpp/tensorrt_llm/kernels/mamba2MTPSSMCache/mamba2MTPSSMCacheKernel.cuh`:
- Around line 372-383: The kernel is restoring parent state from
retrieve_parent_token even when parent_step_idx == t or > t (future/unwritten
entries); update the guard in the RETRIEVE_PARENT_TOKEN block so you only load
when parent_step_idx >= 0 && parent_step_idx < t (not < cache_steps), e.g. in
the block around retrieve_parent_token[bs_id * cache_steps + t] before calling
mtp_load_vec_to_float for state_4_a/state_4_b with inter_base_a/inter_base_b and
stride_nheads_hdim_ssm_dim ensure the check uses t as the upper bound to only
restore from already-materialized steps.
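The corrected guard condition can be isolated as a small predicate (shouldRestoreParentState is a hypothetical name for illustration): only steps strictly before the current step t have been written to the intermediate cache, so t, not cache_steps, is the upper bound.

```cpp
// Sketch of the corrected guard: restore parent state only from steps that
// have already been materialized (parent_step_idx < t), never from the current
// or a future step, and never when there is no parent (negative index).
bool shouldRestoreParentState(int parent_step_idx, int t)
{
    return parent_step_idx >= 0 && parent_step_idx < t;
}
```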
- Around line 113-116: The current mtp_softplus uses __logf(1 +
__expf(dt_value)) which overflows for large dt_value (causing
xdt_val_a/xdt_val_b to become inf); replace it with a numerically stable branch:
for dt_value > 0 return dt_value + __logf(1.f + __expf(-dt_value)), otherwise
return __logf(1.f + __expf(dt_value)) so exponentials are computed on non-large
positive inputs and downstream state/output corruption is avoided; update the
mtp_softplus device function accordingly (it’s the function to change).
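A host-side mirror of the suggested fix looks like this (the device version would use the __expf/__logf intrinsics as in the comment above; stableSoftplus is a hypothetical name). The branch ensures the exponential is only ever taken of a non-positive argument, so it cannot overflow:

```cpp
#include <cmath>

// Numerically stable softplus: for large positive inputs the naive
// log(1 + exp(x)) overflows to inf, while x + log(1 + exp(-x)) stays finite
// because exp is applied to a non-positive value.
float stableSoftplus(float dt_value)
{
    if (dt_value > 0.f)
    {
        return dt_value + std::log(1.f + std::exp(-dt_value));
    }
    return std::log(1.f + std::exp(dt_value));
}
```

For example, at dt_value = 100 the naive form computes exp(100), which overflows a float, whereas the stable form returns approximately 100.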
In `@cpp/tensorrt_llm/thop/mamba2MTPSSMCacheOp.cpp`:
- Around line 78-82: B.size(2) can be zero, causing undefined host behavior when
evaluating nheads % ngroups; add an explicit check that rejects ngroups == 0
before performing the modulo. Concretely, after computing int const ngroups =
B.size(2); add a TORCH_CHECK(ngroups > 0, "ngroups must be > 0") (or similar)
and then keep the existing TORCH_CHECK(nheads % ngroups == 0, "unsupported pair
of nheads and ngroups") so the modulo is only executed on a non-zero ngroups.
- Around line 50-63: Add a host-side validation that when ssm_batch_indices is
not provided (i.e., kernel will use direct bs_id indexing) the first dimension
of ssm is at least the batch size of x: check ssm.size(0) >= x.size(0) and raise
a TORCH_CHECK with a clear message if not; locate this guard near the existing
tensor shape/device checks around ssm, x, and intermediate_states in
mamba2MTPSSMCacheOp.cpp so it runs before launching the kernel and prevents
out-of-bounds access when ssm_batch_indices is null.
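The suggested guard can be sketched with plain integers standing in for ssm.size(0) and x.size(0) (checkSsmRows is a hypothetical name; the real check would be a TORCH_CHECK in the op):

```cpp
#include <stdexcept>

// Sketch of the host-side validation: when no ssm_batch_indices tensor is
// supplied the kernel indexes ssm directly by bs_id, so ssm must provide at
// least batch_size rows or the kernel reads out of bounds.
void checkSsmRows(long ssm_rows, long batch_size, bool has_batch_indices)
{
    if (!has_batch_indices && ssm_rows < batch_size)
    {
        throw std::out_of_range(
            "ssm.size(0) must be >= x.size(0) when ssm_batch_indices is not provided");
    }
}
```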
- Around line 209-210: The code calls at::cuda::getCurrentCUDAStream() without
setting the device, which can pick the thread-local device instead of the device
associated with params; fix by including <c10/cuda/CUDAGuard.h> and creating a
c10::cuda::CUDAGuard guard(params.device()) (or the appropriate device accessor
from params) immediately before calling at::cuda::getCurrentCUDAStream(), then
get the stream and call tk::invokeMamba2MTPSSMCacheUpdate(params, stream) so the
kernel is launched on the correct device.
ℹ️ Review info
⚙️ Run configuration: Path: .coderabbit.yaml · Review profile: CHILL · Plan: Pro · Run ID: 73658ce3-902d-4e56-bd1e-53fc964d780d
📒 Files selected for processing (12)
cpp/tensorrt_llm/CMakeLists.txt
cpp/tensorrt_llm/kernels/CMakeLists.txt
cpp/tensorrt_llm/kernels/mamba2MTPSSMCache/CMakeLists.txt
cpp/tensorrt_llm/kernels/mamba2MTPSSMCache/mamba2MTPSSMCache.cu
cpp/tensorrt_llm/kernels/mamba2MTPSSMCache/mamba2MTPSSMCache.h
cpp/tensorrt_llm/kernels/mamba2MTPSSMCache/mamba2MTPSSMCacheKernel.cuh
cpp/tensorrt_llm/kernels/mamba2MTPSSMCache/mamba2MTPSSMCacheVec16.cu
cpp/tensorrt_llm/kernels/mamba2MTPSSMCache/mamba2MTPSSMCacheVec4.cu
cpp/tensorrt_llm/kernels/mamba2MTPSSMCache/mamba2MTPSSMCacheVec8.cu
cpp/tensorrt_llm/thop/CMakeLists.txt
cpp/tensorrt_llm/thop/mamba2MTPSSMCacheOp.cpp
tensorrt_llm/_torch/modules/mamba/selective_state_update.py
Signed-off-by: jiant <107457950+JadoTu@users.noreply.github.com>
Signed-off-by: jiant <107457950+JadoTu@users.noreply.github.com>

/bot run

PR_Github #40687 [ run ] triggered by Bot.
PR_Github #40687 [ run ] completed with state

Signed-off-by: jiant <107457950+JadoTu@users.noreply.github.com>

/bot run

PR_Github #40860 [ run ] triggered by Bot.
PR_Github #40860 [ run ] completed with state

/bot run

PR_Github #41057 [ run ] triggered by Bot.
PR_Github #41057 [ run ] completed with state
…ulative decoding (NVIDIA#12537) Signed-off-by: jiant <107457950+JadoTu@users.noreply.github.com>
Description
This PR adds a CUDA kernel (mamba2_mtp_ssm_cache_update) that:
Supports tree-based speculative decoding via retrieve_parent_token — at each step, the kernel can restore the SSM state from an arbitrary parent token's cached entry, enabling non-linear (tree-structured) draft token verification. The FlashInfer kernel does not currently support tree-based speculative decoding.
Achieves ~10% performance improvement over the FlashInfer kernel even without tree-based speculative decoding.
Test Coverage
PR Checklist
Please review the following before submitting your PR:
- PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
- PR follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
- Test cases are provided for new code paths (see test instructions).
- Any new dependencies have been scanned for license and vulnerabilities.
- CODEOWNERS updated if ownership changes.
- Documentation updated as needed.
- Update the tava architecture diagram if there is a significant design change in the PR.
- The reviewers assigned automatically/manually are appropriate for the PR.
- Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment /bot help.