Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 0f291e1

Browse filesBrowse files
ikawrakowKawrakow
andauthored
metal : Q6_K implementation (abetlen#1752)
* Metal implementation for Q4_K Very slow for now: 42 ms / token, Q4_0 runs in 28 ms/token on my 30-core M2 Max GPU. * Optimizing Q4_K on metal The first token always takes longer, I guess because the metal kernel is being jit-compiled. So, using n = 128 to measure time. At this point Q4_K takes 29.5 ms / token compared to 27.2 ms / token for Q4_0. Quite a bit better than the initial attempt, but still not good enough. * Optimizing q4_K metal dot some more For n = 256 it is now 28.1 ms/token compared to 27 ms/token for q4_0. * Fix after merge with master * Metal implementation for Q6_K Similar to the CUDA implementation. No idea if this is the optimum for Metal, but the few alternative variants I tried all had a lower performance. We get 36.5 ms / token on M2 Max with 30 GPU cores. This corresponds to ~200 GB/second throughput. * clang-tidy : add config back * Much better Q6_K implementation for metal 28.3 ms / token for 7B. Subtracting ~9 ms that is spent in other compute graph operations, we are left with ~19 ms for the matrix multiplications. The model is ~5.5 GB, so we are getting 1000 / 19 * 5.5 = 290 GB/s! --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
1 parent 8fc8179 commit 0f291e1
Copy full SHA for 0f291e1

File tree

Expand file treeCollapse file tree

2 files changed

+187
-7
lines changed
Filter options
Expand file treeCollapse file tree

2 files changed

+187
-7
lines changed

‎ggml-metal.m

Copy file name to clipboardExpand all lines: ggml-metal.m
+17Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,12 @@
5050
GGML_METAL_DECL_KERNEL(get_rows_f16);
5151
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
5252
GGML_METAL_DECL_KERNEL(get_rows_q4_k);
53+
GGML_METAL_DECL_KERNEL(get_rows_q6_k);
5354
GGML_METAL_DECL_KERNEL(rms_norm);
5455
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
5556
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
5657
GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32);
58+
GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32);
5759
GGML_METAL_DECL_KERNEL(rope);
5860
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
5961
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
@@ -136,10 +138,12 @@
136138
GGML_METAL_ADD_KERNEL(get_rows_f16);
137139
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
138140
GGML_METAL_ADD_KERNEL(get_rows_q4_k);
141+
GGML_METAL_ADD_KERNEL(get_rows_q6_k);
139142
GGML_METAL_ADD_KERNEL(rms_norm);
140143
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
141144
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
142145
GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32);
146+
GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32);
143147
GGML_METAL_ADD_KERNEL(rope);
144148
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
145149
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
@@ -530,6 +534,15 @@ void ggml_metal_graph_compute(
530534
nth1 = 16;
531535
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_k_f32];
532536
} break;
537+
case GGML_TYPE_Q6_K:
538+
{
539+
GGML_ASSERT(ne02 == 1);
540+
GGML_ASSERT(ne12 == 1);
541+
542+
nth0 = 4;
543+
nth1 = 16;
544+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_k_f32];
545+
} break;
533546
default:
534547
{
535548
fprintf(stderr, "Asserting on type %d\n",(int)src0t);
@@ -560,6 +573,9 @@ void ggml_metal_graph_compute(
560573
} else if (src0t == GGML_TYPE_Q4_K) {
561574
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
562575
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
576+
} else if (src0t == GGML_TYPE_Q6_K) {
577+
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
578+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
563579
} else {
564580
[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
565581
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -576,6 +592,7 @@ void ggml_metal_graph_compute(
576592
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
577593
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
578594
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break;
595+
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break;
579596
default: GGML_ASSERT(false && "not implemented");
580597
}
581598

‎ggml-metal.metal

Copy file name to clipboardExpand all lines: ggml-metal.metal
+170-7Lines changed: 170 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -303,18 +303,37 @@ kernel void kernel_mul_mat_q4_0_f32(
303303
sum[ith] += acc*d;
304304
}
305305

306-
// accumulate the sum from all threads in the threadgroup
306+
//
307+
// Accumulate the sum from all threads in the threadgroup
308+
// This version is slightly faster than the commented out one below,
309+
// which I copy-pasted from ggerganov's q4_0 dot product for metal.
310+
//
307311
threadgroup_barrier(mem_flags::mem_threadgroup);
308-
for (uint i = nth/2; i > 0; i /= 2) {
309-
if (ith < i) {
310-
sum[ith] += sum[ith + i];
311-
}
312-
threadgroup_barrier(mem_flags::mem_threadgroup);
312+
if (ith%4 == 0) {
313+
for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
313314
}
314-
315+
threadgroup_barrier(mem_flags::mem_threadgroup);
316+
if (ith%16 == 0) {
317+
for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
318+
}
319+
threadgroup_barrier(mem_flags::mem_threadgroup);
315320
if (ith == 0) {
321+
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
316322
dst[r1*ne0 + r0] = sum[0];
317323
}
324+
325+
//// accumulate the sum from all threads in the threadgroup
326+
//threadgroup_barrier(mem_flags::mem_threadgroup);
327+
//for (uint i = nth/2; i > 0; i /= 2) {
328+
// if (ith < i) {
329+
// sum[ith] += sum[ith + i];
330+
// }
331+
// threadgroup_barrier(mem_flags::mem_threadgroup);
332+
//}
333+
334+
//if (ith == 0) {
335+
// dst[r1*ne0 + r0] = sum[0];
336+
//}
318337
}
319338

320339
kernel void kernel_mul_mat_f16_f32(
@@ -515,6 +534,13 @@ typedef struct {
515534
uint8_t qs[QK_K/2]; // 4--bit quants
516535
} block_q4_k;
517536

537+
typedef struct {
538+
uint8_t ql[QK_K/2]; // quants, lower 4 bits
539+
uint8_t qh[QK_K/4]; // quants, upper 2 bits
540+
int8_t scales[QK_K/16]; // scales, quantized with 8 bits
541+
half d; // super-block scale
542+
} block_q6_k;
543+
518544
static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
519545
uchar4 r;
520546
if (j < 4) {
@@ -554,6 +580,38 @@ static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, i
554580
}
555581
}
556582

583+
static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, int k) {
584+
assert(k % QK_K == 0);
585+
const int nb = k / QK_K;
586+
587+
for (int i = 0; i < nb; i++) {
588+
589+
const float d = x[i].d;
590+
591+
device const uint8_t * ql = x[i].ql;
592+
device const uint8_t * qh = x[i].qh;
593+
device const int8_t * sc = x[i].scales;
594+
595+
for (int n = 0; n < QK_K; n += 128) {
596+
for (int l = 0; l < 32; ++l) {
597+
int is = l/16;
598+
const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
599+
const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
600+
const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
601+
const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
602+
y[l + 0] = d * sc[is + 0] * q1;
603+
y[l + 32] = d * sc[is + 2] * q2;
604+
y[l + 64] = d * sc[is + 4] * q3;
605+
y[l + 96] = d * sc[is + 6] * q4;
606+
}
607+
y += 128;
608+
ql += 64;
609+
qh += 32;
610+
sc += 8;
611+
}
612+
}
613+
}
614+
557615
kernel void kernel_get_rows_q4_k(
558616
device const void * src0,
559617
device const int * src1,
@@ -665,3 +723,108 @@ kernel void kernel_mul_mat_q4_k_f32(
665723
// dst[r1*ne0 + r0] = sum[0];
666724
//}
667725
}
726+
727+
kernel void kernel_get_rows_q6_k(
728+
device const void * src0,
729+
device const int * src1,
730+
device float * dst,
731+
constant int64_t & ne00,
732+
constant uint64_t & nb01,
733+
constant uint64_t & nb1,
734+
uint tpig[[thread_position_in_grid]]) {
735+
const int i = tpig;
736+
const int r = ((device int32_t *) src1)[i];
737+
738+
dequantize_row_q6_k(
739+
(device const block_q6_k *) ((device char *) src0 + r*nb01),
740+
(device float *) ((device char *) dst + i*nb1), ne00);
741+
}
742+
743+
kernel void kernel_mul_mat_q6_k_f32(
744+
device const void * src0,
745+
device const float * src1,
746+
device float * dst,
747+
constant int64_t & ne00,
748+
constant int64_t & ne01,
749+
constant uint64_t & nb00,
750+
constant uint64_t & nb01,
751+
constant uint64_t & nb02,
752+
constant int64_t & ne10,
753+
constant int64_t & ne11,
754+
constant uint64_t & nb10,
755+
constant uint64_t & nb11,
756+
constant uint64_t & nb12,
757+
constant int64_t & ne0,
758+
constant int64_t & ne1,
759+
threadgroup float * sum [[threadgroup(0)]],
760+
uint2 tgpig[[threadgroup_position_in_grid]],
761+
uint2 tpig[[thread_position_in_grid]], // we don't use this for now
762+
uint2 tpitg[[thread_position_in_threadgroup]],
763+
uint2 tptg[[threads_per_threadgroup]]) {
764+
765+
const uint8_t kmask1 = 0x03;
766+
const uint8_t kmask2 = 0x0C;
767+
const uint8_t kmask3 = 0x30;
768+
const uint8_t kmask4 = 0xC0;
769+
770+
const int nb = ne00/QK_K;
771+
772+
const int64_t r0 = tgpig.x;
773+
const int64_t r1 = tgpig.y;
774+
775+
device const block_q6_k * x = (device const block_q6_k *) src0 + r0*nb;
776+
device const float * yy = (device const float *) src1 + r1*ne10;
777+
778+
const uint nth = tptg.x*tptg.y;
779+
const uint ith = tptg.y*tpitg.x + tpitg.y;
780+
781+
const int step = QK_K / tptg.y; // we expect this to be 16
782+
const int iqs = step * tpitg.y; // 0...240 in steps of 16
783+
const int ip = iqs / 128; // 0 or 1
784+
const int il = (iqs - 128*ip)/16; // 0...7
785+
const int n = 4;
786+
const int is = 8*ip + (n*il)/16;
787+
788+
float sumf = 0;
789+
for (int i = tpitg.x; i < nb; i += tptg.x) {
790+
791+
device const uint8_t * ql = x[i].ql + 64*ip + n*il;
792+
device const uint8_t * qh = x[i].qh + 32*ip + n*il;
793+
device const int8_t * sc = x[i].scales + is;
794+
795+
device const float * y = yy + i * QK_K + 128*ip + n*il;
796+
797+
const float dall = x[i].d;
798+
799+
float4 sums = {0.f, 0.f, 0.f, 0.f};
800+
for (int l = 0; l < n; ++l) {
801+
sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
802+
sums[1] += y[l+32] * ((int8_t)((ql[l+32] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
803+
sums[2] += y[l+64] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
804+
sums[3] += y[l+96] * ((int8_t)((ql[l+32] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
805+
}
806+
807+
sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
808+
809+
}
810+
811+
sum[ith] = sumf;
812+
813+
//
814+
// Accumulate the sum from all threads in the threadgroup
815+
//
816+
threadgroup_barrier(mem_flags::mem_threadgroup);
817+
if (ith%4 == 0) {
818+
for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
819+
}
820+
threadgroup_barrier(mem_flags::mem_threadgroup);
821+
if (ith%16 == 0) {
822+
for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
823+
}
824+
threadgroup_barrier(mem_flags::mem_threadgroup);
825+
if (ith == 0) {
826+
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
827+
dst[r1*ne0 + r0] = sum[0];
828+
}
829+
830+
}

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.