[TRTLLM-11064][fix] Remove duplicated MoE Computation with Helix CP+DP #11167

brb-nv merged 2 commits into NVIDIA:main from brb-nv:user/brb/deduplicate-moe-computation-mr
Conversation
Force-pushed from 382e920 to 1485b6d

/bot run --disable-fail-fast
📝 Walkthrough

Added Helix Context Parallelism deduplication support to the DeepseekV3 model and the PyTorch executor. Threads mapping_with_cp through model layers, implements Helix-aware input preparation and output broadcasting, and deduplicates token counts across ranks.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Executor as PyTorchModelEngine
    participant Model as DeepseekV3Model
    participant Layer as DecoderLayer
    participant MoE as MoE/Attention
    participant Dist as Distributed Ops
    Executor->>Executor: _get_all_rank_num_tokens()
    Executor->>Executor: _apply_helix_deduplication(all_rank_num_tokens)
    Note over Executor: Zero out non-cp_rank_0<br/>within each DP group
    Executor->>Model: Forward pass with dedup tokens
    Model->>Layer: Forward (per layer)
    Layer->>Layer: _is_helix_with_cp()?
    alt Helix CP Mode Enabled
        Layer->>Layer: _prepare_helix_inputs(hidden_states)
        Note over Layer: Slice inputs for cp_rank > 0<br/>(empty tensors for non-rank_0)
        Layer->>MoE: Forward with prepared inputs
        MoE->>MoE: compute_routed_output()
        Note over MoE: Handle empty hidden_states<br/>Skip gate when empty
        MoE->>Dist: cp_allgather(routed_output)
        Dist->>Layer: Gathered outputs
        Layer->>Layer: _broadcast_helix_output(outputs)
    else Non-Helix Mode
        Layer->>MoE: Forward normally
        MoE->>MoE: Standard routing logic
    end
    Layer->>Model: Return processed output
```
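As a rough illustration of the executor-side step above ("zero out non-cp_rank_0 within each DP group"), here is a minimal sketch. The function name mirrors `_apply_helix_deduplication` from the walkthrough; the rank layout (CP ranks contiguous within each DP group) is an assumption made for the example, not a statement about the actual gather order.

```python
def apply_helix_deduplication(all_rank_num_tokens: list[int], cp_size: int) -> list[int]:
    """Keep each DP group's token count only on its cp_rank-0 entry.

    Assumes the gathered counts are ordered [dp0cp0, dp0cp1, ..., dp1cp0, ...],
    i.e. CP ranks are contiguous within each DP group (illustrative layout).
    """
    return [n if i % cp_size == 0 else 0
            for i, n in enumerate(all_rank_num_tokens)]

# With CP=2 and the counts from this PR's description:
assert apply_helix_deduplication([128, 128, 64, 64], cp_size=2) == [128, 0, 64, 0]
```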
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks: ❌ 1 failed (1 warning) | ✅ 2 passed
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@tensorrt_llm/_torch/models/modeling_deepseekv3.py`:
- Around line 1070-1076: The empty-tensor branch creates router_logits with
hidden_states.dtype which can differ from the gate output
(DeepseekV3Gate.forward uses out_dtype=torch.float32); change the empty
router_logits creation to use dtype=torch.float32 (keep
device=hidden_states.device and shape (0, num_experts)) so the dtype matches
gate logits and MoE collectives; update the code that builds router_logits (the
block referencing self.gate / self.experts and hidden_states) to explicitly set
torch.float32 as the dtype.
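A minimal sketch of that suggested fix, written as a standalone helper; `gate` and `num_experts` are stand-ins for `self.gate` and the gate's expert count, and the surrounding control flow is assumed from the comment's description:

```python
import torch

def make_router_logits(hidden_states: torch.Tensor, gate, num_experts: int) -> torch.Tensor:
    """Sketch of the suggested fix (not the exact code in the PR)."""
    if hidden_states.shape[0] == 0:
        # Empty CP rank: build router_logits in float32 so the dtype matches
        # DeepseekV3Gate.forward (out_dtype=torch.float32) and the MoE
        # collectives see a consistent dtype across ranks.
        return torch.empty((0, num_experts),
                           dtype=torch.float32,  # not hidden_states.dtype
                           device=hidden_states.device)
    return gate(hidden_states)
```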
- Around line 1048-1055: The DP-padding block guarded by self.use_dp,
self.mapping.tp_size > 1 and get_sm_version() == 120 must also check for a
non-empty local token batch to avoid padding zero-token ranks; wrap the existing
torch.nn.functional.pad call (the hidden_states padding) in a conditional if
hidden_states.shape[0] > 0 so that when hidden_states is empty (zero-token CP
ranks) no padding is applied and the later MoE routing (router_logits, expert
dispatch code paths) is not triggered; locate the block using the symbols
self.use_dp, self.mapping.tp_size, get_sm_version(), and hidden_states and add
the if-check around the padding call.
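A hedged sketch of that guard as a standalone helper; the function boundary, parameter names, and `pad_len` are illustrative, while the condition mirrors the review comment:

```python
import torch

def maybe_pad_for_dp(hidden_states: torch.Tensor, pad_len: int,
                     use_dp: bool, tp_size: int, sm_version: int) -> torch.Tensor:
    """Sketch of the suggested DP-padding guard (not the exact code in the PR)."""
    if use_dp and tp_size > 1 and sm_version == 120:
        # Only pad ranks that actually hold tokens: zero-token CP ranks must
        # stay empty so they skip MoE routing / expert dispatch entirely.
        if hidden_states.shape[0] > 0:
            hidden_states = torch.nn.functional.pad(
                hidden_states, (0, 0, 0, pad_len))
    return hidden_states
```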
🧹 Nitpick comments (1)
tensorrt_llm/_torch/models/modeling_deepseekv3.py (1)
50-52: Prefer a module-qualified import for the new cp_allgather usage.

To keep the Helix additions aligned with the namespace-import guideline, consider importing the distributed module and referencing cp_allgather via that module (e.g., dist.cp_allgather) instead of expanding the from-import list. Remember to update the call site in _broadcast_helix_output.

Proposed change (imports):

```diff
-from ..distributed import (AllReduce, AllReduceFusionOp, AllReduceParams,
-                           MoEAllReduce, MoEAllReduceParams, allgather,
-                           cp_allgather)
+from ..distributed import (AllReduce, AllReduceFusionOp, AllReduceParams,
+                           MoEAllReduce, MoEAllReduceParams, allgather)
+from .. import distributed as dist
```

As per coding guidelines: Always maintain the namespace when importing Python modules, even if only one class or function from a module is used.
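The corresponding call-site follow-up in `_broadcast_helix_output` would look roughly like this; the argument list is whatever the existing call already passes and is not shown in this view:

```python
# Before (from-import):
#   gathered = cp_allgather(routed_output, ...)
# After (module-qualified, per the guideline):
#   gathered = dist.cp_allgather(routed_output, ...)
```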
PR_Github #34329 [ run ] triggered by Bot. Commit:

PR_Github #34329 [ run ] completed with state

/bot run --disable-fail-fast
Force-pushed from e0e35b9 to 57eff37

Force-pushed from 1d33e5d to 57579eb

Force-pushed from 53c8ff0 to c6a82b6

Force-pushed from e2e85d4 to 6e53b2c
/bot run --disable-fail-fast

PR_Github #36196 [ run ] triggered by Bot. Commit:
Force-pushed from b3bd411 to 49e38d7
/bot run --disable-fail-fast

PR_Github #36202 [ run ] triggered by Bot. Commit:

PR_Github #36202 [ run ] completed with state
/bot run --disable-fail-fast
Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
Force-pushed from 49e38d7 to c1462c2
PR_Github #36320 [ run ] triggered by Bot. Commit:

PR_Github #36320 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #36350 [ run ] triggered by Bot. Commit:

PR_Github #36350 [ run ] completed with state
Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
Force-pushed from 5d5a594 to 7875a69

/bot run --disable-fail-fast

PR_Github #36964 [ run ] triggered by Bot. Commit:

PR_Github #36964 [ run ] completed with state
NVIDIA#11167) Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
Description
Background
When using Helix CP with Attention DP on the generation server, all CP ranks within the same DP group have identical data after attention computation. The MoE layer receives token counts from all ranks via `tp_cp_allgather` to coordinate expert routing and communication.

Issue
MoE computation was duplicated across CP ranks: each CP rank processed the same tokens independently, resulting in redundant computation and degraded performance (accuracy was unaffected).
Root Cause
The `tp_cp_allgather` operation gathers token counts from all TP×CP ranks. In Helix mode, CP ranks hold identical data post-attention, so the gathered token counts included duplicates. For example, with CP=2:

`[dp0cp0=128, dp0cp1=128, dp1cp0=64, dp1cp1=64]`

The MoE layer would then process 128 tokens on both
`dp0cp0` and `dp0cp1`, doubling the computation.

Fix
All-gather at the beginning of the MLA layer and reduce-scatter at the end of it when Attention DP is enabled.
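A conceptual sketch of why this removes the duplication (illustrative only, not the PR's exact code: the opening all-gather that reassembles the full batch is elided, `run_moe` stands in for the routed-expert computation, and the batch is assumed evenly divisible across CP ranks):

```python
import torch
import torch.distributed as dist

def helix_moe_block(x, cp_rank, cp_group, run_moe):
    """Sketch of dedup composed with the closing reduce-scatter."""
    cp = dist.get_world_size(group=cp_group)
    # Post-attention, all CP ranks in a DP group hold identical activations,
    # so only cp_rank 0 runs the MoE; the rest contribute zeros.
    out = run_moe(x) if cp_rank == 0 else torch.zeros_like(x)
    # Summing reduce-scatter at the end of the layer: zeros plus the single
    # result yield exactly one copy, split back into per-CP-rank shards.
    shard = torch.empty_like(x.chunk(cp, dim=0)[0])
    dist.reduce_scatter(shard, list(out.chunk(cp, dim=0)), group=cp_group)
    return shard
```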
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user-friendly way for developers to interact with a Jenkins server.
Run /bot [-h|--help] to print this help message. See details below for each supported subcommand.
Details
run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.
- --reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
- --disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.
- --disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.
- --skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.
- --stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.
- --gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.
- --test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only supports [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.
- --only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.
- --disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.
- --add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running the L0 pre-merge pipeline.
- --post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
- --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
- --detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
- --debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md and the scripts/test_to_stage_mapping.py helper.

kill
kill : Kill all running builds associated with the pull request.
skip
skip --comment COMMENT : Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline
reuse-pipeline : Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.
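For example, combining flags documented above (the runs in this conversation used the first flag alone):

```
/bot run --disable-fail-fast --test-backend "pytorch" --add-multi-gpu-test
```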
Summary by CodeRabbit
Release Notes