[https://nvbugs/5983390][perf] Multiple host perf optimizations for DSA part#12581

Merged
longlee0622 merged 6 commits into NVIDIA/TensorRT-LLM:main from hyukn/TensorRT-LLM:fix/5983390_torch_compile
Apr 2, 2026

Conversation


@hyukn hyukn commented Mar 30, 2026

Summary by CodeRabbit

  • Refactor
    • Restructured how compiled functions are organized in the attention and speculative-decoding paths to reduce host-side overhead.

Description

This PR reduces host-side overhead in the DSA (Dense Sparse Attention) path through several complementary optimizations:

  1. Cache step-invariant values across layers (dsa.py): Pool view, block table slices, request index slices, and stride factor are now computed once per forward step in _ensure_pool_view_cached() and reused across all layers, eliminating redundant Python/CUDA overhead (~50us x 2 per layer, totaling ~4-5ms per forward step).

  2. Replace Triton kernels with C++ custom ops (dsa.py, kernel.py, cpp_custom_ops.py, new C++ files): triton_gather_k_cache and triton_convert_req_index_to_global_index are replaced with trtllm::indexer_k_cache_gather_op and trtllm::convert_req_index_to_global. This avoids Triton's host-side compilation/dispatch overhead.

  3. Fix torch.compile recompilation in MTP speculative decoding (mtp.py): Moved torch.compile-decorated closures (prepare_position_ids_and_last_tokens, update_kv_lens) out of method bodies into proper class methods, preventing repeated graph tracing on every call.

  4. Relax unnecessary tl.constexpr annotations (kernel.py): Removed tl.constexpr from Triton kernel parameters (max_num_blocks_per_req, BLOCK_SIZE, stride_factor, layer_id) that don't need to be compile-time constants, reducing Triton recompilation across layers.

Test Coverage

  • New unit tests: tests/unittest/_torch/attention/sparse/test_cpp_custom_ops.py — covers both indexer_k_cache_gather_op and convert_req_index_to_global C++ custom ops.
  • Removed stale test: tests/unittest/_torch/attention/sparse/test_triton_gather_k_cache.py (replaced by the new C++ op tests).
  • Existing DSA e2e and sparse attention tests cover the integrated behavior.
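The semantics of the new index-conversion op can be sketched in pure Python. This is inferred from the op's name and the paged-KV context; the real `trtllm::convert_req_index_to_global` signature and memory layout may differ:

```python
def convert_req_index_to_global(req_indices, block_table, block_size):
    """Map a request's block-local token indices to global KV-pool indices
    through that request's block table (illustrative sketch only)."""
    out = []
    for idx in req_indices:
        block, offset = divmod(idx, block_size)  # logical block, offset within it
        out.append(block_table[block] * block_size + offset)
    return out
```

For example, with `block_size=64` and `block_table=[7, 3]`, local index 65 (block 1, offset 1) maps to 3 * 64 + 1 = 193.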

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

@hyukn hyukn requested review from chang-l, lfr-0531 and liji-nv March 30, 2026 06:38
@hyukn hyukn requested review from a team as code owners March 30, 2026 06:38
@hyukn hyukn requested review from QiJune and zheyuf March 30, 2026 06:38

coderabbitai Bot commented Mar 30, 2026

📝 Walkthrough

Refactored two compiled helper functions into instance methods. _get_dense_topk_indices moved from a nested closure to a compiled method on the DSA attention metadata, with added device-parameter handling; prepare_position_ids_and_last_tokens was extracted as a compiled method on the MTP worker, and an inner compiled closure was replaced with direct in-place updates for KV-lens operations.

Changes

Cohort / File(s) Summary
DSA Attention Backend
tensorrt_llm/_torch/attention_backend/sparse/dsa.py
Refactored _get_dense_topk_indices from a nested @maybe_compile closure to a compiled instance method on DSAtrtllmAttentionMetadata. Added device parameter to method signature with logic to derive device from kv_lens.device. Updated callers in context and generation buffer paths to invoke self._get_dense_topk_indices(...) with num_tokens adjustments.
MTP Speculative Worker
tensorrt_llm/_torch/speculative/mtp.py
Extracted prepare_position_ids_and_last_tokens as a @torch.compile(max_autotune=True) instance method on MTPEagleWorker, moving squeeze-and-cumsum logic from forward's inner closure. Removed compiled update_kv_lens closure and replaced with direct in-place update attn_metadata.kv_lens_cuda[:batch_size] += 1 for iterations where i > 0.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks (2 passed, 1 warning)

❌ Failed checks (1 warning)

  • Docstring Coverage — ⚠️ Warning: docstring coverage is 0.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2)

  • Title check — ✅ Passed: the title accurately describes the main change, refactoring torch.compile usage to reduce host overhead in speculative decoding (DSA/MTP), with a clear reference to the NVBugs tracking ID.
  • Description check — ✅ Passed: the PR description is comprehensive and well-structured, covering four distinct optimization strategies with clear technical details, expected performance improvements, test coverage, and completed checklist items.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands and usage tips.


@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
tensorrt_llm/_torch/attention_backend/sparse/dsa.py (1)

1-1: ⚠️ Potential issue | 🟠 Major

Add/update NVIDIA copyright header in this modified file.

This file was modified but does not include the required OSS copyright header/update.

As per coding guidelines: “All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the year of its latest meaningful modification.”

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/attention_backend/sparse/dsa.py` at line 1, This file
(dsa.py) is missing the required NVIDIA OSS copyright header; add the official
NVIDIA copyright header (including the year of the latest meaningful
modification) at the very top of the file before any code or imports (i.e.,
insert it above the existing import math line), ensuring the header format
matches other TensorRT-LLM files in the repo.
tensorrt_llm/_torch/speculative/mtp.py (1)

1-1: ⚠️ Potential issue | 🟠 Major

Add/update NVIDIA copyright header in this modified file.

This file was modified but does not include the required OSS copyright header/update.

As per coding guidelines: “All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the year of its latest meaningful modification.”

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/speculative/mtp.py` at line 1, This file is missing the
required NVIDIA OSS copyright header; add/update a proper NVIDIA copyright
header at the top of the module (above the first import) that includes the year
of the latest meaningful modification and the standard NVIDIA OSS header text,
so the tensorrt_llm._torch.speculative.mtp module begins with the required
NVIDIA copyright block before the existing import sys line.
🧹 Nitpick comments (1)
tensorrt_llm/_torch/attention_backend/sparse/dsa.py (1)

586-587: Remove the redundant device argument from _get_dense_topk_indices.

device is passed by callers but immediately overwritten, so it adds API noise and unnecessary call-time variability.

♻️ Suggested cleanup
-    def _get_dense_topk_indices(self, seq_lens, kv_lens, num_tokens, device):
-        device = kv_lens.device
+    def _get_dense_topk_indices(self, seq_lens, kv_lens, num_tokens):
+        device = kv_lens.device
@@
-                    self._get_dense_topk_indices(
+                    self._get_dense_topk_indices(
                         self.seq_lens_cuda[:self.num_contexts],
-                        kv_lens[:self.num_contexts], self.num_ctx_tokens,
-                        device),
+                        kv_lens[:self.num_contexts], self.num_ctx_tokens),
@@
-                    ctx_range, :] = self._get_dense_topk_indices(
+                    ctx_range, :] = self._get_dense_topk_indices(
                         self.seq_lens[:self.num_contexts],
-                        kv_lens[:self.num_contexts], self.num_ctx_tokens,
-                        device)
+                        kv_lens[:self.num_contexts], self.num_ctx_tokens)
@@
-                    self._get_dense_topk_indices(
+                    self._get_dense_topk_indices(
                         self.seq_lens_cuda[self.num_contexts:self.num_seqs],
                         kv_lens[self.num_contexts:self.num_seqs],
-                        self.num_tokens - self.num_ctx_tokens, device),
+                        self.num_tokens - self.num_ctx_tokens),
@@
-                    gen_range, :] = self._get_dense_topk_indices(
+                    gen_range, :] = self._get_dense_topk_indices(
                         self.seq_lens[self.num_contexts:self.num_seqs],
                         kv_lens[self.num_contexts:self.num_seqs],
-                        self.num_tokens - self.num_ctx_tokens, device)
+                        self.num_tokens - self.num_ctx_tokens)

Also applies to: 612-615, 619-622, 631-634, 638-641

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/attention_backend/sparse/dsa.py` around lines 586 - 587,
The _get_dense_topk_indices function currently accepts a device parameter but
immediately overwrites it with kv_lens.device; remove the redundant device
parameter from the function signature and delete the line "device =
kv_lens.device", update all internal uses to rely on kv_lens.device directly,
and update every call site to stop passing device; apply the same change to the
sibling methods that mirror this pattern (the other definitions at the ranges
you noted) so their signatures and callers are consistent.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tensorrt_llm/_torch/speculative/mtp.py`:
- Around line 1135-1140: The helper prepare_position_ids_and_last_tokens
(decorated with `@torch.compile`) currently takes attn_metadata and accesses
attn_metadata.seq_lens_cuda, which creates Dynamo guards; change its signature
to accept seq_lens_cuda (e.g., seq_lens_cuda: Tensor) instead of attn_metadata,
use seq_lens_cuda directly in the torch.cumsum call to compute last_tokens_idx,
and update any callers (including the similar call site around the other usage)
to pass attn_metadata.seq_lens_cuda rather than the whole attn_metadata object
so the compiled graph is stable.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 18d0ec17-1df3-4e43-a34b-fea417d352b2

📥 Commits

Reviewing files that changed from the base of the PR and between f0b336e and 2da139a.

📒 Files selected for processing (2)
  • tensorrt_llm/_torch/attention_backend/sparse/dsa.py
  • tensorrt_llm/_torch/speculative/mtp.py

Comment thread tensorrt_llm/_torch/speculative/mtp.py

hyukn commented Mar 30, 2026

/bot run --disable-fail-fast

@tensorrt-cicd

PR_Github #40702 [ run ] triggered by Bot. Commit: 2da139a Link to invocation

@tensorrt-cicd

PR_Github #40702 [ run ] completed with state SUCCESS. Commit: 2da139a
/LLM/main/L0_MergeRequest_PR pipeline #31727 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation


hyukn commented Mar 30, 2026

/bot run --disable-fail-fast

@tensorrt-cicd

PR_Github #40738 [ run ] triggered by Bot. Commit: 2da139a Link to invocation

@tensorrt-cicd

PR_Github #40738 [ run ] completed with state SUCCESS. Commit: 2da139a
/LLM/main/L0_MergeRequest_PR pipeline #31758 completed with status: 'SUCCESS'

CI Report

Link to invocation


hyukn commented Mar 30, 2026

/bot run --disable-fail-fast

@tensorrt-cicd

PR_Github #40793 [ run ] triggered by Bot. Commit: 687480c Link to invocation

@tensorrt-cicd

PR_Github #40793 [ run ] completed with state SUCCESS. Commit: 687480c
/LLM/main/L0_MergeRequest_PR pipeline #31809 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation


hyukn commented Mar 31, 2026

/bot run --disable-fail-fast

@tensorrt-cicd

PR_Github #40904 [ run ] triggered by Bot. Commit: a23570f Link to invocation

@hyukn hyukn requested review from mikeiovine and yuxianq March 31, 2026 09:14

@mikeiovine mikeiovine left a comment


MTP changes look OK, I don't have context on the other stuff

@tensorrt-cicd

PR_Github #40904 [ run ] completed with state SUCCESS. Commit: a23570f
/LLM/main/L0_MergeRequest_PR pipeline #31905 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

CI Report

Link to invocation

Inline comment thread on tensorrt_llm/_torch/attention_backend/sparse/dsa.py:

        self._pool_cache_valid = True

    @maybe_compile(dynamic=True)
    def _get_dense_topk_indices(self, seq_lens, kv_lens, num_tokens, device):

Reviewer (Collaborator): device argument is unused.

hyukn (Author): I have removed it.

Comment thread tensorrt_llm/_torch/speculative/mtp.py Outdated
Comment thread tensorrt_llm/_torch/attention_backend/sparse/kernel.py
@hyukn hyukn changed the title [https://nvbugs/5983390][perf] Reduce host overhead caused by torch.compile in speculative decoding. [https://nvbugs/5983390][perf] Multiple host perf optimizations for DSA part Apr 1, 2026
@tensorrt-cicd

PR_Github #41113 [ run ] completed with state SUCCESS. Commit: 344b8c6
/LLM/main/L0_MergeRequest_PR pipeline #32086 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@hyukn hyukn requested a review from a team as a code owner April 1, 2026 16:24

hyukn commented Apr 1, 2026

/bot run --disable-fail-fast

@tensorrt-cicd

PR_Github #41226 [ run ] triggered by Bot. Commit: fc5eebd Link to invocation

@tensorrt-cicd

PR_Github #41226 [ run ] completed with state SUCCESS. Commit: fc5eebd
/LLM/main/L0_MergeRequest_PR pipeline #32186 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation


hyukn commented Apr 2, 2026

/bot run --disable-fail-fast

@tensorrt-cicd

PR_Github #41306 [ run ] triggered by Bot. Commit: fc5eebd Link to invocation

hyukn added 6 commits April 2, 2026 12:12
…ompile in speculative decoding.

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
…are_pool_view. This reduces about 50us * 2 per layer, totally 4~5 ms in a forward step.

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
@longlee0622 longlee0622 force-pushed the fix/5983390_torch_compile branch from fc5eebd to 66f6b07 Compare April 2, 2026 04:12
@longlee0622

/bot run --disable-fail-fast

@tensorrt-cicd

PR_Github #41332 [ run ] triggered by Bot. Commit: 66f6b07 Link to invocation

@tensorrt-cicd

PR_Github #41306 [ run ] completed with state ABORTED. Commit: fc5eebd
/LLM/main/L0_MergeRequest_PR pipeline #32259 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation


lancelly commented Apr 2, 2026

/bot run --disable-fail-fast

@tensorrt-cicd

PR_Github #41411 [ run ] triggered by Bot. Commit: 66f6b07 Link to invocation

longlee0622 pushed a commit that referenced this pull request Apr 2, 2026
… optimizations for DSA part (#12681)

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
@tensorrt-cicd

PR_Github #41411 [ run ] completed with state SUCCESS. Commit: 66f6b07
/LLM/main/L0_MergeRequest_PR pipeline #32346 completed with status: 'SUCCESS'

CI Report

Link to invocation

@longlee0622 longlee0622 merged commit edbb4b2 into NVIDIA:main Apr 2, 2026
7 checks passed
2ez4bz added a commit to 2ez4bz/TensorRT-LLM that referenced this pull request Apr 2, 2026
SimengLiu-nv pushed a commit to SimengLiu-nv/TensorRT-LLM that referenced this pull request Apr 6, 2026
…t perf optimizations for DSA part (NVIDIA#12681)

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
karen-sy pushed a commit to karen-sy/TensorRT-LLM that referenced this pull request Apr 7, 2026
…SA part (NVIDIA#12581)

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
dongfengy pushed a commit to dongfengy/TensorRT-LLM that referenced this pull request Apr 8, 2026
…t perf optimizations for DSA part (NVIDIA#12681)

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
dongfengy pushed a commit to dongfengy/TensorRT-LLM that referenced this pull request Apr 10, 2026
…t perf optimizations for DSA part (NVIDIA#12681)

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
dongfengy pushed a commit to dongfengy/TensorRT-LLM that referenced this pull request Apr 10, 2026
…t perf optimizations for DSA part (NVIDIA#12681)

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>

