[None][feat] Add triton paged attention for AutoDeploy #12642

Merged

suyoggupta merged 5 commits into NVIDIA/TensorRT-LLM:main from nv-auto-deploy/TensorRT-LLM:chenghao/triton_paged_attention_0331 on Apr 2, 2026

Conversation

@nvchenghaoz (Collaborator) commented Apr 1, 2026

Add two-stage flash-decode triton paged attention with HND layout support for the AutoDeploy attention backend, including unit tests.

Summary by CodeRabbit

  • New Features

    • Replaced previous attention implementation with an optimized Triton-based paged attention backend featuring two-stage FlashDecoding and adaptive optimizations for improved inference performance.
  • Tests

    • Added comprehensive test suite validating decode and prefill kernels, cache updates, and integration paths with optional comparisons against reference implementations.

Here is the comparison for nvidia/Nemotron-Nano-3-30B-A3.5B-dev-1024, CW H100, ISL/OSL 1k/1k. The default attention backend is TRTLLM (run 1416) vs. TRITON_PAGED (run 1451):

260401_1416 (default config: TRTLLM)
concurrency,TTFT avg (ms),prefill_tps avg (tok/s),TPOT avg (ms)
1,53.909,19007.4,4.266
2,117.818,9718.9,5.010
4,110.373,9138.6,6.799
8,165.829,6223.6,9.225
16,4861.987,3376.5,13.951
32,522.902,2855.5,18.534
64,438.548,2671.9,23.005
128,617.709,2281.7,28.320
256,1668.507,1663.4,36.104


260401_1451 (Triton paged)
concurrency,TTFT avg (ms),prefill_tps avg (tok/s),TPOT avg (ms)
1,49.085,20383.9,4.172
2,104.615,10076.6,4.952
4,187.451,7445.3,6.618
8,180.637,5663.4,9.116
16,1283.926,2875.8,12.636
32,467.463,2473.3,17.537
64,513.821,2382.9,23.209
128,646.571,2024.6,28.071
256,1035.160,1657.9,37.356

Add two-stage flash-decode triton paged attention with HND layout
support for the AutoDeploy attention backend, including unit tests.

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Made-with: Cursor
@nvchenghaoz nvchenghaoz requested a review from a team as a code owner April 1, 2026 00:42
@nvchenghaoz nvchenghaoz requested a review from MrGeva April 1, 2026 00:42
@coderabbitai (Contributor, bot) commented Apr 1, 2026

📝 Walkthrough

A new Triton-based paged attention implementation for TensorRT-LLM auto-deploy custom ops introduces two-stage FlashDecoding for decode operations, optimized context processing with optional SDPA acceleration, and KV-cache updates. The public API exports were updated to reflect the new backend.
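
To make the two-stage FlashDecoding idea in the walkthrough concrete, here is a minimal PyTorch reference sketch of the same math: stage 1 attends over each KV split independently while tracking that split's log-sum-exp, and stage 2 merges the partial outputs. Shapes and the function name here are illustrative assumptions and do not mirror the actual kernel signatures.

```python
import torch

def flash_decode_reference(q, k, v, num_splits=4):
    """q: [n_heads, head_dim] for one decode token; k, v: [seq_len, n_heads, head_dim]."""
    n_heads, head_dim = q.shape
    sm_scale = head_dim ** -0.5
    partial_out, partial_lse = [], []

    # Stage 1: attention over each KV split, keeping the split's log-sum-exp.
    for k_chunk, v_chunk in zip(k.chunk(num_splits, dim=0), v.chunk(num_splits, dim=0)):
        scores = torch.einsum("hd,shd->hs", q, k_chunk) * sm_scale      # [n_heads, split_len]
        partial_lse.append(torch.logsumexp(scores, dim=-1))             # [n_heads]
        probs = torch.softmax(scores, dim=-1)
        partial_out.append(torch.einsum("hs,shd->hd", probs, v_chunk))  # [n_heads, head_dim]

    # Stage 2: softmax over the per-split log-sum-exp recovers each split's share
    # of the global softmax denominator, so a weighted sum reproduces full attention.
    lse = torch.stack(partial_lse)                                      # [num_splits, n_heads]
    weights = torch.softmax(lse, dim=0).unsqueeze(-1)                   # [num_splits, n_heads, 1]
    return (torch.stack(partial_out) * weights).sum(dim=0)              # [n_heads, head_dim]
```

In FlashDecoding the split axis is what gets parallelized; the stage-2 reduction is cheap because it only touches num_splits partial outputs per head.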

Changes

Public API Update (tensorrt_llm/_torch/auto_deploy/custom_ops/attention/__init__.py):
Replaced triton_attention_with_paged_kv_cache export with triton_paged_attention in __all__ and updated module docstring.

Triton Paged Attention Implementation (tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py):
Added comprehensive Triton kernels for two-stage FlashDecoding (decode), causal context kernel (prefill), and KV-cache writer. Implements metadata preparer and main MHA custom op with descriptor registration. Includes cache initialization via KVPagedResourceHandler with HND layout and fallback to optimized SDPA gather path.

Unit and Integration Tests (tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_paged_attention.py):
Comprehensive test suite validating individual kernels (triton_paged_decode, triton_paged_context, update_paged_kv_cache), higher-level dispatch paths, cache correctness, and optional FlashInfer baseline comparisons.
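
The implementation entry above mentions an HND-layout paged cache created through KVPagedResourceHandler. As a rough illustration of what that layout implies for indexing, the sketch below writes one token's K/V through a page table; the cache shape follows the [num_pages, 2, n_kv_heads, page_size, head_dim] convention suggested in the review comments further down, and the argument names are assumptions rather than the op's real signature.

```python
import torch

def write_token_to_cache(kv_cache, kv_indices, kv_indptr, batch_idx, position, k_tok, v_tok):
    """k_tok, v_tok: [n_kv_heads, head_dim] for one new token of one sequence.

    kv_cache: [num_pages, 2, n_kv_heads, page_size, head_dim] (assumed HND-style layout)
    kv_indices / kv_indptr: flattened page table and per-sequence offsets into it.
    """
    page_size = kv_cache.shape[3]
    logical_page = position // page_size                         # which page of this sequence
    slot = position % page_size                                  # slot within that page
    phys_page = kv_indices[kv_indptr[batch_idx] + logical_page]  # page-table lookup
    kv_cache[phys_page, 0, :, slot, :] = k_tok                   # K plane: [head, slot, dim]
    kv_cache[phys_page, 1, :, slot, :] = v_tok                   # V plane
```

The actual update_paged_kv_cache op does this for a whole batch of tokens inside a Triton kernel; the point here is only the head-major (HND) ordering within each page.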

Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Input as Input (Q/K/V)
    participant Prep as Metadata<br/>Preparer
    participant Cache as KV Cache<br/>Update
    participant Decode as Decode Path
    participant Context as Context Path
    participant Output as Output

    Input->>Prep: position_ids, batch_info_host, cu_seqlen
    Prep-->>Input: batch_indices, positions
    Input->>Cache: k, v, batch_indices, positions, page_table
    Cache-->>Input: cache_updated
    Input->>Decode: q, cached_k/v (two-stage)
    Decode->>Decode: Stage 1: per-split partial outputs
    Decode->>Decode: Stage 2: reduce across splits
    Decode-->>Output: final_output (decode)
    Input->>Context: q, k, v (pages to contiguous)
    Context->>Context: Optional SDPA with is_causal=True
    Context-->>Output: final_output (context)
```
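
The metadata-preparation step in the diagram boils down to mapping each flattened token to a (sequence index, position) pair from the inputs shown above. A minimal sketch of that mapping, using names from the diagram (the real preparer's signature and dtypes may differ):

```python
import torch

def prepare_batch_metadata(position_ids: torch.Tensor, cu_seqlen: torch.Tensor):
    """position_ids: [num_tokens] position of each new token within its sequence;
    cu_seqlen: [num_seqs + 1] cumulative counts of new tokens per sequence."""
    seq_lens = cu_seqlen[1:] - cu_seqlen[:-1]               # new tokens per sequence
    batch_indices = torch.repeat_interleave(
        torch.arange(seq_lens.numel(), device=cu_seqlen.device), seq_lens
    )                                                       # [num_tokens] sequence index per token
    positions = position_ids.reshape(-1).to(torch.int64)    # [num_tokens] slot to write in the cache
    return batch_indices, positions

# Example: two sequences with 3 and 2 new tokens.
# cu_seqlen = [0, 3, 5]  ->  batch_indices = [0, 0, 0, 1, 1]
```

The cache-update and decode/context kernels then consume these per-token indices, as in the diagram's Cache and Decode/Context lanes.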

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

  • Docstring Coverage (⚠️ Warning): Docstring coverage is 69.44%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
  • Description check (❓ Inconclusive): The PR description lacks the structured sections required by the template and omits critical implementation details. Resolution: expand the description to include an explicit 'Description' section explaining what is being added and why, a 'Test Coverage' section documenting test cases, and confirmation of checklist items (especially CODING_GUIDELINES compliance, dependency scans, CODEOWNERS updates, and documentation).

✅ Passed checks (1 passed)

  • Title check (✅ Passed): The title '[None][feat] Add triton paged attention for AutoDeploy' clearly summarizes the main change: adding a triton paged attention implementation for AutoDeploy.

@coderabbitai (Contributor, bot) left a comment

Actionable comments posted: 4

🧹 Nitpick comments (3)
tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py (2)

126-159: Consider documenting tensor shapes in docstring.

Per coding guidelines, documentation of Tensor-like arguments should include expected dimensions. The current docstring is minimal.

Suggested docstring improvement

```diff
 def update_paged_kv_cache(
     k: torch.Tensor,
     v: torch.Tensor,
     batch_indices: torch.Tensor,
     positions: torch.Tensor,
     kv_cache: torch.Tensor,
     kv_indices: torch.Tensor,
     kv_indptr: torch.Tensor,
 ) -> None:
-    """Update the combined paged KV cache with new K, V tensors."""
+    """Update the combined paged KV cache with new K, V tensors.
+
+    Args:
+        k: Key tensor [num_tokens, n_kv_heads, head_dim]
+        v: Value tensor [num_tokens, n_kv_heads, head_dim]
+        batch_indices: Sequence index per token [num_tokens]
+        positions: Position within sequence per token [num_tokens]
+        kv_cache: Paged cache [num_blocks, 2, n_kv_heads, page_size, head_dim]
+        kv_indices: Physical page indices (flattened)
+        kv_indptr: Cumulative page counts [num_seqs + 1]
+    """
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py`
around lines 126 - 159, The update_paged_kv_cache docstring is too
minimal—expand it to document the expected tensor shapes and meanings for each
argument (e.g., k: (num_tokens, n_kv_heads, head_dim), v: (num_tokens,
n_kv_heads, head_dim), batch_indices: (num_tokens,) int64, positions:
(num_tokens,) int64, kv_cache: (batch, n_kv_heads, num_pages, page_size,
head_dim) or the actual layout used, kv_indices: (batch, num_kv_heads,
num_pages) int32, kv_indptr: (batch+1,) or whatever indexing format is used),
and note that the function updates kv_cache in-place and returns None; include
any assumptions about contiguous memory/strides (cache_stride_block,
cache_stride_kv, cache_stride_head, cache_stride_token) so callers know layout
requirements and axis ordering when invoking update_paged_kv_cache.

849-866: GPU sync for multi-sequence batches may impact latency.

Line 853 uses .item() which synchronizes CPU-GPU. While documented as necessary for variable-length sequences, this could add latency in multi-sequence prefill scenarios. The comment explains the rationale well.

Also, the SDPA path condition at line 865 (kv_indices.shape[0] == total_expected_pages) restricts usage to batches where all sequences have identical page counts, which may rarely be true in practice.

Consider tracking this as a future optimization opportunity if multi-sequence prefill latency becomes a bottleneck.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py`
around lines 849 - 866, The code forces a GPU sync via q_lens.max().item() when
computing max_q_len and then strictly requires kv_indices.shape[0] ==
total_expected_pages, which is brittle; change max_q_len to remain a tensor
(e.g., max_q_len = q_lens.max()) and propagate tensor arithmetic for max_pages
and total_expected_pages so no .item() is called, then update the SDPA predicate
in use_sdpa (the condition built from max_q_len, max_pages and kv_indices) to
operate on tensors and relax the kv_indices check (e.g., allow
kv_indices.shape[0] >= total_expected_pages or remove the exact-equality
requirement) so multi-sequence batches don't unnecessarily miss the SDPA path;
adjust any downstream uses of max_q_len/max_pages to handle tensor vs int
accordingly (referencing qo_indptr, max_q_len, max_pages, total_expected_pages,
use_sdpa, kv_indices).
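
For readers unfamiliar with the "SDPA path" this comment refers to, a rough sketch of what that fallback does: gather one sequence's pages into contiguous K/V and run causal SDPA. Shapes and names are assumptions, the real op batches this across sequences, and the sketch assumes a full prefill (q covers all seq_len tokens) so is_causal masking lines up.

```python
import torch
import torch.nn.functional as F

def sdpa_gather_path(q, kv_cache, kv_indices, seq_len):
    """q: [q_len, n_heads, head_dim] with q_len == seq_len (no earlier cached tokens);
    kv_cache: [num_pages, 2, n_heads, page_size, head_dim]; kv_indices: this sequence's pages."""
    pages = kv_cache[kv_indices]                           # [n_pages, 2, n_heads, page_size, head_dim]
    kv = pages.permute(1, 2, 0, 3, 4).flatten(2, 3)        # [2, n_heads, n_pages*page_size, head_dim]
    k, v = kv[0, :, :seq_len], kv[1, :, :seq_len]          # trim padding in the last page
    out = F.scaled_dot_product_attention(
        q.permute(1, 0, 2),                                # [n_heads, q_len, head_dim]
        k, v, is_causal=True,
    )
    return out.permute(1, 0, 2)                            # [q_len, n_heads, head_dim]
```
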
tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_paged_attention.py (1)

220-299: Context kernel tests lack direct PyTorch reference comparison.

Unlike TestTritonPagedDecodeKernel, this class only tests for shape/NaN/Inf sanity but doesn't include a direct correctness test against PyTorch SDPA reference. The TestFlashInferComparison.test_prefill_vs_flashinfer provides correctness validation but requires FlashInfer to be installed.

Consider adding a PyTorch SDPA reference test similar to test_decode_kernel_vs_pytorch_reference for cases where FlashInfer is unavailable.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_paged_attention.py`
around lines 220 - 299, Add a new test method (e.g.,
test_context_kernel_vs_pytorch_reference) that mirrors test_context_kernel_basic
but computes a direct PyTorch SDPA reference and compares it to
triton_paged_context; reconstruct the full per-batch per-head KV sequence from
kv_cache using kv_indices, kv_indptr, and seq_len_with_cache (or reassemble k/v
from the original k and v tensors used to populate the cache), cast tensors to
float32 for the reference, compute scaled dot-product attention per head (use
torch.matmul and softmax with sm_scale) to produce a reference output shaped
like q, then compare triton_paged_context(...) to the reference with
torch.allclose (use reasonable atol/rtol) and keep existing shape/NaN/Inf
checks; reference the functions triton_paged_context, update_paged_kv_cache, and
create_paged_kv_cache so the test locates the same fixtures.
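
For illustration, a direct PyTorch SDPA reference along the lines this comment asks for could look like the sketch below. run_context_kernel is a hypothetical stand-in for however the test invokes triton_paged_context after populating the cache, and the tolerances are placeholders rather than values from the existing suite.

```python
import torch
import torch.nn.functional as F

def check_context_against_sdpa(q, k, v, run_context_kernel, atol=2e-2, rtol=2e-2):
    """q, k, v: [num_tokens, n_heads, head_dim] for a single causal prefill sequence."""
    out = run_context_kernel(q, k, v)                      # kernel under test

    # Reference in float32; SDPA expects [batch, heads, seq, dim].
    ref = F.scaled_dot_product_attention(
        q.float().permute(1, 0, 2).unsqueeze(0),
        k.float().permute(1, 0, 2).unsqueeze(0),
        v.float().permute(1, 0, 2).unsqueeze(0),
        is_causal=True,
    ).squeeze(0).permute(1, 0, 2)                          # back to [num_tokens, n_heads, head_dim]

    torch.testing.assert_close(out.float(), ref, atol=atol, rtol=rtol)
```

A GQA case (fewer KV heads than query heads) would additionally need the KV heads repeated to match q before calling SDPA.
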
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In
`@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py`:
- Around line 1168-1176: The current branch retrieves attn_mask, dropout_p, and
is_causal via extract_op_args for source_attn_node and only logs unsupported
combos with ad_logger.debug, which can silently continue; update the handling in
the block that checks attn_mask, dropout_p, and is_causal (the variables
attn_mask, dropout_p, is_causal and the logging call ad_logger.debug referencing
source_attn_node) to either (a) escalate to ad_logger.warning with a clear
message that these features are unsupported, or (b) for stricter validation,
raise NotImplementedError per-case (e.g., if attn_mask is not None raise
NotImplementedError("attn_mask not supported by triton_paged backend"),
similarly for dropout_p != 0.0 and not is_causal) so callers cannot proceed with
unsupported configurations.
- Around line 55-60: _get_num_sms currently always queries device 0 using
_NUM_SMS which breaks multi-GPU runs; change it to query the current CUDA device
(use torch.cuda.current_device()) and maintain a per-device cache (e.g., replace
_NUM_SMS with a dict like _NUM_SMS_BY_DEVICE keyed by device index) so
_get_num_sms reads the current device index, checks the cache, and queries
torch.cuda.get_device_properties(device).multi_processor_count only when
missing.

In
`@tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_paged_attention.py`:
- Around line 503-508: The test function test_decode_vs_flashinfer imports
flashinfer unprotected causing ImportError when the package is absent; replace
the direct import with a pytest.importorskip call (e.g., call
pytest.importorskip("flashinfer") at the start of test_decode_vs_flashinfer) so
the test is skipped automatically when FlashInfer isn't installed and subsequent
references in the test use the returned module or rely on the skip.
- Around line 610-613: The skip decorator is misused: remove the
`@pytest.mark.skipif`(not pytest.importorskip(...)) decorator and instead call
pytest.importorskip("flashinfer", reason="FlashInfer not installed") at the
start of each test that requires FlashInfer (e.g., inside
test_decode_vs_flashinfer and inside the test function(s) that currently have
the skipif decorator); this ensures the test is skipped cleanly when FlashInfer
is absent and avoids the confusing decorator logic—locate the decorator usage
and replace it with an importorskip call at the top of the corresponding test
functions.
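
As a concrete illustration of the per-device SM cache suggested in the _get_num_sms comment above, a minimal sketch (the dict name is an assumption, not the module's actual variable):

```python
import torch

_NUM_SMS_BY_DEVICE: dict = {}

def _get_num_sms() -> int:
    device = torch.cuda.current_device()        # active device, not hard-coded device 0
    if device not in _NUM_SMS_BY_DEVICE:
        props = torch.cuda.get_device_properties(device)
        _NUM_SMS_BY_DEVICE[device] = props.multi_processor_count
    return _NUM_SMS_BY_DEVICE[device]
```

Likewise, the FlashInfer-gated tests flagged above would simply begin with flashinfer = pytest.importorskip("flashinfer"), which skips the test cleanly when the package is missing and otherwise returns the imported module.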

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 531ca809-24fe-4681-b510-398139df3ea8

📥 Commits

Reviewing files that changed from the base of the PR and between 85ab1c0 and d7801ac.

📒 Files selected for processing (3)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention/__init__.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py
  • tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_paged_attention.py

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
@nvchenghaoz (Collaborator, Author) commented:

/bot run --stage-list "DGX_B200-4_GPUs-AutoDeploy-1, DGX_B200-AutoDeploy-1"

@suyoggupta (Collaborator) commented:

can you compare perf for a model like llama3.1 fp8?

@tensorrt-cicd (Collaborator) commented:

PR_Github #41241 [ run ] triggered by Bot. Commit: 530b476 Link to invocation

@suyoggupta suyoggupta requested a review from bmarimuthu-nv April 1, 2026 19:05
Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
@nvchenghaoz nvchenghaoz requested a review from a team as a code owner April 1, 2026 22:14
@nvchenghaoz (Collaborator, Author) commented:

/bot run --stage-list "DGX_B200-4_GPUs-AutoDeploy-1, DGX_B200-AutoDeploy-1"

@tensorrt-cicd (Collaborator) commented:

PR_Github #41271 [ run ] triggered by Bot. Commit: df3b634 Link to invocation

@tensorrt-cicd (Collaborator) commented:

PR_Github #41271 [ run ] completed with state SUCCESS. Commit: df3b634
/LLM/main/L0_MergeRequest_PR pipeline #32229 (Partly Tested) completed with status: 'SUCCESS'

CI Report

Link to invocation

@suyoggupta suyoggupta enabled auto-merge (squash) April 2, 2026 15:34
@suyoggupta (Collaborator) commented:

/bot run --stage-list "DGX_B200-4_GPUs-AutoDeploy-1, DGX_B200-AutoDeploy-1"

@tensorrt-cicd (Collaborator) commented:

PR_Github #41454 [ run ] triggered by Bot. Commit: 753d71c Link to invocation

@tensorrt-cicd (Collaborator) commented:

PR_Github #41454 [ run ] completed with state SUCCESS. Commit: 753d71c
/LLM/main/L0_MergeRequest_PR pipeline #32384 (Partly Tested) completed with status: 'SUCCESS'

CI Report

Link to invocation

@nvchenghaoz (Collaborator, Author) commented:

/bot run

@tensorrt-cicd (Collaborator) commented:

PR_Github #41489 [ run ] triggered by Bot. Commit: 753d71c Link to invocation

bmarimuthu-nv added a commit to nv-auto-deploy/TensorRT-LLM that referenced this pull request Apr 2, 2026
…VIDIA#12642)

Adds triton-based paged attention implementation for AutoDeploy,
supporting head_dim=512 which is needed for Gemma4.

Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
@nvchenghaoz (Collaborator, Author) commented:

/bot skip --comment "Only touches AutoDeploy related tests / files, the autodeploy pipeline passed."

@tensorrt-cicd (Collaborator) commented:

PR_Github #41501 [ skip ] triggered by Bot. Commit: 753d71c Link to invocation

@tensorrt-cicd (Collaborator) commented:

PR_Github #41501 [ skip ] completed with state SUCCESS. Commit: 753d71c
Skipping testing for commit 753d71c

Link to invocation

@suyoggupta suyoggupta merged commit 7aa7818 into NVIDIA:main Apr 2, 2026
5 checks passed
govind-ramnarayan pushed a commit to nv-auto-deploy/TensorRT-LLM that referenced this pull request Apr 6, 2026
Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Co-authored-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
karen-sy pushed a commit to karen-sy/TensorRT-LLM that referenced this pull request Apr 7, 2026
Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Co-authored-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>