[#11368][fix] FP4 CUTLASS GEMM shared memory overflow on GB10 (SM121) #12141

Merged

pamelap-nvidia merged 11 commits into NVIDIA/TensorRT-LLM:main from mihai-chiorean/TensorRT-LLM:fix/fp4-gemm-sm121-smem-overflow on Mar 18, 2026
Conversation

@mihai-chiorean (Contributor) commented Mar 12, 2026

Summary

Fixes #11368

nvfp4_gemm_cutlass fails on SM12x devices (SM120/RTX 5090 and SM121/GB10/DGX Spark) with kErrorInternal because the SM120 tile configurations (128x128x256B and 256x128x128B) require more shared memory than any SM12x device provides (~99 KiB). These tile sizes were designed for SM100 (B200/B100, ~227 KiB SMEM) and do not fit on SM12x.

The CtaShape128x128x128B tile config was already compiled for all output types (bf16, fp16, fp32) in fp4_gemm_bf16.cu, fp4_gemm_fp16.cu, and fp4_gemm_fp32.cu, but it was unreachable at runtime.

This patch makes three changes across two files:

  • fp4_gemm_template.h: Adds a CtaShape128x128x128B dispatch case in dispatchNVFP4xNVFP4GemmCTAShapeSm120
  • fp4_gemm_template.h: Enables CtaShape128x128x128B in the getConfigs() candidate list for profiling
  • fp4Gemm.cpp: Uses CtaShape128x128x128B as the default tile config for SM120+, letting the autotuner select the optimal config for the current device

On SM100 (B200/B100, ~227 KiB SMEM), this code path is not reached — SM100 uses its own tile configs via a separate branch. On SM12x (~99 KiB SMEM), the 128x128x128B tile is selected as default, and the autotuner profiles all candidates including larger tiles where they fit.
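
For illustration, here is a minimal, self-contained sketch of the shape of these changes. The enum values mirror the CutlassTileConfigSM120 names discussed in this PR; the launcher, its template parameters, and main() are placeholders rather than the actual TensorRT-LLM code.

```cpp
// Sketch only: mirrors the dispatch-case and default-config idea, not the
// real TensorRT-LLM implementation.
#include <cstdio>
#include <stdexcept>

enum class CutlassTileConfigSM120
{
    CtaShape128x128x128B, // fits within the ~99 KiB SMEM of SM120/SM121
    CtaShape128x128x256B, // sized for SM100-class (~227 KiB) SMEM
    CtaShape256x128x128B, // sized for SM100-class (~227 KiB) SMEM
};

// Placeholder for the templated CUTLASS kernel launcher.
template <int CtaM, int CtaN, int CtaKBytes>
void launchNvfp4GemmSketch()
{
    std::printf("launch %dx%dx%dB tile, 1x1x1 cluster\n", CtaM, CtaN, CtaKBytes);
}

// Analogue of dispatchNVFP4xNVFP4GemmCTAShapeSm120: the 128x128x128B case
// is the one this PR adds.
void dispatchSm120Sketch(CutlassTileConfigSM120 tile)
{
    switch (tile)
    {
    case CutlassTileConfigSM120::CtaShape128x128x128B:
        launchNvfp4GemmSketch<128, 128, 128>();
        break;
    case CutlassTileConfigSM120::CtaShape128x128x256B:
        launchNvfp4GemmSketch<128, 128, 256>();
        break;
    case CutlassTileConfigSM120::CtaShape256x128x128B:
        launchNvfp4GemmSketch<256, 128, 128>();
        break;
    default:
        throw std::runtime_error("unsupported SM120 tile config");
    }
}

int main()
{
    // Default for SM120+ after this PR: the tile that fits on every SM12x
    // part; the autotuner may still select larger candidates where they fit.
    dispatchSm120Sketch(CutlassTileConfigSM120::CtaShape128x128x128B);
    return 0;
}
```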

Background

  • All SM12x devices (SM120 = RTX 5090, SM121 = GB10) report cudaDevAttrMaxSharedMemoryPerBlockOptin = 101,376 bytes (99 KiB), confirmed by the static_assert(smemSize <= 99 * 1024) in mla_sm120.cu; a standalone query sketch follows this list
  • SM100 (B200/B100) has ~227 KiB per block — the large SM120 tile configs were likely ported from SM100 without accounting for the smaller SM12x SMEM
  • CUTLASS StageCountAutoCarveout sizes pipeline stages at compile time using the arch SharedMemoryCapacity, so the resulting SharedStorage struct exceeds SM12x limits — this cannot be fixed at runtime without a different tile config
  • The issue reporter demonstrated CUTLASS example 79 (blackwell_geforce_gemm) running FP4 at 41.6 TFLOPS on GB10, proving this is not a hardware limitation
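
For reference, the opt-in SMEM figure cited above can be checked on any device with a few lines against the public CUDA runtime API (standalone snippet, not part of this PR):

```cpp
// Query the per-block opt-in shared memory limit for the current device.
// Build with: nvcc -o smem_query smem_query.cu
#include <cstdio>
#include <cuda_runtime_api.h>

int main()
{
    int device = 0;
    cudaGetDevice(&device);

    int smemOptin = 0;
    cudaDeviceGetAttribute(&smemOptin, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);

    // SM120/SM121 report 101,376 bytes (99 KiB); SM100 (B200/B100) reports
    // roughly 227 KiB, which is why the larger tiles only fit there.
    std::printf("max opt-in shared memory per block: %d bytes (%.1f KiB)\n",
        smemOptin, smemOptin / 1024.0);
    return 0;
}
```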

Test Coverage

  • tests/unittest/_torch/thop/parallel/test_fp4_gemm_quantize.py::TestFunctional::test_fp4_quantize_gemm_torch — validates FP4 GEMM quantize correctness
  • tests/unittest/_torch/thop/parallel/test_fp4_bmm_quantize.py::TestFunctional::test_fp4_bmm_torch — validates FP4 batched matmul
  • tests/unittest/_torch/thop/parallel/test_fp4_linear.py::test_fp4_linear — validates FP4 linear layer end-to-end

Test plan (GB10 / SM121)

  • test_fp4_quantize_gemm_torch — 2/2 passed
  • test_fp4_bmm_torch — 4/4 passed
  • test_fp4_linear — 8/8 passed

All 14 tests passed on DGX Spark (GB10 / SM121) with TRT-LLM 1.3.0rc8, CUDA 13.1, PyTorch 2.10.0a0.

Note: SM100 (B200) regression testing deferred to NVIDIA CI — no B200 hardware available to the author.

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

The nvfp4_gemm_cutlass kernel fails on GB10 (SM121 / DGX Spark) because
SM121 is routed to SM120 tile configurations (128x128x256B and
256x128x128B) that require more shared memory than GB10 provides
(~99 KiB vs ~228 KiB on B200).

The 128x128x128B tile config was already compiled for all output types
(bf16, fp16, fp32) but unreachable at runtime due to:
1. Missing dispatch case in dispatchNVFP4xNVFP4GemmCTAShapeSm120
2. Commented out in getConfigs() candidate list
3. getDefaultGemmConfig() unconditionally selecting the larger tile

This patch:
- Adds CtaShape128x128x128B to the SM120 dispatch switch
- Enables it as a candidate in getConfigs() for profiling
- Selects tile config at runtime based on cudaDevAttrMaxSharedMemoryPerBlockOptin
  rather than compile-time assumptions, so the binary works on both
  B200 (~228 KiB) and GB10 (~99 KiB)

On B200, behavior is unchanged: the larger tiles are still preferred.
On GB10, the 128x128x128B tile is selected (the only one that fits),
and the runtime SMEM check rejects oversized configs during profiling.

Fixes NVIDIA#11368

Signed-off-by: Mihai Chiorean <mihai.v.chiorean@gmail.com>
@coderabbitai (Bot) commented Mar 12, 2026

📝 Walkthrough

Updates FP4 GEMM kernel dispatch and configuration selection to support both large and smaller tile shapes. The kernel routing gains a new 128x128x128B CTA dispatch path, while runtime configuration logic selects tile size based on available shared memory, choosing 128x128x256B when sufficient memory is available or 128x128x128B otherwise.

Changes

  • FP4 Kernel Dispatch & Tile Configuration (cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/fp4_gemm_template.h): Enables the 128x128x128B CTA shape in the SM120 tile list and adds a dispatch case routing to a 1x1x1 cluster with Int<128>, Int<128>, Int<128> template parameters.
  • Runtime GEMM Configuration Selection (cpp/tensorrt_llm/thop/fp4Gemm.cpp): Implements memory-aware tile configuration for W4A4_NVFP4_NVFP4 on SM120+: queries max shared memory per block and selects 128x128x256B or 128x128x128B based on a 105 KiB threshold.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: docstring coverage is 0.00%, below the required 80.00% threshold. Write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (4)

  • Linked Issues check ✅: all code changes directly address the requirements from issue #11368: adding CtaShape128x128x128B dispatch support in fp4_gemm_template.h and implementing runtime memory-aware tile selection in fp4Gemm.cpp to resolve the SM121 SMEM overflow.
  • Out of Scope Changes check ✅: all changes are scoped to fixing the FP4 CUTLASS GEMM shared memory issue on GB10; no modifications to unrelated systems, utilities, or non-FP4 code paths are present.
  • Title check ✅: the title clearly and specifically identifies the fix: addressing the FP4 CUTLASS GEMM shared memory overflow on GB10 (SM121). It directly corresponds to the main changes across both files.
  • Description check ✅: the PR description comprehensively explains the issue, solution, and testing, matching all template requirements with clear sections and a complete checklist.



@mihai-chiorean changed the title from "fix https://github.com/NVIDIA/TensorRT-LLM/issues/11368: FP4 CUTLASS GEMM shared memory overflow on GB10 (SM121)" to "fix: FP4 CUTLASS GEMM shared memory overflow on GB10 (SM121)" on Mar 12, 2026
@coderabbitai (Bot) left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
cpp/tensorrt_llm/thop/fp4Gemm.cpp (1)

69-72: Static variable naming should use 's' prefix per coding guidelines.

The static local variable maxSmem should be named sMaxSmem to follow the project convention for locally visible static variables.

Proposed fix
-            static int const maxSmem = tensorrt_llm::common::getMaxSharedMemoryPerBlockOptin();
+            static int const sMaxSmem = tensorrt_llm::common::getMaxSharedMemoryPerBlockOptin();
             constexpr int kMinSmemForLargeTile = 105 * 1024;
-            auto tileConfig = maxSmem > kMinSmemForLargeTile ? tkc::CutlassTileConfigSM120::CtaShape128x128x256B
-                                                             : tkc::CutlassTileConfigSM120::CtaShape128x128x128B;
+            auto tileConfig = sMaxSmem > kMinSmemForLargeTile ? tkc::CutlassTileConfigSM120::CtaShape128x128x256B
+                                                              : tkc::CutlassTileConfigSM120::CtaShape128x128x128B;

As per coding guidelines: "Locally visible static variables in C++ should use camel case with lowercase prefix 's' as the first letter (e.g., static std::once_flag sFlag;)".

Inline review comment (cpp/tensorrt_llm/thop/fp4Gemm.cpp, line 69):

The static local maxSmem in getDefaultGemmConfig() caches a device-specific value from getMaxSharedMemoryPerBlockOptin() and will be wrong after a CUDA device switch. Either remove the static qualifier so the value is queried on each call (i.e., call getMaxSharedMemoryPerBlockOptin() directly when computing the tile config), or implement a per-device cache keyed by cudaGetDevice() that stores per-device SMEM values using the same getMaxSharedMemoryPerBlockOptin() call, so the chosen tile config matches the current device. A sketch of the per-device option follows.
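
A minimal sketch of the per-device caching alternative, assuming a free-standing helper (the name and locking scheme are illustrative; the fix actually adopted in this PR simply drops the static and queries on every call):

```cpp
// Cache the opt-in SMEM limit per CUDA device so the cached value stays
// correct after cudaSetDevice() switches in multi-GPU processes.
#include <cstdio>
#include <cuda_runtime_api.h>
#include <mutex>
#include <unordered_map>

int getMaxSmemPerBlockOptinCached()
{
    static std::unordered_map<int, int> sCache; // device id -> SMEM bytes
    static std::mutex sMutex;

    int device = 0;
    cudaGetDevice(&device);

    std::lock_guard<std::mutex> lock(sMutex);
    auto it = sCache.find(device);
    if (it == sCache.end())
    {
        int smem = 0;
        cudaDeviceGetAttribute(&smem, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
        it = sCache.emplace(device, smem).first;
    }
    return it->second;
}

int main()
{
    std::printf("opt-in SMEM for current device: %d bytes\n", getMaxSmemPerBlockOptinCached());
    return 0;
}
```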


📥 Commits

Reviewing files that changed from the base of the PR and between 8de01ac and 35e9bfa.

📒 Files selected for processing (2)
  • cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/fp4_gemm_template.h
  • cpp/tensorrt_llm/thop/fp4Gemm.cpp

Comment thread cpp/tensorrt_llm/thop/fp4Gemm.cpp Outdated
Address code review feedback: the static local cached maxSmem per-process
rather than per-device, which would produce incorrect tile selection after
a CUDA device switch in multi-GPU scenarios. Query on each call instead.

Fixes NVIDIA#11368

Signed-off-by: Mihai Chiorean <mihai.v.chiorean@gmail.com>
@eugr commented Mar 12, 2026

a small nitpick: sm120 and sm121 have the same amount of SMEM (~99KB). GB200 is sm100 and has 228KB.

SM120 is RTX 5090 (not B200), and SM121 is GB10/DGX Spark. Both SM12x
variants have ~99 KiB SMEM per block. B200 is SM100 with ~227 KiB.

The runtime logic was already correct (queries actual device SMEM), but
the comment incorrectly stated SM120 = B200 with 228 KiB.

Fixes NVIDIA#11368

Signed-off-by: Mihai Chiorean <mihai.v.chiorean@gmail.com>
@svc-trtllm-gh-bot added the "Community want to contribute" label (PRs initiated from Community) on Mar 12, 2026
@mihai-chiorean (Contributor, Author) commented:

a small nitpick: sm120 and sm121 have the same amount of SMEM (~99KB). GB200 is sm100 and has 228KB.

Good catch, thanks! I've already corrected this in ef554d9 — the comment now correctly states that all SM12x devices (SM120 and SM121) share the same ~99 KiB SMEM limit, and that SM100 (B200/B100) has ~227 KiB.

@pengbowang-nv (Collaborator) commented:

Hi @pamelap-nvidia , could you take a look at this DGX Spark issue or reassign? Thanks!

Comment thread cpp/tensorrt_llm/thop/fp4Gemm.cpp Outdated
@pamelap-nvidia (Collaborator) commented:

/bot run

@tensorrt-cicd (Collaborator) commented:

PR_Github #39001 [ run ] triggered by Bot. Commit: c866cbb Link to invocation

@tensorrt-cicd (Collaborator) commented:

PR_Github #39001 [ run ] completed with state SUCCESS. Commit: c866cbb
/LLM/main/L0_MergeRequest_PR pipeline #30280 completed with status: 'SUCCESS'

CI Report

Link to invocation

@mihai-chiorean changed the title from "fix: FP4 CUTLASS GEMM shared memory overflow on GB10 (SM121)" to "[#11368][fix] FP4 CUTLASS GEMM shared memory overflow on GB10 (SM121)" on Mar 16, 2026
Use CtaShape128x128x128B directly as the default config instead of
querying SMEM at runtime. The autotuner will select the optimal tile
(including larger ones) for the current device.

Addresses review feedback from @pamelap-nvidia.

Signed-off-by: Mihai Chiorean <mihai.v.chiorean@gmail.com>
@pamelap-nvidia (Collaborator) commented:

/bot run

@tensorrt-cicd (Collaborator) commented:

PR_Github #39144 [ run ] triggered by Bot. Commit: ebee0d5 Link to invocation

@tensorrt-cicd (Collaborator) commented:

PR_Github #39144 [ run ] completed with state FAILURE. Commit: ebee0d5
/LLM/main/L0_MergeRequest_PR pipeline #30404 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@pamelap-nvidia (Collaborator) commented:

CI blocked. I'll wait for the CI fix PR to land and try again tomorrow.

Meanwhile, @mihai-chiorean could you help fix the PR check? It's now blocked by the checklist in the description. We can either mark them as done or remove them.

@pamelap-nvidia (Collaborator) commented:

/bot run

@tensorrt-cicd (Collaborator) commented:

PR_Github #39279 [ run ] triggered by Bot. Commit: 832d750 Link to invocation

@mihai-chiorean (Contributor, Author) commented:

CI blocked. I'll wait for the CI fix PR to land and try again tomorrow.

Meanwhile, @mihai-chiorean could you help fix the PR check? It's now blocked by the checklist in the description. We can either mark them as done or remove them.

done.

@tensorrt-cicd (Collaborator) commented:

PR_Github #39279 [ run ] completed with state SUCCESS. Commit: 832d750
/LLM/main/L0_MergeRequest_PR pipeline #30530 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@pamelap-nvidia (Collaborator) commented:

/bot run

@tensorrt-cicd (Collaborator) commented:

PR_Github #39308 [ run ] triggered by Bot. Commit: 832d750 Link to invocation

@mihai-chiorean (Contributor, Author) commented:

PR_Github #39279 [ run ] completed with state SUCCESS. Commit: 832d750 /LLM/main/L0_MergeRequest_PR pipeline #30530 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@pamelap-nvidia CI failed and I can't see why unfortunately.

@tensorrt-cicd (Collaborator) commented:

PR_Github #39308 [ run ] completed with state SUCCESS. Commit: 832d750
/LLM/main/L0_MergeRequest_PR pipeline #30561 completed with status: 'SUCCESS'

CI Report

Link to invocation

@pamelap-nvidia (Collaborator) commented:

/bot help

@github-actions commented:

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

Details

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental) --high-priority]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

--high-priority (OPTIONAL) : Run the pipeline with high priority. This option is restricted to authorized users only and will route the job to a high-priority queue.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

@pamelap-nvidia (Collaborator) commented:

/bot reuse-pipeline

@tensorrt-cicd (Collaborator) commented:

PR_Github #39370 [ reuse-pipeline ] triggered by Bot. Commit: 3504110 Link to invocation

@tensorrt-cicd (Collaborator) commented:

PR_Github #39370 [ reuse-pipeline ] completed with state SUCCESS. Commit: 3504110
Release Check Pipeline #3398 failed
Reusing PR_Github #39308 for commit 3504110

Link to invocation

@pamelap-nvidia (Collaborator) commented:

/bot reuse-pipeline

@tensorrt-cicd (Collaborator) commented:

PR_Github #39372 [ reuse-pipeline ] triggered by Bot. Commit: 3504110 Link to invocation

@tensorrt-cicd (Collaborator) commented:

PR_Github #39372 [ reuse-pipeline ] completed with state SUCCESS. Commit: 3504110
Reusing PR_Github #39308 for commit 3504110

Link to invocation

@pamelap-nvidia pamelap-nvidia merged commit a87dd31 into NVIDIA:main Mar 18, 2026
2 checks passed
mihai-chiorean added a commit to mihai-chiorean/TensorRT-LLM that referenced this pull request Mar 18, 2026
Remove the NotImplementedError gate in TRTLLMGenFusedMoE.__init__ that
blocked ALL MoE models on SM120+ (DGX Spark / RTX 5090).  The underlying
CUTLASS kernels already have SM120 templates and PR NVIDIA#12141 fixed the FP4
GEMM shared-memory overflow on SM121, so the Python-side SM version
checks were the only remaining barrier.

Changes:
- tensorrt_llm/_utils.py: add is_sm_120f() and is_blackwell() helpers
- fused_moe_trtllm_gen.py: remove __init__ SM>=120 gate; extend
  can_implement() SM set {100,103} -> {100,103,120,121}
- fused_moe_cute_dsl.py: extend NVFP4 SM check to include 120/121;
  use is_blackwell() for FP8 scale layout (shared across Blackwell)
- model_config.py: route SM120/121 to TRTLLM backend in
  resolve_moe_backend() (was falling back to CUTLASS)
- torch_custom_ops.py: extend CuTE DSL NVFP4 dense GEMM SM check
- tests/integration/defs/conftest.py: add matching test helpers

Signed-off-by: Mihai <mihai@dgx-spark>
Signed-off-by: Mihai Chiorean <mihai.v.chiorean@gmail.com>
limin2021 pushed a commit to limin2021/TensorRT-LLM that referenced this pull request Mar 19, 2026
…SM121) (NVIDIA#12141)

Signed-off-by: Mihai Chiorean <mihai.v.chiorean@gmail.com>
Co-authored-by: Pamela Peng <179191831+pamelap-nvidia@users.noreply.github.com>
mihai-chiorean added a commit to mihai-chiorean/TensorRT-LLM that referenced this pull request Mar 23, 2026
Remove the NotImplementedError gate in TRTLLMGenFusedMoE.__init__ that blocked ALL MoE models on SM120+ (DGX Spark / RTX 5090).
mihai-chiorean added a commit to mihai-chiorean/TensorRT-LLM that referenced this pull request Mar 24, 2026
Remove the NotImplementedError gate in TRTLLMGenFusedMoE.__init__ that blocked ALL MoE models on SM120+ (DGX Spark / RTX 5090).
longcheng-nv pushed a commit to longcheng-nv/TensorRT-LLM that referenced this pull request Mar 31, 2026
…SM121) (NVIDIA#12141)

Signed-off-by: Mihai Chiorean <mihai.v.chiorean@gmail.com>
Co-authored-by: Pamela Peng <179191831+pamelap-nvidia@users.noreply.github.com>
mihai-chiorean added a commit to mihai-chiorean/TensorRT-LLM that referenced this pull request Apr 2, 2026
Remove the NotImplementedError gate in TRTLLMGenFusedMoE.__init__ that blocked ALL MoE models on SM120+ (DGX Spark / RTX 5090).

Labels

Community want to contribute (PRs initiated from Community)


Development

Successfully merging this pull request may close these issues.

[Bug] FP4 CUTLASS GEMM fails on GB10 (SM121) — shared memory overflow from B200-sized tile configs

6 participants
