[TRTLLM-11064][fix] Remove duplicated MoE Computation with Helix CP+DP #11167

Merged

brb-nv merged 2 commits into NVIDIA/TensorRT-LLM:main from brb-nv:user/brb/deduplicate-moe-computation-mr
Feb 27, 2026

Conversation

@brb-nv
Collaborator

@brb-nv brb-nv commented Feb 1, 2026

Description

Background

When using Helix CP with Attention DP on the generation server, all CP ranks within the same DP group have identical data after attention computation. The MoE layer receives token counts from all ranks via tp_cp_allgather to coordinate expert routing and communication.

Issue

MoE computation was duplicated across CP ranks: each CP rank processed the same tokens independently, resulting in redundant computation and degraded performance, although accuracy was unaffected.

Root Cause

The tp_cp_allgather operation gathers token counts from all TP×CP ranks. In Helix mode, CP ranks hold identical data post-attention, so the gathered token counts included duplicates. For example, with CP=2:

  • [dp0cp0=128, dp0cp1=128, dp1cp0=64, dp1cp1=64]

The MoE layer would then process 128 tokens on both dp0cp0 and dp0cp1, doubling the computation.
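The deduplication this PR applies can be sketched in a few lines (the helper name is illustrative, not the actual TensorRT-LLM code; the rank ordering, with CP ranks contiguous within each DP group, follows the example above):

```python
# Illustrative sketch: within each DP group, keep the token count only
# on cp_rank 0 and zero it on the other CP ranks, since those ranks
# hold identical data after attention.
def dedup_helix_token_counts(all_rank_num_tokens, cp_size):
    # Ranks are assumed ordered [dp0cp0, dp0cp1, ..., dp1cp0, ...],
    # i.e. CP ranks contiguous within each DP group.
    return [
        count if rank % cp_size == 0 else 0
        for rank, count in enumerate(all_rank_num_tokens)
    ]

# The CP=2 example above: [dp0cp0=128, dp0cp1=128, dp1cp0=64, dp1cp1=64]
print(dedup_helix_token_counts([128, 128, 64, 64], cp_size=2))
# → [128, 0, 64, 0]
```

With the duplicate counts zeroed, the MoE layer processes the 128 tokens on dp0cp0 only instead of on both CP ranks.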

Fix

When Attention DP is enabled, perform an Allgather at the beginning of the MLA layer and a ReduceScatter at the end of it.
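To build intuition for why this pairing is shape-symmetric, here is a toy single-process sketch; the ranks and collectives are simulated with plain lists (purely illustrative, the actual implementation uses TensorRT-LLM's distributed ops):

```python
# Toy sketch of the collective pattern in the fix: an allgather going
# into the layer and a reduce-scatter coming out of it.
def allgather(shards):
    # Every rank receives the concatenation of all ranks' shards.
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]

def reduce_scatter(per_rank_full, shard_size):
    # Element-wise sum across ranks, then each rank keeps one shard.
    summed = [sum(vals) for vals in zip(*per_rank_full)]
    return [summed[i * shard_size:(i + 1) * shard_size]
            for i in range(len(per_rank_full))]

shards = [[1, 2], [3, 4]]              # each rank starts with a slice
gathered = allgather(shards)           # the layer sees the full data
outputs = reduce_scatter(gathered, 2)  # partial results recombined
print(gathered)  # → [[1, 2, 3, 4], [1, 2, 3, 4]]
print(outputs)   # → [[2, 4], [6, 8]]
```

The reduce-scatter undoes the allgather's expansion, so each rank ends the layer holding only its own slice again.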

Test Coverage

$ pytest tests/integration/defs/accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo_v2-cudagraph:with_padding-pp1dp2cp2] -s -v

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR follows the TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provides a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

Details

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

Summary by CodeRabbit

Release Notes

  • New Features
    • Added support for Helix Context Parallelism in DeepseekV3 models, enabling optimized distributed inference with automatic token deduplication across compute ranks.


@brb-nv brb-nv changed the title [None][fix] Remove Duplicated MoE Computation in Helix CP + DP Mode [None][fix] Remove duplicated MoE Computation with Helix CP+DP Feb 1, 2026
@brb-nv brb-nv force-pushed the user/brb/deduplicate-moe-computation-mr branch 2 times, most recently from 382e920 to 1485b6d Compare February 1, 2026 01:42
@brb-nv brb-nv marked this pull request as ready for review February 1, 2026 01:43
@brb-nv brb-nv requested review from a team as code owners February 1, 2026 01:43
@brb-nv
Collaborator Author

brb-nv commented Feb 1, 2026

/bot run --disable-fail-fast

@coderabbitai
Contributor

coderabbitai Bot commented Feb 1, 2026

📝 Walkthrough

Walkthrough

Added Helix Context Parallelism deduplication support to DeepseekV3 model and PyTorch executor. Threads mapping_with_cp through model layers, implements Helix-aware input preparation and output broadcasting, and deduplicates token counts across ranks.

Changes

  • DeepseekV3 Model Enhancements (tensorrt_llm/_torch/models/modeling_deepseekv3.py): Added a mapping_with_cp parameter to the constructors of DeepseekV3MoE, DeepseekV3Attention, DeepseekV3DecoderLayer, and DeepseekV3Model. Implemented Helix-aware methods: _is_helix_with_cp() for mode detection, _prepare_helix_inputs() for input slicing on non-cp_rank_0 ranks, and _broadcast_helix_output() for output gathering. Modified MoE routing to handle empty hidden states and guarded the gate computation logic. Added a cp_allgather import for distributed operations.

  • Executor Token Deduplication (tensorrt_llm/_torch/pyexecutor/model_engine.py): Added an _apply_helix_deduplication() method to PyTorchModelEngine that deduplicates per-rank token counts in Helix CP mode by zeroing non-cp_rank_0 ranks within each DP group. Integrated the deduplication into _get_all_rank_num_tokens() when attention data parallelism is enabled.

Sequence Diagram

sequenceDiagram
    participant Executor as PyTorchModelEngine
    participant Model as DeepseekV3Model
    participant Layer as DecoderLayer
    participant MoE as MoE/Attention
    participant Dist as Distributed Ops

    Executor->>Executor: _get_all_rank_num_tokens()
    Executor->>Executor: _apply_helix_deduplication(all_rank_num_tokens)
    Note over Executor: Zero out non-cp_rank_0<br/>within each DP group
    
    Executor->>Model: Forward pass with dedup tokens
    Model->>Layer: Forward (per layer)
    Layer->>Layer: _is_helix_with_cp()?
    alt Helix CP Mode Enabled
        Layer->>Layer: _prepare_helix_inputs(hidden_states)
        Note over Layer: Slice inputs for cp_rank > 0<br/>(empty tensors for non-rank_0)
        Layer->>MoE: Forward with prepared inputs
        MoE->>MoE: compute_routed_output()
        Note over MoE: Handle empty hidden_states<br/>Skip gate when empty
        MoE->>Dist: cp_allgather(routed_output)
        Dist->>Layer: Gathered outputs
        Layer->>Layer: _broadcast_helix_output(outputs)
    else Non-Helix Mode
        Layer->>MoE: Forward normally
        MoE->>MoE: Standard routing logic
    end
    Layer->>Model: Return processed output

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 warning

❌ Failed checks (1 warning)

  • Docstring Coverage (⚠️ Warning): Docstring coverage is 41.67%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

  • Title check (✅ Passed): The title accurately describes the main change: removing duplicated MoE computation in the Helix CP+DP configuration by deduplicating token counts and refactoring MoE routing.
  • Description check (✅ Passed): The PR description comprehensively explains the background, issue, root cause, and implemented fix with clear examples and test coverage.


Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 2

🤖 Fix all issues with AI agents
In `@tensorrt_llm/_torch/models/modeling_deepseekv3.py`:
- Around line 1070-1076: The empty-tensor branch creates router_logits with
hidden_states.dtype which can differ from the gate output
(DeepseekV3Gate.forward uses out_dtype=torch.float32); change the empty
router_logits creation to use dtype=torch.float32 (keep
device=hidden_states.device and shape (0, num_experts)) so the dtype matches
gate logits and MoE collectives; update the code that builds router_logits (the
block referencing self.gate / self.experts and hidden_states) to explicitly set
torch.float32 as the dtype.
- Around line 1048-1055: The DP-padding block guarded by self.use_dp,
self.mapping.tp_size > 1 and get_sm_version() == 120 must also check for
non-empty local token batch to avoid padding zero-token ranks; wrap the existing
torch.nn.functional.pad call (the hidden_states padding) in a conditional if
hidden_states.shape[0] > 0 so that when hidden_states is empty (zero-token CP
ranks) no padding is applied and the later MoE routing (router_logits, expert
dispatch code paths) is not triggered; locate the block using the symbols
self.use_dp, self.mapping.tp_size, get_sm_version(), and hidden_states and add
the if-check around the padding call.
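The first comment above (matching the gate's float32 output dtype in the empty-batch branch) can be sketched as follows; the function name and expert count are illustrative, not the actual modeling_deepseekv3.py code:

```python
import torch

NUM_EXPERTS = 8  # illustrative value

def route_tokens(hidden_states: torch.Tensor, gate) -> torch.Tensor:
    """Sketch of the guarded gate computation: skip the gate for
    zero-token batches and return float32 logits either way."""
    if hidden_states.shape[0] == 0:
        # Empty batch on a non-cp_rank_0 rank: use the gate's float32
        # output dtype explicitly rather than hidden_states.dtype.
        return torch.empty((0, NUM_EXPERTS), dtype=torch.float32,
                           device=hidden_states.device)
    return gate(hidden_states).to(torch.float32)

# Zero-token input, as produced for cp_rank > 0 in Helix mode:
empty = torch.empty((0, 16), dtype=torch.bfloat16)
logits = route_tokens(empty, gate=None)  # gate is never called
print(logits.shape, logits.dtype)  # torch.Size([0, 8]) torch.float32
```

Keeping the empty tensor in float32 ensures its dtype agrees with real gate logits when the tensors later meet in MoE collectives.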
🧹 Nitpick comments (1)
tensorrt_llm/_torch/models/modeling_deepseekv3.py (1)

50-52: Prefer module-qualified import for the new cp_allgather usage.

To keep the Helix additions aligned with the namespace-import guideline, consider importing the distributed module and referencing cp_allgather via that module (e.g., dist.cp_allgather) instead of expanding the from-import list. Remember to update the call site in _broadcast_helix_output.

Proposed change (imports)
-from ..distributed import (AllReduce, AllReduceFusionOp, AllReduceParams,
-                           MoEAllReduce, MoEAllReduceParams, allgather,
-                           cp_allgather)
+from ..distributed import (AllReduce, AllReduceFusionOp, AllReduceParams,
+                           MoEAllReduce, MoEAllReduceParams, allgather)
+from .. import distributed as dist

As per coding guidelines: Always maintain the namespace when importing Python modules, even if only one class or function from a module is used.

@tensorrt-cicd
Collaborator

PR_Github #34329 [ run ] triggered by Bot. Commit: 1485b6d

@tensorrt-cicd
Collaborator

PR_Github #34329 [ run ] completed with state SUCCESS. Commit: 1485b6d
/LLM/main/L0_MergeRequest_PR pipeline #26478 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

@brb-nv
Collaborator Author

brb-nv commented Feb 5, 2026

/bot run --disable-fail-fast

@brb-nv brb-nv force-pushed the user/brb/deduplicate-moe-computation-mr branch from e0e35b9 to 57eff37 Compare February 6, 2026 23:37
@brb-nv brb-nv requested a review from a team as a code owner February 9, 2026 01:37
@brb-nv brb-nv requested a review from yuxianq February 9, 2026 01:37
@brb-nv brb-nv force-pushed the user/brb/deduplicate-moe-computation-mr branch 2 times, most recently from 1d33e5d to 57579eb Compare February 10, 2026 22:13
@brb-nv brb-nv requested a review from a team as a code owner February 10, 2026 22:40
@brb-nv brb-nv requested a review from syuoni February 10, 2026 22:40
@brb-nv brb-nv force-pushed the user/brb/deduplicate-moe-computation-mr branch 2 times, most recently from 53c8ff0 to c6a82b6 Compare February 10, 2026 23:12
@brb-nv brb-nv removed request for a team and yuxianq February 10, 2026 23:50
@brb-nv brb-nv force-pushed the user/brb/deduplicate-moe-computation-mr branch 2 times, most recently from e2e85d4 to 6e53b2c Compare February 10, 2026 23:52
@brb-nv
Collaborator Author

brb-nv commented Feb 10, 2026

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #36196 [ run ] triggered by Bot. Commit: b3bd411

@brb-nv brb-nv force-pushed the user/brb/deduplicate-moe-computation-mr branch from b3bd411 to 49e38d7 Compare February 19, 2026 03:49
@brb-nv
Collaborator Author

brb-nv commented Feb 19, 2026

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #36202 [ run ] triggered by Bot. Commit: 49e38d7

@tensorrt-cicd
Collaborator

PR_Github #36202 [ run ] completed with state SUCCESS. Commit: 49e38d7
/LLM/main/L0_MergeRequest_PR pipeline #27983 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again


@brb-nv
Collaborator Author

brb-nv commented Feb 20, 2026

/bot run --disable-fail-fast

Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
@brb-nv brb-nv force-pushed the user/brb/deduplicate-moe-computation-mr branch from 49e38d7 to c1462c2 Compare February 20, 2026 06:32
@tensorrt-cicd
Collaborator

PR_Github #36320 [ run ] triggered by Bot. Commit: c1462c2

@tensorrt-cicd
Collaborator

PR_Github #36320 [ run ] completed with state SUCCESS. Commit: c1462c2
/LLM/main/L0_MergeRequest_PR pipeline #28092 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again


@brb-nv
Collaborator Author

brb-nv commented Feb 20, 2026

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #36350 [ run ] triggered by Bot. Commit: c1462c2

@tensorrt-cicd
Collaborator

PR_Github #36350 [ run ] completed with state SUCCESS. Commit: c1462c2
/LLM/main/L0_MergeRequest_PR pipeline #28118 completed with status: 'SUCCESS'


Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
@brb-nv brb-nv force-pushed the user/brb/deduplicate-moe-computation-mr branch from 5d5a594 to 7875a69 Compare February 26, 2026 19:43
@brb-nv brb-nv requested a review from yuxianq February 26, 2026 19:45
@brb-nv
Collaborator Author

brb-nv commented Feb 26, 2026

/bot run --disable-fail-fast

@brb-nv brb-nv enabled auto-merge (squash) February 26, 2026 19:51
@tensorrt-cicd
Collaborator

PR_Github #36964 [ run ] triggered by Bot. Commit: 7875a69

@tensorrt-cicd
Collaborator

PR_Github #36964 [ run ] completed with state SUCCESS. Commit: 7875a69
/LLM/main/L0_MergeRequest_PR pipeline #28621 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.
Pipeline has performance regression cases. Check the performance regression report for details.


@brb-nv brb-nv merged commit 2237d7d into NVIDIA:main Feb 27, 2026
7 checks passed
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Mar 9, 2026
NVIDIA#11167)

Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
tianyuz-nv pushed a commit to wanqian-nv/TensorRT-LLM that referenced this pull request Mar 19, 2026
NVIDIA#11167)

Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>