[https://nvbugs/5983390][perf] Multiple host perf optimizations for DSA part #12581
longlee0622 merged 6 commits into NVIDIA:main from hyukn:fix/5983390_torch_compile
Conversation
📝 Walkthrough
Refactored two compiled helper functions into instance methods.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tensorrt_llm/_torch/attention_backend/sparse/dsa.py (1)
Line 1: ⚠️ Potential issue | 🟠 Major: Add/update NVIDIA copyright header in this modified file.
This file was modified but does not include the required OSS copyright header/update.
As per coding guidelines: “All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the year of its latest meaningful modification.”
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/attention_backend/sparse/dsa.py` at line 1: this file (dsa.py) is missing the required NVIDIA OSS copyright header; add the official NVIDIA copyright header (including the year of the latest meaningful modification) at the very top of the file, before any code or imports (i.e., insert it above the existing `import math` line), ensuring the header format matches other TensorRT-LLM files in the repo.
tensorrt_llm/_torch/speculative/mtp.py (1)
Line 1: ⚠️ Potential issue | 🟠 Major: Add/update NVIDIA copyright header in this modified file.
This file was modified but does not include the required OSS copyright header/update.
As per coding guidelines: “All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the year of its latest meaningful modification.”
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/speculative/mtp.py` at line 1, This file is missing the required NVIDIA OSS copyright header; add/update a proper NVIDIA copyright header at the top of the module (above the first import) that includes the year of the latest meaningful modification and the standard NVIDIA OSS header text, so the tensorrt_llm._torch.speculative.mtp module begins with the required NVIDIA copyright block before the existing import sys line.
🧹 Nitpick comments (1)
tensorrt_llm/_torch/attention_backend/sparse/dsa.py (1)
Lines 586-587: Remove the redundant `device` argument from `_get_dense_topk_indices`.
`device` is passed by callers but immediately overwritten, so it adds API noise and unnecessary call-time variability.
♻️ Suggested cleanup

```diff
-    def _get_dense_topk_indices(self, seq_lens, kv_lens, num_tokens, device):
-        device = kv_lens.device
+    def _get_dense_topk_indices(self, seq_lens, kv_lens, num_tokens):
+        device = kv_lens.device
@@
-                self._get_dense_topk_indices(
+                self._get_dense_topk_indices(
                     self.seq_lens_cuda[:self.num_contexts],
-                    kv_lens[:self.num_contexts], self.num_ctx_tokens,
-                    device),
+                    kv_lens[:self.num_contexts], self.num_ctx_tokens),
@@
-            ctx_range, :] = self._get_dense_topk_indices(
+            ctx_range, :] = self._get_dense_topk_indices(
                 self.seq_lens[:self.num_contexts],
-                kv_lens[:self.num_contexts], self.num_ctx_tokens,
-                device)
+                kv_lens[:self.num_contexts], self.num_ctx_tokens)
@@
-                self._get_dense_topk_indices(
+                self._get_dense_topk_indices(
                     self.seq_lens_cuda[self.num_contexts:self.num_seqs],
                     kv_lens[self.num_contexts:self.num_seqs],
-                    self.num_tokens - self.num_ctx_tokens, device),
+                    self.num_tokens - self.num_ctx_tokens),
@@
-            gen_range, :] = self._get_dense_topk_indices(
+            gen_range, :] = self._get_dense_topk_indices(
                 self.seq_lens[self.num_contexts:self.num_seqs],
                 kv_lens[self.num_contexts:self.num_seqs],
-                self.num_tokens - self.num_ctx_tokens, device)
+                self.num_tokens - self.num_ctx_tokens)
```

Also applies to: 612-615, 619-622, 631-634, 638-641
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/attention_backend/sparse/dsa.py` around lines 586 - 587, The _get_dense_topk_indices function currently accepts a device parameter but immediately overwrites it with kv_lens.device; remove the redundant device parameter from the function signature and delete the line "device = kv_lens.device", update all internal uses to rely on kv_lens.device directly, and update every call site to stop passing device; apply the same change to the sibling methods that mirror this pattern (the other definitions at the ranges you noted) so their signatures and callers are consistent.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tensorrt_llm/_torch/speculative/mtp.py`:
- Around line 1135-1140: The helper prepare_position_ids_and_last_tokens
(decorated with `@torch.compile`) currently takes attn_metadata and accesses
attn_metadata.seq_lens_cuda, which creates Dynamo guards; change its signature
to accept seq_lens_cuda (e.g., seq_lens_cuda: Tensor) instead of attn_metadata,
use seq_lens_cuda directly in the torch.cumsum call to compute last_tokens_idx,
and update any callers (including the similar call site around the other usage)
to pass attn_metadata.seq_lens_cuda rather than the whole attn_metadata object
so the compiled graph is stable.
---
Outside diff comments:
In `@tensorrt_llm/_torch/attention_backend/sparse/dsa.py`:
- Line 1: This file (dsa.py) is missing the required NVIDIA OSS copyright
header; add the official NVIDIA copyright header (including the year of the
latest meaningful modification) at the very top of the file before any code or
imports (i.e., insert it above the existing import math line), ensuring the
header format matches other TensorRT-LLM files in the repo.
In `@tensorrt_llm/_torch/speculative/mtp.py`:
- Line 1: This file is missing the required NVIDIA OSS copyright header;
add/update a proper NVIDIA copyright header at the top of the module (above the
first import) that includes the year of the latest meaningful modification and
the standard NVIDIA OSS header text, so the tensorrt_llm._torch.speculative.mtp
module begins with the required NVIDIA copyright block before the existing
import sys line.
---
Nitpick comments:
In `@tensorrt_llm/_torch/attention_backend/sparse/dsa.py`:
- Around line 586-587: The _get_dense_topk_indices function currently accepts a
device parameter but immediately overwrites it with kv_lens.device; remove the
redundant device parameter from the function signature and delete the line
"device = kv_lens.device", update all internal uses to rely on kv_lens.device
directly, and update every call site to stop passing device; apply the same
change to the sibling methods that mirror this pattern (the other definitions at
the ranges you noted) so their signatures and callers are consistent.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 18d0ec17-1df3-4e43-a34b-fea417d352b2
📒 Files selected for processing (2)
tensorrt_llm/_torch/attention_backend/sparse/dsa.py
tensorrt_llm/_torch/speculative/mtp.py
/bot run --disable-fail-fast

PR_Github #40702 [ run ] triggered by Bot. Commit:
PR_Github #40702 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #40738 [ run ] triggered by Bot. Commit:
PR_Github #40738 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #40793 [ run ] triggered by Bot. Commit:
PR_Github #40793 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #40904 [ run ] triggered by Bot. Commit:
mikeiovine left a comment:
MTP changes look OK, I don't have context on the other stuff.

PR_Github #40904 [ run ] completed with state
Inline diff comment:

```python
        self._pool_cache_valid = True


    @maybe_compile(dynamic=True)
    def _get_dense_topk_indices(self, seq_lens, kv_lens, num_tokens, device):
```

`device` argument is unused.
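The `_pool_cache_valid` flag in the snippet above drives a cache of step-invariant values: recompute once on the first layer of a forward step, reuse for the remaining layers, invalidate at the start of the next step. A minimal sketch of that pattern with hypothetical names (this is not the TensorRT-LLM implementation, just the shape of the technique):

```python
class StepCache:
    """Caches values that are invariant within one forward step."""

    def __init__(self):
        self._valid = False
        self._pool_view = None
        self.computations = 0          # for demonstration only

    def begin_step(self):
        # Invalidate once per forward step, before any layer runs.
        self._valid = False

    def pool_view(self, pool):
        # First layer of the step recomputes; later layers reuse the result.
        if not self._valid:
            self._pool_view = [x * 2 for x in pool]   # stand-in for a real view/slice
            self.computations += 1
            self._valid = True
        return self._pool_view
```

With N layers per step, the expensive computation runs once per step instead of N times, which is where the claimed ~50us x 2 per-layer saving accumulates.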
PR_Github #41113 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #41226 [ run ] triggered by Bot. Commit:
PR_Github #41226 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #41306 [ run ] triggered by Bot. Commit:
…ompile in speculative decoding. Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
…are_pool_view. This reduces about 50us * 2 per layer, totally 4~5 ms in a forward step. Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
fc5eebd to 66f6b07
/bot run --disable-fail-fast

PR_Github #41332 [ run ] triggered by Bot. Commit:
PR_Github #41306 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #41411 [ run ] triggered by Bot. Commit:

… optimizations for DSA part (#12681) Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>

PR_Github #41411 [ run ] completed with state
…ns for DSA part (NVIDIA#12581)" This reverts commit edbb4b2.
…t perf optimizations for DSA part (NVIDIA#12681) Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
…SA part (NVIDIA#12581) Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
…t perf optimizations for DSA part (NVIDIA#12681) Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
…t perf optimizations for DSA part (NVIDIA#12681) Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
…t perf optimizations for DSA part (NVIDIA#12681) Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
Description
This PR reduces host-side overhead in the DSA (Dense Sparse Attention) path through several complementary optimizations:

1. Cache step-invariant values across layers (`dsa.py`): pool view, block table slices, request index slices, and the stride factor are now computed once per forward step in `_ensure_pool_view_cached()` and reused across all layers, eliminating redundant Python/CUDA overhead (~50us x 2 per layer, totaling ~4-5 ms per forward step).
2. Replace Triton kernels with C++ custom ops (`dsa.py`, `kernel.py`, `cpp_custom_ops.py`, new C++ files): `triton_gather_k_cache` and `triton_convert_req_index_to_global_index` are replaced with `trtllm::indexer_k_cache_gather_op` and `trtllm::convert_req_index_to_global`. This avoids Triton's host-side compilation/dispatch overhead.
3. Fix `torch.compile` recompilation in MTP speculative decoding (`mtp.py`): moved `torch.compile`-decorated closures (`prepare_position_ids_and_last_tokens`, `update_kv_lens`) out of method bodies into proper class methods, preventing repeated graph tracing on every call.
4. Relax unnecessary `tl.constexpr` annotations (`kernel.py`): removed `tl.constexpr` from Triton kernel parameters (`max_num_blocks_per_req`, `BLOCK_SIZE`, `stride_factor`, `layer_id`) that don't need to be compile-time constants, reducing Triton recompilation across layers.

Test Coverage

- `tests/unittest/_torch/attention/sparse/test_cpp_custom_ops.py`: covers both `indexer_k_cache_gather_op` and `convert_req_index_to_global` C++ custom ops.
- `tests/unittest/_torch/attention/sparse/test_triton_gather_k_cache.py` (replaced by the new C++ op tests).

PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.