[None][fix] Fix nemotron super MTP crash on SM90 #11807
mikeiovine merged 5 commits into NVIDIA/TensorRT-LLM:main from sunnyqgg:nemotron-super-h100
Conversation
📝 Walkthrough
This PR updates the flashinfer-python dependency, modifies dtype handling in an argmax kernel, refactors speculative decoding and Mamba2 selective state update logic, adds **kwargs propagation to Nemotron model forward methods, and introduces new NVFP4 MTP integration tests.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ❌ 1 failed (1 warning) | ✅ 2 passed
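The **kwargs propagation mentioned in the walkthrough follows a simple pattern. Here is a runnable sketch with illustrative names (`DecoderLayer`, `Mixer`, and the `spec_metadata` key are stand-ins, not the actual Nemotron classes), showing how extra metadata reaches sub-modules once forward signatures pass it through:

```python
import torch
from torch import nn


class Mixer(nn.Module):
    """Illustrative sub-module that needs extra metadata (not the real class)."""

    def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        # 'spec_metadata' is a hypothetical key standing in for whatever
        # speculative-decoding state the real modules consume.
        assert kwargs.get("spec_metadata") is not None
        return hidden_states


class DecoderLayer(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.mixer = Mixer()

    def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        # The fix: forward **kwargs instead of dropping them, so metadata
        # reaches sub-modules like the mixer.
        return self.mixer(hidden_states, **kwargs)


layer = DecoderLayer()
out = layer(torch.zeros(1, 4), spec_metadata={"draft_len": 1})
```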
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/cute_dsl_kernels/argmax.py (1)
Lines 650-653: ⚠️ Potential issue | 🟡 Minor — Inconsistent dtype in CUTLASS DSL fallback.
The fallback function used when CUTLASS DSL is unavailable still converts indices to `x.dtype` instead of `float32`, which is inconsistent with the main `argmax` function's output contract (lines 600-603).

🐛 Proposed fix for dtype consistency:

```diff
 else:
     # Fallback if CUTLASS DSL is not available
     def argmax(x: torch.Tensor) -> torch.Tensor:
         """Fallback argmax using PyTorch when CUTLASS DSL is not available."""
         max_vals, max_indices = torch.max(x, dim=-1, keepdim=True)
-        return torch.cat([max_vals, max_indices.to(x.dtype)], dim=-1)
+        return torch.cat([max_vals.to(torch.float32), max_indices.to(torch.float32)], dim=-1)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/cute_dsl_kernels/argmax.py` around lines 650 - 653, The fallback argmax in function argmax currently casts max indices to x.dtype which mismatches the main implementation's contract; change the cast so indices are converted to torch.float32 (not x.dtype) before concatenation with max_vals, ensuring the returned tensor matches the main argmax output dtype/shape (use torch.max(..., keepdim=True) then torch.cat([max_vals, max_indices.to(torch.float32)], dim=-1)).
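For reference, here is a self-contained version of the proposed fallback that can be sanity-checked locally (a sketch of the suggested change, not the upstream function):

```python
import torch


def argmax_fallback(x: torch.Tensor) -> torch.Tensor:
    """PyTorch fallback mirroring the CUTLASS DSL argmax output contract:
    a float32 tensor holding [max_value, max_index] along the last dim."""
    max_vals, max_indices = torch.max(x, dim=-1, keepdim=True)
    # Cast both halves to float32 so the fallback matches the main kernel's
    # output dtype regardless of the input dtype (e.g. bfloat16 logits).
    return torch.cat(
        [max_vals.to(torch.float32), max_indices.to(torch.float32)], dim=-1)


out = argmax_fallback(torch.randn(2, 8, dtype=torch.bfloat16))
assert out.dtype == torch.float32 and out.shape == (2, 2)
```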
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments (a combined sketch of both fixes follows these items):
In `@tests/integration/defs/accuracy/test_llm_api_pytorch.py`:
- Around line 5865-5867: Guard against divide-by-zero when computing
accept_rate: before computing accept_rate = num_accepted / num_drafted, check
num_drafted and handle the zero case (e.g., assert with a clear message like "No
drafted tokens for prompt {i}" or skip the prompt) so the code doesn't raise
ZeroDivisionError; update the block around the variables accept_rate,
num_accepted, num_drafted and the prompt index i accordingly.
- Around line 5833-5868: The LLM instance is created without deterministic
cleanup; change the creation of llm_spec to use a context manager (e.g., with
LLM(**llm_common_config, speculative_config=mtp_config) as llm_spec:) so the LLM
is deterministically closed after the test block; update the block that uses
llm_spec.tokenizer, llm_spec.generate_async, and related variables to be inside
that with scope (or alternatively call llm_spec.close() in a finally) to ensure
proper teardown in the integration test.
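Taken together, the two inline comments amount to a pattern like this runnable sketch. `FakeLLM` stands in for the real `LLM` class so the pattern runs without a model, and the per-prompt counters are illustrative; the actual test would use `LLM(**llm_common_config, speculative_config=mtp_config)`:

```python
class FakeLLM:
    """Stand-in for tensorrt_llm.LLM so the pattern runs without a model."""

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        self.close()  # deterministic teardown, even if an assert fires

    def close(self):
        print("engine released")


# (num_accepted, num_drafted) per prompt -- illustrative values only.
per_prompt_counts = [(6, 10), (3, 8)]

with FakeLLM() as llm_spec:
    for i, (num_accepted, num_drafted) in enumerate(per_prompt_counts):
        # Guard the division: a prompt with zero drafted tokens should fail
        # with a clear message instead of raising ZeroDivisionError.
        assert num_drafted > 0, f"No drafted tokens for prompt {i}"
        accept_rate = num_accepted / num_drafted
        print(f"prompt {i}: accept_rate={accept_rate:.2f}")
```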
---
Outside diff comments:
In `@tensorrt_llm/_torch/cute_dsl_kernels/argmax.py`:
- Around line 650-653: The fallback argmax in function argmax currently casts
max indices to x.dtype which mismatches the main implementation's contract;
change the cast so indices are converted to torch.float32 (not x.dtype) before
concatenation with max_vals, ensuring the returned tensor matches the main
argmax output dtype/shape (use torch.max(..., keepdim=True) then
torch.cat([max_vals, max_indices.to(torch.float32)], dim=-1)).
ℹ️ Review info
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- requirements.txt
- tensorrt_llm/_torch/cute_dsl_kernels/argmax.py
- tensorrt_llm/_torch/models/modeling_nemotron_h.py
- tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py
- tensorrt_llm/_torch/speculative/mtp.py
- tests/integration/defs/accuracy/test_llm_api_pytorch.py
- tests/integration/test_lists/qa/llm_function_core.txt
- tests/integration/test_lists/test-db/l0_dgx_b200.yml
Force-pushed 46cb498 to 2aff41d
Force-pushed 2aff41d to 8e09f8e
/bot run
PR_Github #37270 [ run ] triggered by Bot. Commit:
/bot run
PR_Github #37270 [ run ] completed with state
PR_Github #37274 [ run ] triggered by Bot. Commit:
/bot run
PR_Github #37276 [ run ] triggered by Bot. Commit:
/bot run
PR_Github #37282 [ run ] triggered by Bot. Commit:
/bot run
PR_Github #37417 [ run ] triggered by Bot. Commit:
Signed-off-by: qgai <qgai@nvidia.com>
/bot run
PR_Github #37423 [ run ] triggered by Bot. Commit:
/bot run
PR_Github #37434 [ run ] triggered by Bot. Commit:
…e test Signed-off-by: qgai <qgai@nvidia.com>
Force-pushed de34a7f to ab64049
PR_Github #37434 [ run ] completed with state
/bot run
PR_Github #37523 [ run ] triggered by Bot. Commit:
PR_Github #37523 [ run ] completed with state
/bot run
PR_Github #37588 [ run ] triggered by Bot. Commit:
PR_Github #37588 [ run ] completed with state
/bot run
PR_Github #37626 [ run ] triggered by Bot. Commit:
PR_Github #37626 [ run ] completed with state
/bot run
PR_Github #37640 [ run ] triggered by Bot. Commit:
PR_Github #37640 [ run ] completed with state
/bot run
PR_Github #37779 [ run ] triggered by Bot. Commit:
PR_Github #37779 [ run ] completed with state
Signed-off-by: qgai <qgai@nvidia.com>
Signed-off-by: qgai <qgai@nvidia.com>
Signed-off-by: qgai <qgai@nvidia.com>
Summary

- Fix MTP crash on SM90 (_torch/speculative/mtp.py)
- Fix MTP with enable_attention_dp
- Add Nemotron accuracy tests (test_llm_api_pytorch.py)

Changes

- tensorrt_llm/_torch/speculative/mtp.py: Fix MTP speculative decoding crash on SM90
- tensorrt_llm/_torch/models/modeling_nemotron_h.py: Fix MTP with attention DP enabled
- tensorrt_llm/_torch/cute_dsl_kernels/argmax.py: Related kernel fixes
- tests/integration/defs/accuracy/test_llm_api_pytorch.py: Add Nemotron accuracy tests
- tests/integration/test_lists/: Update test lists and DB configs
- requirements.txt: Update dependencies

Test plan

- Verify enable_attention_dp works correctly (see the configuration sketch below)
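A minimal configuration sketch of the scenario under test. The model path is a placeholder; `MTPDecodingConfig`, `enable_attention_dp`, and `num_nextn_predict_layers` are taken from the TensorRT-LLM LLM API and should be checked against the installed version:

```python
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import MTPDecodingConfig

# One MTP draft layer; tune per model.
mtp_config = MTPDecodingConfig(num_nextn_predict_layers=1)

with LLM(
        model="/models/nemotron-super",  # placeholder checkpoint path
        enable_attention_dp=True,  # the attention-DP path fixed in this PR
        speculative_config=mtp_config,
) as llm:
    # On SM90 (H100) this combination previously crashed in
    # _torch/speculative/mtp.py.
    output = llm.generate("The capital of France is")
    print(output.outputs[0].text)
```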
Summary by CodeRabbit

- New Features
- Bug Fixes
- Performance Improvements
- Chores