[https://nvbugs/5983390][fix] Remove redundant D2H sync to optimize perf#12445

Merged
hyukn merged 1 commit into NVIDIA/TensorRT-LLM:main from hyukn:fix/5983390_remove_sync
Mar 24, 2026

Conversation


@hyukn hyukn commented Mar 23, 2026

Description

_compute_slot_mappings() guards a debug assert with not is_current_stream_capturing(), which skips it only during CUDA graph capture. During eager execution, on_update_kv_lens() calls this function with GPU tensors produced by _preprocess_inputs() (model_engine.py:1507, 1543). The .all() call then launches a reduce_kernel<bool>, performs a 1-byte D2H memcpy, and issues a cudaStreamSynchronize that stalls 12-15 ms while the GPU queue drains. This happens twice per context forward step (once unconditionally, once for the overlap scheduler + MTP), adding ~25 ms of GPU bubble per iteration with context requests. Decode-only steps are unaffected.

Fix: Guard with not block_indices_in_seq.is_cuda instead. The assert still fires on the CPU path (Indexer.prepare()) at zero cost, but is skipped on the GPU path (on_update_kv_lens()) where block offsets were already validated and clamped during prepare().
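The guard change can be sketched as follows. This is a minimal illustration of the control flow, not the actual code in tensorrt_llm/_torch/attention_backend/sparse/dsa.py; the FakeTensor class is a hypothetical stand-in for torch.Tensor so the sketch runs without a GPU.

```python
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor: just enough state
    (values plus an is_cuda flag) to show the guard's control flow."""

    def __init__(self, values, is_cuda):
        self.values = values
        self.is_cuda = is_cuda

    def all_in_bounds(self, limit):
        # On a real CUDA tensor, an expression like (t < limit).all()
        # launches a reduce kernel, copies 1 byte device-to-host, and
        # synchronizes the stream -- the stall this PR removes.
        return all(0 <= v < limit for v in self.values)


def check_bounds_old(block_indices_in_seq, num_blocks, capturing):
    # Old guard: skip only during CUDA graph capture, so eager GPU
    # execution still paid the D2H sync on every call.
    if not capturing:
        assert block_indices_in_seq.all_in_bounds(num_blocks)


def check_bounds_new(block_indices_in_seq, num_blocks):
    # New guard: skip whenever the tensor lives on the GPU. The CPU
    # path keeps the assert at zero cost; the GPU path relies on the
    # offsets having been validated and clamped earlier.
    if not block_indices_in_seq.is_cuda:
        assert block_indices_in_seq.all_in_bounds(num_blocks)
```

Under the new guard, a CPU tensor is still checked eagerly, while a GPU tensor bypasses the assert entirely and never forces the device-to-host round trip.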

nsys evidence (DeepSeek-V3.2, TP8, piecewise CUDA graph, c16):

  • Context forward steps: 2 syncs/GPU, 12-15ms each → ~25ms total bubble
  • Decode-only forward steps: 0 syncs

Test Coverage

Existing DSA accuracy tests cover correctness. The change only removes a redundant assert on the GPU path; the CPU-path assert is unchanged.

PR Checklist

  • Please check this after reviewing the above items as appropriate for this PR.

Summary by CodeRabbit

  • Bug Fixes
    • Fixed sparse attention backend assertion logic to properly determine when to skip bounds checking based on tensor device placement instead of CUDA stream capture state.

…erf.

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>

hyukn commented Mar 23, 2026

/bot run --disable-fail-fast --add-multi-gpu-test

@tensorrt-cicd

PR_Github #39905 [ run ] triggered by Bot. Commit: 2ef0662 Link to invocation

@tensorrt-cicd

PR_Github #39905 [ run ] completed with state SUCCESS. Commit: 2ef0662
/LLM/main/L0_MergeRequest_PR pipeline #31074 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation


hyukn commented Mar 24, 2026

/bot run --disable-fail-fast --add-multi-gpu-test


hyukn commented Mar 24, 2026

/bot --help

@github-actions

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

Details

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental) --high-priority]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is given. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

--high-priority (OPTIONAL) : Run the pipeline with high priority. This option is restricted to authorized users only and will route the job to a high-priority queue.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

@tensorrt-cicd

PR_Github #40011 [ run ] triggered by Bot. Commit: 2ef0662 Link to invocation

@tensorrt-cicd

PR_Github #40011 [ run ] completed with state SUCCESS. Commit: 2ef0662
/LLM/main/L0_MergeRequest_PR pipeline #31169 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@hyukn hyukn requested a review from yuxianq March 24, 2026 03:56

hyukn commented Mar 24, 2026

/bot run --disable-fail-fast --add-multi-gpu-test

@tensorrt-cicd

PR_Github #40047 [ run ] triggered by Bot. Commit: 2ef0662 Link to invocation

@hyukn hyukn marked this pull request as ready for review March 24, 2026 04:48
@hyukn hyukn requested a review from a team as a code owner March 24, 2026 04:48
@hyukn hyukn requested a review from pengbowang-nv March 24, 2026 04:48

coderabbitai Bot commented Mar 24, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: f8163aa3-17d1-4a9d-ba54-17deef34c070

📥 Commits

Reviewing files that changed from the base of the PR and between 7aa1383 and 2ef0662.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/attention_backend/sparse/dsa.py

📝 Walkthrough

Walkthrough

A condition in _compute_slot_mappings was modified to check whether block_indices_in_seq resides on CUDA rather than whether the current CUDA stream is in graph-capture mode, affecting when an out-of-bounds assertion is skipped.

Changes

Cohort / File(s): CUDA Device Placement Logic (tensorrt_llm/_torch/attention_backend/sparse/dsa.py)
Summary: Replaced the torch.cuda.is_current_stream_capturing() check with block_indices_in_seq.is_cuda in _compute_slot_mappings, changing the condition for skipping the out-of-bounds assertion from stream-capture detection to tensor device placement.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~8 minutes

🚥 Pre-merge checks (3 passed)
  • Title check ✅ Passed: The title clearly identifies the fix (removing a redundant D2H sync) and references the NVBugs ticket, matching the description's core objective.
  • Description check ✅ Passed: The description includes all required sections: a clear explanation of the issue and solution, test coverage details, and a completed PR checklist.
  • Docstring Coverage ✅ Passed: Docstring coverage is 100.00%, which meets the required threshold of 80.00%.


@tensorrt-cicd

PR_Github #40047 [ run ] completed with state SUCCESS. Commit: 2ef0662
/LLM/main/L0_MergeRequest_PR pipeline #31201 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation


hyukn commented Mar 24, 2026

/bot run --disable-fail-fast

@tensorrt-cicd

PR_Github #40132 [ run ] triggered by Bot. Commit: 2ef0662 Link to invocation

@hyukn hyukn enabled auto-merge (squash) March 24, 2026 14:48
@tensorrt-cicd

PR_Github #40132 [ run ] completed with state SUCCESS. Commit: 2ef0662
/LLM/main/L0_MergeRequest_PR pipeline #31279 completed with status: 'SUCCESS'

CI Report

Link to invocation

@hyukn hyukn merged commit d8eb3a6 into NVIDIA:main Mar 24, 2026
9 of 10 checks passed