[None][fix] Fix VLM guided decoding startup crash due to missing vocab_size_padded property#12284

Merged

pengbowang-nv merged 3 commits into NVIDIA/TensorRT-LLM:main from stefanpantic/TensorRT-LLM:users/stefan/fix-vllm-vocab-size on Apr 9, 2026

Conversation

@stefanpantic (Contributor) commented Mar 17, 2026

Summary by CodeRabbit

Release Notes

  • New Features

    • Added vocab_size_padded property to vision-language models for accessing padded vocabulary size information.
  • Tests

    • Added comprehensive test suite validating vocab_size_padded property across all vision-language models.
    • Added test for guided decoding with JSON schema validation support in chat completions.

Description

VLM wrapper classes (Qwen3VLModel, Qwen2VLModel, LlavaNextModel, etc.) extend
HuggingFace's PreTrainedModel rather than TRT-LLM's DecoderModelForCausalLM, so they
do not inherit the vocab_size_padded property. py_executor_creator.py unconditionally
reads model_engine.model.vocab_size_padded when initialising the GuidedDecoder, causing
an AttributeError at server startup whenever guided decoding is configured with any VLM
model, regardless of GPU or request.

Fix: add vocab_size_padded as a @property to all 9 affected VLM wrapper classes,
delegating to self.llm.vocab_size_padded. This follows the same pattern as the existing
infer_max_seq_len delegation already present in every one of these classes.

Affected classes: Qwen3VLModelBase, Qwen2VLModelBase, LlavaNextModel, Gemma3VLM,
Phi4MMForCausalLM, Mistral3VLM, VilaModel, HCXVisionForCausalLM,
NemotronH_Nano_VL_V2.

Test Coverage

  • tests/unittest/_torch/models/test_vlm_vocab_size_padded.py (no GPU required):
    verifies vocab_size_padded is defined as a @property on each class, delegates to
    self.llm, and reflects live updates (27 parametrized cases across 9 classes × 3 tests).

  • tests/unittest/llmapi/apps/_test_openai_chat_vlm_guided_decoding.py (E2E, requires
    GPU): starts trtllm-serve with Qwen3-VL-8B-Instruct and
    guided_decoding_backend: xgrammar, sends a multimodal request with
    response_format: json_schema, and validates the response against the schema. Without the
    fix the server crashes at startup and the fixture itself fails.
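For reference, the shape of the request the E2E test sends can be sketched as below. The schema, image URL, and endpoint here are illustrative assumptions (actually issuing the POST requires a running trtllm-serve instance, so the example only builds and inspects the payload):

```python
import json

# Illustrative JSON schema the response must conform to (assumption, not
# the schema used by the actual test).
schema = {
    "type": "object",
    "properties": {"description": {"type": "string"}},
    "required": ["description"],
}

# Payload for POST /v1/chat/completions against a trtllm-serve instance
# started with guided_decoding_backend: xgrammar.
payload = {
    "model": "Qwen3-VL-8B-Instruct",
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe the image as JSON."},
                {"type": "image_url",
                 "image_url": {"url": "https://example.com/image.png"}},
            ],
        }
    ],
    "response_format": {
        "type": "json_schema",
        "json_schema": {"name": "image_description", "schema": schema},
    },
}

# With a live server, the returned message content would be json.loads()-able
# and constrained by guided decoding to match `schema`.
print(json.dumps(payload["response_format"], indent=2))
```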

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
  • Test cases are provided for new code paths (see test instructions)
  • Any new dependencies have been scanned for license and vulnerabilities
  • CODEOWNERS updated if ownership changes
  • Documentation updated as needed
  • Update tava architecture diagram if there is a significant design change in PR.
  • The reviewers assigned automatically/manually are appropriate for the PR.
  • Please check this after reviewing the above items as appropriate for this PR.

@stefanpantic stefanpantic marked this pull request as ready for review March 17, 2026 13:33
@stefanpantic stefanpantic requested review from a team as code owners March 17, 2026 13:33
@svc-trtllm-gh-bot svc-trtllm-gh-bot added the Community want to contribute PRs initiated from Community label Mar 17, 2026
@coderabbitai (Bot) commented Mar 17, 2026

Walkthrough

The changes add a new public vocab_size_padded property to nine VLM model classes, each delegating to the underlying llm.vocab_size_padded attribute. A comprehensive test module validates this property across all VLM classes, and a regression test for guided decoding in VLM is introduced.

Changes

  • VLM Model Properties (tensorrt_llm/_torch/models/modeling_gemma3vl.py, modeling_hyperclovax.py, modeling_llava_next.py, modeling_mistral.py, modeling_nemotron_nano.py, modeling_phi4mm.py, modeling_qwen2vl.py, modeling_qwen3vl.py, modeling_vila.py): added a public vocab_size_padded property to each VLM class, delegating to self.llm.vocab_size_padded. Enables external access to the padded vocabulary size without altering existing logic.
  • VLM Property Tests (tests/unittest/_torch/models/test_vlm_vocab_size_padded.py): new test module validating that vocab_size_padded is a property on each VLM class, delegates to llm.vocab_size_padded, and does not cache values across runtime changes.
  • Guided Decoding Test (tests/unittest/llmapi/apps/_test_openai_chat_vlm_guided_decoding.py): regression test for VLM guided decoding with JSON schema validation. Includes fixtures for temporary config files and a remote server, plus a test validating that chat completion responses conform to the JSON schema.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~12 minutes

🚥 Pre-merge checks: 2 passed, 1 failed

Failed checks (1 warning):

  • Docstring Coverage ⚠️ Warning: docstring coverage is 2.86%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

Passed checks (2):

  • Title check ✅: the title accurately describes the main fix: adding the missing vocab_size_padded property to VLM classes to resolve guided decoding startup crashes.
  • Description check ✅: the PR description comprehensively explains the problem, solution, affected classes, and test coverage with clear examples and rationale.


@coderabbitai (Bot) left a comment

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/unittest/_torch/models/test_vlm_vocab_size_padded.py`:
- Around line 1-14: Update the file header copyright year range from "2022-2024"
to "2022-2026" in the top-of-file license block so it reads "Copyright (c)
2022-2026 NVIDIA CORPORATION & AFFILIATES"; locate the SPDX and license comment
block at the start of the file (the header containing "SPDX-FileCopyrightText"
and "SPDX-License-Identifier") and change the year range accordingly without
altering any other license text.

In `@tests/unittest/llmapi/apps/_test_openai_chat_vlm_guided_decoding.py`:
- Around line 1-18: Update the file header: change the copyright year range from
"2022-2024" to "2022-2026" and replace the placeholder issue reference
"https://github.com/NVIDIA/TensorRT-LLM/issues/XXXX" with the actual PR number
"https://github.com/NVIDIA/TensorRT-LLM/issues/12284"; ensure the updated lines
(the SPDX header and the regression-test comment containing the issue link)
reflect these exact substitutions.


📥 Commits

Reviewing files that changed from the base of the PR and between 4c4b688 and ee0ec72.

📒 Files selected for processing (11)
  • tensorrt_llm/_torch/models/modeling_gemma3vl.py
  • tensorrt_llm/_torch/models/modeling_hyperclovax.py
  • tensorrt_llm/_torch/models/modeling_llava_next.py
  • tensorrt_llm/_torch/models/modeling_mistral.py
  • tensorrt_llm/_torch/models/modeling_nemotron_nano.py
  • tensorrt_llm/_torch/models/modeling_phi4mm.py
  • tensorrt_llm/_torch/models/modeling_qwen2vl.py
  • tensorrt_llm/_torch/models/modeling_qwen3vl.py
  • tensorrt_llm/_torch/models/modeling_vila.py
  • tests/unittest/_torch/models/test_vlm_vocab_size_padded.py
  • tests/unittest/llmapi/apps/_test_openai_chat_vlm_guided_decoding.py

@stefanpantic stefanpantic force-pushed the users/stefan/fix-vllm-vocab-size branch from ee0ec72 to e6a6e29 Compare March 17, 2026 14:06
@pengbowang-nv (Collaborator) commented Mar 18, 2026

Thank you for your contribution! Hi @NVIDIA/trt-llm-torch-models-vlm-devs (also cc @yechank-nvidia and @chang-l ) , could you please take a look at this PR? I have managed to confirm both the problem and the fix from this PR.

@pengbowang-nv (Collaborator):
/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator):

PR_Github #40921 [ run ] triggered by Bot. Commit: 1f8ad5e Link to invocation

@yechank-nvidia (Collaborator) left a comment:
LGTM. Thanks for the work!

@tensorrt-cicd (Collaborator):

PR_Github #40921 [ run ] completed with state SUCCESS. Commit: 1f8ad5e
/LLM/main/L0_MergeRequest_PR pipeline #31918 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@pengbowang-nv pengbowang-nv force-pushed the users/stefan/fix-vllm-vocab-size branch from 1f8ad5e to afad089 Compare April 1, 2026 03:24
@pengbowang-nv (Collaborator):

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator):

PR_Github #41097 [ run ] triggered by Bot. Commit: afad089 Link to invocation

@tensorrt-cicd (Collaborator):

PR_Github #41097 [ run ] completed with state FAILURE. Commit: afad089
/LLM/main/L0_MergeRequest_PR pipeline #32071 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@pengbowang-nv (Collaborator):

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator):

PR_Github #41162 [ run ] triggered by Bot. Commit: afad089 Link to invocation

@tensorrt-cicd (Collaborator):

PR_Github #41162 [ run ] completed with state SUCCESS. Commit: afad089
/LLM/main/L0_MergeRequest_PR pipeline #32131 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

Commit: …ided decoding startup crash (Code review update)
Signed-off-by: Stefan Pantic <stefanpantic13@gmail.com>
@pengbowang-nv pengbowang-nv force-pushed the users/stefan/fix-vllm-vocab-size branch from afad089 to 0c4d84a Compare April 2, 2026 03:34
@pengbowang-nv (Collaborator):

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator):

PR_Github #41324 [ run ] triggered by Bot. Commit: 0c4d84a Link to invocation

@tensorrt-cicd (Collaborator):

PR_Github #41324 [ run ] completed with state SUCCESS. Commit: 0c4d84a
/LLM/main/L0_MergeRequest_PR pipeline #32273 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@pengbowang-nv (Collaborator):

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator):

PR_Github #41451 [ run ] triggered by Bot. Commit: 0c4d84a Link to invocation

@tensorrt-cicd (Collaborator):

PR_Github #41451 [ run ] completed with state SUCCESS. Commit: 0c4d84a
/LLM/main/L0_MergeRequest_PR pipeline #32382 completed with status: 'SUCCESS'

CI Report

Link to invocation

@pengbowang-nv (Collaborator):

Hi, @stefanpantic . I found tests/unittest/_torch/models/test_vlm_vocab_size_padded.py rather trivial and would like to remove the file, what do you think? Thanks!

@stefanpantic (Contributor, Author) commented Apr 3, 2026

@pengbowang-nv No objections on my end. Should I remove or will you do it?

@pengbowang-nv (Collaborator):

Hi @stefanpantic , could you please remove it and I'll continue with CI and merge after that? Thanks!

Signed-off-by: Stefan Pantic <stefanpantic13@gmail.com>
@stefanpantic (Contributor, Author) commented Apr 7, 2026

@pengbowang-nv done ✔️

@pengbowang-nv (Collaborator):

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator):

PR_Github #42074 [ run ] triggered by Bot. Commit: df8ed08 Link to invocation

@tensorrt-cicd (Collaborator):

PR_Github #42074 [ run ] completed with state SUCCESS. Commit: df8ed08
/LLM/main/L0_MergeRequest_PR pipeline #32914 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@pengbowang-nv (Collaborator):

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator):

PR_Github #42233 [ run ] triggered by Bot. Commit: df8ed08 Link to invocation

@tensorrt-cicd (Collaborator):

PR_Github #42233 [ run ] completed with state SUCCESS. Commit: df8ed08
/LLM/main/L0_MergeRequest_PR pipeline #33044 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@pengbowang-nv (Collaborator):

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator):

PR_Github #42304 [ run ] triggered by Bot. Commit: df8ed08 Link to invocation

@tensorrt-cicd (Collaborator):

PR_Github #42304 [ run ] completed with state FAILURE. Commit: df8ed08
/LLM/main/L0_MergeRequest_PR pipeline #33098 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@pengbowang-nv (Collaborator):

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator):

PR_Github #42353 [ run ] triggered by Bot. Commit: df8ed08 Link to invocation

@tensorrt-cicd (Collaborator):

PR_Github #42353 [ run ] completed with state FAILURE. Commit: df8ed08
/LLM/main/L0_MergeRequest_PR pipeline #33138 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@pengbowang-nv (Collaborator):

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator):

PR_Github #42434 [ run ] triggered by Bot. Commit: df8ed08 Link to invocation

@tensorrt-cicd (Collaborator):

PR_Github #42434 [ run ] completed with state SUCCESS. Commit: df8ed08
/LLM/main/L0_MergeRequest_PR pipeline #33203 completed with status: 'SUCCESS'

CI Report

Link to invocation

@pengbowang-nv pengbowang-nv merged commit 2dff089 into NVIDIA:main Apr 9, 2026
5 checks passed

Labels

Community want to contribute PRs initiated from Community

Projects

None yet


6 participants
