[TRTLLM-10386][fix] torch.compile: register add+norm fallback pass in multi-GPU mode #11739
yizhang-nv merged 1 commit into NVIDIA:main from luyiyun1021:dev-fix-add-norm-fallback-multi-gpu
Conversation
In multi-GPU mode, recover_pass splits fused_add_rmsnorm but only allreduce-based fusion patterns were registered. Layers without allreduce (attention DP, hybrid Mamba, PP-only) remained unfused. Add register_add_norm as a fallback pass after allreduce fusions. Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>
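The pass-ordering point in the commit message can be modeled with a toy sketch (plain Python, not the real torch `PatternMatcherPass` API; the op and pass names here are illustrative, not TensorRT-LLM's actual identifiers): passes run in order, so the allreduce-based fusion claims its ops first and the fallback only fuses what remains.

```python
# Toy model of priority-ordered fusion passes. Each "graph" is a flat list
# of op names; fuse() replaces every occurrence of `pattern` with one
# fused op, mimicking a pattern-matcher rewrite.

def fuse(ops, pattern, fused_name):
    out, i = [], 0
    while i < len(ops):
        if ops[i:i + len(pattern)] == pattern:
            out.append(fused_name)
            i += len(pattern)
        else:
            out.append(ops[i])
            i += 1
    return out

graph = ["allreduce", "add", "rmsnorm",   # TP layer: allreduce fusion applies
         "add", "rmsnorm"]                # attention-DP layer: no allreduce

# Allreduce-based fusions run first (higher priority)...
graph = fuse(graph, ["allreduce", "add", "rmsnorm"], "ar_residual_rmsnorm")
# ...then the add+norm fallback catches the pair that was left unfused.
graph = fuse(graph, ["add", "rmsnorm"], "fused_add_rmsnorm")

print(graph)
```

Run in the other order, the plain `add + rmsnorm` fallback would consume the `add`/`rmsnorm` pair that belongs to the allreduce pattern, which is why the fix appends the fallback after the allreduce fusions.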
/bot run --disable-fail-fast

PR_Github #36886 [ run ] triggered by Bot. Commit:

/bot run --disable-fail-fast

PR_Github #36912 [ run ] triggered by Bot. Commit:

PR_Github #36912 [ run ] completed with state
… multi-GPU mode (NVIDIA#11739) Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>
Description
In multi-GPU mode (`world_size > 1`), the torch.compile custom `Backend` uses `recover_pass` to split `fused_add_rmsnorm` into separate `add` + `rmsnorm` ops, then re-fuses them via pattern matching. However, only allreduce-based fusion patterns (`register_ar_fusions`) were registered; the simpler `register_add_norm` fallback was missing.

This causes layers without an explicit `allreduce` before `add + rmsnorm` to remain permanently unfused, resulting in 2 kernel launches instead of 1 and ~24-30% more GPU time for these operations.

Affected scenarios:
- `enable_attention_dp=True`: attention layers skip allreduce (Llama, Qwen3, DeepSeekV3, etc.)
- PP-only (`PP > 1, TP = 1`): `mpi_world_size() > 1` but no allreduce anywhere
- Hybrid Mamba fusion flags (`PRE_MLP_FUSION`, `POST_MLP_FUSION`): disable allreduce on specific layers

Fix: add `register_add_norm` as a fallback `PatternMatcherPass` in the multi-GPU branch, placed after all allreduce-based fusion passes to ensure correct priority (2 lines added in `backend.py`).

Test Coverage
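As a plain-Python reference for the computation under test (a sketch of the math only, not the CUDA kernels or the actual test code in `test_add_norm.py`): the fused `add + rmsnorm` op must match the two separate ops exactly; fusion only changes how many kernel launches it takes.

```python
import math

def rmsnorm(x, weight, eps=1e-6):
    # RMSNorm: scale x by 1 / sqrt(mean(x^2) + eps), then apply weight.
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms * w for v, w in zip(x, weight)]

def add_rmsnorm_unfused(x, residual, weight):
    # Two separate ops -> two kernel launches on a GPU.
    h = [a + b for a, b in zip(x, residual)]
    return rmsnorm(h, weight)

def add_rmsnorm_fused(x, residual, weight):
    # Same math in a single pass -> one kernel launch.
    h = [a + b for a, b in zip(x, residual)]
    rms = math.sqrt(sum(v * v for v in h) / len(h) + 1e-6)
    return [v / rms * w for v, w in zip(h, weight)]

x, residual, weight = [1.0, 2.0], [0.5, -0.5], [1.0, 1.0]
assert add_rmsnorm_unfused(x, residual, weight) == add_rmsnorm_fused(x, residual, weight)
```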
Unit tests: `tests/unittest/_torch/compilation/test_add_norm.py` (4 tests passed: single-GPU `register_add_norm` correctness, both dtypes, with/without inductor). Confirms no regression.

nsys profiling verification (2x NVIDIA B200): all three testable scenarios verified with before/after nsys profiling using `trtllm-serve`, including `enable_attention_dp=True` and `fullgraph=False`.

PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment `/bot help`.