[None][feat] Add Mamba2 MTP SSM cache CUDA kernel for tree-based speculative decoding #12537

Merged: JadoTu merged 4 commits into NVIDIA/TensorRT-LLM:main from JadoTu:mamba2_tree_based_mtp_CUDA_kernel on Apr 1, 2026

Conversation

@JadoTu (Collaborator) commented Mar 25, 2026

Summary by CodeRabbit

Release Notes

  • New Features
    • Added Mamba2 MTP SSM cache update operation with CUDA acceleration and PyTorch integration for enhanced multi-token state space model computations.
    • Supports float32, float16, and bfloat16 precision formats with configurable gating and caching options.

Description

This commit adds a CUDA kernel (mamba2_mtp_ssm_cache_update) that:

  1. Supports tree-based speculative decoding via retrieve_parent_token — at each step, the kernel can restore the SSM state from an arbitrary parent token's cached entry, enabling non-linear (tree-structured) draft token verification (see the sketch after this list).

  2. Outperforms the FlashInfer kernel, which currently does not support tree-based speculative decoding: even without tree-based speculative decoding, this kernel achieves a ~10% performance improvement over FlashInfer. [Benchmark screenshots: FlashInfer kernel vs. this PR's kernel]

  3. With tree-based speculative decoding enabled, this PR's kernel is 2x faster than the initial Triton kernel (183 µs vs. 369 µs).
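
For intuition, here is a hypothetical retrieve_parent_token layout for a tiny draft tree (the flat [bs * cache_steps] indexing matches the kernel description further below; the -1-for-root encoding is an assumption for illustration, not taken from the PR):

```cpp
// Hypothetical example: one sequence, cache_steps = 3, a depth-2 draft tree.
//
//        token0
//        /    \
//   token1   token2
//
// Tokens 1 and 2 are both drafted from token 0, so when verifying token 2 the
// kernel must restore the SSM state cached at step 0 rather than chaining
// linearly from token 1's state.
int const retrieve_parent_token[3] = {-1, 0, 0}; // -1 marks the tree root
```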

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

JadoTu added a commit: [None][feat] Add Mamba2 MTP SSM cache CUDA kernel for tree-based speculative decoding

Signed-off-by: jiant <107457950+JadoTu@users.noreply.github.com>
@JadoTu requested review from a team as code owners March 25, 2026 09:55
@coderabbitai (Bot, Contributor) commented Mar 25, 2026

📝 Walkthrough

This pull request introduces a new Mamba2 MTP (Multi-Token Prediction) SSM cache update operation. It adds a complete CUDA kernel implementation with templated specializations, PyTorch custom operator bindings, and Python API wrappers to enable multi-token state caching for Mamba2 models.

Changes

  • CUDA Kernel Core — cpp/tensorrt_llm/kernels/mamba2MTPSSMCache/mamba2MTPSSMCache.h, mamba2MTPSSMCacheKernel.cuh: New header defining the Mamba2Dtype enum, the Mamba2MTPSSMCacheParams struct, and kernel function declarations. The core kernel implements a per-head-batch SSM state update with optional parent-token retrieval, fused B/C accumulation, and data type dispatch across bfloat16, float16, and float32.

  • Kernel Specializations — cpp/tensorrt_llm/kernels/mamba2MTPSSMCache/mamba2MTPSSMCache.cu, mamba2MTPSSMCacheVec4.cu, mamba2MTPSSMCacheVec8.cu, mamba2MTPSSMCacheVec16.cu: The runtime dispatch function validates head_dim divisibility and selects a compile-time vector size (4/8/16) based on ssm_dim (128/256/512). Three translation units explicitly instantiate the launchMamba2MTPSSMCacheKernel templates, one per vector size. A dispatch sketch follows this list.

  • PyTorch Operator — cpp/tensorrt_llm/thop/mamba2MTPSSMCacheOp.cpp: New Torch custom operator mamba2_mtp_ssm_cache_update validates tensor shapes, device placement, and stride requirements for all inputs (including the optional D, z, dt_bias, and batch-indices tensors). It constructs kernel parameters, handles dtype mapping, and dispatches to invokeMamba2MTPSSMCacheUpdate.

  • Python API — tensorrt_llm/_torch/modules/mamba/selective_state_update.py: Added a selective_state_update_mtp_ssm_cache_trtllm() wrapper function forwarding inputs to the custom Torch operator with multi-token caching parameters and intermediate state buffering.

  • Build Configuration — cpp/tensorrt_llm/CMakeLists.txt, cpp/tensorrt_llm/kernels/CMakeLists.txt, cpp/tensorrt_llm/kernels/mamba2MTPSSMCache/CMakeLists.txt, cpp/tensorrt_llm/thop/CMakeLists.txt: Extended the main build system to link the mamba2_mtp_ssm_cache_src library, added the new subdirectory, and excluded its .cu files from the main globbing. The new CMake file defines an object library with position-independent code and device symbol resolution enabled.
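
As a rough illustration of that dispatch (the 128→4, 256→8, 512→16 mapping follows the listed order, and the field/template names besides launchMamba2MTPSSMCacheKernel are assumptions; the PR's actual macro body may differ):

```cpp
// Sketch of runtime vec-size selection; illustrative, not the PR's code.
// Template parameters other than the vector size are omitted.
switch (params.ssm_dim)
{
case 128: launchMamba2MTPSSMCacheKernel<4>(params, stream); break;  // 4-wide vectors
case 256: launchMamba2MTPSSMCacheKernel<8>(params, stream); break;  // 8-wide vectors
case 512: launchMamba2MTPSSMCacheKernel<16>(params, stream); break; // 16-wide vectors
default: TLLM_CHECK_WITH_INFO(false, "Unsupported ssm_dim: %d", params.ssm_dim);
}
```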

Sequence Diagram(s)

sequenceDiagram
    participant PY as Python Code
    participant TORCH as PyTorch Layer
    participant OP as Custom Operator<br/>(mamba2MTPSSMCacheOp)
    participant DISPATCH as Runtime Dispatch<br/>(mamba2MTPSSMCache.cu)
    participant KERNEL as CUDA Kernel<br/>(mamba2MTPSSMCacheKernel)

    PY->>TORCH: Call selective_state_update_mtp_ssm_cache_trtllm()
    TORCH->>OP: torch.ops.trtllm.mamba2_mtp_ssm_cache_update(...)
    OP->>OP: Validate tensor shapes & device placement
    OP->>OP: Extract raw pointers & optional parameters
    OP->>OP: Map PyTorch dtypes → Mamba2Dtype
    OP->>DISPATCH: invokeMamba2MTPSSMCacheUpdate(params, stream)
    DISPATCH->>DISPATCH: Validate head_dim divisibility
    DISPATCH->>DISPATCH: Select VEC_SIZE from ssm_dim
    DISPATCH->>KERNEL: Launch kernel with grid/block/stream
    KERNEL->>KERNEL: Process per (head_id, bs_id) pair
    KERNEL->>KERNEL: Load state rows & iterate cache_steps
    KERNEL->>KERNEL: Compute dt dynamics & accumulate B/C
    KERNEL->>KERNEL: Optional D gating & z modulation
    KERNEL->>KERNEL: Write results to intermediate_states & state
    KERNEL-->>PY: Return updated tensors

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks: ✅ 1 passed | ❌ 2 failed

❌ Failed checks (1 warning, 1 inconclusive)
  • Docstring Coverage — ⚠️ Warning: Docstring coverage is 30.77%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.
  • Description check — ❓ Inconclusive: The PR description addresses key features and performance metrics but lacks clarity in critical sections required by the template. Resolution: complete the 'Test Coverage' section with specific test names/locations, and clarify which checklist items were actually completed versus marked as assumed complete.

✅ Passed checks (1 passed)
  • Title check — ✅ Passed: The title clearly and concisely summarizes the main change, adding a Mamba2 MTP SSM cache CUDA kernel for tree-based speculative decoding, which aligns with the substantial code additions across multiple kernel files and wrapper functions.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai (Bot, Contributor) left a comment

Actionable comments posted: 7

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@cpp/tensorrt_llm/kernels/mamba2MTPSSMCache/mamba2MTPSSMCache.cu`:
- Around lines 56-65: Add a precondition check in invokeMamba2MTPSSMCacheUpdate that validates params.ngroups and the head/group ratio before dispatch: ensure params.ngroups > 0 and params.nheads % params.ngroups == 0, so that heads_groups_ratio = nheads/ngroups is integral and non-zero. This prevents division by zero and the incorrect grouping that mamba2MTPSSMCacheKernel.cuh would otherwise use. Follow the existing TLLM_CHECK_WITH_INFO pattern (same as the head_dim check) to fail fast with a clear message referencing ngroups and nheads before calling MTP_DISPATCH_VEC_SIZE and launchMamba2MTPSSMCacheKernel; a sketch follows.
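
A minimal sketch of that guard, assuming TLLM_CHECK_WITH_INFO accepts a printf-style message (as it does for the existing head_dim check) and that ngroups/nheads are int fields of the params struct:

```cpp
// Fail fast before MTP_DISPATCH_VEC_SIZE / launchMamba2MTPSSMCacheKernel.
TLLM_CHECK_WITH_INFO(params.ngroups > 0, "ngroups must be positive, got %d", params.ngroups);
TLLM_CHECK_WITH_INFO(params.nheads % params.ngroups == 0,
    "nheads (%d) must be divisible by ngroups (%d)", params.nheads, params.ngroups);
```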

In `@cpp/tensorrt_llm/kernels/mamba2MTPSSMCache/mamba2MTPSSMCache.h`:
- Around lines 35-79: Update the Mamba2MTPSSMCacheParams struct documentation to reflect the tensor layouts actually used by mamba2MTPSSMCacheKernel: the dt, A, D, and dt_bias comments should indicate that these tensors are indexed only by head (no head_dim/ssm_dim axes), with shapes corrected accordingly (e.g., dt: [bs, cache_steps, nheads]; A, D, dt_bias: [nheads], i.e., 1D per head as used in the kernel). Also convert all public comments to Doxygen C++ style (//! for single-line comments, //!< for member annotations), including the prototype for invokeMamba2MTPSSMCacheUpdate, and reference the kernel usage locations in mamba2MTPSSMCacheKernel.cuh so callers provision tensors with the right layouts; a sketch follows.
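
A sketch of the suggested Doxygen style (the member names and exact field set of Mamba2MTPSSMCacheParams are assumptions for illustration; the shapes follow the review note above):

```cpp
//! Parameters for the Mamba2 MTP SSM cache update kernel.
struct Mamba2MTPSSMCacheParams
{
    void const* dt;      //!< [bs, cache_steps, nheads] — indexed per head, no head_dim/ssm_dim axes
    void const* A;       //!< [nheads]
    void const* D;       //!< [nheads], optional (nullptr to disable)
    void const* dt_bias; //!< [nheads], optional (nullptr to disable)
    // ... remaining members documented the same way
};

//! Launches the Mamba2 MTP SSM cache update on the given stream.
void invokeMamba2MTPSSMCacheUpdate(Mamba2MTPSSMCacheParams const& params, cudaStream_t stream);
```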

In `@cpp/tensorrt_llm/kernels/mamba2MTPSSMCache/mamba2MTPSSMCacheKernel.cuh`:
- Around lines 372-383: The kernel restores parent state from retrieve_parent_token even when parent_step_idx == t or > t, i.e., from future or not-yet-written entries. Update the guard in the RETRIEVE_PARENT_TOKEN block so the load only happens when parent_step_idx >= 0 && parent_step_idx < t (not < cache_steps): in the block around retrieve_parent_token[bs_id * cache_steps + t], before calling mtp_load_vec_to_float for state_4_a/state_4_b with inter_base_a/inter_base_b and stride_nheads_hdim_ssm_dim, make the check use t as the upper bound so state is restored only from already-materialized steps.
- Around lines 113-116: The current mtp_softplus uses __logf(1 + __expf(dt_value)), which overflows for large dt_value and makes xdt_val_a/xdt_val_b become inf. Replace it with a numerically stable branch: for dt_value > 0 return dt_value + __logf(1.f + __expf(-dt_value)); otherwise return __logf(1.f + __expf(dt_value)). This way exponentials are never computed on large positive inputs, and downstream state/output corruption is avoided. A combined sketch of both fixes follows.
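
A combined sketch of both device-side fixes (local names t, bs_id, and cache_steps follow the comments above; the surrounding kernel code is not reproduced):

```cpp
// Numerically stable softplus: split on the sign so __expf never sees a large
// positive argument; for x > 0, log(1 + e^x) == x + log(1 + e^-x).
__device__ __forceinline__ float mtp_softplus(float dt_value)
{
    return dt_value > 0.f ? dt_value + __logf(1.f + __expf(-dt_value))
                          : __logf(1.f + __expf(dt_value));
}

// Parent-state restore: only load from steps that were already written
// (strictly earlier than the current step t), never from t itself or later.
int const parent_step_idx = retrieve_parent_token[bs_id * cache_steps + t];
if (parent_step_idx >= 0 && parent_step_idx < t) // upper bound is t, not cache_steps
{
    // mtp_load_vec_to_float(...) for state_4_a / state_4_b goes here
}
```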

In `@cpp/tensorrt_llm/thop/mamba2MTPSSMCacheOp.cpp`:
- Around lines 78-82: B.size(2) can be zero, causing undefined host behavior when evaluating nheads % ngroups; add an explicit check that rejects ngroups == 0 before performing the modulo. Concretely, after computing int const ngroups = B.size(2);, add TORCH_CHECK(ngroups > 0, "ngroups must be > 0") (or similar), then keep the existing TORCH_CHECK(nheads % ngroups == 0, "unsupported pair of nheads and ngroups") so the modulo only runs on a non-zero ngroups.
- Around lines 50-63: Add a host-side validation that when ssm_batch_indices is not provided (i.e., the kernel will use direct bs_id indexing), the first dimension of ssm is at least the batch size of x: check ssm.size(0) >= x.size(0) and raise a TORCH_CHECK with a clear message if not. Place this guard near the existing tensor shape/device checks around ssm, x, and intermediate_states so it runs before the kernel launch and prevents out-of-bounds access when ssm_batch_indices is null.
- Around lines 209-210: The code calls at::cuda::getCurrentCUDAStream() without setting the device, which can pick the thread-local device instead of the device associated with the inputs. Fix by including <c10/cuda/CUDAGuard.h> and creating a c10::cuda::CUDAGuard on the appropriate device immediately before calling at::cuda::getCurrentCUDAStream(), then get the stream and call tk::invokeMamba2MTPSSMCacheUpdate(params, stream) so the kernel launches on the correct device. A combined sketch of these three guards follows.
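
A combined sketch of the three host-side guards (tensor names B, x, ssm, and the optional ssm_batch_indices follow the comments above; guarding on x.device() rather than a params accessor is an assumption, and the surrounding operator code is not reproduced):

```cpp
#include <c10/cuda/CUDAGuard.h>

// 1. Reject ngroups == 0 before the modulo (B.size(2) can legally be zero).
int const ngroups = B.size(2);
TORCH_CHECK(ngroups > 0, "ngroups must be > 0");
TORCH_CHECK(nheads % ngroups == 0, "unsupported pair of nheads and ngroups");

// 2. Without ssm_batch_indices the kernel indexes ssm directly by bs_id,
//    so the state tensor must cover the whole batch.
if (!ssm_batch_indices.has_value())
{
    TORCH_CHECK(ssm.size(0) >= x.size(0),
        "ssm.size(0) must be >= batch size when ssm_batch_indices is not provided");
}

// 3. Pin the current device to the input's device before grabbing the stream,
//    so the kernel launches on the correct device.
c10::cuda::CUDAGuard const guard(x.device());
auto stream = at::cuda::getCurrentCUDAStream();
tk::invokeMamba2MTPSSMCacheUpdate(params, stream);
```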

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 73658ce3-902d-4e56-bd1e-53fc964d780d

📥 Commits

Reviewing files that changed from the base of the PR and between 2b5c434 and 1d2bb3b.

📒 Files selected for processing (12)
  • cpp/tensorrt_llm/CMakeLists.txt
  • cpp/tensorrt_llm/kernels/CMakeLists.txt
  • cpp/tensorrt_llm/kernels/mamba2MTPSSMCache/CMakeLists.txt
  • cpp/tensorrt_llm/kernels/mamba2MTPSSMCache/mamba2MTPSSMCache.cu
  • cpp/tensorrt_llm/kernels/mamba2MTPSSMCache/mamba2MTPSSMCache.h
  • cpp/tensorrt_llm/kernels/mamba2MTPSSMCache/mamba2MTPSSMCacheKernel.cuh
  • cpp/tensorrt_llm/kernels/mamba2MTPSSMCache/mamba2MTPSSMCacheVec16.cu
  • cpp/tensorrt_llm/kernels/mamba2MTPSSMCache/mamba2MTPSSMCacheVec4.cu
  • cpp/tensorrt_llm/kernels/mamba2MTPSSMCache/mamba2MTPSSMCacheVec8.cu
  • cpp/tensorrt_llm/thop/CMakeLists.txt
  • cpp/tensorrt_llm/thop/mamba2MTPSSMCacheOp.cpp
  • tensorrt_llm/_torch/modules/mamba/selective_state_update.py

Comment threads (collapsed):
  • cpp/tensorrt_llm/kernels/mamba2MTPSSMCache/mamba2MTPSSMCache.cu
  • cpp/tensorrt_llm/kernels/mamba2MTPSSMCache/mamba2MTPSSMCache.h
  • cpp/tensorrt_llm/kernels/mamba2MTPSSMCache/mamba2MTPSSMCacheKernel.cuh (outdated)
  • cpp/tensorrt_llm/thop/mamba2MTPSSMCacheOp.cpp
  • cpp/tensorrt_llm/thop/mamba2MTPSSMCacheOp.cpp (outdated)
  • cpp/tensorrt_llm/thop/mamba2MTPSSMCacheOp.cpp
JadoTu added 2 commits March 30, 2026 06:43
Signed-off-by: jiant <107457950+JadoTu@users.noreply.github.com>
Signed-off-by: jiant <107457950+JadoTu@users.noreply.github.com>
@JadoTu (Collaborator, Author) commented Mar 30, 2026

/bot run

@tensorrt-cicd (Collaborator) commented:
PR_Github #40687 [ run ] triggered by Bot. Commit: 156a50b Link to invocation

@tensorrt-cicd (Collaborator) commented:
PR_Github #40687 [ run ] completed with state SUCCESS. Commit: 156a50b
/LLM/main/L0_MergeRequest_PR pipeline #31716 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

Signed-off-by: jiant <107457950+JadoTu@users.noreply.github.com>
@JadoTu (Collaborator, Author) commented Mar 31, 2026

/bot run

@tensorrt-cicd (Collaborator) commented:
PR_Github #40860 [ run ] triggered by Bot. Commit: 33c9959 Link to invocation

@tensorrt-cicd (Collaborator) commented:
PR_Github #40860 [ run ] completed with state SUCCESS. Commit: 33c9959
/LLM/main/L0_MergeRequest_PR pipeline #31867 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@JadoTu (Collaborator, Author) commented Apr 1, 2026

/bot run

@tensorrt-cicd (Collaborator) commented:
PR_Github #41057 [ run ] triggered by Bot. Commit: 33c9959 Link to invocation

@tensorrt-cicd (Collaborator) commented:
PR_Github #41057 [ run ] completed with state SUCCESS. Commit: 33c9959
/LLM/main/L0_MergeRequest_PR pipeline #32034 completed with status: 'SUCCESS'

CI Report

Link to invocation

JadoTu merged commit 1b6a2bc into NVIDIA:main on Apr 1, 2026
5 checks passed
karen-sy pushed a commit to karen-sy/TensorRT-LLM that referenced this pull request Apr 7, 2026
[None][feat] Add Mamba2 MTP SSM cache CUDA kernel for tree-based speculative decoding (NVIDIA#12537)

Signed-off-by: jiant <107457950+JadoTu@users.noreply.github.com>


3 participants
