Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 4161bdc

Browse files
ikawrakow and Kawrakow authored
metal : add Q4_K implementation (abetlen#1733)
* Metal implementation for Q4_K

  Very slow for now: 42 ms / token, Q4_0 runs in 28 ms/token on my 30-core M2 Max GPU.

* Optimizing Q4_K on metal

  The first token always takes longer, I guess because the metal kernel is being jit-compiled. So, using n = 128 to measure time. At this point Q4_K takes 29.5 ms / token compared to 27.2 ms / token for Q4_0. Quite a bit better than the initial attempt, but still not good enough.

* Optimizing q4_K metal dot some more

  For n = 256 it is now 28.1 ms/token compared to 27 ms/token for q4_0.

* Fix after merge with master

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
1 parent 0035858 commit 4161bdc
Copy full SHA for 4161bdc

File tree

Expand file tree / Collapse file tree

3 files changed

+184
-19
lines changed
Filter options
Expand file tree / Collapse file tree

3 files changed

+184
-19
lines changed

‎.clang-tidy

Copy file name to clipboardExpand all lines: .clang-tidy
Lines changed: 0 additions & 18 deletions (-18)
This file was deleted.

‎ggml-metal.m

Copy file name to clipboardExpand all lines: ggml-metal.m
Lines changed: 22 additions & 1 deletion (+22, -1)
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,11 @@
4949
GGML_METAL_DECL_KERNEL(diag_mask_inf);
5050
GGML_METAL_DECL_KERNEL(get_rows_f16);
5151
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
52+
GGML_METAL_DECL_KERNEL(get_rows_q4_k);
5253
GGML_METAL_DECL_KERNEL(rms_norm);
5354
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
5455
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
56+
GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32);
5557
GGML_METAL_DECL_KERNEL(rope);
5658
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
5759
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
@@ -133,9 +135,11 @@
133135
GGML_METAL_ADD_KERNEL(diag_mask_inf);
134136
GGML_METAL_ADD_KERNEL(get_rows_f16);
135137
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
138+
GGML_METAL_ADD_KERNEL(get_rows_q4_k);
136139
GGML_METAL_ADD_KERNEL(rms_norm);
137140
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
138141
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
142+
GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32);
139143
GGML_METAL_ADD_KERNEL(rope);
140144
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
141145
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
@@ -517,7 +521,20 @@ void ggml_metal_graph_compute(
517521
nth1 = 4;
518522
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
519523
} break;
520-
default: GGML_ASSERT(false && "not implemented");
524+
case GGML_TYPE_Q4_K:
525+
{
526+
GGML_ASSERT(ne02 == 1);
527+
GGML_ASSERT(ne12 == 1);
528+
529+
nth0 = 4;
530+
nth1 = 16;
531+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_k_f32];
532+
} break;
533+
default:
534+
{
535+
fprintf(stderr, "Asserting on type %d\n",(int)src0t);
536+
GGML_ASSERT(false && "not implemented");
537+
}
521538
};
522539

523540

@@ -540,6 +557,9 @@ void ggml_metal_graph_compute(
540557
if (src0t == GGML_TYPE_Q4_0) {
541558
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
542559
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
560+
} else if (src0t == GGML_TYPE_Q4_K) {
561+
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
562+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
543563
} else {
544564
[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
545565
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -555,6 +575,7 @@ void ggml_metal_graph_compute(
555575
switch (src0->type) {
556576
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
557577
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
578+
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break;
558579
default: GGML_ASSERT(false && "not implemented");
559580
}
560581

‎ggml-metal.metal

Copy file name to clipboardExpand all lines: ggml-metal.metal
Lines changed: 162 additions & 0 deletions (+162)
Original file line numberDiff line numberDiff line change
@@ -503,3 +503,165 @@ kernel void kernel_cpy_f32_f32(
503503
dst_data[i00] = src[0];
504504
}
505505
}
506+
507+
//============================================ k-quants ======================================================
508+
509+
#define QK_K 256
510+
511+
typedef struct {
512+
half d; // super-block scale for quantized scales
513+
half dmin; // super-block scale for quantized mins
514+
uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
515+
uint8_t qs[QK_K/2]; // 4--bit quants
516+
} block_q4_k;
517+
518+
static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
519+
uchar4 r;
520+
if (j < 4) {
521+
r[0] = q[j+0] & 63; r[1] = q[j+4] & 63;
522+
r[2] = q[j+1] & 63; r[3] = q[j+5] & 63;
523+
} else {
524+
r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
525+
r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
526+
r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4);
527+
r[3] = (q[j+5] >> 4) | ((q[j+1] >> 6) << 4);
528+
}
529+
return r;
530+
}
531+
532+
static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) {
533+
assert(k % QK_K == 0);
534+
const int nb = k / QK_K;
535+
536+
for (int i = 0; i < nb; i++) {
537+
538+
const float d = x[i].d;
539+
const float min = x[i].dmin;
540+
541+
device const uint8_t * q = x[i].qs;
542+
device const uint8_t * scales = x[i].scales;
543+
544+
int is = 0;
545+
for (int j = 0; j < QK_K; j += 64) {
546+
const uchar4 sc = get_scale_min_k4(is, scales);
547+
const float d1 = d * sc[0]; const float m1 = min * sc[1];
548+
const float d2 = d * sc[2]; const float m2 = min * sc[3];
549+
for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
550+
for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
551+
q += 32; is += 2;
552+
}
553+
554+
}
555+
}
556+
557+
kernel void kernel_get_rows_q4_k(
558+
device const void * src0,
559+
device const int * src1,
560+
device float * dst,
561+
constant int64_t & ne00,
562+
constant uint64_t & nb01,
563+
constant uint64_t & nb1,
564+
uint tpig[[thread_position_in_grid]]) {
565+
const int i = tpig;
566+
const int r = ((device int32_t *) src1)[i];
567+
568+
dequantize_row_q4_k(
569+
(device const block_q4_k *) ((device char *) src0 + r*nb01),
570+
(device float *) ((device char *) dst + i*nb1), ne00);
571+
}
572+
573+
kernel void kernel_mul_mat_q4_k_f32(
574+
device const void * src0,
575+
device const float * src1,
576+
device float * dst,
577+
constant int64_t & ne00,
578+
constant int64_t & ne01,
579+
constant uint64_t & nb00,
580+
constant uint64_t & nb01,
581+
constant uint64_t & nb02,
582+
constant int64_t & ne10,
583+
constant int64_t & ne11,
584+
constant uint64_t & nb10,
585+
constant uint64_t & nb11,
586+
constant uint64_t & nb12,
587+
constant int64_t & ne0,
588+
constant int64_t & ne1,
589+
threadgroup float * sum [[threadgroup(0)]],
590+
uint2 tgpig[[threadgroup_position_in_grid]],
591+
uint2 tpig[[thread_position_in_grid]], // we don't use this for now
592+
uint2 tpitg[[thread_position_in_threadgroup]],
593+
uint2 tptg[[threads_per_threadgroup]]) {
594+
595+
const int nb = ne00/QK_K;
596+
597+
const int64_t r0 = tgpig.x;
598+
const int64_t r1 = tgpig.y;
599+
600+
device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb;
601+
device const float * yy = (device const float *) src1 + r1*ne10;
602+
603+
const uint nth = tptg.x*tptg.y;
604+
const uint ith = tptg.y*tpitg.x + tpitg.y;
605+
606+
const int tid = tpitg.y; // 0...16
607+
const int il = tid/4; // 0...3
608+
const int ir = tid%4; // 0...3
609+
const int n = 8;
610+
const int is = 2*il;
611+
612+
sum[ith] = 0.0f;
613+
614+
float sumf = 0;
615+
for (int i = tpitg.x; i < nb; i += tptg.x) {
616+
617+
device const uint8_t * q = (x + i)->qs + 32*il + n*ir;
618+
device const float * y = yy + i*QK_K + 64*il + n*ir;
619+
device const uint8_t * scales = (x + i)->scales;
620+
621+
const float dall = (float)((x + i)->d);
622+
const float dmin = (float)((x + i)->dmin);
623+
624+
const uchar4 sc = get_scale_min_k4(is, scales);
625+
626+
float4 s = {0.f, 0.f, 0.f, 0.f};
627+
for (int l = 0; l < n; ++l) {
628+
s[0] += y[l+ 0] * (q[l] & 0xF); s[1] += y[l+ 0];
629+
s[2] += y[l+32] * (q[l] >> 4); s[3] += y[l+32];
630+
}
631+
sumf += dall * (s[0] * sc[0] + s[2] * sc[2]) - dmin * (s[1] * sc[1] + s[3] * sc[3]);
632+
633+
}
634+
sum[ith] = sumf;
635+
636+
//
637+
// Accumulate the sum from all threads in the threadgroup
638+
// This version is slightly faster than the commented out one below,
639+
// which I copy-pasted from ggerganov's q4_0 dot product for metal.
640+
//
641+
threadgroup_barrier(mem_flags::mem_threadgroup);
642+
if (ith%4 == 0) {
643+
for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
644+
}
645+
threadgroup_barrier(mem_flags::mem_threadgroup);
646+
if (ith%16 == 0) {
647+
for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
648+
}
649+
threadgroup_barrier(mem_flags::mem_threadgroup);
650+
if (ith == 0) {
651+
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
652+
dst[r1*ne0 + r0] = sum[0];
653+
}
654+
655+
//// accumulate the sum from all threads in the threadgroup
656+
//threadgroup_barrier(mem_flags::mem_threadgroup);
657+
//for (uint i = nth/2; i > 0; i /= 2) {
658+
// if (ith < i) {
659+
// sum[ith] += sum[ith + i];
660+
// }
661+
// threadgroup_barrier(mem_flags::mem_threadgroup);
662+
//}
663+
664+
//if (ith == 0) {
665+
// dst[r1*ne0 + r0] = sum[0];
666+
//}
667+
}

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.