
[TRTLLM-10386][fix] torch.compile: register add+norm fallback pass in multi-GPU mode #11739

Merged

yizhang-nv merged 1 commit into NVIDIA/TensorRT-LLM:main from luyiyun1021:dev-fix-add-norm-fallback-multi-gpu on Feb 27, 2026

Conversation

@luyiyun1021 (Collaborator) commented Feb 26, 2026

Summary by CodeRabbit

  • New Features
    • Improved compiler optimization for distributed multi-GPU inference: normalization operations are now fused more reliably during model compilation, improving performance for multi-GPU deployments.

Description

In multi-GPU mode (world_size > 1), the torch.compile custom Backend uses recover_pass to split fused_add_rmsnorm into separate add + rmsnorm ops, then re-fuses them via pattern matching. However, only allreduce-based fusion patterns (register_ar_fusions) were registered. The simpler register_add_norm fallback was missing.

This causes layers without an explicit allreduce before add + rmsnorm to remain permanently unfused, resulting in two kernel launches instead of one and roughly 24-30% more GPU time for these operations.
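To see why the fusion is purely a launch-count optimization, here is a minimal NumPy sketch of the math (not TensorRT-LLM's actual CUDA kernels): the unfused path computes add and rmsnorm as two separate ops, while the fused path computes the identical result in one. Function names here are illustrative, not the library's API.

```python
import numpy as np

def rmsnorm(x, weight, eps=1e-6):
    # RMSNorm: scale x by the reciprocal root-mean-square over the last axis
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return (x / rms) * weight

def unfused(residual, hidden, weight):
    # Two separate ops -> two kernel launches on a GPU
    s = residual + hidden
    return rmsnorm(s, weight), s

def fused_add_rmsnorm(residual, hidden, weight, eps=1e-6):
    # Same computation in a single op -> one kernel launch
    s = residual + hidden
    rms = np.sqrt(np.mean(s * s, axis=-1, keepdims=True) + eps)
    return (s / rms) * weight, s

rng = np.random.default_rng(0)
res, hid, w = rng.normal(size=(2, 8)), rng.normal(size=(2, 8)), np.ones(8)
a, _ = unfused(res, hid, w)
b, _ = fused_add_rmsnorm(res, hid, w)
assert np.allclose(a, b)  # numerically identical; fusion only saves launches
```

The outputs match to floating-point tolerance, so leaving layers unfused costs launch overhead and extra memory traffic, not accuracy.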

Affected scenarios:

  • enable_attention_dp=True — attention layers skip allreduce (Llama, Qwen3, DeepSeekV3, etc.)
  • Hybrid Mamba+Transformer models (Nemotron-H) — Mamba layers have allreduce hidden inside Linear, not as FX graph node
  • Pipeline-parallel-only setups (PP > 1, TP = 1) — mpi_world_size() > 1 but no allreduce anywhere
  • Fusion config flags (PRE_MLP_FUSION, POST_MLP_FUSION) — disable allreduce on specific layers

Fix: Add register_add_norm as a fallback PatternMatcherPass in the multi-GPU branch, placed after all allreduce-based fusion passes to ensure correct priority (2 lines added in backend.py).
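The ordering requirement can be illustrated with a toy analogue of the pass sequence (plain Python, not the real torch._inductor PatternMatcherPass API; pass and node names are invented for illustration): allreduce-based fusion passes run first, and the plain add+norm fallback runs last so it only fuses what the earlier passes left behind.

```python
# Toy pass pipeline over a graph represented as a list of op names.

def ar_add_norm_pass(graph):
    # Fuse only [allreduce, add, norm] triples (the allreduce-based pattern)
    out, i = [], 0
    while i < len(graph):
        if graph[i:i + 3] == ["allreduce", "add", "norm"]:
            out.append("ar_add_norm_fused")
            i += 3
        else:
            out.append(graph[i])
            i += 1
    return out

def add_norm_fallback_pass(graph):
    # Fallback: fuse any remaining [add, norm] pairs
    out, i = [], 0
    while i < len(graph):
        if graph[i:i + 2] == ["add", "norm"]:
            out.append("add_norm_fused")
            i += 2
        else:
            out.append(graph[i])
            i += 1
    return out

# Order matters: the fallback must run after the allreduce pass, otherwise
# it would consume the add+norm pair before ar_add_norm_pass can match it.
passes = [ar_add_norm_pass, add_norm_fallback_pass]
graph = ["allreduce", "add", "norm", "matmul", "add", "norm"]
for p in passes:
    graph = p(graph)
print(graph)  # ['ar_add_norm_fused', 'matmul', 'add_norm_fused']
```

With the fallback registered last, the allreduce-based patterns keep priority and the second add+norm pair (the one with no preceding allreduce) still gets fused, mirroring the intent of the two-line change in backend.py.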

Test Coverage

Unit tests:

  • Existing tests/unittest/_torch/compilation/test_add_norm.py — 4 tests passed (single-GPU register_add_norm correctness, both dtypes, with/without inductor). Confirms no regression.

nsys profiling verification (2x NVIDIA B200):

All three testable scenarios verified with before/after nsys profiling using trtllm-serve:

| Scenario | Model | Config | Unfused kernel calls | Fused kernel calls | GPU time improvement |
|---|---|---|---|---|---|
| Attention DP | Qwen3-30B-A3B | TP=2, enable_attention_dp=True | 960 add + 970 norm | 960 FusedAddRMSNorm | -23.7% |
| Nemotron-H | Nemotron-H-8B-FP8 | TP=2, fullgraph=False | 280 add + 280 norm | 280 FusedAddRMSNorm | -29.6% |
| Pipeline Parallel | Qwen3-30B-A3B | PP=2, TP=1 | 480 add + 485 norm | 480 FusedAddRMSNorm | -29.2% |

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

In multi-GPU mode, recover_pass splits fused_add_rmsnorm but only allreduce-based fusion patterns were registered. Layers without allreduce (attention DP, hybrid Mamba, PP-only) remained unfused. Add register_add_norm as a fallback pass after allreduce fusions.

Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>
@luyiyun1021 luyiyun1021 requested a review from a team as a code owner February 26, 2026 07:44
@luyiyun1021 luyiyun1021 requested a review from hyukn February 26, 2026 07:44
@coderabbitai (Contributor) commented Feb 26, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between eba1b54 and 595f796.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/compilation/backend.py

📝 Walkthrough

Walkthrough

A fallback optimization pass is added to the custom compilation pass sequence in the multi-process environment (world_size > 1). After the initial allreduce fusion setup, a second PatternMatcherPass with register_add_norm is appended to fuse additional add+rmsnorm operations that lack preceding allreduce operations.

Changes

| Cohort / File(s) | Summary |
|---|---|
| Compilation optimization pass: `tensorrt_llm/_torch/compilation/backend.py` | Added a fallback PatternMatcherPass with register_add_norm in get_custom_pass for multi-process environments, enabling fusion of add+rmsnorm operations not preceded by allreduce. |

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 0.00%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title clearly and specifically describes the main change: registering an add+norm fallback pass in multi-GPU mode for torch.compile, directly matching the core fix in the changeset. |
| Description check | ✅ Passed | The PR description comprehensively covers all required template sections: clear explanation of the issue and solution, detailed test coverage with specific results, and completion of all PR checklist items. |


@luyiyun1021 (Collaborator, Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #36886 [ run ] triggered by Bot. Commit: 595f796

@luyiyun1021 (Collaborator, Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #36912 [ run ] triggered by Bot. Commit: 595f796

@tensorrt-cicd (Collaborator)

PR_Github #36912 [ run ] completed with state SUCCESS. Commit: 595f796
/LLM/main/L0_MergeRequest_PR pipeline #28579 completed with status: 'SUCCESS'


@yizhang-nv (Member) left a comment


LGTM

@yizhang-nv yizhang-nv merged commit 37ab642 into NVIDIA:main Feb 27, 2026
7 checks passed
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Mar 9, 2026
… multi-GPU mode (NVIDIA#11739)

Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>
tianyuz-nv pushed a commit to wanqian-nv/TensorRT-LLM that referenced this pull request Mar 19, 2026
… multi-GPU mode (NVIDIA#11739)

Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>


3 participants
