[TRTLLM-11622][fix] fix parallel WAN vae when return_dict=True #12460
chang-l merged 4 commits into NVIDIA/TensorRT-LLM:main from NVShreyas/TensorRT-LLM:user/shreyasm/par-vae-fix
Conversation
/bot run --disable-fail-fast

PR_Github #39954 [ run ] triggered by Bot. Commit:

PR_Github #39954 [ run ] completed with state

@chang-l / @zhenhuaw-me could you please review?
📝 Walkthrough

Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~22 minutes

🚥 Pre-merge checks: ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 1
🧹 Nitpick comments (2)
tests/unittest/_torch/visual_gen/multi_gpu/test_parallel_vae.py (1)

131-135: Exercise the default decode path here.

This still asserts through `parallel.decode(..., return_dict=False)[0]`, so the new `DecoderOutput(sample=...)` branch never runs. Please switch this check to `parallel.decode(latent).sample` or add a companion assertion for the default path.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/_torch/visual_gen/multi_gpu/test_parallel_vae.py` around lines 131-135: the test currently only exercises the non-default return_dict=False path by calling parallel.decode(latent, return_dict=False)[0], so the new DecoderOutput(sample=...) branch isn't executed. Replace or supplement that call with parallel.decode(latent).sample, or add a companion assertion that calls parallel.decode(latent) and checks .sample on the returned DecoderOutput, so the DecoderOutput branch is exercised.
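A minimal sketch of the suggested companion assertion; `check_default_decode_path` is a hypothetical helper, and `parallel` and `latent` stand in for the ParallelVAE_Wan instance and latent tensor built earlier in the test:

```python
import torch


def check_default_decode_path(parallel, latent):
    """Hypothetical: both decode paths should yield the same sample."""
    # Non-default path: return_dict=False yields a plain tuple.
    sample_tuple = parallel.decode(latent, return_dict=False)[0]
    # Default path: return_dict=True yields a DecoderOutput with a .sample field.
    sample_attr = parallel.decode(latent).sample
    torch.testing.assert_close(sample_tuple, sample_attr)
```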
tensorrt_llm/_torch/visual_gen/models/wan/parallel_vae.py (1)

5-7: Prefer module imports for the new Diffusers types.

Please keep these symbols namespaced instead of adding more direct class imports here. It makes ownership clearer and matches the repo's Python import rule.

As per coding guidelines, "When importing in Python, always maintain the namespace. Import the module, not individual classes or functions (e.g., use `from package.subpackage import foo` then `foo.SomeClass()` instead of `from package.subpackage.foo import SomeClass`)."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/visual_gen/models/wan/parallel_vae.py` around lines 5-7: the current imports bring specific classes into the local namespace (AutoencoderKLOutput, WanAttentionBlock, WanCausalConv3d, DecoderOutput, DiagonalGaussianDistribution). Import their modules instead (e.g., diffusers.models.autoencoders.autoencoder_kl as autoencoder_kl, diffusers.models.autoencoders.autoencoder_kl_wan as autoencoder_kl_wan, and diffusers.models.autoencoders.vae as vae), then update all usages in parallel_vae.py to reference the classes via those modules (autoencoder_kl.AutoencoderKLOutput, autoencoder_kl_wan.WanAttentionBlock, autoencoder_kl_wan.WanCausalConv3d, vae.DecoderOutput, vae.DiagonalGaussianDistribution) so the names remain namespaced per project import rules.
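A short sketch of the namespaced style the comment asks for; the module paths are the ones named in the comment and should be verified against the installed diffusers version:

```python
# Namespaced imports as suggested by the review; paths per the comment above.
from diffusers.models.autoencoders import autoencoder_kl, autoencoder_kl_wan, vae

# Call sites then stay namespaced, e.g.:
#   vae.DiagonalGaussianDistribution(...)    instead of DiagonalGaussianDistribution(...)
#   vae.DecoderOutput(sample=...)            instead of DecoderOutput(sample=...)
#   autoencoder_kl.AutoencoderKLOutput(...)  instead of AutoencoderKLOutput(...)
#   autoencoder_kl_wan.WanCausalConv3d(...)  instead of WanCausalConv3d(...)
```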
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tensorrt_llm/_torch/visual_gen/models/wan/parallel_vae.py`:
- Around line 54-68: The encode/decode implementations drop caller kwargs by
hard-coding return_dict=False when calling vae_backend.encode/decode; update
_encode_impl and _decode_impl to forward all kwargs to vae_backend while
ensuring the backend still receives return_dict=False (e.g. copy kwargs, set or
override return_dict=False, then pass that dict into vae_backend.encode(...) and
vae_backend.decode(...)), keeping the existing use of _split_tensor,
_gather_tensor, DiagonalGaussianDistribution, AutoencoderKLOutput and
DecoderOutput unchanged.
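A minimal sketch of the forwarding fix described above, assuming `_decode_impl` wraps `self.vae_backend.decode`; the class name and signature are guesses, while `vae_backend`, `_split_tensor`, `_gather_tensor`, and `DecoderOutput` come from the prompt. `_encode_impl` would forward kwargs the same way and wrap the posterior in `AutoencoderKLOutput`:

```python
from diffusers.models.autoencoders.vae import DecoderOutput


class ParallelWanVAESketch:
    """Hypothetical stand-in for the real parallel VAE wrapper."""

    def _decode_impl(self, z, return_dict=True, **kwargs):
        # Shard the latent across ranks (helper names taken from the prompt).
        z_local = self._split_tensor(z)
        # Forward the caller's kwargs, but force the backend to return a tuple.
        kwargs["return_dict"] = False
        decoded_local = self.vae_backend.decode(z_local, **kwargs)[0]
        # Reassemble the full sample, then honor the caller's return_dict choice.
        decoded = self._gather_tensor(decoded_local)
        return DecoderOutput(sample=decoded) if return_dict else (decoded,)
```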
ℹ️ Review info
⚙️ Run configuration
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 60e32875-eb00-47eb-9e42-cbc50afc5546
📒 Files selected for processing (2)
tensorrt_llm/_torch/visual_gen/models/wan/parallel_vae.py
tests/unittest/_torch/visual_gen/multi_gpu/test_parallel_vae.py
Force-pushed from 65dd189 to b4929dd (Compare)
/bot run --disable-fail-fast

PR_Github #39991 [ run ] triggered by Bot. Commit:
Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
Force-pushed from b4929dd to 55994bd (Compare)
Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
/bot kill

PR_Github #40004 [ kill ] triggered by Bot. Commit:

PR_Github #40004 [ kill ] completed with state

/bot run --disable-fail-fast

PR_Github #40012 [ run ] triggered by Bot. Commit:

PR_Github #40012 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #40131 [ run ] triggered by Bot. Commit:

/bot help
GitHub Bot Help

Provide a user-friendly way for developers to interact with a Jenkins server. See details below for each supported subcommand.

run
Launch build/test pipelines. All previously running jobs will be killed.

kill
Kill all running builds associated with the pull request.

skip
Skip testing for the latest commit on the pull request.

reuse-pipeline
Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.
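For example, the pipelines in this thread were launched and managed by commenting directly on the PR:

```
/bot run --disable-fail-fast
/bot kill
/bot help
```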
PR_Github #40131 [ run ] completed with state

@chang-l could you merge?
Summary by CodeRabbit
Refactor
Tests
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment /bot help.