k-quants with super-block size of 64 #2001

Merged: 35 commits merged on Jun 26, 2023
Changes from 1 commit (of 35)
d2f12ac
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 21, 2023
9fe2a2b
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 21, 2023
1f6195c
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 21, 2023
aebd547
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 21, 2023
2b2ab31
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 21, 2023
bcf8c5c
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 21, 2023
c6c3536
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 21, 2023
5aae4b8
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 22, 2023
41e46ec
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 22, 2023
460dd84
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 22, 2023
3bd9ae7
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 22, 2023
03f30c8
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 22, 2023
cda47a6
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 22, 2023
80c75fe
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 22, 2023
2b2a13c
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 22, 2023
9d27d8d
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 22, 2023
2ff543c
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 22, 2023
d92c5a9
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 23, 2023
fae24af
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 23, 2023
e1bbcfc
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 23, 2023
167a0bb
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 23, 2023
6081a65
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 23, 2023
ff83e32
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 23, 2023
285eeb1
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 23, 2023
8b98d01
k_quants: call them _K, not _k, also on Metal
Kawrakow Jun 23, 2023
558a194
k_quants: correctly define QK_K in llama.cpp
Kawrakow Jun 23, 2023
333ffcc
Fixed bug in q4_K quantization added with the 64-block addition
Kawrakow Jun 23, 2023
88412a1
Simplify via lambda
Kawrakow Jun 23, 2023
aeefd4e
k_quants: switch Q3_K to 4-bit scales when QK_K = 64
Kawrakow Jun 24, 2023
ce19b96
k_quants: switch Q4_K to 4-bit scales when QK_K = 64
Kawrakow Jun 24, 2023
4f61506
k_quants: forgot to add the Metal changes in last commit
Kawrakow Jun 24, 2023
ccf4901
k_quants: change Q5_K to be type 0 when QK_K = 64
Kawrakow Jun 24, 2023
2da3a59
k_quants: AVX2 implementation for new 64-weight Q5_K
Kawrakow Jun 24, 2023
53e81ca
k_quants: 10% faster ARM_NEON Q5_K dot product
Kawrakow Jun 24, 2023
5fd8337
k_quants: fixed issue caused by merging with master
Kawrakow Jun 26, 2023
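
A note on the scale-packing commits above ("switch Q3_K to 4-bit scales when QK_K = 64" and its Q4_K counterpart): with only a handful of sub-blocks per 64-weight super-block, the elaborate 6-bit scale packing used at QK_K = 256 is unnecessary, and plain 4-bit nibbles suffice. A minimal sketch of nibble-packed scales, assuming two 4-bit scales per byte; the helper names and the exact layout are illustrative, not taken from the PR:

#include <stdint.h>

// Pack 2*n_pairs 4-bit scales into n_pairs bytes: even index in the low
// nibble, odd index in the high nibble (assumed layout, for illustration).
static void pack_scales_4bit(const uint8_t *s, int n_pairs, uint8_t *out) {
    for (int i = 0; i < n_pairs; ++i)
        out[i] = (uint8_t)((s[2*i] & 0xF) | ((s[2*i + 1] & 0xF) << 4));
}

static uint8_t get_scale_4bit(const uint8_t *packed, int j) {
    return (j & 1) ? (packed[j/2] >> 4) : (packed[j/2] & 0xF);
}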
k_quants: WIP super-blocks with 64 weights
Q6_K working on ARM_NEON
Kawrakow committed Jun 26, 2023
commit 03f30c8eca735ca3656a2b53a84da53f688b2e6e
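
Before the diff, a scalar sketch of what this commit's ARM_NEON Q6_K kernel computes for QK_K = 64. The layout is inferred from the loads below (32 bytes of low nibbles via vld1q_u8_x2, 16 bytes of high 2-bit groups via vld1q_u8, four sub-block scales used as scale[0..3]); struct and function names are illustrative rather than the authoritative k_quants.h definitions, and d is assumed to be a plain float in this WIP state because the kernel reads it with a (float) cast:

#include <stdint.h>

#define QK_K 64

typedef struct {
    uint8_t ql[QK_K/2];      // lower 4 bits of the 6-bit quants
    uint8_t qh[QK_K/4];      // upper 2 bits, four quants per byte
    int8_t  scales[QK_K/16]; // one signed scale per 16 weights
    float   d;               // super-block scale (assumed float here)
} block_q6_K_sketch;

typedef struct {
    float  d;
    int8_t qs[QK_K];
} block_q8_K_sketch;

// Scalar reference for the NEON kernel in this commit: rebuild each 6-bit
// value from its low nibble and two high bits, remove the offset of 32, and
// accumulate against the int8 activations, one scale per 16-weight group.
static float vec_dot_q6_K_q8_K_ref(int n, const block_q6_K_sketch *x,
                                   const block_q8_K_sketch *y) {
    const int nb = n / QK_K;
    float sum = 0.0f;
    for (int i = 0; i < nb; ++i) {
        const uint8_t *ql = x[i].ql;
        const uint8_t *qh = x[i].qh;
        const int8_t  *q8 = y[i].qs;
        int32_t isum = 0;
        for (int l = 0; l < 16; ++l) {
            // The four 16-weight groups share qh[l], two bits each; this is
            // the same shuffle the q6h/q6bytes NEON code below performs.
            const int w0 = ((ql[l]      & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
            const int w1 = ((ql[l + 16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
            const int w2 = ((ql[l]      >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32;
            const int w3 = ((ql[l + 16] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32;
            isum += x[i].scales[0] * w0 * q8[l]
                  + x[i].scales[1] * w1 * q8[l + 16]
                  + x[i].scales[2] * w2 * q8[l + 32]
                  + x[i].scales[3] * w3 * q8[l + 48];
        }
        sum += x[i].d * y[i].d * isum;
    }
    return sum;
}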
k_quants.c: 141 changes, 37 additions & 104 deletions
@@ -261,6 +261,7 @@ static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t
return scale;
}

#if QK_K == 256
static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
if (j < 4) {
*d = q[j] & 63; *m = q[j + 4] & 63;
@@ -269,6 +270,7 @@ static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t *
*m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
}
}
#endif

//========================- 2-bit (de)-quantization

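The first hunk above only wraps get_scale_min_k4 in #if QK_K == 256: the 6-bit (scale, min) packing it decodes exists only for 256-weight super-blocks. For reference, a sketch of the inverse packing implied by the decode logic, eight 6-bit scale/min pairs in 12 bytes (the helper name is illustrative, not from the PR):

#include <stdint.h>

// Inverse of get_scale_min_k4, inferred from the decode above: for j < 4 the
// low 6 bits of q[j] / q[j+4] hold d[j] / m[j]; the upper four pairs put
// their low 4 bits in q[j+8] and their top 2 bits in bits 6-7 of q[j] / q[j+4].
static void pack_scale_min_k4(const uint8_t *d, const uint8_t *m, uint8_t *q) {
    for (int j = 0; j < 4; ++j) {
        q[j]     = (uint8_t)((d[j] & 63) | ((d[j + 4] >> 4) << 6));
        q[j + 4] = (uint8_t)((m[j] & 63) | ((m[j + 4] >> 4) << 6));
        q[j + 8] = (uint8_t)((d[j + 4] & 0xF) | ((m[j + 4] & 0xF) << 4));
    }
}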
@@ -1895,7 +1897,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri

const int nb = n / QK_K;

#ifdef __ARM_NEON
#ifdef z__ARM_NEON

uint32_t aux[3];
uint32_t utmp[4];
@@ -2361,7 +2363,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri

const int nb = n / QK_K;

#ifdef __ARM_NEON
#ifdef z__ARM_NEON

const uint8x16_t m4b = vdupq_n_u8(0xf);
#ifdef __ARM_FEATURE_DOTPROD
@@ -2779,7 +2781,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri

const int nb = n / QK_K;

#ifdef __ARM_NEON
#ifdef z__ARM_NEON

const uint8x16_t m4b = vdupq_n_u8(0xf);
const int32x4_t mzero = vdupq_n_s32(0);
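
An editorial reading of the three identical one-character changes above (__ARM_NEON becoming z__ARM_NEON in the Q3_K, Q4_K and Q5_K dot products): z__ARM_NEON is never defined, so the rename parks those NEON paths on the generic fallback while their 64-weight layouts are still in flux; only the Q6_K kernel below is brought up on NEON in this commit. This is inferred from the diff, not stated in the commit message.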
@@ -3243,7 +3245,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri

const uint8x16_t m4b = vdupq_n_u8(0xF);
const int32x4_t vzero = vdupq_n_s32(0);
//const int8x16_t m32s = vdupq_n_s8(32);
const int8x16_t m32s = vdupq_n_s8(32);

const uint8x16_t mone = vdupq_n_u8(3);

@@ -3252,124 +3254,55 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri

for (int i = 0; i < nb; ++i) {

const float d_all = ggml_fp16_to_fp32(x[i].d);
const float d_all = (float)x[i].d;

const uint8_t * restrict q6 = x[i].ql;
const uint8_t * restrict qh = x[i].qh;
const int8_t * restrict q8 = y[i].qs;

const int8_t * restrict scale = x[i].scales;

const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums);
const int8x16_t scales = vld1q_s8(scale);
const int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))};

const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])),
vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))),
vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[1]), vget_low_s16 (q6scales.val[1])),
vmull_s16(vget_high_s16(q8sums.val[1]), vget_high_s16(q6scales.val[1]))));
int32_t isum_mins = vaddvq_s32(prod);

int32_t isum = 0;

for (int j = 0; j < QK_K/128; ++j) {

uint8x16x2_t qhbits = vld1q_u8_x2(qh); qh += 32;
uint8x16x4_t q6bits = vld1q_u8_x4(q6); q6 += 64;
int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64;

q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
uint8x16_t shifted = vshrq_n_u8(qhbits.val[0], 2);
q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
shifted = vshrq_n_u8(qhbits.val[1], 2);
q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);

//q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s);
//q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s);
//q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])), m32s);
//q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])), m32s);
q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0]));
q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1]));
q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2]));
q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3]));

#if defined(__ARM_FEATURE_DOTPROD)

isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
scale += 4;

#else

int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];
scale += 2;

int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1];
scale += 2;
#endif
uint8x16_t qhbits = vld1q_u8(qh);
uint8x16x2_t q6bits = vld1q_u8_x2(q6);
int8x16x4_t q8bytes = vld1q_s8_x4(q8);

q8bytes = vld1q_s8_x4(q8); q8 += 64;
q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits), 4);
uint8x16_t shifted = vshrq_n_u8(qhbits, 2);
q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
shifted = vshrq_n_u8(qhbits, 4);
q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
shifted = vshrq_n_u8(qhbits, 6);
q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);

shifted = vshrq_n_u8(qhbits.val[0], 4);
q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
shifted = vshrq_n_u8(qhbits.val[1], 4);
q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
shifted = vshrq_n_u8(qhbits.val[0], 6);
q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
shifted = vshrq_n_u8(qhbits.val[1], 6);
q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);

//q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])), m32s);
//q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])), m32s);
//q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])), m32s);
//q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])), m32s);
q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0]));
q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1]));
q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2]));
q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3]));
q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s);
q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s);
q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[2])), m32s);
q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[3])), m32s);

#if defined(__ARM_FEATURE_DOTPROD)

isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
scale += 4;

//for (int l = 0; l < 4; ++l) {
// const int32x4_t p = vdotq_s32(vzero, q6bytes.val[l], q8bytes.val[l]);
// isum += vaddvq_s32(p) * *scale++;
//}
isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
#else
p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];
scale += 2;

p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1];
scale += 2;
int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];

int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
isum += vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
#endif

}
//sum += isum * d_all * y[i].d;
sum += d_all * y[i].d * (isum - 32 * isum_mins);
sum += isum * d_all * y[i].d;

}
*s = sum;
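
On the final accumulation change (sum += d_all * y[i].d * (isum - 32 * isum_mins) becoming sum += isum * d_all * y[i].d): the old path kept the quants biased by +32 and corrected with the bsums-derived isum_mins term, while the new path subtracts 32 in-register (vsubq_s8 with m32s), so the correction term disappears. The two forms are algebraically identical; a tiny self-contained check with illustrative values, not from the PR:

#include <assert.h>

int main(void) {
    // One 4-element sub-block with scale 2: folding the -32 offset into the
    // weights, or deferring it through the activation block sum, must agree.
    const int q6[4] = {5, 63, 0, 17};
    const int q8[4] = {-3, 7, 1, -9};
    const int scale = 2;
    int folded = 0, biased = 0, bsum = 0;
    for (int l = 0; l < 4; ++l) {
        folded += scale * (q6[l] - 32) * q8[l];
        biased += scale * q6[l] * q8[l];
        bsum   += q8[l];
    }
    assert(folded == biased - 32 * scale * bsum);
    return 0;
}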