[TRTLLM-11532][refactor] Unify VisualGen parallelism #12509

Merged
chang-l merged 5 commits into NVIDIA:main from NVShreyas:user/shreyasm/mapping-refactor
Apr 12, 2026

Conversation

@NVShreyas
Collaborator

@NVShreyas NVShreyas commented Mar 24, 2026

  • Introduces VisualGenMapping, a single class that manages all VisualGen parallelism axes (CFG, TP, Ring, Ulysses) through one PyTorch DeviceMesh, replacing the previous split between tensorrt_llm.mapping.Mapping (for TP) and ad-hoc dist.new_group() calls in ParallelConfig (for CFG/Ulysses).
  • Removes parallelism.py and the ParallelConfig field from DiffusionModelConfig, establishing VisualGenMapping as the single source of truth for all rank decomposition, process groups, and the LLM Mapping bridge (to_llm_mapping()).
  • Dimension ordering is configurable via a string (e.g. "cfg-tp-ring-ulysses"), making it easy to experiment with different rank-contiguity layouts without code changes.
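The configurable dimension ordering can be illustrated with a small, self-contained sketch. This is not the actual TensorRT-LLM API — the real VisualGenMapping builds a PyTorch DeviceMesh — and the `decompose_rank` helper and axis-size dict below are hypothetical, shown only to make the rank arithmetic behind an order string like "cfg-tp-ring-ulysses" concrete:

```python
# Hypothetical sketch of how a flat rank decomposes into per-axis
# coordinates for an order string such as "cfg-tp-ring-ulysses".
# VisualGenMapping itself uses a PyTorch DeviceMesh; this only
# illustrates the rank-contiguity implied by the ordering.

def decompose_rank(rank: int, order: str, sizes: dict) -> dict:
    """Map a flat rank to one coordinate per parallelism axis.

    `order` lists axes outermost-first, so the last axis varies
    fastest (its ranks are contiguous).
    """
    coords = {}
    for axis in reversed(order.split("-")):
        coords[axis] = rank % sizes[axis]
        rank //= sizes[axis]
    return coords

# world_size = 2 * 1 * 1 * 2 = 4
sizes = {"cfg": 2, "tp": 1, "ring": 1, "ulysses": 2}
print(decompose_rank(3, "cfg-tp-ring-ulysses", sizes))
# rank 3 lands in cfg group 1, ulysses slot 1
```

Swapping the order string (e.g. "ulysses-ring-tp-cfg") changes which axis gets contiguous ranks without any code change, which is the experimentation knob the PR describes.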

Summary by CodeRabbit

  • New Features

    • Introduced VisualGenMapping for unified management of distributed parallelism strategies (CFG, tensor parallelism, ring, and Ulysses sequence parallelism) in visual generation workloads.
  • Refactor

    • Reorganized parallelism configuration architecture: moved VAE parallelism settings into model config and replaced direct ParallelConfig usage with VisualGenMapping for improved separation of concerns.
    • Simplified parallel sequence setup by removing helper utilities and integrating configuration directly into model initialization.

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

@NVShreyas NVShreyas requested a review from a team as a code owner March 24, 2026 17:22
@coderabbitai
Contributor

coderabbitai Bot commented Mar 24, 2026

📝 Walkthrough

Walkthrough

This refactoring centralizes visual generation parallelism configuration into a new VisualGenMapping class. It removes parallelism logic from ParallelConfig and VisualGenArgs.to_mapping(), replaces the parallel field in DiffusionModelConfig with visual_gen_mapping, and updates all model implementations to derive CFG, Ulysses, and tensor parallelism settings from the new mapping instead of distributed utility functions.

Changes

Cohort / File(s) Summary
Core mapping and configuration
tensorrt_llm/_torch/visual_gen/__init__.py, tensorrt_llm/_torch/visual_gen/config.py, tensorrt_llm/_torch/visual_gen/mapping.py
Added new VisualGenMapping class (187 lines) for unified multi-dimensional device mesh management across CFG/TP/ring/Ulysses axes. Removed to_mapping() methods and parallelism validation logic from ParallelConfig and VisualGenArgs. Replaced parallel: ParallelConfig field in DiffusionModelConfig with visual_gen_mapping: Optional[Any], and promoted VAE parallelism settings (enable_parallel_vae, parallel_vae_split_dim) directly into DiffusionModelConfig.
Pipeline and executor core
tensorrt_llm/_torch/visual_gen/executor.py, tensorrt_llm/_torch/visual_gen/pipeline.py, tensorrt_llm/_torch/visual_gen/pipeline_loader.py
Removed world-size validation check from run_diffusion_worker. Updated pipeline.py to read CFG/Ulysses settings from visual_gen_mapping instead of parallel config, and refactored _setup_cfg_config() and _denoise_step_cfg_parallel() to use mapping-derived process groups. Modified pipeline_loader.py to construct VisualGenMapping using distributed context and populate config.visual_gen_mapping before pipeline creation.
Parallelism utilities
tensorrt_llm/_torch/visual_gen/parallelism.py
Removed entire setup_sequence_parallelism() utility function (100 lines) that previously handled distributed sequence parallelism setup.
Flux model components
tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py, tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py, tensorrt_llm/_torch/visual_gen/models/flux/transformer_flux.py, tensorrt_llm/_torch/visual_gen/models/flux/transformer_flux2.py
Updated pipeline CFG validation and transformer Ulysses initialization to use visual_gen_mapping fields instead of parallel config. Removed setup_sequence_parallelism calls and replaced with direct reads from visual_gen_mapping. Added num_attention_heads divisibility validation against ulysses_size.
LTX2 model components
tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py, tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py
Updated CFG and Ulysses configuration derivation to use visual_gen_mapping. Removed setup_sequence_parallelism calls. Added head divisibility validation for Ulysses parallelism in LTXModel.
WAN transformer
tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py
Replaced setup_sequence_parallelism call with direct reads from visual_gen_mapping for Ulysses settings. Changed tensor parallelism validation from parallel.dit_tp_size to visual_gen_mapping.tp_size.
Attention modules
tensorrt_llm/_torch/visual_gen/modules/attention.py
Updated Ulysses configuration source from config.parallel.dit_ulysses_size to config.visual_gen_mapping.ulysses_size, with fallback to 1 when mapping is absent.
Test updates
tests/unittest/_torch/visual_gen/multi_gpu/test_flux_ulysses.py, tests/unittest/_torch/visual_gen/multi_gpu/test_visual_gen_mapping.py, tests/unittest/_torch/visual_gen/test_flux_pipeline.py, tests/unittest/_torch/visual_gen/test_model_loader.py, tests/unittest/_torch/visual_gen/test_wan.py, tests/unittest/_torch/visual_gen/test_wan_i2v.py
Updated test configurations to use VisualGenMapping instead of ParallelConfig. Added comprehensive new test module (280 lines) validating VisualGenMapping across single-GPU and multi-GPU distributed scenarios. Updated assertion checks to verify configuration via visual_gen_mapping fields.
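The head-divisibility validation mentioned for the Flux and LTX2 transformers can be sketched as below. The function name and error message are illustrative, not the PR's actual code; the invariant is that Ulysses shards attention heads across ranks, so the head count must divide evenly by ulysses_size:

```python
# Hypothetical sketch of the num_attention_heads / ulysses_size
# divisibility check the Flux and LTX2 changes add.

def validate_ulysses_heads(num_attention_heads: int, ulysses_size: int) -> int:
    """Return the per-rank head count, or raise if sharding is uneven."""
    if ulysses_size > 1 and num_attention_heads % ulysses_size != 0:
        raise ValueError(
            f"num_attention_heads ({num_attention_heads}) must be "
            f"divisible by ulysses_size ({ulysses_size})")
    return num_attention_heads // max(ulysses_size, 1)

print(validate_ulysses_heads(24, 4))  # 6 heads per rank
```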

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 1 passed | ❌ 2 failed

❌ Failed checks (1 warning, 1 inconclusive)

  • Docstring Coverage — ⚠️ Warning: Docstring coverage is 40.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
  • Description check — ❓ Inconclusive: The PR description clearly explains the problem and solution, but lacks the required template structure with explicit 'Description', 'Test Coverage', and 'PR Checklist' sections properly separated. Resolution: expand the description with dedicated sections that separate the problem statement and solution approach, and list the specific test cases (test_visual_gen_mapping.py, test_flux_ulysses.py, etc.) that validate the new unified mapping implementation.

✅ Passed checks (1 passed)

  • Title check — ✅ Passed: The title accurately summarizes the main change: introducing a unified VisualGenMapping class to consolidate VisualGen parallelism management, replacing the previous split architecture.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands and usage tips.

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 7

🧹 Nitpick comments (4)
tensorrt_llm/_torch/visual_gen/__init__.py (1)

24-24: Use a module import for the new public export.

This adds another direct symbol import into the package root. Please import .mapping and re-export mapping.VisualGenMapping instead so this file stays aligned with the repo's namespace-import rule.

As per coding guidelines, "When importing in Python, always maintain the namespace. Import the module, not individual classes or functions."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/visual_gen/__init__.py` at line 24, Replace the direct
symbol import in the package root by importing the module and re-exporting the
class via the module namespace: change the current "from .mapping import
VisualGenMapping" to import the module (e.g., "from . import mapping") and then
expose mapping.VisualGenMapping in this package's __all__ or by assigning
VisualGenMapping = mapping.VisualGenMapping; update the module-level export so
the package root still provides VisualGenMapping but preserves namespace-import
style.
tests/unittest/_torch/visual_gen/test_wan_i2v.py (1)

287-289: This worker still assumes the default CFG/Ulysses rank ordering.

The assertion now reads cfg_size from visual_gen_mapping, but Line 310 still derives expected_cfg_group as rank // cfg_config["ulysses_size"]. That only matches the default contiguous layout, so a valid non-default order will fail this test even if CFG routing is correct. Please derive the expectation from pipeline.model_config.visual_gen_mapping.cfg_rank instead.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unittest/_torch/visual_gen/test_wan_i2v.py` around lines 287 - 289, The
test wrongly assumes contiguous CFG ordering by computing expected_cfg_group as
rank // cfg_config["ulysses_size"]; instead use the actual CFG ordering from
pipeline.model_config.visual_gen_mapping.cfg_rank: find the position of the
current rank in pipeline.model_config.visual_gen_mapping.cfg_rank (e.g. position
= pipeline.model_config.visual_gen_mapping.cfg_rank.index(rank)) and then
compute expected_cfg_group = position // cfg_config["ulysses_size"], replacing
the old division-based expression so non-default orderings pass.
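The order-agnostic expectation the reviewer suggests can be sketched as follows, assuming (per the prompt above) that `cfg_rank` enumerates ranks in CFG order; the helper name is hypothetical:

```python
# Hypothetical sketch of deriving the expected CFG group from the
# mapping's rank ordering instead of assuming a contiguous layout.

def expected_cfg_group(rank: int, cfg_rank_order: list, ulysses_size: int) -> int:
    """Find this rank's position in the CFG ordering, then group it."""
    position = cfg_rank_order.index(rank)
    return position // ulysses_size

# Default contiguous layout: groups are [0, 1] and [2, 3]
assert expected_cfg_group(3, [0, 1, 2, 3], 2) == 1
# A permuted (non-default) ordering still resolves correctly
assert expected_cfg_group(1, [0, 2, 1, 3], 2) == 1
```

The division-based form `rank // ulysses_size` only agrees with this when the ordering is the default contiguous one, which is exactly the fragility the comment flags.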
tensorrt_llm/_torch/visual_gen/config.py (1)

454-459: Consider using a forward reference or TYPE_CHECKING import for visual_gen_mapping.

The field is typed as Optional[Any] with a comment about lazy import. Using a string forward reference or TYPE_CHECKING import would provide better type hints while avoiding circular imports.

♻️ Optional: Add type hint via TYPE_CHECKING

At the top of the file, add:

if TYPE_CHECKING:
    from tensorrt_llm._torch.visual_gen.mapping import VisualGenMapping

Then update the field:

-    visual_gen_mapping: Optional[Any] = None  # VisualGenMapping (lazy import)
+    visual_gen_mapping: Optional["VisualGenMapping"] = None
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/visual_gen/config.py` around lines 454 - 459, The
visual_gen_mapping dataclass field is currently typed as Optional[Any] with a
lazy-import comment; replace that with a proper forward reference or
TYPE_CHECKING import to improve type hints without causing circular imports: add
a TYPE_CHECKING block at the top importing VisualGenMapping from
tensorrt_llm._torch.visual_gen.mapping (or use the string forward reference
"VisualGenMapping") and update the field declaration visual_gen_mapping:
Optional["VisualGenMapping"] = None so setup_visual_gen_mapping and the field
get correct typing while retaining lazy import behavior.
tests/unittest/_torch/visual_gen/multi_gpu/test_visual_gen_mapping.py (1)

254-268: Consider using torch.device("cuda", rank) for explicit device assignment.

The current code uses torch.device(f"cuda:{rank}") which works, but the numeric form torch.device("cuda", rank) is slightly more explicit and avoids string formatting.

♻️ Optional improvement
-    device = torch.device(f"cuda:{rank}")
+    device = torch.device("cuda", rank)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unittest/_torch/visual_gen/multi_gpu/test_visual_gen_mapping.py` around
lines 254 - 268, Replace the string-style device construction with the explicit
numeric form: change the torch.device creation used to assign device (currently
torch.device(f"cuda:{rank}") stored in variable device) to torch.device("cuda",
rank); ensure the rest of the test (tensor, tensor2, dist.all_reduce calls using
vgm.tp_group_pg and vgm.ulysses_group and assertions referencing vgm.tp_size and
vgm.ulysses_size) continues to use that device variable unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tensorrt_llm/_torch/visual_gen/mapping.py`:
- Around line 1-9: This file is missing the required NVIDIA copyright header and
SPDX license; add the standard NVIDIA Apache-2.0 header (including
SPDX-License-Identifier: Apache-2.0) with the year of latest modification at the
top of tensorrt_llm/_torch/visual_gen/mapping.py above the module docstring so
all definitions such as VisualGenMapping, DeviceMeshTopologyImpl and build_mesh
are covered by the license header.

In `@tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py`:
- Around line 985-990: The multi-modal CFG branch selection uses vgm.cfg_rank
and then reconstructs cond_v/uncond_v from fixed gather slots, which only works
when CFG is the outermost dimension; update the logic in the block that computes
do_cfg_parallel_mm/cfg_group (using symbols vgm, cfg_size, ulysses_size,
do_cfg_parallel_mm, cfg_group) to either (a) enforce that this path only
supports binary cond/uncond by asserting cfg_size == 2 (and fail fast with a
clear message), or (b) properly compute the local cond/uncond indices from
VisualGenMapping.order (i.e., use vgm mapping offsets instead of hard-coded
slots) so cond_v/uncond_v reconstruction (the code that rebuilds cond_v/uncond_v
later) uses the correct per-rank slots; prefer (a) if you only support
cond/uncond now, otherwise implement (b) to generalize.

In `@tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py`:
- Around line 770-783: The code assumes model_config is non-null when accessing
model_config.visual_gen_mapping; guard that access by first checking
model_config (or using getattr/model_config and default None) and set vgm =
model_config.visual_gen_mapping only if model_config is not None, then compute
ulysses_size = vgm.ulysses_size if vgm else 1; ensure primary_heads uses
num_attention_heads/audio_num_attention_heads as before and keep the
divisibility check and assignments to self.use_ulysses, self.ulysses_size,
self.ulysses_pg, and self.ulysses_rank unchanged but using the guarded vgm to
avoid an AttributeError.

In `@tensorrt_llm/_torch/visual_gen/modules/attention.py`:
- Around line 70-71: The code uses config.visual_gen_mapping.ulysses_size to
shard backend_num_heads even when no process group/mesh exists, causing
UlyssesAttention (which falls back to world_size=1 when process_group=None) to
disagree on head counts; change the logic so that when build_mesh()/the
mesh/group is not present you reject ulysses_size>1 (raise a clear error) or
derive the sharding factor from the actual process group instead of
vgm.ulysses_size; specifically update the branches around vgm =
config.visual_gen_mapping and places wrapping UlyssesAttention (and the
analogous block at lines ~138-145) to either (a) check for an actual process
group/mesh before using ulysses_size and raise if ulysses_size>1 and no group,
or (b) query the group.world_size and compute shard = min(ulysses_size,
group.world_size) and use that for backend_num_heads and when constructing the
wrapper with process_group.

In `@tensorrt_llm/_torch/visual_gen/pipeline_loader.py`:
- Around line 109-116: PipelineLoader is not passing the new ring axis into
VisualGenMapping, so VisualGenMapping's product check fails when
args.parallel.dit_ring_size > 1; update the instantiation of VisualGenMapping in
pipeline_loader.py to forward dit_ring_size (e.g., pass
ring_size=self.args.parallel.dit_ring_size or the correct parameter name
expected by VisualGenMapping) alongside cfg_size, tp_size, ulysses_size and
order, and ensure any downstream product/check logic uses that ring_size when
computing world_size.
- Around line 105-120: When args is None we must create a
single-rank/single-world VisualGenMapping instead of using the global dist world
size; update _setup_visual_gen_mapping so the else branch passes ws=1 and rk=0
to VisualGenMapping (so VisualGenMapping(...) is called with world-size 1 and
rank 0), then assign config.visual_gen_mapping = vgm and config.mapping =
vgm.to_llm_mapping() exactly as before; this ensures PipelineLoader(...,
args=None) always produces a no-visual-parallelism mapping even when dist is
initialized.

In `@tensorrt_llm/_torch/visual_gen/pipeline.py`:
- Around line 609-617: The CFG all-gather assumes exactly two CFG ranks but
doesn't validate that; add an assertion or explicit check that cfg_size == 2
before using gather_list[0] and gather_list[1], or document and raise a clear
error if cfg_size != 2. Locate the block that calls dist.all_gather with
gather_list (variables noise_pred_local, gather_list, cfg_size, cfg_pg) and
either assert cfg_size == 2 (or raise ValueError with explanatory message) or
handle more than two entries by selecting the correct conditional/unconditional
indices via a named mapping; ensure noise_cond and noise_uncond assignments and
subsequent guidance_scale computation (noise_pred = noise_uncond +
guidance_scale * (noise_cond - noise_uncond)) only run when the check passes.
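The cfg_size == 2 guard requested for the CFG all-gather can be sketched with plain floats standing in for tensors. The function name, the assumption that slot 0 holds the conditional prediction, and the error message are all illustrative, not the PR's actual code:

```python
# Hypothetical sketch of guarding the CFG-parallel combine step:
# fail fast unless there are exactly two CFG ranks (cond/uncond),
# then apply classifier-free guidance.

def combine_cfg(gather_list: list, cfg_size: int, guidance_scale: float):
    if cfg_size != 2:
        raise ValueError(
            f"CFG-parallel combine expects exactly 2 CFG ranks "
            f"(cond/uncond), got cfg_size={cfg_size}")
    # Slot assignment is an assumption: 0 = conditional, 1 = unconditional.
    noise_cond, noise_uncond = gather_list[0], gather_list[1]
    return noise_uncond + guidance_scale * (noise_cond - noise_uncond)

# cond=2.0, uncond=1.0, scale=5.0 -> 1.0 + 5.0 * (2.0 - 1.0) = 6.0
assert combine_cfg([2.0, 1.0], 2, 5.0) == 6.0
```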


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 0cae41a6-88a0-4c75-80ff-73fe2a55a8f0

📥 Commits

Reviewing files that changed from the base of the PR and between d8eb3a6 and 2786477.

📒 Files selected for processing (21)
  • tensorrt_llm/_torch/visual_gen/__init__.py
  • tensorrt_llm/_torch/visual_gen/config.py
  • tensorrt_llm/_torch/visual_gen/executor.py
  • tensorrt_llm/_torch/visual_gen/mapping.py
  • tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py
  • tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py
  • tensorrt_llm/_torch/visual_gen/models/flux/transformer_flux.py
  • tensorrt_llm/_torch/visual_gen/models/flux/transformer_flux2.py
  • tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py
  • tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py
  • tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py
  • tensorrt_llm/_torch/visual_gen/modules/attention.py
  • tensorrt_llm/_torch/visual_gen/parallelism.py
  • tensorrt_llm/_torch/visual_gen/pipeline.py
  • tensorrt_llm/_torch/visual_gen/pipeline_loader.py
  • tests/unittest/_torch/visual_gen/multi_gpu/test_flux_ulysses.py
  • tests/unittest/_torch/visual_gen/multi_gpu/test_visual_gen_mapping.py
  • tests/unittest/_torch/visual_gen/test_flux_pipeline.py
  • tests/unittest/_torch/visual_gen/test_model_loader.py
  • tests/unittest/_torch/visual_gen/test_wan.py
  • tests/unittest/_torch/visual_gen/test_wan_i2v.py
💤 Files with no reviewable changes (2)
  • tensorrt_llm/_torch/visual_gen/executor.py
  • tensorrt_llm/_torch/visual_gen/parallelism.py

@NVShreyas
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #40166 [ run ] triggered by Bot. Commit: bf51784 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #40166 [ run ] completed with state SUCCESS. Commit: bf51784
/LLM/main/L0_MergeRequest_PR pipeline #31311 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@zhenhuaw-me zhenhuaw-me requested a review from a team March 25, 2026 03:17
@NVShreyas
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #41709 [ run ] triggered by Bot. Commit: d5f8ae2 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #41709 [ run ] completed with state SUCCESS. Commit: d5f8ae2
/LLM/main/L0_MergeRequest_PR pipeline #32611 completed with status: 'FAILURE'


@NVShreyas
Collaborator Author

/bot run --disable-fail-fast

1 similar comment
@NVShreyas
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #41835 [ run ] triggered by Bot. Commit: 75d6c42 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #41835 [ run ] completed with state FAILURE. Commit: 75d6c42
/LLM/main/L0_MergeRequest_PR pipeline #32707 completed with status: 'FAILURE'


@NVShreyas
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #41884 [ run ] triggered by Bot. Commit: 75d6c42 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #41884 [ run ] completed with state SUCCESS. Commit: 75d6c42
/LLM/main/L0_MergeRequest_PR pipeline #32749 completed with status: 'FAILURE'


@NVShreyas
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #41937 [ run ] triggered by Bot. Commit: 75d6c42 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #41937 [ run ] completed with state SUCCESS. Commit: 75d6c42
/LLM/main/L0_MergeRequest_PR pipeline #32794 completed with status: 'FAILURE'


@NVShreyas
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #41991 [ run ] triggered by Bot. Commit: 75d6c42 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #41991 [ run ] completed with state SUCCESS. Commit: 75d6c42
/LLM/main/L0_MergeRequest_PR pipeline #32842 completed with status: 'FAILURE'


@NVShreyas
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #42142 [ run ] triggered by Bot. Commit: 75d6c42 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #42142 [ run ] completed with state SUCCESS. Commit: 75d6c42
/LLM/main/L0_MergeRequest_PR pipeline #32976 completed with status: 'FAILURE'


@NVShreyas
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #42190 [ run ] triggered by Bot. Commit: 75d6c42 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #42190 [ run ] completed with state SUCCESS. Commit: 75d6c42
/LLM/main/L0_MergeRequest_PR pipeline #33014 completed with status: 'FAILURE'


@NVShreyas NVShreyas force-pushed the user/shreyasm/mapping-refactor branch from 75d6c42 to 19579b2 Compare April 8, 2026 15:47
@NVShreyas
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #42360 [ run ] triggered by Bot. Commit: 19579b2 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #42360 [ run ] completed with state SUCCESS. Commit: 19579b2
/LLM/main/L0_MergeRequest_PR pipeline #33145 completed with status: 'FAILURE'


@NVShreyas
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #42547 [ run ] triggered by Bot. Commit: 19579b2 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #42547 [ run ] completed with state SUCCESS. Commit: 19579b2
/LLM/main/L0_MergeRequest_PR pipeline #33284 completed with status: 'SUCCESS'


Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
@NVShreyas NVShreyas force-pushed the user/shreyasm/mapping-refactor branch from 19579b2 to 5c9406e Compare April 10, 2026 15:46
@NVShreyas
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #42710 [ run ] triggered by Bot. Commit: 5c9406e Link to invocation

Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
@NVShreyas NVShreyas force-pushed the user/shreyasm/mapping-refactor branch from 5c9406e to 13c686c Compare April 10, 2026 16:01
@tensorrt-cicd
Collaborator

PR_Github #42710 [ run ] completed with state SUCCESS. Commit: 5c9406e
/LLM/main/L0_MergeRequest_PR pipeline #33404 completed with status: 'FAILURE'


@NVShreyas
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #42746 [ run ] triggered by Bot. Commit: 13c686c Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #42746 [ run ] completed with state SUCCESS. Commit: 13c686c
/LLM/main/L0_MergeRequest_PR pipeline #33424 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.


@NVShreyas
Collaborator Author

@chang-l / @zhenhuaw-me could you please merge?

@chang-l chang-l merged commit 5c73ac0 into NVIDIA:main Apr 12, 2026
5 checks passed
bmarimuthu-nv pushed a commit to nv-auto-deploy/TensorRT-LLM that referenced this pull request Apr 16, 2026
Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>

4 participants

Morty Proxy This is a proxified and sanitized view of the page, visit original site.