[#11368][fix] FP4 CUTLASS GEMM shared memory overflow on GB10 (SM121)#12141
pamelap-nvidia merged 11 commits into NVIDIA:main from mihai-chiorean:fix/fp4-gemm-sm121-smem-overflow
Conversation
The nvfp4_gemm_cutlass kernel fails on GB10 (SM121 / DGX Spark) because SM121 is routed to SM120 tile configurations (128x128x256B and 256x128x128B) that require more shared memory than GB10 provides (~99 KiB vs ~228 KiB on B200). The 128x128x128B tile config was already compiled for all output types (bf16, fp16, fp32) but unreachable at runtime due to:

1. Missing dispatch case in dispatchNVFP4xNVFP4GemmCTAShapeSm120
2. Commented out in the getConfigs() candidate list
3. getDefaultGemmConfig() unconditionally selecting the larger tile

This patch:

- Adds CtaShape128x128x128B to the SM120 dispatch switch
- Enables it as a candidate in getConfigs() for profiling
- Selects the tile config at runtime based on cudaDevAttrMaxSharedMemoryPerBlockOptin rather than compile-time assumptions, so the binary works on both B200 (~228 KiB) and GB10 (~99 KiB); see the sketch below

On B200, behavior is unchanged: the larger tiles are still preferred. On GB10, the 128x128x128B tile is selected (the only one that fits), and the runtime SMEM check rejects oversized configs during profiling.

Fixes NVIDIA#11368

Signed-off-by: Mihai Chiorean <mihai.v.chiorean@gmail.com>
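For illustration, the runtime selection described in the last bullet might look roughly like this. This is a minimal sketch based on the review diff further down; the header paths, the free-function wrapper, and the namespace alias mapping are assumptions, not the exact PR code.

```cpp
// Minimal sketch of the runtime tile selection (header paths are assumptions;
// getMaxSharedMemoryPerBlockOptin() and the SM120 tile enums appear in the review diff).
#include "cutlass_extensions/gemm_configs.h"  // assumed path for CutlassTileConfigSM120
#include "tensorrt_llm/common/cudaUtils.h"    // assumed path for getMaxSharedMemoryPerBlockOptin()

namespace tkc = tensorrt_llm::cutlass_extensions; // alias as used in the review diff (assumed mapping)

tkc::CutlassTileConfigSM120 pickSm120TileConfig()
{
    // ~99 KiB on GB10 / RTX 5090 (SM12x), ~228 KiB on B200 (SM100).
    int const maxSmem = tensorrt_llm::common::getMaxSharedMemoryPerBlockOptin();
    constexpr int kMinSmemForLargeTile = 105 * 1024; // threshold taken from the review diff
    return maxSmem > kMinSmemForLargeTile ? tkc::CutlassTileConfigSM120::CtaShape128x128x256B
                                          : tkc::CutlassTileConfigSM120::CtaShape128x128x128B;
}
```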
📝 Walkthrough

Updates FP4 GEMM kernel dispatch and configuration selection to support both large and smaller tile shapes. The kernel routing gains a new 128x128x128B CTA dispatch path, while runtime configuration logic selects tile size based on available shared memory, choosing 128x128x256B when sufficient memory is available or 128x128x128B otherwise.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (warning)
Actionable comments posted: 1
🧹 Nitpick comments (1)
cpp/tensorrt_llm/thop/fp4Gemm.cpp (1)
69-72: Static variable naming should use 's' prefix per coding guidelines.

The static local variable maxSmem should be named sMaxSmem to follow the project convention for locally visible static variables.

Proposed fix:

```diff
- static int const maxSmem = tensorrt_llm::common::getMaxSharedMemoryPerBlockOptin();
+ static int const sMaxSmem = tensorrt_llm::common::getMaxSharedMemoryPerBlockOptin();
  constexpr int kMinSmemForLargeTile = 105 * 1024;
- auto tileConfig = maxSmem > kMinSmemForLargeTile ? tkc::CutlassTileConfigSM120::CtaShape128x128x256B
-                                                  : tkc::CutlassTileConfigSM120::CtaShape128x128x128B;
+ auto tileConfig = sMaxSmem > kMinSmemForLargeTile ? tkc::CutlassTileConfigSM120::CtaShape128x128x256B
+                                                   : tkc::CutlassTileConfigSM120::CtaShape128x128x128B;
```

As per coding guidelines: "Locally visible static variables in C++ should use camel case with lowercase prefix 's' as the first letter (e.g., static std::once_flag sFlag;)".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/tensorrt_llm/thop/fp4Gemm.cpp` around lines 69 - 72, The static local variable maxSmem should follow the project's naming convention for locally visible statics; rename maxSmem to sMaxSmem where it is declared and used (the call to tensorrt_llm::common::getMaxSharedMemoryPerBlockOptin() and the ternary that selects tkc::CutlassTileConfigSM120::CtaShape128x128x256B vs CtaShape128x128x128B for tileConfig) so all references to maxSmem are updated to sMaxSmem and nothing else changes in the logic.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@cpp/tensorrt_llm/thop/fp4Gemm.cpp`:
- Line 69: The static local maxSmem in getDefaultGemmConfig() caches a
device-specific value from getMaxSharedMemoryPerBlockOptin() and will be wrong
after CUDA device switches; remove the static qualifier so maxSmem is queried
each call (i.e., call getMaxSharedMemoryPerBlockOptin() directly when computing
tile config in getDefaultGemmConfig()), or implement a per-device cache keyed by
cudaGetDevice() that stores/get updates per-device SMEM values using the same
getMaxSharedMemoryPerBlockOptin() call; update references to maxSmem accordingly
to ensure the chosen tile configs match the current device.
---
Nitpick comments:
In `@cpp/tensorrt_llm/thop/fp4Gemm.cpp`:
- Around line 69-72: The static local variable maxSmem should follow the
project's naming convention for locally visible statics; rename maxSmem to
sMaxSmem where it is declared and used (the call to
tensorrt_llm::common::getMaxSharedMemoryPerBlockOptin() and the ternary that
selects tkc::CutlassTileConfigSM120::CtaShape128x128x256B vs
CtaShape128x128x128B for tileConfig) so all references to maxSmem are updated to
sMaxSmem and nothing else changes in the logic.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: e590b607-1364-48db-b554-fe053a242e1d
📒 Files selected for processing (2)
cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/fp4_gemm_template.h
cpp/tensorrt_llm/thop/fp4Gemm.cpp
Address code review feedback: the static local variable cached maxSmem per process rather than per device, which would produce incorrect tile selection after a CUDA device switch in multi-GPU scenarios. Query it on each call instead.

Fixes NVIDIA#11368

Signed-off-by: Mihai Chiorean <mihai.v.chiorean@gmail.com>
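The alternative mentioned in the review comment — a per-device cache keyed by cudaGetDevice() — could look like the following hypothetical helper. It is not code from this PR, and the 64-device cap is an arbitrary simplification.

```cpp
// Hypothetical per-device cache of cudaDevAttrMaxSharedMemoryPerBlockOptin, so the
// cached value stays correct after cudaSetDevice() switches in multi-GPU processes.
#include <cuda_runtime_api.h>
#include <array>
#include <mutex>

int maxSmemOptinForCurrentDevice()
{
    static std::array<int, 64> sCache{}; // 0 = not queried yet; assumes fewer than 64 devices
    static std::mutex sMutex;

    int device = 0;
    cudaGetDevice(&device);

    std::lock_guard<std::mutex> lock(sMutex);
    if (sCache[device] == 0)
    {
        cudaDeviceGetAttribute(&sCache[device], cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    }
    return sCache[device];
}
```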
|
a small nitpick: sm120 and sm121 have the same amount of SMEM (~99KB). GB200 is sm100 and has 228KB. |
SM120 is RTX 5090 (not B200), and SM121 is GB10/DGX Spark. Both SM12x variants have ~99 KiB SMEM per block. B200 is SM100 with ~227 KiB. The runtime logic was already correct (queries actual device SMEM), but the comment incorrectly stated SM120 = B200 with 228 KiB. Fixes NVIDIA#11368 Signed-off-by: Mihai Chiorean <mihai.v.chiorean@gmail.com>
Good catch, thanks! I've already corrected this in ef554d9 — the comment now correctly states that all SM12x devices (SM120 and SM121) share the same ~99 KiB SMEM limit, and that SM100 (B200/B100) has ~227 KiB. |
|
Hi @pamelap-nvidia, could you take a look at this DGX Spark issue or reassign? Thanks! |
|
/bot run |
|
PR_Github #39001 [ run ] triggered by Bot. Commit: |
|
PR_Github #39001 [ run ] completed with state |
Use CtaShape128x128x128B directly as the default config instead of querying SMEM at runtime. The autotuner will select the optimal tile (including larger ones) for the current device. Addresses review feedback from @pamelap-nvidia. Signed-off-by: Mihai Chiorean <mihai.v.chiorean@gmail.com>
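A minimal sketch of the simplified default described in this commit (assuming the same tkc alias and enum as in the review diff; not the literal PR code):

```cpp
// Earlier revision: pick the tile from a runtime SMEM query (see the diff above).
// Final revision: always default to the tile every SM12x device can run; the
// autotuner still profiles the larger getConfigs() candidates where they fit.
auto const tileConfig = tkc::CutlassTileConfigSM120::CtaShape128x128x128B;
```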
|
/bot run |
|
PR_Github #39144 [ run ] triggered by Bot. Commit: |
|
PR_Github #39144 [ run ] completed with state
|
|
CI blocked. I'll wait for the CI fix PR to go in and try again tomorrow. Meanwhile, @mihai-chiorean could you help fix the PR check? It's now blocked by the checklist in the description. We can either mark the items as done or remove them. |
|
/bot run |
|
PR_Github #39279 [ run ] triggered by Bot. Commit: |
done. |
|
PR_Github #39279 [ run ] completed with state
|
|
/bot run |
|
PR_Github #39308 [ run ] triggered by Bot. Commit: |
@pamelap-nvidia CI failed and I can't see why unfortunately. |
|
PR_Github #39308 [ run ] completed with state |
|
/bot help |
GitHub Bot Help
Provide a user friendly way for developers to interact with a Jenkins server. See details below for each supported subcommand.

Details

run
Launch build/test pipelines. All previously running jobs will be killed.

kill
Kill all running builds associated with pull request.

skip
Skip testing for latest commit on pull request.

reuse-pipeline
Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break. |
|
/bot reuse-pipeline |
|
PR_Github #39370 [ reuse-pipeline ] triggered by Bot. Commit: |
|
PR_Github #39370 [ reuse-pipeline ] completed with state |
|
/bot reuse-pipeline |
|
PR_Github #39372 [ reuse-pipeline ] triggered by Bot. Commit: |
|
PR_Github #39372 [ reuse-pipeline ] completed with state |
Remove the NotImplementedError gate in TRTLLMGenFusedMoE.__init__ that blocked ALL MoE models on SM120+ (DGX Spark / RTX 5090). The underlying CUTLASS kernels already have SM120 templates and PR NVIDIA#12141 fixed the FP4 GEMM shared-memory overflow on SM121, so the Python-side SM version checks were the only remaining barrier.

Changes:
- tensorrt_llm/_utils.py: add is_sm_120f() and is_blackwell() helpers
- fused_moe_trtllm_gen.py: remove __init__ SM>=120 gate; extend can_implement() SM set {100,103} -> {100,103,120,121}
- fused_moe_cute_dsl.py: extend NVFP4 SM check to include 120/121; use is_blackwell() for FP8 scale layout (shared across Blackwell)
- model_config.py: route SM120/121 to TRTLLM backend in resolve_moe_backend() (was falling back to CUTLASS)
- torch_custom_ops.py: extend CuTE DSL NVFP4 dense GEMM SM check
- tests/integration/defs/conftest.py: add matching test helpers

Signed-off-by: Mihai <mihai@dgx-spark>
Signed-off-by: Mihai Chiorean <mihai.v.chiorean@gmail.com>
…SM121) (NVIDIA#12141) Signed-off-by: Mihai Chiorean <mihai.v.chiorean@gmail.com> Co-authored-by: Pamela Peng <179191831+pamelap-nvidia@users.noreply.github.com>
Summary
Fixes #11368
nvfp4_gemm_cutlass fails on SM12x devices (SM120 / RTX 5090 and SM121 / GB10 / DGX Spark) with kErrorInternal because the SM120 tile configurations (128x128x256B and 256x128x128B) require more shared memory than any SM12x device provides (~99 KiB). These tile sizes were designed for SM100 (B200/B100, ~227 KiB SMEM) and do not fit on SM12x.

The CtaShape128x128x128B tile config was already compiled for all output types (bf16, fp16, fp32) in fp4_gemm_bf16.cu, fp4_gemm_fp16.cu, and fp4_gemm_fp32.cu — but it was unreachable at runtime.

This patch makes three changes across two files:

- fp4_gemm_template.h: adds a CtaShape128x128x128B dispatch case in dispatchNVFP4xNVFP4GemmCTAShapeSm120 and enables it in the getConfigs() candidate list for profiling (a rough sketch of the dispatch pattern follows below)
- fp4Gemm.cpp: uses CtaShape128x128x128B as the default tile config for SM120+, letting the autotuner select the optimal config for the current device

On SM100 (B200/B100, ~227 KiB SMEM), this code path is not reached — SM100 uses its own tile configs via a separate branch. On SM12x (~99 KiB SMEM), the 128x128x128B tile is selected as default, and the autotuner profiles all candidates, including larger tiles where they fit.
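A self-contained sketch of the dispatch pattern behind the first bullet above: the enum values match the PR, but the helper names and structure are placeholders, not the real TRT-LLM symbols.

```cpp
// Placeholder sketch: the fix adds the CtaShape128x128x128B case so the already
// compiled small-tile kernel becomes reachable; without it, SM12x devices are
// routed only to tiles whose shared-memory footprint exceeds ~99 KiB.
#include <stdexcept>

enum class CutlassTileConfigSM120
{
    CtaShape128x128x128B,
    CtaShape128x128x256B,
    CtaShape256x128x128B,
};

void runLargeTileKernel() { /* pre-existing paths, need more than ~99 KiB SMEM */ }
void runSmallTileKernel() { /* 128x128x128B path, fits in the SM12x budget */ }

void dispatchSm120(CutlassTileConfigSM120 tile)
{
    switch (tile)
    {
    case CutlassTileConfigSM120::CtaShape128x128x256B:
    case CutlassTileConfigSM120::CtaShape256x128x128B: runLargeTileKernel(); break;
    case CutlassTileConfigSM120::CtaShape128x128x128B: runSmallTileKernel(); break; // added by this PR
    default: throw std::runtime_error("Unsupported SM120 tile config");
    }
}
```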
Background
- cudaDevAttrMaxSharedMemoryPerBlockOptin = 101,376 bytes (99 KiB), confirmed by the static_assert(smemSize <= 99 * 1024) in mla_sm120.cu (see the snippet below)
- StageCountAutoCarveout sizes pipeline stages at compile time using the arch SharedMemoryCapacity, so the resulting SharedStorage struct exceeds SM12x limits — this cannot be fixed at runtime without a different tile config
- An existing kernel (blackwell_geforce_gemm) running FP4 at 41.6 TFLOPS on GB10 proves this is not a hardware limitation
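The 99 KiB figure in the first bullet can be reproduced with a standalone query; the snippet below is for illustration only and is not part of the PR.

```cpp
// Prints the opt-in per-block shared memory limit for the current device.
// On GB10 (SM121) this reports 101376 bytes, i.e. 99 * 1024.
#include <cuda_runtime_api.h>
#include <cstdio>

int main()
{
    int device = 0;
    cudaGetDevice(&device);
    int maxSmemOptin = 0;
    cudaDeviceGetAttribute(&maxSmemOptin, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    std::printf("cudaDevAttrMaxSharedMemoryPerBlockOptin = %d bytes\n", maxSmemOptin);
    return 0;
}
```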
Test Coverage

- tests/unittest/_torch/thop/parallel/test_fp4_gemm_quantize.py::TestFunctional::test_fp4_quantize_gemm_torch — validates FP4 GEMM quantize correctness
- tests/unittest/_torch/thop/parallel/test_fp4_bmm_quantize.py::TestFunctional::test_fp4_bmm_torch — validates FP4 batched matmul
- tests/unittest/_torch/thop/parallel/test_fp4_linear.py::test_fp4_linear — validates FP4 linear layer end-to-end

Test plan (GB10 / SM121)
- test_fp4_quantize_gemm_torch — 2/2 passed
- test_fp4_bmm_torch — 4/4 passed
- test_fp4_linear — 8/8 passed

All 14 tests passed on DGX Spark (GB10 / SM121) with TRT-LLM 1.3.0rc8, CUDA 13.1, PyTorch 2.10.0a0.
Note: SM100 (B200) regression testing deferred to NVIDIA CI — no B200 hardware available to the author.
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.