Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit fac63a3

Browse filesBrowse files
authored
musa: refine compute capability (ggml-org#12493)
* musa: refine compute capability Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com> * Address review comments Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com> --------- Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
1 parent eddfb43 commit fac63a3
Copy full SHA for fac63a3

File tree

Expand file treeCollapse file tree

5 files changed

+45
-33
lines changed
Filter options
Expand file treeCollapse file tree

5 files changed

+45
-33
lines changed

‎ggml/src/ggml-cuda/common.cuh

Copy file name to clipboardExpand all lines: ggml/src/ggml-cuda/common.cuh
+31-19Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,17 @@
4141
#define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
4242
#define CUDART_HMASK 12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons
4343

44-
#define GGML_CUDA_CC_PASCAL 600
45-
#define GGML_CUDA_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
46-
#define GGML_CUDA_CC_VOLTA 700
47-
#define GGML_CUDA_CC_TURING 750
48-
#define GGML_CUDA_CC_AMPERE 800
49-
#define GGML_CUDA_CC_ADA_LOVELACE 890
50-
#define GGML_CUDA_CC_OFFSET_AMD 0x1000000
51-
44+
#define GGML_CUDA_CC_PASCAL 600
45+
#define GGML_CUDA_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
46+
#define GGML_CUDA_CC_VOLTA 700
47+
#define GGML_CUDA_CC_TURING 750
48+
#define GGML_CUDA_CC_AMPERE 800
49+
#define GGML_CUDA_CC_ADA_LOVELACE 890
50+
#define GGML_CUDA_CC_OFFSET_AMD 0x1000000
51+
#define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000
52+
#define GGML_CUDA_CC_IS_NVIDIA(cc) (cc < GGML_CUDA_CC_OFFSET_MTHREADS)
53+
54+
// AMD
5255
// GCN/CDNA, wave size is 64
5356
#define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16
5457
#define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 0x900) // Vega56/64, minimum for fp16 dual issue
@@ -70,8 +73,17 @@
7073
#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
7174
#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
7275

73-
#define GGML_CUDA_CC_QY1 210
74-
#define GGML_CUDA_CC_QY2 220
76+
// Moore Threads
77+
#define GGML_CUDA_MUSA_ARCH_IS_QY1 (__MUSA_ARCH__ <= 210)
78+
79+
#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
80+
#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
81+
#define GGML_CUDA_CC_NG (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD
82+
83+
#define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD)
84+
#define GGML_CUDA_CC_IS_QY1(cc) (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2)
85+
#define GGML_CUDA_CC_IS_QY2(cc) ((cc) >= GGML_CUDA_CC_QY2 && (cc) < GGML_CUDA_CC_NG) // true for cc in [QY2, NG)
86+
#define GGML_CUDA_CC_IS_NG(cc) (cc >= GGML_CUDA_CC_NG)
7587

7688
#ifdef __CUDA_ARCH_LIST__
7789
constexpr bool ggml_cuda_has_arch_impl(int) {
@@ -209,42 +221,42 @@ typedef float2 dfloat2;
209221
#define CP_ASYNC_AVAILABLE
210222
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
211223

212-
#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
224+
#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1)
213225
#define FLASH_ATTN_AVAILABLE
214-
#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
226+
#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1)
215227

216228
static bool fp16_available(const int cc) {
217229
return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
218230
}
219231

220232
static bool fast_fp16_available(const int cc) {
221-
return fp16_available(cc) && cc != 610;
233+
return (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
222234
}
223235

224236
// To be used for feature selection of external libraries, e.g. cuBLAS.
225237
static bool fast_fp16_hardware_available(const int cc) {
226-
return cc >= GGML_CUDA_CC_PASCAL && cc != 610;
238+
return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
227239
}
228240

229241
// Any FP16 tensor core instructions are available for ggml code.
230242
static bool fp16_mma_available(const int cc) {
231243
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
232244
return false;
233245
#else
234-
return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ||
235-
GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3;
246+
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ||
247+
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc);
236248
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
237249
}
238250

239251
// To be used for feature selection of external libraries, e.g. cuBLAS.
240252
static bool fp16_mma_hardware_available(const int cc) {
241-
return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA ||
242-
GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3;
253+
return GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA ||
254+
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc);
243255
}
244256

245257
// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
246258
static bool new_mma_available(const int cc) {
247-
return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
259+
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
248260
}
249261

250262
static bool cp_async_available(const int cc) {

‎ggml/src/ggml-cuda/fattn.cu

Copy file name to clipboardExpand all lines: ggml/src/ggml-cuda/fattn.cu
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
253253
const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
254254
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
255255

256-
if (cc >= GGML_CUDA_CC_OFFSET_AMD) {
256+
if (GGML_CUDA_CC_IS_AMD(cc)) {
257257
#if defined(GGML_HIP_ROCWMMA_FATTN)
258258
if (fp16_mma_available(cc)) {
259259
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);

‎ggml/src/ggml-cuda/ggml-cuda.cu

Copy file name to clipboardExpand all lines: ggml/src/ggml-cuda/ggml-cuda.cu
+5-5Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -264,9 +264,9 @@ static ggml_cuda_device_info ggml_cuda_init() {
264264
#elif defined(GGML_USE_MUSA)
265265
// FIXME: Ensure compatibility with varying warp sizes across different MUSA archs.
266266
info.devices[id].warp_size = 32;
267-
// TODO: refine the .cc to reflect MUSA's actual CC capabilities
268267
info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
269-
info.devices[id].cc = 100*prop.major + 10*prop.minor;
268+
info.devices[id].cc = GGML_CUDA_CC_OFFSET_MTHREADS + prop.major * 0x100;
269+
info.devices[id].cc += prop.minor * 0x10;
270270
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n",
271271
id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
272272
#else
@@ -1188,11 +1188,11 @@ static void ggml_cuda_op_mul_mat_cublas(
11881188
// ldc == nrows of the matrix that cuBLAS writes into
11891189
int64_t ldc = id == ctx.device ? ne0 : row_diff;
11901190

1191-
const int compute_capability = ggml_cuda_info().devices[id].cc;
1191+
const int cc = ggml_cuda_info().devices[id].cc;
11921192

11931193
const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
11941194

1195-
if (compute_capability >= GGML_CUDA_CC_VOLTA && use_fp16) {
1195+
if (((cc >= GGML_CUDA_CC_VOLTA && GGML_CUDA_CC_IS_NVIDIA(cc)) || GGML_CUDA_CC_IS_AMD(cc)) && use_fp16) {
11961196
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
11971197
ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
11981198
if (src0->type != GGML_TYPE_F16) {
@@ -1216,7 +1216,7 @@ static void ggml_cuda_op_mul_mat_cublas(
12161216

12171217
CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
12181218

1219-
if (GGML_CUDA_CC_IS_CDNA(compute_capability)) {
1219+
if (GGML_CUDA_CC_IS_CDNA(cc)) {
12201220
const float alpha = 1.0f;
12211221
const float beta = 0.0f;
12221222
CUBLAS_CHECK(

‎ggml/src/ggml-cuda/mmq.cu

Copy file name to clipboardExpand all lines: ggml/src/ggml-cuda/mmq.cu
+2-2Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ void ggml_cuda_op_mul_mat_q(
2828
// Also its fixup needs to allocate a temporary buffer in the memory pool.
2929
// There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
3030
const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA &&
31-
cc < GGML_CUDA_CC_OFFSET_AMD && src1_ncols == ne11;
31+
GGML_CUDA_CC_IS_NVIDIA(cc) && src1_ncols == ne11;
3232
const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst, use_stream_k};
3333

3434
switch (src0->type) {
@@ -145,7 +145,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
145145
return true;
146146
#endif //GGML_CUDA_FORCE_MMQ
147147

148-
if (cc < GGML_CUDA_CC_OFFSET_AMD) {
148+
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
149149
return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
150150
}
151151

‎ggml/src/ggml-cuda/mmq.cuh

Copy file name to clipboardExpand all lines: ggml/src/ggml-cuda/mmq.cuh
+6-6Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ struct tile_x_sizes {
9090

9191
static int get_mmq_x_max_host(const int cc) {
9292
return new_mma_available(cc) ? 128 :
93-
ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ?
93+
ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && GGML_CUDA_CC_IS_NVIDIA(cc) ?
9494
#ifdef GGML_CUDA_FORCE_MMQ
9595
128 : 64;
9696
#else
@@ -123,8 +123,8 @@ static constexpr __device__ int get_mmq_x_max_device() {
123123
}
124124

125125
static int get_mmq_y_host(const int cc) {
126-
return cc >= GGML_CUDA_CC_OFFSET_AMD ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) :
127-
(ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ? 128 : 64);
126+
return GGML_CUDA_CC_IS_AMD(cc) ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) :
127+
((ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && GGML_CUDA_CC_IS_NVIDIA(cc)) ? 128 : 64);
128128
}
129129

130130
static constexpr __device__ int get_mmq_y_device() {
@@ -2772,14 +2772,14 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
27722772

27732773
const int shmem = mmq_get_shmem<type>(mmq_x, mmq_y, cc);
27742774

2775-
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
2775+
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
27762776
static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
27772777
if (!shmem_limit_raised[id]) {
27782778
CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
27792779
CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
27802780
shmem_limit_raised[id] = true;
27812781
}
2782-
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
2782+
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
27832783

27842784
const int nty = (args.ne01 + mmq_y - 1) / mmq_y;
27852785
const int ntx = (args.ne11 + mmq_x - 1) / mmq_x;
@@ -2832,7 +2832,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
28322832
const int mmq_x_max = get_mmq_x_max_host(cc);
28332833
const int mmq_y = get_mmq_y_host(cc);
28342834
const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
2835-
const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD;
2835+
const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && GGML_CUDA_CC_IS_NVIDIA(cc);
28362836

28372837
int mmq_x_best = 0;
28382838
int nparts_best = INT_MAX;

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.