@@ -8158,7 +8158,156 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
8158
8158
8159
8159
const int nb = n / QK_K;
8160
8160
8161
- #ifdef __ARM_NEON
8161
+ #ifdef __ARM_FEATURE_SVE
8162
+ const int vector_length = ggml_cpu_get_sve_cnt()*8;
8163
+ float sum = 0;
8164
+ svuint8_t m4b = svdup_n_u8(0xf);
8165
+ svint32_t vzero = svdup_n_s32(0);
8166
+ svuint8_t mone = svdup_n_u8(0x30);
8167
+ svint8_t q6bytes_1, q6bytes_2, q6bytes_3, q6bytes_4;
8168
+ svuint8_t q6h_1, q6h_2, q6h_3, q6h_4;
8169
+
8170
+ for (int i = 0; i < nb; ++i) {
8171
+ const float d_all = GGML_FP16_TO_FP32(x[i].d);
8172
+
8173
+ const uint8_t * GGML_RESTRICT q6 = x[i].ql;
8174
+ const uint8_t * GGML_RESTRICT qh = x[i].qh;
8175
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
8176
+
8177
+ const int8_t * GGML_RESTRICT scale = x[i].scales;
8178
+
8179
+ const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
8180
+ const svint16_t q8sums_1 = svld1_s16(pg16_8, y[i].bsums);
8181
+ const svint16_t q8sums_2 = svld1_s16(pg16_8, y[i].bsums + 8);
8182
+ const svint16_t q6scales_1 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale));
8183
+ const svint16_t q6scales_2 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale + 8));
8184
+ const svint64_t prod = svdup_n_s64(0);
8185
+ int32_t isum_mins = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(prod, q8sums_1, q6scales_1),
8186
+ svdot_s64(prod, q8sums_2, q6scales_2)));
8187
+ int32_t isum = 0;
8188
+
8189
+ switch (vector_length) {
8190
+ case 128:
8191
+ {
8192
+ const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
8193
+ const svbool_t pg8_16 = svptrue_pat_b8(SV_VL16);
8194
+ svint32_t isum_tmp = svdup_n_s32(0);
8195
+ for (int j = 0; j < QK_K/128; ++j) {
8196
+ svuint8_t qhbits_1 = svld1_u8(pg8_16, qh);
8197
+ svuint8_t qhbits_2 = svld1_u8(pg8_16, qh+16);
8198
+ qh += 32;
8199
+ svuint8_t q6bits_1 = svld1_u8(pg8_16, q6);
8200
+ svuint8_t q6bits_2 = svld1_u8(pg8_16, q6+16);
8201
+ svuint8_t q6bits_3 = svld1_u8(pg8_16, q6+32);
8202
+ svuint8_t q6bits_4 = svld1_u8(pg8_16, q6+48);
8203
+ q6 += 64;
8204
+ svint8_t q8bytes_1 = svld1_s8(pg8_16, q8);
8205
+ svint8_t q8bytes_2 = svld1_s8(pg8_16, q8+16);
8206
+ svint8_t q8bytes_3 = svld1_s8(pg8_16, q8+32);
8207
+ svint8_t q8bytes_4 = svld1_s8(pg8_16, q8+48);
8208
+ q8 += 64;
8209
+
8210
+ q6h_1 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 4));
8211
+ q6h_2 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 4));
8212
+ q6h_3 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 2));
8213
+ q6h_4 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 2));
8214
+ q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_1, m4b), q6h_1));
8215
+ q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_2, m4b), q6h_2));
8216
+ q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_3, m4b), q6h_3));
8217
+ q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_4, m4b), q6h_4));
8218
+ isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
8219
+ isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
8220
+ isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
8221
+ isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
8222
+
8223
+ scale += 4;
8224
+ q8bytes_1 = svld1_s8(pg8_16, q8);
8225
+ q8bytes_2 = svld1_s8(pg8_16, q8+16);
8226
+ q8bytes_3 = svld1_s8(pg8_16, q8+32);
8227
+ q8bytes_4 = svld1_s8(pg8_16, q8+48);
8228
+ q8 += 64;
8229
+
8230
+ q6h_1 = svand_u8_x(pg16_8, mone, qhbits_1);
8231
+ q6h_2 = svand_u8_x(pg16_8, mone, qhbits_2);
8232
+ q6h_3 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_1, 2));
8233
+ q6h_4 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_2, 2));
8234
+ q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_1, 4), q6h_1));
8235
+ q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_2, 4), q6h_2));
8236
+ q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_3, 4), q6h_3));
8237
+ q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_4, 4), q6h_4));
8238
+ isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
8239
+ isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
8240
+ isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
8241
+ isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
8242
+ scale += 4;
8243
+ }
8244
+ isum += svaddv_s32(pg32_4, isum_tmp);
8245
+ sum += d_all * y[i].d * (isum - 32 * isum_mins);
8246
+ }
8247
+ break;
8248
+ case 256:
8249
+ case 512:
8250
+ {
8251
+ const svbool_t pg8_2 = svptrue_pat_b8(SV_VL2);
8252
+ const svbool_t pg32_8 = svptrue_pat_b32(SV_VL8);
8253
+ const svbool_t pg8_32 = svptrue_pat_b8(SV_VL32);
8254
+ svint32_t isum_tmp = svdup_n_s32(0);
8255
+ for (int j = 0; j < QK_K/128; j++) {
8256
+ svuint8_t qhbits_1 = svld1_u8(pg8_32, qh);
8257
+ qh += 32;
8258
+ svuint8_t q6bits_1 = svld1_u8(pg8_32, q6);
8259
+ svuint8_t q6bits_2 = svld1_u8(pg8_32, q6+32);
8260
+ q6 += 64;
8261
+ svint8_t q8bytes_1 = svld1_s8(pg8_32, q8);
8262
+ svint8_t q8bytes_2 = svld1_s8(pg8_32, q8+32);
8263
+ svint8_t q8bytes_3 = svld1_s8(pg8_32, q8+64);
8264
+ svint8_t q8bytes_4 = svld1_s8(pg8_32, q8+96);
8265
+ q8 += 128;
8266
+ q6h_1 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 4));
8267
+ q6h_2 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 2));
8268
+ q6h_3 = svand_u8_x(pg8_32, mone, qhbits_1);
8269
+ q6h_4 = svand_u8_x(pg8_32, mone, svlsr_n_u8_x(pg8_32, qhbits_1, 2));
8270
+ q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_1, m4b), q6h_1));
8271
+ q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_2, m4b), q6h_2));
8272
+ q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_1, 4), q6h_3));
8273
+ q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_2, 4), q6h_4));
8274
+
8275
+ svint8_t scale_lane_1_tmp = svld1_s8(pg8_2, scale);
8276
+ scale_lane_1_tmp= svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
8277
+ scale_lane_1_tmp= svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
8278
+ svint8_t scale_lane_2_tmp = svld1_s8(pg8_2, scale+2);
8279
+ scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
8280
+ scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
8281
+ svint8_t scale_lane_3_tmp = svld1_s8(pg8_2, scale+4);
8282
+ scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
8283
+ scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
8284
+ svint8_t scale_lane_4_tmp = svld1_s8(pg8_2, scale+6);
8285
+ scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
8286
+ scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
8287
+ svint32_t scale_lane_1 = svunpklo_s32(svunpklo_s16(scale_lane_1_tmp));
8288
+ svint32_t scale_lane_2 = svunpklo_s32(svunpklo_s16(scale_lane_2_tmp));
8289
+ svint32_t scale_lane_3 = svunpklo_s32(svunpklo_s16(scale_lane_3_tmp));
8290
+ svint32_t scale_lane_4 = svunpklo_s32(svunpklo_s16(scale_lane_4_tmp));
8291
+
8292
+ isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale_lane_1);
8293
+ isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale_lane_2);
8294
+ isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale_lane_3);
8295
+ isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale_lane_4);
8296
+ scale += 8;
8297
+ }
8298
+ isum += svaddv_s32(pg32_8, isum_tmp);
8299
+ sum += d_all * y[i].d * (isum - 32 * isum_mins);
8300
+ }
8301
+ break;
8302
+ default:
8303
+ assert(false && "Unsupported vector length");
8304
+ break;
8305
+ }
8306
+ }
8307
+
8308
+ *s = sum;
8309
+
8310
+ #elif __ARM_NEON
8162
8311
float sum = 0;
8163
8312
8164
8313
const uint8x16_t m4b = vdupq_n_u8(0xF);
0 commit comments