Summary
nvfp4_gemm_cutlass fails on GB10 (SM121) with "Error Internal no error". The SM120 tile configs require more shared memory than GB10 provides. nvfp4_gemm_cublaslt works fine on the same hardware.
Problem: fp4_gemm_template.h routes SM121 to the SM120 path (line ~378: mSm == 120 || mSm == 121), but both SM120 tile configs (128×128×256B and 256×128×128B) need >99 KiB SMEM. GB10 only has 99 KiB vs B200's ~228 KiB.
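As a rough sanity check (not the exact CUTLASS layout), the per-stage SMEM footprint of a 128×128×256 NVFP4 tile can be estimated assuming 4-bit operands plus one FP8 scale per 16-element block; the M×N×K interpretation of the tile shape below is an assumption:

```cpp
// Illustrative SMEM estimate only; the real CUTLASS layout adds epilogue
// carveout, barriers, and padding, so treat these numbers as approximate.
#include <cstdio>

int main() {
    const int M = 128, N = 128, K = 256;            // assumed M x N x K CTA tile
    const int aBytes   = M * K / 2;                 // FP4 operand A: 16 KiB
    const int bBytes   = N * K / 2;                 // FP4 operand B: 16 KiB
    const int sfBytes  = (M * K + N * K) / 16;      // FP8 scale per 16 elements: 4 KiB
    const int perStage = aBytes + bBytes + sfBytes; // ~36 KiB per pipeline stage
    printf("per stage: %d KiB, 2 stages: %d KiB, 3 stages: %d KiB\n",
           perStage / 1024, 2 * perStage / 1024, 3 * perStage / 1024);
    // ~36 / ~72 / ~108 KiB: a 3-stage pipeline already exceeds GB10's 99 KiB
    // limit, while B200's ~228 KiB accommodates it comfortably.
    return 0;
}
```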
Proof it's not a hardware limitation: CUTLASS example 79 (blackwell_geforce_gemm) uses GB10-appropriate tiles and runs FP4 successfully at 41.6 TFLOPS.
Environment
- GPU: NVIDIA GB10 (SM 12.1, DGX Spark)
- TRT-LLM: 1.3.0rc2 (commit f42a6cb)
- CUDA: 12.8
- PyTorch: 2.7 (CUDA 12.6)
Failure Path
- fp4_gemm_template.h:378 — SM121 dispatches to the SM120 path
  - SM120 tile configs both require >99 KiB SMEM
- gemm_universal_adapter.h:342-353 — cudaFuncSetAttribute(MaxDynamicSharedMemorySize) fails
  - Returns kErrorInternal → "Error Internal no error"
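The failing step can be reproduced outside TRT-LLM with a minimal sketch: requesting more dynamic SMEM than the device's opt-in limit makes cudaFuncSetAttribute return an error status, which CUTLASS's adapter then folds into kErrorInternal. kDummyKernel and the 128 KiB request are placeholders for illustration.

```cpp
#include <cstdio>
#include <cuda_runtime.h>

// Placeholder kernel; only its attributes matter for this sketch.
__global__ void kDummyKernel(char* /*smem_user*/) {}

int main() {
    int smemOptin = 0;
    cudaDeviceGetAttribute(&smemOptin, cudaDevAttrMaxSharedMemoryPerBlockOptin, 0);
    printf("max opt-in SMEM per block: %d KiB\n", smemOptin / 1024);  // ~99 KiB on GB10

    // Request more dynamic SMEM than the device allows, as the SM120 tile
    // configs effectively do on GB10.
    const int kRequestedSmem = 128 * 1024;
    cudaError_t err = cudaFuncSetAttribute(
        kDummyKernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kRequestedSmem);
    printf("cudaFuncSetAttribute(%d KiB): %s\n", kRequestedSmem / 1024,
           cudaGetErrorString(err));
    // The non-success status is what gemm_universal_adapter.h turns into
    // Status::kErrorInternal, surfaced as "Error Internal no error".
    return 0;
}
```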
Benchmarks on GB10
| Backend | Peak TFLOPS | Status |
| --- | --- | --- |
| cuBLASLt FP4 (nvfp4_gemm_cublaslt) | 99.6 | ✅ Works |
| CUTLASS example 79 (SM121 tiles) | 41.6 | ✅ Works |
| TRT-LLM CUTLASS FP4 (nvfp4_gemm_cutlass) | — | ❌ SMEM overflow |
| BF16 baseline | 8.4 | ✅ Reference |
CUTLASS example 79 results:
| Problem Size | Avg Runtime | TFLOPS |
| --- | --- | --- |
| 3072×3072×3072 | 1.57 ms | 37.0 |
| 4096×4096×4096 | 3.35 ms | 41.0 |
| 2048×5120×8192 | 4.13 ms | 41.6 |
Additional Blockers
Two other FP4 backends are also blocked on GB10:
- CuteDSL: cute_dsl_custom_ops.py:766 gates on sm_version not in (100, 103). Patching that reveals a second gate in the nvidia-cutlass-dsl package — MmaMXF4NVF4Op at cutlass/cute/nvgpu/tcgen05/mma.py:1271 hardcodes sm_100a/sm_103a only.
- CUDA Core FP4: No dispatch path for SM121.
Suggested Fix
Add GB10-specific tile configs to fp4_gemm_template.h with fewer pipeline stages so they fit within 99 KiB. Example 79's GeForce tile configs could serve as a reference.
Alternatively, StageCountAutoCarveout could query runtime SMEM via cudaDeviceGetAttribute(cudaDevAttrMaxSharedMemoryPerBlockOptin) instead of using compile-time SM120 assumptions — this would make the same binary work on both B200 and GB10 by adjusting pipeline depth.
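A minimal sketch of that runtime-query idea, assuming the stage count can be derived from the queried limit (pickStageCount, kSmemPerStage, and kReserved are illustrative names, not existing TRT-LLM or CUTLASS APIs):

```cpp
#include <algorithm>
#include <cstdio>
#include <cuda_runtime.h>

// Pick as many pipeline stages as the device's opt-in SMEM allows, keeping a
// reserved carveout (epilogue, barriers) and never going below 2 stages.
int pickStageCount(int device, int smemPerStageBytes, int reservedBytes) {
    int smemOptin = 0;
    cudaDeviceGetAttribute(&smemOptin, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    int stages = (smemOptin - reservedBytes) / smemPerStageBytes;
    return std::max(stages, 2);
}

int main() {
    const int kSmemPerStage = 36 * 1024;  // illustrative per-stage footprint
    const int kReserved     = 8 * 1024;   // illustrative epilogue carveout
    // ~99 KiB (GB10) would yield 2 stages; ~228 KiB (B200) would yield 6.
    printf("stages that fit on device 0: %d\n",
           pickStageCount(0, kSmemPerStage, kReserved));
    return 0;
}
```

Since CUTLASS stage counts are compile-time template parameters, in practice this would likely mean instantiating a small set of stage-count variants and selecting among them at dispatch time based on the queried limit.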
Reproduction
```python
import torch
import tensorrt_llm

a = torch.randn(256, 512, dtype=torch.bfloat16, device="cuda")
b = torch.randn(256, 512, dtype=torch.bfloat16, device="cuda")
global_sf = torch.tensor(1.0, dtype=torch.float32, device="cuda")
a_q, a_sf = torch.ops.trtllm.fp4_quantize(a, global_sf, 16)
b_q, b_sf = torch.ops.trtllm.fp4_quantize(b, global_sf, 16)

# Works:
result = torch.ops.trtllm.nvfp4_gemm_cublaslt(a_q, a_sf, b_q, b_sf, global_sf, 256, 256, 512)

# Fails with "Error Internal no error":
result = torch.ops.trtllm.nvfp4_gemm_cutlass(a_q, a_sf, b_q, b_sf, global_sf, 256, 256, 512)
```
Relevant Files
- cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/fp4_gemm_template.h — dispatch + tile configs
- cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/nvfp4_nvfp4_gemm_template_sm120.h — SM120 kernel
- cpp/build/_deps/cutlass-src/include/cutlass/gemm/device/gemm_universal_adapter.h:342-353 — failure point
- cpp/build/_deps/cutlass-src/examples/79_blackwell_geforce_gemm/ — working GB10 FP4 reference
Tested on: NVIDIA GB10 (DGX Spark, compute 12.1)