Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit faaaff5

Browse filesBrowse files
authored
CANN: Support MUL_MAT_ID for q8_0 and q4_0 (ggml-org#13705)
* [CANN] Support MUL_MAT_ID Q8 && Q4

Signed-off-by: noemotiovon <757486878@qq.com>

* codestyle adjustment

Signed-off-by: noemotiovon <757486878@qq.com>

---------

Signed-off-by: noemotiovon <757486878@qq.com>
1 parent e16c473 commit faaaff5
Copy full SHA for faaaff5

File tree

Expand file treeCollapse file tree

2 files changed

+142
-6
lines changed
Filter options
Expand file treeCollapse file tree

2 files changed

+142
-6
lines changed

‎ggml/src/ggml-cann/aclnn_ops.cpp

Copy file name to clipboardExpand all lines: ggml/src/ggml-cann/aclnn_ops.cpp
+133-6Lines changed: 133 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2697,14 +2697,10 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor*
26972697
}
26982698
}
26992699

2700-
// GroupedMatmulV2 required tensor_list.size < 128
27012700
size_t GROUP_SIZE = 128;
2702-
std::vector<std::vector<aclTensor*>> src0_tensor_vec_vec;
2703-
std::vector<std::vector<aclTensor*>> src1_tensor_vec_vec;
2704-
std::vector<std::vector<aclTensor*>> dst_tensor_vec_vec;
2705-
2706-
// split and call GroupedMatmulV2
2701+
// GroupedMatmulV2 required tensor_list.size < 128
27072702
for (size_t i = 0; i < src0_tensor_vec.size(); i += GROUP_SIZE) {
2703+
// split and call GroupedMatmulV2
27082704
size_t end = std::min(i + GROUP_SIZE, src0_tensor_vec.size());
27092705
std::vector<aclTensor*> src0_tensor_vec_split(src0_tensor_vec.begin() + i, src0_tensor_vec.begin() + end);
27102706
std::vector<aclTensor*> src1_tensor_vec_split(src1_tensor_vec.begin() + i, src1_tensor_vec.begin() + end);
@@ -2722,13 +2718,144 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor*
27222718
return;
27232719
}
27242720

2721+
/**
 * @brief Performs expert-specific matrix multiplication (MoE) with
 * quantized precision using the CANN backend.
 *
 * This function executes a matrix multiplication operation tailored for
 * Mixture of Experts (MoE) models, where the input tensor is multiplied
 * with expert-specific quantized weight matrices. It leverages the CANN
 * backend to perform efficient low-precision computations and stores the
 * quantized result in the destination tensor `dst`.
 *
 * Quantization techniques reduce memory footprint and improve performance
 * by using lower-bit representations (e.g., int8) instead of floating-point.
 * This function is designed to work with such formats and may incorporate
 * optimizations like identity-based fast paths or routing masks for sparse
 * expert selection.
 *
 * @param ctx The context for executing CANN backend operations.
 * @param dst The destination tensor where the quantized MoE multiplication result
 * will be stored.
 *
 * @note This function assumes quantized data types (Q4_0 / Q8_0 only) and is
 * designed for MoE architectures with potential sparse expert routing.
 * @note NOTE(review): assumes src0 is stored in the CANN backend's transformed
 * quant layout — all weight nibbles/bytes first, then all per-block scales —
 * rather than ggml's interleaved block layout; confirm against the CANN
 * buffer transform code.
 */
static void ggml_cann_mul_mat_id_quant(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
    // TODO: Use aclnnGroupedMatMul
    //dst [M, K, N, 1]
    ggml_tensor * src0 = dst->src[0]; //src0 [D, M, A, 1]  (A experts of D x M quantized weights)
    ggml_tensor * src1 = dst->src[1]; //src1 [D, B, N, 1], B = K or B = 1 (B = 1 means broadcast)
    ggml_tensor * ids = dst->src[2]; //ids [K, N]  (expert indices per token)

    GGML_TENSOR_BINARY_OP_LOCALS

    // copy index from npu to cpu
    // The routing indices must be read on the host to drive the per-expert
    // gather loop below, so a device->host copy plus a stream sync is required.
    int64_t n_as = ne02; // A — number of experts
    int64_t n_ids = ids->ne[0]; // K — experts selected per token

    std::vector<char> ids_host(ggml_nbytes(ids));
    ggml_cann_async_memcpy(ctx, ids_host.data(), ids->data, ggml_nbytes(ids),
                           ACL_MEMCPY_DEVICE_TO_HOST);
    // Block until the async copy lands before dereferencing ids_host.
    ACL_CHECK(aclrtSynchronizeStream(ctx.stream()));

    char * src0_original = (char *) src0->data;
    char * src1_original = (char *) src1->data;
    char * dst_original = (char *) dst->data;

    // Shallow per-row views: shapes/strides are overwritten below and the data
    // pointers are re-targeted inside the loop for each (token, expert) pair.
    ggml_tensor src0_row = *src0;
    ggml_tensor src1_row = *src1;
    ggml_tensor dst_row = *dst;

    const enum ggml_type type = dst->src[0]->type;
    // Bytes per weight element: Q4_0 packs two 4-bit values per byte (0.5),
    // Q8_0 uses one byte per value. Kept as float so the Q4_0 half-byte
    // stride works; products below are exact since ne00 is even for Q4_0.
    float weight_elem_size;
    if (type == GGML_TYPE_Q4_0) {
        weight_elem_size = float(sizeof(uint8_t)) / 2;
    } else if (type == GGML_TYPE_Q8_0) {
        weight_elem_size = float(sizeof(uint8_t));
    } else {
        GGML_ABORT("MUL_MAT_ID only support quant type Q4_0 and Q8_0 ");
    }

    // src0_row [D, M, 1, 1] weight without permute
    src0_row.ne[2] = 1;
    src0_row.ne[3] = 1;
    src0_row.nb[0] = weight_elem_size;
    src0_row.nb[1] = weight_elem_size * ne00;
    src0_row.nb[2] = weight_elem_size * ne00;
    src0_row.nb[3] = weight_elem_size * ne00;
    // Quant-data bytes of one expert matrix, and of all experts combined.
    // weight_size is also the offset where the scale section begins.
    size_t weight_stride = ne00 * ne01 * weight_elem_size;
    size_t weight_size = weight_stride * ne02 * ne03;

    // scale [D, M, 1, 1] -> scale && permute
    // One fp16 scale per quant block. QK8_0 == QK4_0 == 32, so the same
    // stride formula serves both supported types.
    size_t scale_elem_size = sizeof(uint16_t);
    size_t scale_stride = src0->ne[1] * src0->ne[0] / QK8_0 * scale_elem_size;

    // src1_row [D, 1, 1, 1] -> input — a single activation row
    src1_row.ne[1] = 1;
    src1_row.ne[2] = 1;
    src1_row.ne[3] = 1;
    src1_row.nb[2] = nb11;
    src1_row.nb[3] = nb11;

    // dst_row [M, 1, 1, 1] -> out — a single output row
    dst_row.ne[1] = 1;
    dst_row.ne[2] = 1;
    dst_row.ne[3] = 1;
    dst_row.nb[2] = nb1;
    dst_row.nb[3] = nb1;

    //create weight for one row
    // nb02 (bytes of one expert in ggml's native quant layout) equals
    // weight_stride + scale_stride for Q4_0 (16+2 per 32 elems) and
    // Q8_0 (32+2), so the buffer holds one expert's weights plus scales.
    ggml_cann_pool_alloc weight_allocator(ctx.pool());
    void* weight_buffer = weight_allocator.alloc(nb02);
    for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
        for (int64_t id = 0; id < n_ids; id++) {
            // expert index
            int32_t i02 = *(int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
            GGML_ASSERT(i02 >= 0 && i02 < n_as);

            // If B = 1 (broadcast), always use 0; otherwise, use id.
            int64_t i11 = (ne11 == 1 ? 0 : id);
            int64_t i12 = iid1;

            int64_t i1 = id;
            int64_t i2 = i12;

            // Source pointers for the selected expert: quant data lives at
            // i02*weight_stride, its scales after the whole weight section.
            void* src0_tmp_ptr = src0_original + i02*weight_stride;
            void* scale_tmp_ptr = src0_original + weight_size + i02*scale_stride;
            void* src1_tmp_ptr = src1_original + i11*nb11 + i12*nb12;
            void* dst_tmp_ptr = dst_original + i1*nb1 + i2*nb2;

            // mem cpy
            // Gather this expert's weights and scales into one contiguous
            // buffer (weights first, scales packed right after) so the row
            // matmul sees the layout it expects.
            ggml_cann_async_memcpy(ctx, weight_buffer, src0_tmp_ptr, weight_stride,
                                   ACL_MEMCPY_DEVICE_TO_DEVICE);
            void* scale_buffer = (char*)weight_buffer + weight_stride;
            ggml_cann_async_memcpy(ctx, scale_buffer, scale_tmp_ptr, scale_stride,
                                   ACL_MEMCPY_DEVICE_TO_DEVICE);

            // Re-point the row views and dispatch a plain 1-row mul_mat.
            src0_row.data = weight_buffer;
            src1_row.data = src1_tmp_ptr;
            dst_row.data = dst_tmp_ptr;
            dst_row.src[0] = &src0_row;
            dst_row.src[1] = &src1_row;

            ggml_cann_mul_mat(ctx, &dst_row);
        }
    }
    return;
}
2847+
27252848
void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
27262849
const enum ggml_type type = dst->src[0]->type;
27272850
switch (type) {
27282851
case GGML_TYPE_F32:
27292852
case GGML_TYPE_F16:
27302853
ggml_cann_mul_mat_id_fp(ctx, dst);
27312854
break;
2855+
case GGML_TYPE_Q4_0:
2856+
case GGML_TYPE_Q8_0:
2857+
ggml_cann_mul_mat_id_quant(ctx, dst);
2858+
break;
27322859
default:
27332860
GGML_ABORT("Unsupported type for mul_mat_id");
27342861
break;

‎ggml/src/ggml-cann/ggml-cann.cpp

Copy file name to clipboardExpand all lines: ggml/src/ggml-cann/ggml-cann.cpp
+9Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2035,6 +2035,15 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
20352035
case GGML_TYPE_F16:
20362036
case GGML_TYPE_F32:
20372037
return true;
2038+
case GGML_TYPE_Q8_0:
2039+
case GGML_TYPE_Q4_0:
2040+
#ifdef ASCEND_310P
2041+
// Q4 && Q8 per-group quantization is not supported on the 310p device
2042+
return false;
2043+
#endif
2044+
// only support contiguous for quantized types.
2045+
return ggml_is_contiguous(op->src[0]) &&
2046+
ggml_is_contiguous(op->src[1]);
20382047
default:
20392048
return false;
20402049
}

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.