[None][feat] Add triton paged attention for AutoDeploy #12642
suyoggupta merged 5 commits into NVIDIA/TensorRT-LLM:main from nv-auto-deploy:chenghao/triton_paged_attention_0331
Conversation
Add two-stage flash-decode triton paged attention with HND layout support for the AutoDeploy attention backend, including unit tests.
Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Made-with: Cursor
📝 Walkthrough
A new Triton-based paged attention implementation for TensorRT-LLM auto-deploy custom ops introduces two-stage FlashDecoding for decode operations, optimized context processing with optional SDPA acceleration, and KV-cache updates. The public API exports were updated to reflect the new backend.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Input as Input (Q/K/V)
    participant Prep as Metadata<br/>Preparer
    participant Cache as KV Cache<br/>Update
    participant Decode as Decode Path
    participant Context as Context Path
    participant Output as Output

    Input->>Prep: position_ids, batch_info_host, cu_seqlen
    Prep-->>Input: batch_indices, positions
    Input->>Cache: k, v, batch_indices, positions, page_table
    Cache-->>Input: cache_updated
    Input->>Decode: q, cached_k/v (two-stage)
    Decode->>Decode: Stage 1: per-split partial outputs
    Decode->>Decode: Stage 2: reduce across splits
    Decode-->>Output: final_output (decode)
    Input->>Context: q, k, v (pages to contiguous)
    Context->>Context: Optional SDPA with is_causal=True
    Context-->>Output: final_output (context)
```
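For intuition on the two decode stages in the diagram, here is a minimal PyTorch sketch (not the PR's Triton kernel) of how Stage 2 can combine Stage 1's per-split partial results. The shapes, and the assumption that each split emits a softmax-normalized partial output plus its logsumexp, are illustrative:

```python
import torch

def combine_splits(partial_out: torch.Tensor, lse: torch.Tensor) -> torch.Tensor:
    """Reduce per-split FlashDecoding results into the final attention output.

    partial_out: [num_splits, n_heads, head_dim], softmax-normalized per split
    lse:         [num_splits, n_heads], logsumexp of attention scores per split
    """
    global_lse = torch.logsumexp(lse, dim=0)                  # [n_heads]
    weights = torch.exp(lse - global_lse)                     # each split's share of the softmax mass
    return (weights.unsqueeze(-1) * partial_out).sum(dim=0)   # [n_heads, head_dim]
```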
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 1 | ❌ 2
❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (1 passed)
Actionable comments posted: 4
🧹 Nitpick comments (3)
tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py (2)
126-159: Consider documenting tensor shapes in the docstring.
Per coding guidelines, documentation of Tensor-like arguments should include expected dimensions. The current docstring is minimal.
Suggested docstring improvement
```diff
 def update_paged_kv_cache(
     k: torch.Tensor,
     v: torch.Tensor,
     batch_indices: torch.Tensor,
     positions: torch.Tensor,
     kv_cache: torch.Tensor,
     kv_indices: torch.Tensor,
     kv_indptr: torch.Tensor,
 ) -> None:
-    """Update the combined paged KV cache with new K, V tensors."""
+    """Update the combined paged KV cache with new K, V tensors.
+
+    Args:
+        k: Key tensor [num_tokens, n_kv_heads, head_dim]
+        v: Value tensor [num_tokens, n_kv_heads, head_dim]
+        batch_indices: Sequence index per token [num_tokens]
+        positions: Position within sequence per token [num_tokens]
+        kv_cache: Paged cache [num_blocks, 2, n_kv_heads, page_size, head_dim]
+        kv_indices: Physical page indices (flattened)
+        kv_indptr: Cumulative page counts [num_seqs + 1]
+    """
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py` around lines 126 - 159, The update_paged_kv_cache docstring is too minimal—expand it to document the expected tensor shapes and meanings for each argument (e.g., k: (num_tokens, n_kv_heads, head_dim), v: (num_tokens, n_kv_heads, head_dim), batch_indices: (num_tokens,) int64, positions: (num_tokens,) int64, kv_cache: (batch, n_kv_heads, num_pages, page_size, head_dim) or the actual layout used, kv_indices: (batch, num_kv_heads, num_pages) int32, kv_indptr: (batch+1,) or whatever indexing format is used), and note that the function updates kv_cache in-place and returns None; include any assumptions about contiguous memory/strides (cache_stride_block, cache_stride_kv, cache_stride_head, cache_stride_token) so callers know layout requirements and axis ordering when invoking update_paged_kv_cache.
849-866: GPU sync for multi-sequence batches may impact latency.
Line 853 uses `.item()`, which synchronizes CPU and GPU. While documented as necessary for variable-length sequences, this could add latency in multi-sequence prefill scenarios. The comment explains the rationale well.
Also, the SDPA path condition at line 865 (`kv_indices.shape[0] == total_expected_pages`) restricts usage to batches where all sequences have identical page counts, which may rarely hold in practice.
Consider tracking this as a future optimization opportunity if multi-sequence prefill latency becomes a bottleneck; a tensor-only sketch follows the AI prompt below.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py` around lines 849-866: the code forces a GPU sync via q_lens.max().item() when computing max_q_len and then strictly requires kv_indices.shape[0] == total_expected_pages, which is brittle; change max_q_len to remain a tensor (e.g., max_q_len = q_lens.max()) and propagate tensor arithmetic for max_pages and total_expected_pages so no .item() is called, then update the SDPA predicate in use_sdpa (the condition built from max_q_len, max_pages, and kv_indices) to operate on tensors and relax the kv_indices check (e.g., allow kv_indices.shape[0] >= total_expected_pages or remove the exact-equality requirement) so multi-sequence batches don't unnecessarily miss the SDPA path; adjust any downstream uses of max_q_len/max_pages to handle tensor vs. int accordingly (referencing qo_indptr, max_q_len, max_pages, total_expected_pages, use_sdpa, kv_indices).
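A minimal sketch of the tensor-only variant suggested above, using the variable names from the prompt as assumptions (the actual code in the PR may differ):

```python
import torch

def sdpa_path_predicate(q_lens: torch.Tensor,
                        kv_indptr: torch.Tensor,
                        kv_indices: torch.Tensor) -> torch.Tensor:
    # Keep everything as tensors so no .item() sync is forced here; the one
    # unavoidable sync happens later when the caller branches on the result.
    max_q_len = q_lens.max()
    pages_per_seq = kv_indptr[1:] - kv_indptr[:-1]   # pages used by each sequence
    total_expected_pages = pages_per_seq.sum()
    # Relaxed check: enough pages available, not an exact count match, so
    # mixed-length multi-sequence batches can still take the SDPA path.
    return (kv_indices.shape[0] >= total_expected_pages) & (max_q_len >= 1)
```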
tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_paged_attention.py (1)
220-299: Context kernel tests lack a direct PyTorch reference comparison.
Unlike `TestTritonPagedDecodeKernel`, this class only tests for shape/NaN/Inf sanity and doesn't include a direct correctness test against a PyTorch SDPA reference. `TestFlashInferComparison.test_prefill_vs_flashinfer` provides correctness validation but requires FlashInfer to be installed.
Consider adding a PyTorch SDPA reference test similar to `test_decode_kernel_vs_pytorch_reference` for cases where FlashInfer is unavailable; a sketch follows the AI prompt below.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_paged_attention.py` around lines 220 - 299, Add a new test method (e.g., test_context_kernel_vs_pytorch_reference) that mirrors test_context_kernel_basic but computes a direct PyTorch SDPA reference and compares it to triton_paged_context; reconstruct the full per-batch per-head KV sequence from kv_cache using kv_indices, kv_indptr, and seq_len_with_cache (or reassemble k/v from the original k and v tensors used to populate the cache), cast tensors to float32 for the reference, compute scaled dot-product attention per head (use torch.matmul and softmax with sm_scale) to produce a reference output shaped like q, then compare triton_paged_context(...) to the reference with torch.allclose (use reasonable atol/rtol) and keep existing shape/NaN/Inf checks; reference the functions triton_paged_context, update_paged_kv_cache, and create_paged_kv_cache so the test locates the same fixtures.
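A hedged sketch of such a reference test. The q/k/v reshaping and the call site for the Triton kernel are assumptions, since the fixture and kernel signatures in this PR aren't reproduced here:

```python
import torch
import torch.nn.functional as F

def test_context_kernel_vs_pytorch_reference():
    torch.manual_seed(0)
    seq_len, n_heads, head_dim = 128, 8, 64
    q = torch.randn(seq_len, n_heads, head_dim, device="cuda", dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)

    # Causal SDPA reference in float32: [S, H, D] -> [1, H, S, D] -> [S, H, D]
    ref = F.scaled_dot_product_attention(
        q.permute(1, 0, 2).unsqueeze(0).float(),
        k.permute(1, 0, 2).unsqueeze(0).float(),
        v.permute(1, 0, 2).unsqueeze(0).float(),
        is_causal=True,
    ).squeeze(0).permute(1, 0, 2)

    # Populate the paged cache and run the kernel under test here, e.g.:
    #   update_paged_kv_cache(...); out = triton_paged_context(...)
    out = ref.half()  # placeholder so the sketch runs standalone

    assert not torch.isnan(out).any() and not torch.isinf(out).any()
    torch.testing.assert_close(out.float(), ref, atol=2e-2, rtol=2e-2)
```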
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In
`@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py`:
- Around line 1168-1176: The current branch retrieves attn_mask, dropout_p, and
is_causal via extract_op_args for source_attn_node and only logs unsupported
combos with ad_logger.debug, which can silently continue; update the handling in
the block that checks attn_mask, dropout_p, and is_causal (the variables
attn_mask, dropout_p, is_causal and the logging call ad_logger.debug referencing
source_attn_node) to either (a) escalate to ad_logger.warning with a clear
message that these features are unsupported, or (b) for stricter validation,
raise NotImplementedError per-case (e.g., if attn_mask is not None raise
NotImplementedError("attn_mask not supported by triton_paged backend"),
similarly for dropout_p != 0.0 and not is_causal) so callers cannot proceed with
unsupported configurations (a sketch of option (b) appears after this list).
- Around line 55-60: _get_num_sms currently always queries device 0 using
_NUM_SMS which breaks multi-GPU runs; change it to query the current CUDA device
(use torch.cuda.current_device()) and maintain a per-device cache (e.g., replace
_NUM_SMS with a dict like _NUM_SMS_BY_DEVICE keyed by device index) so
_get_num_sms reads the current device index, checks the cache, and queries
torch.cuda.get_device_properties(device).multi_processor_count only when
missing (see the per-device cache sketch after this list).
In
`@tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_paged_attention.py`:
- Around line 503-508: The test function test_decode_vs_flashinfer imports
flashinfer unprotected causing ImportError when the package is absent; replace
the direct import with a pytest.importorskip call (e.g., call
pytest.importorskip("flashinfer") at the start of test_decode_vs_flashinfer) so
the test is skipped automatically when FlashInfer isn't installed and subsequent
references in the test use the returned module or rely on the skip.
- Around line 610-613: The skip decorator is misused: remove the
`@pytest.mark.skipif`(not pytest.importorskip(...)) decorator and instead call
pytest.importorskip("flashinfer", reason="FlashInfer not installed") at the
start of each test that requires FlashInfer (e.g., inside
test_decode_vs_flashinfer and inside the test function(s) that currently have
the skipif decorator); this ensures the test is skipped cleanly when FlashInfer
is absent and avoids the confusing decorator logic—locate the decorator usage
and replace it with an importorskip call at the top of the corresponding test
functions (an importorskip sketch follows this list).
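For the attn_mask/dropout_p/is_causal comment, a minimal sketch of option (b), with the variable names taken from the prompt as assumptions:

```python
def _validate_attention_args(attn_mask, dropout_p: float, is_causal: bool) -> None:
    # Hypothetical strict-validation variant (option b); names mirror the prompt.
    if attn_mask is not None:
        raise NotImplementedError("attn_mask not supported by triton_paged backend")
    if dropout_p != 0.0:
        raise NotImplementedError("dropout_p != 0.0 not supported by triton_paged backend")
    if not is_causal:
        raise NotImplementedError("non-causal attention not supported by triton_paged backend")
```

For the SM-count comment, a sketch of the suggested per-device cache:

```python
import torch

_NUM_SMS_BY_DEVICE: dict[int, int] = {}

def _get_num_sms() -> int:
    # Query the *current* device instead of hard-coding device 0, and cache
    # the result per device index so repeated calls stay cheap.
    device = torch.cuda.current_device()
    if device not in _NUM_SMS_BY_DEVICE:
        props = torch.cuda.get_device_properties(device)
        _NUM_SMS_BY_DEVICE[device] = props.multi_processor_count
    return _NUM_SMS_BY_DEVICE[device]
```

For the two FlashInfer test comments, the importorskip pattern looks like this (test name reused from the prompt):

```python
import pytest

def test_decode_vs_flashinfer():
    # Skips cleanly when FlashInfer isn't installed; otherwise returns the module.
    flashinfer = pytest.importorskip("flashinfer")
    ...
```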
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 531ca809-24fe-4681-b510-398139df3ea8
📒 Files selected for processing (3)
- tensorrt_llm/_torch/auto_deploy/custom_ops/attention/__init__.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py
- tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_paged_attention.py
Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
/bot run --stage-list "DGX_B200-4_GPUs-AutoDeploy-1, DGX_B200-AutoDeploy-1"
can you compare perf for a model like llama3.1 fp8?
PR_Github #41241 [ run ] triggered by Bot. Commit:
Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
/bot run --stage-list "DGX_B200-4_GPUs-AutoDeploy-1, DGX_B200-AutoDeploy-1"
PR_Github #41271 [ run ] triggered by Bot. Commit:
PR_Github #41271 [ run ] completed with state
/bot run --stage-list "DGX_B200-4_GPUs-AutoDeploy-1, DGX_B200-AutoDeploy-1"
PR_Github #41454 [ run ] triggered by Bot. Commit:
PR_Github #41454 [ run ] completed with state
/bot run
PR_Github #41489 [ run ] triggered by Bot. Commit:
…VIDIA#12642) Adds triton-based paged attention implementation for AutoDeploy, supporting head_dim=512 which is needed for Gemma4. Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
/bot skip --comment "Only touches AutoDeploy related tests / files, the autodeploy pipeline passed."
PR_Github #41501 [ skip ] triggered by Bot. Commit:
PR_Github #41501 [ skip ] completed with state
Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com> Co-authored-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Add two-stage flash-decode triton paged attention with HND layout support for the AutoDeploy attention backend, including unit tests.
Summary by CodeRabbit
New Features
- Added a Triton-based paged attention backend for AutoDeploy, with two-stage FlashDecoding for decode and optional SDPA-accelerated context processing.
Tests
- Added unit tests covering the new Triton paged attention ops.
Here is the comparison for nvidia/Nemotron-Nano-3-30B-A3.5B-dev-1024 (CW, H100), ISL/OSL 1k/1k: the default TRTLLM attention backend scores 1416 vs. 1451 for TRITON_PAGED.