[None][feat] Fuse all_reduce with norm for nemotron_h models #12410

Merged
Wanli-Jiang merged 1 commit into NVIDIA:main from Wanli-Jiang:user/williamj/fuse_allreduce_norm
Mar 24, 2026

Conversation

@Wanli-Jiang (Collaborator) commented Mar 20, 2026

Features

Fuses the AllReduce communication, residual addition, and RMS normalization into a single kernel for Nemotron-H models when running with tensor parallelism (tp_size > 1).

How It Works

Before (unfused):

Layer N mixer → AllReduce → write to DRAM → read → residual add → RMS norm → Layer N+1 mixer

After (fused):

Layer N mixer → (unreduced output) → Layer N+1 pre_allreduce [AllReduce + residual + RMS norm in one kernel] → Layer N+1 mixer

Each layer's pre_allreduce handles the AllReduce of the previous layer's mixer output, fused with the current layer's norm. Layer 0 skips this (embedding output is already reduced). A final_allreduce handles the last layer's output → final norm.
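
A minimal sketch of what each layer's fused pre_allreduce step might look like, assuming TensorRT-LLM's AllReduce module and the AllReduceParams/AllReduceFusionOp names used in this PR; the exact wiring here is an illustration, not the merged code:

```python
# Illustrative sketch only; argument names follow TensorRT-LLM's
# AllReduceParams, but the exact wiring is an assumption.
from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
                                             AllReduceParams)

def pre_allreduce(self, unreduced_mixer_out, residual):
    # One fused kernel: allreduce the previous layer's unreduced mixer
    # output, add the running residual, then apply this layer's RMS norm.
    hidden, residual = self.allreduce(
        unreduced_mixer_out,
        all_reduce_params=AllReduceParams(
            fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
            residual=residual,
            norm_weight=self.norm.weight,
            eps=self.norm.variance_epsilon,
        ),
    )
    return hidden, residual
```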

All mixer types participate in the fusion:

  • Mamba / MLP: disable their own allreduce at init time (reduce_output=False)
  • Transformer / MoE: disable their own allreduce at forward time (AllReduceParams(enable_allreduce=False))

This ensures exactly one allreduce per layer, shifted to the next layer's pre_allreduce.
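
For illustration, the two mechanisms might appear at call sites like these (MLPLayer, reduce_output, and AllReduceParams are named in this PR; the surrounding variables are hypothetical):

```python
from tensorrt_llm._torch.distributed import AllReduceParams

# Mamba / MLP mixers: the internal allreduce is skipped at construction time.
mlp_mixer = MLPLayer(model_config, layer_idx, reduce_output=False)

# Transformer / MoE mixers: the allreduce is suppressed per forward call.
attn_out = attention_mixer(
    hidden_states,
    all_reduce_params=AllReduceParams(enable_allreduce=False),
)
```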

Scenarios That Benefit

| Scenario | Why It Benefits |
| --- | --- |
| TP >= 2, decode phase (small batch) | AllReduce latency dominates; fusing eliminates 2-3 extra kernel launches and the memory round-trips between them |
| All layer types (M, -, *, E) | Every layer's reduction is deferred to the next layer's pre_allreduce, enabling fusion across the full model |
| NVFP4 quantization | Fusion extends to RESIDUAL_RMS_NORM_QUANT_NVFP4: the norm output is quantized in place, avoiding a separate quant kernel and memory round-trip |
| Memory-bandwidth-bound regimes | The fused kernel keeps data in registers through allreduce → residual → norm, avoiding ~3-4x redundant DRAM reads/writes |

Scenarios That Are Unchanged

| Scenario | Why |
| --- | --- |
| tp_size == 1 or enable_attention_dp | Fusion is disabled entirely (fuse_allreduce_norm = False) |
| Large prefill (context phase) | Compute dominates over communication; fusion savings are proportionally smaller |
| NCCL fallback path | If the custom AllReduce strategy falls back to NCCL (e.g., cross-node), the fused kernel does not apply |
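
Taken together, the gating presumably reduces to a condition like the following (a paraphrase of the two tables above; the literal check lives in modeling_nemotron_h.py and may differ in detail):

```python
# Paraphrase of the gating described above, not the literal code.
fuse_allreduce_norm = (
    model_config.mapping.tp_size > 1
    and not model_config.mapping.enable_attention_dp
)
```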

Detailed Benefits

  1. Kernel launch reduction: Each layer saves 2-3 kernel launches (separate allreduce, residual add, rmsnorm → one fused kernel). For Nemotron-H with ~60+ layers, this is 120-180 fewer kernel launches per forward pass.

  2. Memory bandwidth savings: The fused kernel reads peer buffers, accumulates in registers, applies residual + norm, and writes once. Without fusion, intermediate results hit DRAM 3-4 times.

  3. NVFP4 quantization fusion: When NVFP4 is active, the norm output is quantized within the same kernel, saving yet another kernel + memory round-trip.

  4. Latency hiding: trigger_completion_at_end=False allows the allreduce to overlap its completion signaling with subsequent compute.
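
Items 3 and 4 could be expressed roughly as below, assuming the fusion-op and flag names quoted above; the nvfp4_scale handle and the helper itself are hypothetical:

```python
from tensorrt_llm._torch.distributed import AllReduceFusionOp, AllReduceParams

def make_fusion_params(residual, norm, use_nvfp4, nvfp4_scale=None):
    # Benefit 3: pick the quantizing fusion op when NVFP4 is active, so the
    # norm output is quantized inside the same kernel.
    fusion_op = (AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4
                 if use_nvfp4 else AllReduceFusionOp.RESIDUAL_RMS_NORM)
    return AllReduceParams(
        fusion_op=fusion_op,
        residual=residual,
        norm_weight=norm.weight,
        scale=nvfp4_scale,                # hypothetical NVFP4 scale handle
        eps=norm.variance_epsilon,
        # Benefit 4: let completion signaling overlap subsequent compute.
        trigger_completion_at_end=False,
    )
```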

Summary by CodeRabbit

  • Refactor
    • Enhanced collective operation fusion in the Nemotron model to improve distributed inference efficiency.
    • Refined kernel selection logic in Mamba2 mixer based on tensor-parallel configuration compatibility for optimized performance on supported hardware.

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update the tava architecture diagram if there is a significant design change in the PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

@Wanli-Jiang (Collaborator, Author)

/bot run --only-multi-gpu-test --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #39732 [ run ] triggered by Bot. Commit: d8ef936

@coderabbitai (Contributor, Bot) commented Mar 20, 2026

📝 Walkthrough

Two targeted enhancements: the NemotronH model gains AllReduce-with-RMSNorm fusion support via new parameters (reduce_output, fuse_allreduce_norm) and fusion logic in the layer forward passes; the Mamba2 mixer refines FlashInfer kernel selection by adding a head_group_ratio compatibility requirement alongside the existing head_dim validation.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| **NemotronH AllReduce Fusion**<br>tensorrt_llm/_torch/models/modeling_nemotron_h.py | Added a reduce_output parameter to MLPLayer and NemotronHMOE; added fuse_allreduce_norm to NemotronHLayer with conditional AllReduceFusionOp fusion in the forward pass; modified NemotronHModel to support environment-controlled fusion and coordinate AllReduce with RMSNorm at the layer and model levels. |
| **Mamba2 Mixer Kernel Selection**<br>tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py | Refined the FlashInfer kernel selection predicate to require head_group_ratio compatibility (must be in [1, 8, 16]) in addition to the existing head_dim validation. |
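
Based on this summary, the refined predicate presumably has a shape like the following; only head_group_ratio and the supported set [1, 8, 16] come from this PR, and the helper signature is hypothetical:

```python
# Assumed shape of the refined FlashInfer selection predicate.
SUPPORTED_HEAD_GROUP_RATIOS = (1, 8, 16)

def can_use_flashinfer(head_dim_ok: bool, tp_nheads: int, tp_ngroups: int) -> bool:
    # head_dim_ok stands in for the pre-existing head_dim validation.
    head_group_ratio = tp_nheads // tp_ngroups
    return head_dim_ok and head_group_ratio in SUPPORTED_HEAD_GROUP_RATIOS
```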

Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Layer as NemotronHLayer
    participant Norm as RMSNorm
    participant Fusion as AllReduceFusionOp
    participant Mixer as Mixer Layer

    rect rgba(100, 150, 200, 0.5)
    Note over Layer,Mixer: With fuse_allreduce_norm enabled
    Layer->>Norm: Input tensor
    Norm->>Fusion: Normalized output + hp_output
    Fusion->>Mixer: Fused AllReduce-Norm output
    Mixer->>Layer: Mixer output
    end

    rect rgba(200, 150, 100, 0.5)
    Note over Layer,Mixer: With fuse_allreduce_norm disabled
    Layer->>Norm: Input tensor
    Norm->>Layer: Normalized output
    Layer->>Mixer: Pre-norm output
    Mixer->>Layer: Mixer output
    end
```

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (2 warnings)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 11.11%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |
| Description check | ⚠️ Warning | The PR description is incomplete: a detailed technical explanation is present under 'Features', but the required template sections (Title, Description, and Test Coverage) are not properly filled. | Add a proper PR title following the [type] format (e.g., '[None][feat] Fuse all_reduce with norm for nemotron_h models'), complete the Description section with a concise summary, and list specific test cases that validate the changes. |
✅ Passed checks (1 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The title accurately describes the main feature: fusing all_reduce operations with normalization for Nemotron-H models, which aligns with the code changes. |

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai (Contributor, Bot) left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py (1)

1-1: ⚠️ Potential issue | 🟡 Minor

Bump the NVIDIA header year for this modified file.

The SPDX copyright line still ends at 2024 even though this file changes in this PR. Please update it to 2026.

As per coding guidelines: "Add NVIDIA copyright header on ALL new files, and update year on modified files."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py` at line 1, Update the SPDX
copyright header year in the file by modifying the existing SPDX line that
currently reads "SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA
CORPORATION & AFFILIATES. All rights reserved." to reflect the new modification
year 2026; locate that SPDX header at the top of
tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py and change the end year from
2024 to 2026 so it reads "...2022-2026 NVIDIA CORPORATION & AFFILIATES. All
rights reserved."
tensorrt_llm/_torch/models/modeling_nemotron_h.py (1)

233-240: ⚠️ Potential issue | 🔴 Critical

Keep the attention-DP guard on the MoE output all-reduce.

Lines 575-577 force fuse_allreduce_norm off when enable_attention_dp is on, so this constructor now sees reduce_output=True in exactly the mode where Lines 553-556 say all-reduce is invalid. AllReduce.forward() in tensorrt_llm/_torch/distributed/ops.py does not suppress that case for you, so Lines 331-332 can hang or shape-fail on multi-rank MoE layers unless you preserve the old not self.enable_attention_dp check here.

🛠️ Proposed fix

```diff
-        if reduce_output:
+        if reduce_output and not self.enable_attention_dp:
             # AllReduce for combining shared and routed expert outputs in multi-GPU settings.
             self.allreduce = AllReduce(
                 mapping=model_config.mapping,
                 strategy=model_config.allreduce_strategy,
             )
         else:
             self.allreduce = None
```

Also applies to: 331-332

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/models/modeling_nemotron_h.py` around lines 233 - 240,
Constructor currently sets self.allreduce when reduce_output is true even if
attention data-parallel is enabled; restore the attention-DP guard so AllReduce
is only created when reduce_output is true AND attention DP is disabled. Modify
the constructor logic around reduce_output/AllReduce (the block that assigns
self.allreduce and uses AllReduce(mapping=..., strategy=...)) to also check
self.enable_attention_dp (negated) before instantiating AllReduce; this
preserves the previous behavior that avoids calling AllReduce.forward() in the
attention-DP mode (see related checks around lines invoking fuse_allreduce_norm
and the AllReduce.forward usage).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py`:
- Around line 145-148: In __init__, validate tensor-parallel divisibility before
computing tp_ngroups, tp_d_inner/group_size, and head_group_ratio: ensure
n_groups is evenly divisible by tp_size (so tp_ngroups = n_groups // tp_size is
exact and >0), ensure self.tp_d_inner is divisible by tp_ngroups (so group_size
= self.tp_d_inner // self.tp_ngroups is valid), and ensure self.tp_nheads is
divisible by tp_ngroups (so head_group_ratio = self.tp_nheads // self.tp_ngroups
is exact); if any check fails either raise a clear ValueError or fall back to
disabling FlashInfer (set self._use_flashinfer = False) instead of proceeding.
Update the head_group_ratio computation to use the validated tp_ngroups (no
conditional truncation) and only consider FlashInfer/supported_head_group_ratios
when all divisibility checks succeeded so RMSNormGated and kernel dispatch
receive valid group_size and ratios.

---

Outside diff comments:
In `@tensorrt_llm/_torch/models/modeling_nemotron_h.py`:
- Around line 233-240: Constructor currently sets self.allreduce when
reduce_output is true even if attention data-parallel is enabled; restore the
attention-DP guard so AllReduce is only created when reduce_output is true AND
attention DP is disabled. Modify the constructor logic around
reduce_output/AllReduce (the block that assigns self.allreduce and uses
AllReduce(mapping=..., strategy=...)) to also check self.enable_attention_dp
(negated) before instantiating AllReduce; this preserves the previous behavior
that avoids calling AllReduce.forward() in the attention-DP mode (see related
checks around lines invoking fuse_allreduce_norm and the AllReduce.forward
usage).

In `@tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py`:
- Line 1: Update the SPDX copyright header year in the file by modifying the
existing SPDX line that currently reads "SPDX-FileCopyrightText: Copyright (c)
2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved." to reflect the
new modification year 2026; locate that SPDX header at the top of
tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py and change the end year from
2024 to 2026 so it reads "...2022-2026 NVIDIA CORPORATION & AFFILIATES. All
rights reserved."

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 196b529f-4836-4be7-a0fd-7a0be6441e56

📥 Commits

Reviewing files that changed from the base of the PR and between a7e22b9 and d8ef936.

📒 Files selected for processing (2)
  • tensorrt_llm/_torch/models/modeling_nemotron_h.py
  • tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py

Comment thread: tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py
@tensorrt-cicd (Collaborator)

PR_Github #39732 [ run ] completed with state SUCCESS. Commit: d8ef936
/LLM/main/L0_MergeRequest_PR pipeline #30927 (Partly Tested) completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again


Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
Wanli-Jiang force-pushed the user/williamj/fuse_allreduce_norm branch from d8ef936 to b827ea8 on March 23, 2026 at 07:09
@Wanli-Jiang (Collaborator, Author)

/bot run --only-multi-gpu-test --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #39898 [ run ] triggered by Bot. Commit: b827ea8

@tensorrt-cicd (Collaborator)

PR_Github #39898 [ run ] completed with state SUCCESS. Commit: b827ea8
/LLM/main/L0_MergeRequest_PR pipeline #31067 (Partly Tested) completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again


@Wanli-Jiang (Collaborator, Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #40010 [ run ] triggered by Bot. Commit: b827ea8

@tensorrt-cicd (Collaborator)

PR_Github #40010 [ run ] completed with state SUCCESS. Commit: b827ea8
/LLM/main/L0_MergeRequest_PR pipeline #31168 completed with status: 'SUCCESS'


Wanli-Jiang merged commit e5d8435 into NVIDIA:main on Mar 24, 2026
9 checks passed

