[TRTLLM-11532][refactor] Unify VisualGen parallelism #12509

chang-l merged 5 commits into NVIDIA/TensorRT-LLM:main from NVShreyas/TensorRT-LLM:user/shreyasm/mapping-refactor
Conversation
📝 Walkthrough

This refactoring centralizes visual generation parallelism configuration into a new `VisualGenMapping` abstraction.

Changes
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks: ✅ 1 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (1 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings.
Actionable comments posted: 7
🧹 Nitpick comments (4)
tensorrt_llm/_torch/visual_gen/__init__.py (1)
24-24: Use a module import for the new public export.

This adds another direct symbol import into the package root. Please import `.mapping` and re-export `mapping.VisualGenMapping` instead so this file stays aligned with the repo's namespace-import rule. As per coding guidelines, "When importing in Python, always maintain the namespace. Import the module, not individual classes or functions."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/visual_gen/__init__.py` at line 24: Replace the direct symbol import in the package root by importing the module and re-exporting the class via the module namespace: change the current "from .mapping import VisualGenMapping" to import the module (e.g., "from . import mapping") and then expose mapping.VisualGenMapping in this package's __all__ or by assigning VisualGenMapping = mapping.VisualGenMapping; update the module-level export so the package root still provides VisualGenMapping but preserves namespace-import style.

tests/unittest/_torch/visual_gen/test_wan_i2v.py (1)
287-289: This worker still assumes the default CFG/Ulysses rank ordering.

The assertion now reads `cfg_size` from `visual_gen_mapping`, but line 310 still derives `expected_cfg_group` as `rank // cfg_config["ulysses_size"]`. That only matches the default contiguous layout, so a valid non-default `order` will fail this test even if CFG routing is correct. Please derive the expectation from `pipeline.model_config.visual_gen_mapping.cfg_rank` instead.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/_torch/visual_gen/test_wan_i2v.py` around lines 287-289: The test wrongly assumes contiguous CFG ordering by computing expected_cfg_group as rank // cfg_config["ulysses_size"]; instead use the actual CFG ordering from pipeline.model_config.visual_gen_mapping.cfg_rank: find the position of the current rank in pipeline.model_config.visual_gen_mapping.cfg_rank (e.g. position = pipeline.model_config.visual_gen_mapping.cfg_rank.index(rank)) and then compute expected_cfg_group = position // cfg_config["ulysses_size"], replacing the old division-based expression so non-default orderings pass.

tensorrt_llm/_torch/visual_gen/config.py (1)
454-459: Consider using a forward reference or `TYPE_CHECKING` import for `visual_gen_mapping`.

The field is typed as `Optional[Any]` with a comment about lazy import. Using a string forward reference or `TYPE_CHECKING` import would provide better type hints while avoiding circular imports.

♻️ Optional: Add type hint via TYPE_CHECKING

At the top of the file, add:

```python
if TYPE_CHECKING:
    from tensorrt_llm._torch.visual_gen.mapping import VisualGenMapping
```

Then update the field:

```diff
- visual_gen_mapping: Optional[Any] = None  # VisualGenMapping (lazy import)
+ visual_gen_mapping: Optional["VisualGenMapping"] = None
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/visual_gen/config.py` around lines 454-459: The visual_gen_mapping dataclass field is currently typed as Optional[Any] with a lazy-import comment; replace that with a proper forward reference or TYPE_CHECKING import to improve type hints without causing circular imports: add a TYPE_CHECKING block at the top importing VisualGenMapping from tensorrt_llm._torch.visual_gen.mapping (or use the string forward reference "VisualGenMapping") and update the field declaration visual_gen_mapping: Optional["VisualGenMapping"] = None so setup_visual_gen_mapping and the field get correct typing while retaining lazy import behavior.

tests/unittest/_torch/visual_gen/multi_gpu/test_visual_gen_mapping.py (1)
254-268: Consider using `torch.device("cuda", rank)` for explicit device assignment.

The current code uses `torch.device(f"cuda:{rank}")`, which works, but the numeric form `torch.device("cuda", rank)` is slightly more explicit and avoids string formatting.

♻️ Optional improvement

```diff
- device = torch.device(f"cuda:{rank}")
+ device = torch.device("cuda", rank)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/_torch/visual_gen/multi_gpu/test_visual_gen_mapping.py` around lines 254 - 268, Replace the string-style device construction with the explicit numeric form: change the torch.device creation used to assign device (currently torch.device(f"cuda:{rank}") stored in variable device) to torch.device("cuda", rank); ensure the rest of the test (tensor, tensor2, dist.all_reduce calls using vgm.tp_group_pg and vgm.ulysses_group and assertions referencing vgm.tp_size and vgm.ulysses_size) continues to use that device variable unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tensorrt_llm/_torch/visual_gen/mapping.py`:
- Around line 1-9: This file is missing the required NVIDIA copyright header and
SPDX license; add the standard NVIDIA Apache-2.0 header (including
SPDX-License-Identifier: Apache-2.0) with the year of latest modification at the
top of tensorrt_llm/_torch/visual_gen/mapping.py above the module docstring so
all definitions such as VisualGenMapping, DeviceMeshTopologyImpl and build_mesh
are covered by the license header.
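For reference, such a header might look something like the following; the exact wording and year are assumptions here and should be copied from an existing file in the repository rather than from this sketch:

```python
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0
```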
In `@tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py`:
- Around line 985-990: The multi-modal CFG branch selection uses vgm.cfg_rank
and then reconstructs cond_v/uncond_v from fixed gather slots, which only works
when CFG is the outermost dimension; update the logic in the block that computes
do_cfg_parallel_mm/cfg_group (using symbols vgm, cfg_size, ulysses_size,
do_cfg_parallel_mm, cfg_group) to either (a) enforce that this path only
supports binary cond/uncond by asserting cfg_size == 2 (and fail fast with a
clear message), or (b) properly compute the local cond/uncond indices from
VisualGenMapping.order (i.e., use vgm mapping offsets instead of hard-coded
slots) so cond_v/uncond_v reconstruction (the code that rebuilds cond_v/uncond_v
later) uses the correct per-rank slots; prefer (a) if you only support
cond/uncond now, otherwise implement (b) to generalize.
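The fail-fast option (a) could be sketched roughly as below; `select_cfg_slots` and its return convention are hypothetical names for illustration, not the actual pipeline code, which would read real gather buffers instead of slot indices:

```python
def select_cfg_slots(cfg_size: int):
    """Return (cond_slot, uncond_slot) for the gathered CFG results.

    Fail fast unless the CFG axis is the binary cond/uncond case, which
    is the only layout the fixed gather slots support.
    """
    if cfg_size != 2:
        raise ValueError(
            "multi-modal CFG parallelism currently supports only "
            f"cond/uncond (cfg_size == 2); got cfg_size={cfg_size}")
    # With exactly two CFG groups the layout is unambiguous:
    # group 0 holds the conditional branch, group 1 the unconditional one.
    return 0, 1
```

Option (b) would instead look the slots up from `VisualGenMapping.order`, which generalizes beyond two branches at the cost of more bookkeeping.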
In `@tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py`:
- Around line 770-783: The code assumes model_config is non-null when accessing
model_config.visual_gen_mapping; guard that access by first checking
model_config (or using getattr/model_config and default None) and set vgm =
model_config.visual_gen_mapping only if model_config is not None, then compute
ulysses_size = vgm.ulysses_size if vgm else 1; ensure primary_heads uses
num_attention_heads/audio_num_attention_heads as before and keep the
divisibility check and assignments to self.use_ulysses, self.ulysses_size,
self.ulysses_pg, and self.ulysses_rank unchanged but using the guarded vgm to
avoid an AttributeError.
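The guard being requested could be sketched like this (`resolve_ulysses_size` is a hypothetical helper name; the real code performs the same check inline):

```python
from types import SimpleNamespace

def resolve_ulysses_size(model_config):
    """Read ulysses_size without assuming model_config is non-null."""
    vgm = model_config.visual_gen_mapping if model_config is not None else None
    return vgm.ulysses_size if vgm is not None else 1

# Both null cases degrade to the single-GPU default instead of raising.
assert resolve_ulysses_size(None) == 1
assert resolve_ulysses_size(SimpleNamespace(visual_gen_mapping=None)) == 1
cfg = SimpleNamespace(visual_gen_mapping=SimpleNamespace(ulysses_size=4))
assert resolve_ulysses_size(cfg) == 4
```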
In `@tensorrt_llm/_torch/visual_gen/modules/attention.py`:
- Around line 70-71: The code uses config.visual_gen_mapping.ulysses_size to
shard backend_num_heads even when no process group/mesh exists, causing
UlyssesAttention (which falls back to world_size=1 when process_group=None) to
disagree on head counts; change the logic so that when build_mesh()/the
mesh/group is not present you reject ulysses_size>1 (raise a clear error) or
derive the sharding factor from the actual process group instead of
vgm.ulysses_size; specifically update the branches around vgm =
config.visual_gen_mapping and places wrapping UlyssesAttention (and the
analogous block at lines ~138-145) to either (a) check for an actual process
group/mesh before using ulysses_size and raise if ulysses_size>1 and no group,
or (b) query the group.world_size and compute shard = min(ulysses_size,
group.world_size) and use that for backend_num_heads and when constructing the
wrapper with process_group.
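Option (a) above, rejecting `ulysses_size > 1` when no group exists, could be sketched as follows; the function name and the use of a plain sentinel for the group are assumptions for illustration:

```python
def shard_heads(num_heads, ulysses_size, process_group):
    """Compute the per-rank head count, rejecting inconsistent setups."""
    if process_group is None:
        if ulysses_size > 1:
            raise RuntimeError(
                "ulysses_size > 1 but no process group/mesh is "
                "initialized; the attention wrapper would fall back to "
                "world_size=1 and disagree on head counts")
        return num_heads
    if num_heads % ulysses_size:
        raise ValueError(
            f"num_heads={num_heads} is not divisible by "
            f"ulysses_size={ulysses_size}")
    return num_heads // ulysses_size
```

This keeps the head count seen by the backend and the world size seen by the wrapper in lockstep: either both shard, or neither does.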
In `@tensorrt_llm/_torch/visual_gen/pipeline_loader.py`:
- Around line 109-116: PipelineLoader is not passing the new ring axis into
VisualGenMapping, so VisualGenMapping's product check fails when
args.parallel.dit_ring_size > 1; update the instantiation of VisualGenMapping in
pipeline_loader.py to forward dit_ring_size (e.g., pass
ring_size=self.args.parallel.dit_ring_size or the correct parameter name
expected by VisualGenMapping) alongside cfg_size, tp_size, ulysses_size and
order, and ensure any downstream product/check logic uses that ring_size when
computing world_size.
- Around line 105-120: When args is None we must create a
single-rank/single-world VisualGenMapping instead of using the global dist world
size; update _setup_visual_gen_mapping so the else branch passes ws=1 and rk=0
to VisualGenMapping (so VisualGenMapping(...) is called with world-size 1 and
rank 0), then assign config.visual_gen_mapping = vgm and config.mapping =
vgm.to_llm_mapping() exactly as before; this ensures PipelineLoader(...,
args=None) always produces a no-visual-parallelism mapping even when dist is
initialized.
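Both pipeline_loader fixes above can be sketched together; `mapping_dims` and `check_axes` are hypothetical helper names standing in for logic that lives inside `_setup_visual_gen_mapping` and `VisualGenMapping`:

```python
def mapping_dims(args, dist_world_size, dist_rank):
    """Pick the (world_size, rank) pair used to build the mapping."""
    if args is None:
        # No parallel arguments: always a single-rank, no-parallelism
        # mapping, even when torch.distributed is already initialized.
        return 1, 0
    return dist_world_size, dist_rank

def check_axes(world_size, cfg_size=1, tp_size=1, ulysses_size=1,
               ring_size=1):
    """World-size product check that includes the new ring axis."""
    product = cfg_size * tp_size * ulysses_size * ring_size
    if product != world_size:
        raise ValueError(
            f"cfg*tp*ulysses*ring = {product} must equal "
            f"world_size = {world_size}")
    return product
```

Forwarding `dit_ring_size` matters precisely because of the product check: with `ring_size` silently defaulting to 1, any configuration with `dit_ring_size > 1` fails validation even though the axes actually cover the world size.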
In `@tensorrt_llm/_torch/visual_gen/pipeline.py`:
- Around line 609-617: The CFG all-gather assumes exactly two CFG ranks but
doesn't validate that; add an assertion or explicit check that cfg_size == 2
before using gather_list[0] and gather_list[1], or document and raise a clear
error if cfg_size != 2. Locate the block that calls dist.all_gather with
gather_list (variables noise_pred_local, gather_list, cfg_size, cfg_pg) and
either assert cfg_size == 2 (or raise ValueError with explanatory message) or
handle more than two entries by selecting the correct conditional/unconditional
indices via a named mapping; ensure noise_cond and noise_uncond assignments and
subsequent guidance_scale computation (noise_pred = noise_uncond +
guidance_scale * (noise_cond - noise_uncond)) only run when the check passes.
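The guarded combination could be sketched as below, with scalars standing in for the gathered tensors; `combine_cfg` is a hypothetical name, and the slot convention (cond first, uncond second) is an assumption:

```python
def combine_cfg(gather_list, cfg_size, guidance_scale):
    """Combine gathered per-group predictions into one guided prediction."""
    if cfg_size != 2:
        raise ValueError(
            "CFG all-gather combination assumes cond/uncond "
            f"(cfg_size == 2); got cfg_size={cfg_size}")
    noise_cond, noise_uncond = gather_list[0], gather_list[1]
    return noise_uncond + guidance_scale * (noise_cond - noise_uncond)

# uncond=1.0, cond=2.0, scale=5.0 -> 1.0 + 5.0 * (2.0 - 1.0) = 6.0
assert combine_cfg([2.0, 1.0], 2, 5.0) == 6.0
```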
---
Nitpick comments:
In `@tensorrt_llm/_torch/visual_gen/__init__.py`:
- Line 24: Replace the direct symbol import in the package root by importing the
module and re-exporting the class via the module namespace: change the current
"from .mapping import VisualGenMapping" to import the module (e.g., "from .
import mapping") and then expose mapping.VisualGenMapping in this package's
__all__ or by assigning VisualGenMapping = mapping.VisualGenMapping; update the
module-level export so the package root still provides VisualGenMapping but
preserves namespace-import style.
In `@tensorrt_llm/_torch/visual_gen/config.py`:
- Around line 454-459: The visual_gen_mapping dataclass field is currently typed
as Optional[Any] with a lazy-import comment; replace that with a proper forward
reference or TYPE_CHECKING import to improve type hints without causing circular
imports: add a TYPE_CHECKING block at the top importing VisualGenMapping from
tensorrt_llm._torch.visual_gen.mapping (or use the string forward reference
"VisualGenMapping") and update the field declaration visual_gen_mapping:
Optional["VisualGenMapping"] = None so setup_visual_gen_mapping and the field
get correct typing while retaining lazy import behavior.
In `@tests/unittest/_torch/visual_gen/multi_gpu/test_visual_gen_mapping.py`:
- Around line 254-268: Replace the string-style device construction with the
explicit numeric form: change the torch.device creation used to assign device
(currently torch.device(f"cuda:{rank}") stored in variable device) to
torch.device("cuda", rank); ensure the rest of the test (tensor, tensor2,
dist.all_reduce calls using vgm.tp_group_pg and vgm.ulysses_group and assertions
referencing vgm.tp_size and vgm.ulysses_size) continues to use that device
variable unchanged.
In `@tests/unittest/_torch/visual_gen/test_wan_i2v.py`:
- Around line 287-289: The test wrongly assumes contiguous CFG ordering by
computing expected_cfg_group as rank // cfg_config["ulysses_size"]; instead use
the actual CFG ordering from pipeline.model_config.visual_gen_mapping.cfg_rank:
find the position of the current rank in
pipeline.model_config.visual_gen_mapping.cfg_rank (e.g. position =
pipeline.model_config.visual_gen_mapping.cfg_rank.index(rank)) and then compute
expected_cfg_group = position // cfg_config["ulysses_size"], replacing the old
division-based expression so non-default orderings pass.
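The corrected expectation could be sketched like this; `expected_cfg_group` is a hypothetical helper and `cfg_rank_order` stands in for the list exposed by `visual_gen_mapping.cfg_rank`:

```python
def expected_cfg_group(rank, cfg_rank_order, ulysses_size):
    """Derive the expected CFG group from the mapping's rank ordering."""
    position = cfg_rank_order.index(rank)
    return position // ulysses_size

# Default contiguous layout: identical to rank // ulysses_size.
assert expected_cfg_group(3, [0, 1, 2, 3], 2) == 1
# Non-default order: rank 1 sits at position 2, so it lands in group 1.
assert expected_cfg_group(1, [0, 2, 1, 3], 2) == 1
```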
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 0cae41a6-88a0-4c75-80ff-73fe2a55a8f0
📒 Files selected for processing (21)
- tensorrt_llm/_torch/visual_gen/__init__.py
- tensorrt_llm/_torch/visual_gen/config.py
- tensorrt_llm/_torch/visual_gen/executor.py
- tensorrt_llm/_torch/visual_gen/mapping.py
- tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py
- tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py
- tensorrt_llm/_torch/visual_gen/models/flux/transformer_flux.py
- tensorrt_llm/_torch/visual_gen/models/flux/transformer_flux2.py
- tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py
- tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py
- tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py
- tensorrt_llm/_torch/visual_gen/modules/attention.py
- tensorrt_llm/_torch/visual_gen/parallelism.py
- tensorrt_llm/_torch/visual_gen/pipeline.py
- tensorrt_llm/_torch/visual_gen/pipeline_loader.py
- tests/unittest/_torch/visual_gen/multi_gpu/test_flux_ulysses.py
- tests/unittest/_torch/visual_gen/multi_gpu/test_visual_gen_mapping.py
- tests/unittest/_torch/visual_gen/test_flux_pipeline.py
- tests/unittest/_torch/visual_gen/test_model_loader.py
- tests/unittest/_torch/visual_gen/test_wan.py
- tests/unittest/_torch/visual_gen/test_wan_i2v.py
💤 Files with no reviewable changes (2)
- tensorrt_llm/_torch/visual_gen/executor.py
- tensorrt_llm/_torch/visual_gen/parallelism.py
/bot run --disable-fail-fast

PR_Github #40166 [ run ] triggered by Bot. Commit:

PR_Github #40166 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #41709 [ run ] triggered by Bot. Commit:

PR_Github #41709 [ run ] completed with state

/bot run --disable-fail-fast

1 similar comment

/bot run --disable-fail-fast

PR_Github #41835 [ run ] triggered by Bot. Commit:

PR_Github #41835 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #41884 [ run ] triggered by Bot. Commit:

PR_Github #41884 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #41937 [ run ] triggered by Bot. Commit:

PR_Github #41937 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #41991 [ run ] triggered by Bot. Commit:

PR_Github #41991 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #42142 [ run ] triggered by Bot. Commit:

PR_Github #42142 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #42190 [ run ] triggered by Bot. Commit:

PR_Github #42190 [ run ] completed with state

Force-pushed 75d6c42 to 19579b2 (Compare)

/bot run --disable-fail-fast

PR_Github #42360 [ run ] triggered by Bot. Commit:

PR_Github #42360 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #42547 [ run ] triggered by Bot. Commit:

PR_Github #42547 [ run ] completed with state
Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
Force-pushed 19579b2 to 5c9406e (Compare)
/bot run --disable-fail-fast

PR_Github #42710 [ run ] triggered by Bot. Commit:
Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
Force-pushed 5c9406e to 13c686c (Compare)
PR_Github #42710 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #42746 [ run ] triggered by Bot. Commit:

PR_Github #42746 [ run ] completed with state
@chang-l / @zhenhuaw-me could you please merge?
Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
Summary by CodeRabbit

New Features

- Introduced `VisualGenMapping` for unified management of distributed parallelism strategies (CFG, tensor parallelism, ring, and Ulysses sequence parallelism) in visual generation workloads.

Refactor

- Replaced `ParallelConfig` usage with `VisualGenMapping` for improved separation of concerns.

Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment `/bot help`.