Commit d9a1452

ggml : add SVE support for q6_K_q8_K (ggml-org#12361)

1 parent fd123cf
1 file changed: +150 -1 lines changed

ggml/src/ggml-cpu/ggml-cpu-quants.c (150 additions, 1 deletion)
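For reference while reading the diff: this kernel contracts a row of q6_K blocks (x) against a row of q8_K blocks (y). A sketch of the two block layouts as declared in ggml's common header (consult ggml-common.h at this commit for the authoritative definitions; QK_K is 256 in this build):

// block_q6_K: 256 weights at 6 bits each, split into low/high bit planes
typedef struct {
    uint8_t ql[QK_K/2];      // quants, lower 4 bits
    uint8_t qh[QK_K/4];      // quants, upper 2 bits
    int8_t  scales[QK_K/16]; // per-16-weight scales, 8 bits each
    ggml_half d;             // super-block scale
} block_q6_K;

// block_q8_K: the 8-bit activation block it is dotted against
typedef struct {
    float   d;               // delta
    int8_t  qs[QK_K];        // quants
    int16_t bsums[QK_K/16];  // sums of quants in groups of 16
} block_q8_K;

The bsums field is what lets the kernel skip the per-element "-32" offset of q6_K inside the hot loop: isum_mins below is the dot of bsums with the 16 scales, and the final "isum - 32 * isum_mins" applies the offset in one step.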
@@ -8158,7 +8158,156 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
     const int nb = n / QK_K;
 
-#ifdef __ARM_NEON
+#ifdef __ARM_FEATURE_SVE
+    const int vector_length = ggml_cpu_get_sve_cnt()*8;
+    float sum = 0;
+    svuint8_t m4b = svdup_n_u8(0xf);
+    svint32_t vzero = svdup_n_s32(0);
+    svuint8_t mone = svdup_n_u8(0x30);
+    svint8_t q6bytes_1, q6bytes_2, q6bytes_3, q6bytes_4;
+    svuint8_t q6h_1, q6h_2, q6h_3, q6h_4;
+
+    for (int i = 0; i < nb; ++i) {
+        const float d_all = GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * GGML_RESTRICT q6 = x[i].ql;
+        const uint8_t * GGML_RESTRICT qh = x[i].qh;
+        const int8_t * GGML_RESTRICT q8 = y[i].qs;
+
+        const int8_t * GGML_RESTRICT scale = x[i].scales;
+
+        const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
+        const svint16_t q8sums_1 = svld1_s16(pg16_8, y[i].bsums);
+        const svint16_t q8sums_2 = svld1_s16(pg16_8, y[i].bsums + 8);
+        const svint16_t q6scales_1 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale));
+        const svint16_t q6scales_2 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale + 8));
+        const svint64_t prod = svdup_n_s64(0);
+        int32_t isum_mins = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(prod, q8sums_1, q6scales_1),
+                                                                                 svdot_s64(prod, q8sums_2, q6scales_2)));
+        int32_t isum = 0;
+
+        switch (vector_length) {
+            case 128:
+                {
+                    const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
+                    const svbool_t pg8_16 = svptrue_pat_b8(SV_VL16);
+                    svint32_t isum_tmp = svdup_n_s32(0);
+                    for (int j = 0; j < QK_K/128; ++j) {
+                        svuint8_t qhbits_1 = svld1_u8(pg8_16, qh);
+                        svuint8_t qhbits_2 = svld1_u8(pg8_16, qh+16);
+                        qh += 32;
+                        svuint8_t q6bits_1 = svld1_u8(pg8_16, q6);
+                        svuint8_t q6bits_2 = svld1_u8(pg8_16, q6+16);
+                        svuint8_t q6bits_3 = svld1_u8(pg8_16, q6+32);
+                        svuint8_t q6bits_4 = svld1_u8(pg8_16, q6+48);
+                        q6 += 64;
+                        svint8_t q8bytes_1 = svld1_s8(pg8_16, q8);
+                        svint8_t q8bytes_2 = svld1_s8(pg8_16, q8+16);
+                        svint8_t q8bytes_3 = svld1_s8(pg8_16, q8+32);
+                        svint8_t q8bytes_4 = svld1_s8(pg8_16, q8+48);
+                        q8 += 64;
+
+                        q6h_1 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 4));
+                        q6h_2 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 4));
+                        q6h_3 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 2));
+                        q6h_4 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 2));
+                        q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_1, m4b), q6h_1));
+                        q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_2, m4b), q6h_2));
+                        q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_3, m4b), q6h_3));
+                        q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_4, m4b), q6h_4));
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
+
+                        scale += 4;
+                        q8bytes_1 = svld1_s8(pg8_16, q8);
+                        q8bytes_2 = svld1_s8(pg8_16, q8+16);
+                        q8bytes_3 = svld1_s8(pg8_16, q8+32);
+                        q8bytes_4 = svld1_s8(pg8_16, q8+48);
+                        q8 += 64;
+
+                        q6h_1 = svand_u8_x(pg16_8, mone, qhbits_1);
+                        q6h_2 = svand_u8_x(pg16_8, mone, qhbits_2);
+                        q6h_3 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_1, 2));
+                        q6h_4 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_2, 2));
+                        q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_1, 4), q6h_1));
+                        q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_2, 4), q6h_2));
+                        q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_3, 4), q6h_3));
+                        q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_4, 4), q6h_4));
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
+                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
+                        scale += 4;
+                    }
+                    isum += svaddv_s32(pg32_4, isum_tmp);
+                    sum += d_all * y[i].d * (isum - 32 * isum_mins);
+                }
+                break;
+            case 256:
+            case 512:
+                {
+                    const svbool_t pg8_2 = svptrue_pat_b8(SV_VL2);
+                    const svbool_t pg32_8 = svptrue_pat_b32(SV_VL8);
+                    const svbool_t pg8_32 = svptrue_pat_b8(SV_VL32);
+                    svint32_t isum_tmp = svdup_n_s32(0);
+                    for (int j = 0; j < QK_K/128; j++) {
+                        svuint8_t qhbits_1 = svld1_u8(pg8_32, qh);
+                        qh += 32;
+                        svuint8_t q6bits_1 = svld1_u8(pg8_32, q6);
+                        svuint8_t q6bits_2 = svld1_u8(pg8_32, q6+32);
+                        q6 += 64;
+                        svint8_t q8bytes_1 = svld1_s8(pg8_32, q8);
+                        svint8_t q8bytes_2 = svld1_s8(pg8_32, q8+32);
+                        svint8_t q8bytes_3 = svld1_s8(pg8_32, q8+64);
+                        svint8_t q8bytes_4 = svld1_s8(pg8_32, q8+96);
+                        q8 += 128;
+                        q6h_1 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 4));
+                        q6h_2 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 2));
+                        q6h_3 = svand_u8_x(pg8_32, mone, qhbits_1);
+                        q6h_4 = svand_u8_x(pg8_32, mone, svlsr_n_u8_x(pg8_32, qhbits_1, 2));
+                        q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_1, m4b), q6h_1));
+                        q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_2, m4b), q6h_2));
+                        q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_1, 4), q6h_3));
+                        q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_2, 4), q6h_4));
+
+                        svint8_t scale_lane_1_tmp = svld1_s8(pg8_2, scale);
+                        scale_lane_1_tmp = svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
+                        scale_lane_1_tmp = svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
+                        svint8_t scale_lane_2_tmp = svld1_s8(pg8_2, scale+2);
+                        scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
+                        scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
+                        svint8_t scale_lane_3_tmp = svld1_s8(pg8_2, scale+4);
+                        scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
+                        scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
+                        svint8_t scale_lane_4_tmp = svld1_s8(pg8_2, scale+6);
+                        scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
+                        scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
+                        svint32_t scale_lane_1 = svunpklo_s32(svunpklo_s16(scale_lane_1_tmp));
+                        svint32_t scale_lane_2 = svunpklo_s32(svunpklo_s16(scale_lane_2_tmp));
+                        svint32_t scale_lane_3 = svunpklo_s32(svunpklo_s16(scale_lane_3_tmp));
+                        svint32_t scale_lane_4 = svunpklo_s32(svunpklo_s16(scale_lane_4_tmp));
+
+                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale_lane_1);
+                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale_lane_2);
+                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale_lane_3);
+                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale_lane_4);
+                        scale += 8;
+                    }
+                    isum += svaddv_s32(pg32_8, isum_tmp);
+                    sum += d_all * y[i].d * (isum - 32 * isum_mins);
+                }
+                break;
+            default:
+                assert(false && "Unsupported vector length");
+                break;
+        }
+    }
+
+    *s = sum;
+
+#elif __ARM_NEON
     float sum = 0;
 
     const uint8x16_t m4b = vdupq_n_u8(0xF);
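For comparison, a minimal scalar sketch of the dot product the SVE paths vectorize. This is a hypothetical reference helper written for this page, not code from the commit; it mirrors ggml's generic q6_K logic under the block layouts sketched above:

// Hypothetical scalar reference for ggml_vec_dot_q6_K_q8_K (reading aid only).
static float vec_dot_q6_K_q8_K_scalar(int n, const block_q6_K * x, const block_q8_K * y) {
    const int nb = n / QK_K;
    float sumf = 0.0f;
    for (int i = 0; i < nb; ++i) {
        const float d_all = GGML_FP16_TO_FP32(x[i].d);
        const uint8_t * ql = x[i].ql;
        const uint8_t * qh = x[i].qh;
        const int8_t  * q8 = y[i].qs;
        const int8_t  * sc = x[i].scales;
        int32_t isum = 0;
        for (int chunk = 0; chunk < QK_K; chunk += 128) {  // two 128-weight halves
            for (int l = 0; l < 32; ++l) {
                const int is = l/16;  // one 8-bit scale per 16 weights
                // Reassemble each 6-bit weight: 4 low bits from ql, 2 high bits
                // from qh, then subtract the implicit 32 offset.
                const int8_t q1 = (int8_t)((ql[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
                const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
                const int8_t q3 = (int8_t)((ql[l +  0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32;
                const int8_t q4 = (int8_t)((ql[l + 32] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32;
                isum += sc[is + 0] * q1 * q8[l +  0]
                      + sc[is + 2] * q2 * q8[l + 32]
                      + sc[is + 4] * q3 * q8[l + 64]
                      + sc[is + 6] * q4 * q8[l + 96];
            }
            ql += 64; qh += 32; q8 += 128; sc += 8;
        }
        sumf += d_all * y[i].d * (float)isum;
    }
    return sumf;
}

The SVE kernel computes the same isum but leaves out the per-element -32, recovering it at the end as 32 * isum_mins, where isum_mins dots y[i].bsums against the 16 scales. The two paths differ mainly in scale handling: the 128-bit path multiplies each 16-byte group by a scalar scale via svmla_n_s32_x, while the 256/512-bit path broadcasts scale pairs with svzip1_s8/svunpklo so each 32-bit dot lane is multiplied by its matching scale vector.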
