[TRTLLM-11622][fix] fix parallel WAN vae when return_dict=True #12460

Merged
chang-l merged 4 commits into NVIDIA/TensorRT-LLM:main from NVShreyas/TensorRT-LLM:user/shreyasm/par-vae-fix
Mar 24, 2026

Conversation


@NVShreyas (Collaborator) commented Mar 23, 2026

Summary by CodeRabbit

  • Refactor

    • Updated encoder implementation to optimize parameter gathering across distributed partitions and support conditional output formats.
    • Updated decoder implementation to optimize sample gathering across partitions and support conditional output formats.
  • Tests

    • Updated parallel VAE multi-GPU tests to use new API initialization flow.
    • Simplified test logic by removing adjacency group dependencies.

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

@NVShreyas NVShreyas requested a review from a team as a code owner March 23, 2026 15:39
@NVShreyas (Collaborator Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #39954 [ run ] triggered by Bot. Commit: 8d12344 Link to invocation

@tensorrt-cicd (Collaborator)

PR_Github #39954 [ run ] completed with state SUCCESS. Commit: 8d12344
/LLM/main/L0_MergeRequest_PR pipeline #31119 completed with status: 'SUCCESS'

CI Report

Link to invocation

@NVShreyas (Collaborator Author)

@chang-l / @zhenhuaw-me could you please review?

Comment threads on tensorrt_llm/_torch/visual_gen/models/wan/parallel_vae.py (two marked Outdated)

coderabbitai Bot commented Mar 23, 2026

📝 Walkthrough

Walkthrough

Refactored ParallelVAE_Wan encoder and decoder methods to internally manage gathering of parameters and samples across partitions, replacing the external adapter pattern. Adjusted test suite to construct and invoke the parallel VAE directly using the new API.

Changes

Parallel VAE implementation (tensorrt_llm/_torch/visual_gen/models/wan/parallel_vae.py):
Refactored _encode_impl and _decode_impl to call the backend with return_dict=False, manually gather outputs across partitions, wrap results in DiagonalGaussianDistribution or DecoderOutput, and return conditionally on the return_dict flag. Removed explicit return type annotations.

Test updates (tests/unittest/_torch/visual_gen/multi_gpu/test_parallel_vae.py):
Replaced the WanParallelVAEAdapter pattern with direct ParallelVAE_Wan construction using process groups and make_spec(). Removed adjacency group logic and the height-splitting decode test. Updated test calls to invoke methods on ParallelVAE_Wan objects.
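The gather-then-wrap pattern the walkthrough describes can be sketched as follows. This is a simplified stand-in, not the actual parallel_vae.py code: DecoderOutput mimics the diffusers class, and gather_partitions stands in for the cross-partition all-gather.

```python
from dataclasses import dataclass
from typing import List

@dataclass
class DecoderOutput:
    # Simplified stand-in for diffusers' DecoderOutput.
    sample: list

def gather_partitions(partitions: List[list]) -> list:
    # Stand-in for the all-gather across distributed partitions;
    # here it just concatenates per-rank slices in rank order.
    full = []
    for part in partitions:
        full.extend(part)
    return full

def decode_impl(partitions: List[list], return_dict: bool = True):
    # Each rank's backend is invoked with return_dict=False so raw
    # tensors can be gathered; the conditional output format is
    # applied once, to the gathered result -- which is the fix.
    sample = gather_partitions(partitions)
    if not return_dict:
        return (sample,)
    return DecoderOutput(sample=sample)
```

With return_dict=True (the default) callers access .sample; with return_dict=False they index [0], matching the diffusers convention.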

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~22 minutes

🚥 Pre-merge checks: 2 passed, 1 warning

❌ Failed checks (1 warning)
  Description check ⚠️ Warning: The PR description is incomplete, containing only template scaffolding without actual implementation details, test coverage information, or explanation of the fix. Resolution: fill in the Description section explaining the issue and solution, and provide relevant test coverage details that safeguard these changes.

✅ Passed checks (2 passed)
  Title check ✅ Passed: The title accurately describes the main fix, addressing parallel WAN VAE behavior when return_dict=True, matching the core changes in the implementation.
  Docstring Coverage ✅ Passed: No functions were found in the changed files, so the docstring coverage check was skipped.


@coderabbitai (Contributor) left a review comment


Actionable comments posted: 1

🧹 Nitpick comments (2)
tests/unittest/_torch/visual_gen/multi_gpu/test_parallel_vae.py (1)

131-135: Exercise the default decode path here.

This still asserts through parallel.decode(..., return_dict=False)[0], so the new DecoderOutput(sample=...) branch never runs. Please switch this check to parallel.decode(latent).sample or add a companion assertion for the default path.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unittest/_torch/visual_gen/multi_gpu/test_parallel_vae.py` around lines
131 - 135, The test currently only exercises the non-default return_dict=False
path by calling parallel.decode(latent, return_dict=False)[0], so the new
DecoderOutput(sample=...) branch isn't executed; update the assertion to call
the default decode path (e.g., use parallel.decode(latent).sample) or add an
additional assertion that calls parallel.decode(latent) and checks .sample on
the returned DecoderOutput. Locate the test that constructs ParallelVAE_Wan and
replace or supplement the existing call to parallel.decode(latent,
return_dict=False)[0] with a call to parallel.decode(latent).sample (or add a
companion assertion) to ensure the DecoderOutput branch is exercised.
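The suggested companion assertion can be illustrated with a minimal self-contained sketch; FakeParallelVAE is a hypothetical stand-in for ParallelVAE_Wan, used only to show that both output formats should be exercised.

```python
from dataclasses import dataclass

@dataclass
class DecoderOutput:
    sample: list

class FakeParallelVAE:
    # Hypothetical stand-in for ParallelVAE_Wan; "decoding" just
    # doubles each element so the two paths can be compared.
    def decode(self, latent, return_dict=True):
        sample = [x * 2 for x in latent]
        return DecoderOutput(sample=sample) if return_dict else (sample,)

vae = FakeParallelVAE()
latent = [1, 2, 3]
tuple_out = vae.decode(latent, return_dict=False)[0]  # existing assertion path
dict_out = vae.decode(latent).sample                  # suggested default path
assert tuple_out == dict_out
```

Asserting through both paths ensures the DecoderOutput(sample=...) branch is actually executed rather than only the tuple branch.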
tensorrt_llm/_torch/visual_gen/models/wan/parallel_vae.py (1)

5-7: Prefer module imports for the new Diffusers types.

Please keep these symbols namespaced instead of adding more direct class imports here. It makes ownership clearer and matches the repo’s Python import rule.

As per coding guidelines, "When importing in Python, always maintain the namespace. Import the module, not individual classes or functions (e.g., use from package.subpackage import foo then foo.SomeClass() instead of from package.subpackage.foo import SomeClass)."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/visual_gen/models/wan/parallel_vae.py` around lines 5 -
7, The current imports bring specific classes into the local namespace
(AutoencoderKLOutput, WanAttentionBlock, WanCausalConv3d, DecoderOutput,
DiagonalGaussianDistribution); change them to import their modules instead
(e.g., import diffusers.models.autoencoders.autoencoder_kl as autoencoder_kl and
diffusers.models.autoencoders.autoencoder_kl_wan as autoencoder_kl_wan and
diffusers.models.autoencoders.vae as vae) and then update all usages in
parallel_vae.py to reference the classes via those modules
(autoencoder_kl.AutoencoderKLOutput, autoencoder_kl_wan.WanAttentionBlock,
autoencoder_kl_wan.WanCausalConv3d, vae.DecoderOutput,
vae.DiagonalGaussianDistribution) so the names remain namespaced per project
import rules.
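The import guideline can be shown with a stdlib module so the snippet stays runnable without diffusers installed; the actual change would apply the same pattern to the diffusers autoencoder modules named above.

```python
# Import the module, not the class, and keep names namespaced.
from collections import abc  # module import, per the guideline

def is_sequence(x):
    # Referencing abc.Sequence (rather than a bare Sequence import)
    # keeps the owning module visible at every call site.
    return isinstance(x, abc.Sequence) and not isinstance(x, (str, bytes))

assert is_sequence([1, 2, 3])
assert not is_sequence(42)
```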
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tensorrt_llm/_torch/visual_gen/models/wan/parallel_vae.py`:
- Around line 54-68: The encode/decode implementations drop caller kwargs by
hard-coding return_dict=False when calling vae_backend.encode/decode; update
_encode_impl and _decode_impl to forward all kwargs to vae_backend while
ensuring the backend still receives return_dict=False (e.g. copy kwargs, set or
override return_dict=False, then pass that dict into vae_backend.encode(...) and
vae_backend.decode(...)), keeping the existing use of _split_tensor,
_gather_tensor, DiagonalGaussianDistribution, AutoencoderKLOutput and
DecoderOutput unchanged.
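One way to implement the kwargs forwarding this comment asks for, sketched with a fake backend; the function and parameter names are illustrative, not the actual parallel_vae.py code.

```python
def decode_with_forwarded_kwargs(backend_decode, z, **kwargs):
    # Copy the caller's kwargs, remember the caller-facing return_dict,
    # then force return_dict=False on the backend call so the raw
    # sample can be gathered before any wrapping happens.
    backend_kwargs = dict(kwargs)
    caller_return_dict = backend_kwargs.pop("return_dict", True)
    sample = backend_decode(z, return_dict=False, **backend_kwargs)[0]
    return (sample, caller_return_dict)

def fake_backend_decode(z, return_dict=True, scale=1):
    # Stand-in for vae_backend.decode; takes an extra kwarg to show
    # that forwarding actually reaches the backend.
    sample = [x * scale for x in z]
    return (sample,) if not return_dict else {"sample": sample}

sample, rd = decode_with_forwarded_kwargs(fake_backend_decode, [1, 2], scale=3)
assert sample == [3, 6] and rd is True
```

The caller-facing return_dict is honoured separately when wrapping the gathered result, so the backend always returns a plain tuple.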


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 60e32875-eb00-47eb-9e42-cbc50afc5546

📥 Commits

Reviewing files that changed from the base of the PR and between cc462ba and 65dd189.

📒 Files selected for processing (2)
  • tensorrt_llm/_torch/visual_gen/models/wan/parallel_vae.py
  • tests/unittest/_torch/visual_gen/multi_gpu/test_parallel_vae.py

Comment thread tensorrt_llm/_torch/visual_gen/models/wan/parallel_vae.py
@NVShreyas force-pushed the user/shreyasm/par-vae-fix branch from 65dd189 to b4929dd on March 23, 2026 23:15
@NVShreyas (Collaborator Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #39991 [ run ] triggered by Bot. Commit: b4929dd Link to invocation

Signed-off-by: Shreyas Misra <shreyasm@nvidia.com> (all 4 commits)
@NVShreyas force-pushed the user/shreyasm/par-vae-fix branch from b4929dd to 55994bd on March 24, 2026 01:33
@NVShreyas (Collaborator Author)

/bot kill

@tensorrt-cicd (Collaborator)

PR_Github #40004 [ kill ] triggered by Bot. Commit: 6ada7c6 Link to invocation

@tensorrt-cicd (Collaborator)

PR_Github #40004 [ kill ] completed with state SUCCESS. Commit: 6ada7c6
Successfully killed previous jobs for commit 6ada7c6

Link to invocation

@NVShreyas (Collaborator Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #40012 [ run ] triggered by Bot. Commit: 6ada7c6 Link to invocation

@tensorrt-cicd (Collaborator)

PR_Github #40012 [ run ] completed with state SUCCESS. Commit: 6ada7c6
/LLM/main/L0_MergeRequest_PR pipeline #31170 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@NVShreyas (Collaborator Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #40131 [ run ] triggered by Bot. Commit: 6ada7c6 Link to invocation

@NVShreyas (Collaborator Author)

/bot help

@github-actions

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provides a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

Details

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental) --high-priority]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

--high-priority (OPTIONAL) : Run the pipeline with high priority. This option is restricted to authorized users only and will route the job to a high-priority queue.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

@tensorrt-cicd (Collaborator)

PR_Github #40131 [ run ] completed with state SUCCESS. Commit: 6ada7c6
/LLM/main/L0_MergeRequest_PR pipeline #31278 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

CI Report

Link to invocation

@NVShreyas (Collaborator Author)

@chang-l could you merge?

@chang-l chang-l merged commit 94175a8 into NVIDIA:main Mar 24, 2026
5 checks passed

