Commit 72ff528

ikawrakow (Kawrakow) authored
metal : add Q2_K implementation (abetlen#1762)
* metal : add Q2_K implementation

  27.1 ms / token on M2 Max 30-core GPU, so about the same speed as Q4_0.
  Memory throughput is ~156 GB/s.

  The access pattern used in the Q2_K CUDA implementation resulted in
  significantly lower performance (~31 ms/token).

* Fixing merge conflicts

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
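As a rough back-of-the-envelope reading of the quoted figures (not part of the commit message, just arithmetic on the numbers above): at ~156 GB/s sustained for 27.1 ms per token,

    156 GB/s x 27.1 ms/token ~ 4.2 GB of memory traffic per generated token,

which is the quantity the reported throughput and per-token time jointly imply.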
1 parent 0bf7cf1 commit 72ff528

File tree: 2 files changed (+200 −18 lines)

ggml-metal.m
ggml-metal.metal

‎ggml-metal.m

17 additions & 0 deletions
@@ -49,11 +49,13 @@
 GGML_METAL_DECL_KERNEL(diag_mask_inf);
 GGML_METAL_DECL_KERNEL(get_rows_f16);
 GGML_METAL_DECL_KERNEL(get_rows_q4_0);
+GGML_METAL_DECL_KERNEL(get_rows_q2_k);
 GGML_METAL_DECL_KERNEL(get_rows_q4_k);
 GGML_METAL_DECL_KERNEL(get_rows_q6_k);
 GGML_METAL_DECL_KERNEL(rms_norm);
 GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
 GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
+GGML_METAL_DECL_KERNEL(mul_mat_q2_k_f32);
 GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32);
 GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32);
 GGML_METAL_DECL_KERNEL(rope);
@@ -137,11 +139,13 @@
 GGML_METAL_ADD_KERNEL(diag_mask_inf);
 GGML_METAL_ADD_KERNEL(get_rows_f16);
 GGML_METAL_ADD_KERNEL(get_rows_q4_0);
+GGML_METAL_ADD_KERNEL(get_rows_q2_k);
 GGML_METAL_ADD_KERNEL(get_rows_q4_k);
 GGML_METAL_ADD_KERNEL(get_rows_q6_k);
 GGML_METAL_ADD_KERNEL(rms_norm);
 GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
 GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
+GGML_METAL_ADD_KERNEL(mul_mat_q2_k_f32);
 GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32);
 GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32);
 GGML_METAL_ADD_KERNEL(rope);
@@ -525,6 +529,15 @@ void ggml_metal_graph_compute(
                                nth1 = 4;
                                [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
                            } break;
+                        case GGML_TYPE_Q2_K:
+                            {
+                                GGML_ASSERT(ne02 == 1);
+                                GGML_ASSERT(ne12 == 1);
+
+                                nth0 = 4;
+                                nth1 = 16;
+                                [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_k_f32];
+                            } break;
                        case GGML_TYPE_Q4_K:
                            {
                                GGML_ASSERT(ne02 == 1);
@@ -570,6 +583,9 @@ void ggml_metal_graph_compute(
                    if (src0t == GGML_TYPE_Q4_0) {
                        [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    } else if (src0t == GGML_TYPE_Q2_K) {
+                        [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
+                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                    } else if (src0t == GGML_TYPE_Q4_K) {
                        [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -591,6 +607,7 @@ void ggml_metal_graph_compute(
                    switch (src0->type) {
                        case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f16];  break;
                        case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
+                        case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break;
                        case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break;
                        case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break;
                        default: GGML_ASSERT(false && "not implemented");
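For readers following the dispatch parameters in the hunks above, here is a small standalone C sketch (illustrative only, not part of the patch) of the threadgroup geometry the encoder ends up with on the Q2_K mat-vec path, assuming nth0 = 4 and nth1 = 16 as set above.

    /* Illustrative only: threadgroup geometry implied by the Q2_K branch above,
     * assuming nth0 = 4 and nth1 = 16. */
    #include <stdio.h>

    int main(void) {
        const int nth0 = 4, nth1 = 16;              /* threadsPerThreadgroup = (nth0, nth1, 1) */
        const int threads_per_group = nth0 * nth1;  /* 64 threads cooperate on one output row  */
        const unsigned tg_mem = (unsigned)(threads_per_group * sizeof(float)); /* partial sums */

        printf("threads per threadgroup : %d\n", threads_per_group);  /* 64  */
        printf("threadgroup memory      : %u bytes\n", tg_mem);       /* 256 */
        /* The grid is MTLSizeMake(ne01, 1, 1): one threadgroup per row of the
         * Q2_K weight matrix, whereas the Q4_0 path also spans ne11 in y. */
        return 0;
    }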

‎ggml-metal.metal

183 additions & 18 deletions
@@ -527,6 +527,13 @@ kernel void kernel_cpy_f32_f32(
 
 #define QK_K 256
 
+typedef struct {
+    uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
+    uint8_t qs[QK_K/4];      // quants
+    half d;                  // super-block scale for quantized scales
+    half dmin;               // super-block scale for quantized mins
+} block_q2_k;
+
 typedef struct {
     half d;                  // super-block scale for quantized scales
     half dmin;               // super-block scale for quantized mins
@@ -555,6 +562,41 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
     return r;
 }
 
+//========================================== dequantization =============================
+
+static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d = x[i].d;
+        const float min = x[i].dmin;
+
+        device const uint8_t * q = x[i].qs;
+
+        int is = 0;
+        float dl, ml;
+        for (int n = 0; n < QK_K; n += 128) {
+            int shift = 0;
+            for (int j = 0; j < 4; ++j) {
+
+                uint8_t sc = x[i].scales[is++];
+                dl = d * (sc & 0xF); ml = min * (sc >> 4);
+                for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml;
+
+                sc = x[i].scales[is++];
+                dl = d * (sc & 0xF); ml = min * (sc >> 4);
+                for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml;
+
+                shift += 2;
+            }
+            q += 32;
+        }
+
+    }
+}
+
 static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
@@ -586,12 +628,12 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i
 
     for (int i = 0; i < nb; i++) {
 
-        const float d = x[i].d;
-
         device const uint8_t * ql = x[i].ql;
        device const uint8_t * qh = x[i].qh;
        device const int8_t  * sc = x[i].scales;
 
+        const float d = x[i].d;
+
        for (int n = 0; n < QK_K; n += 128) {
            for (int l = 0; l < 32; ++l) {
                int is = l/16;
@@ -612,6 +654,22 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i
     }
 }
 
+kernel void kernel_get_rows_q2_k(
+        device const  void * src0,
+        device const   int * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb1,
+        uint tpig[[thread_position_in_grid]]) {
+    const int i = tpig;
+    const int r = ((device int32_t *) src1)[i];
+
+    dequantize_row_q2_k(
+            (device const block_q2_k *) ((device char *) src0 + r*nb01),
+                       (device float *) ((device char *)  dst + i*nb1), ne00);
+}
+
 kernel void kernel_get_rows_q4_k(
         device const  void * src0,
        device const   int * src1,
@@ -628,6 +686,129 @@ kernel void kernel_get_rows_q4_k(
                        (device float *) ((device char *)  dst + i*nb1), ne00);
 }
 
+kernel void kernel_get_rows_q6_k(
+        device const  void * src0,
+        device const   int * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb1,
+        uint tpig[[thread_position_in_grid]]) {
+    const int i = tpig;
+    const int r = ((device int32_t *) src1)[i];
+
+    dequantize_row_q6_k(
+            (device const block_q6_k *) ((device char *) src0 + r*nb01),
+                       (device float *) ((device char *)  dst + i*nb1), ne00);
+}
+
+//====================================== dot products =========================
+
+kernel void kernel_mul_mat_q2_k_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        threadgroup float  * sum [[threadgroup(0)]],
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint2 tpig[[thread_position_in_grid]],       // we don't use this for now
+        uint2 tpitg[[thread_position_in_threadgroup]],
+        uint2  tptg[[threads_per_threadgroup]]) {
+
+    const int nb = ne00/QK_K;
+
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+
+    device const block_q2_k * x = (device const block_q2_k *) src0 + r0*nb;
+    device const float     * yy = (device const float      *) src1 + r1*ne10;
+
+    const int nth = tptg.x*tptg.y;
+    const int ith = tptg.y*tpitg.x + tpitg.y;
+
+
+    const int tid = tpitg.y;    // 0...16
+    const int il  = tid/4;      // 0...3
+    const int ir  = tid%4;      // 0...3
+    const int ip  = il/2;       // 0 or 1
+    const int shift1 = 4*(il%2);// 0 or 4
+    const int shift2 = shift1+2;// 2 or 6
+    const int n   = 8;
+    const int is  = 4*il + (n*ir)/16;
+
+    sum[ith] = 0.0f;
+
+    float sumf = 0;
+    for (int i = tpitg.x; i < nb; i += tptg.x) {
+
+        device const uint8_t * q = x[i].qs + 32*ip + n*ir;
+        device const uint8_t * scales = x[i].scales + is;
+
+        uint8_t d1 = scales[0] & 0xF;
+        uint8_t m1 = scales[0] >> 4;
+        uint8_t d2 = scales[2] & 0xF;
+        uint8_t m2 = scales[2] >> 4;
+
+        device const float * y = yy + i*QK_K + 64*il + n*ir;
+
+        const float dall = (float)x[i].d;
+        const float dmin = (float)x[i].dmin;
+
+        float4 s = {0.f, 0.f, 0.f, 0.f};
+        for (int l = 0; l < n; ++l) {
+            s[0] += y[l+ 0] * ((q[l] >> shift1) & 3); s[1] += y[l+ 0];
+            s[2] += y[l+32] * ((q[l] >> shift2) & 3); s[3] += y[l+32];
+        }
+        sumf += dall * (s[0] * d1 + s[2] * d2) - dmin * (s[1] * m1 + s[3] * m2);
+
+
+    }
+    sum[ith] = sumf;
+
+    //
+    // Accumulate the sum from all threads in the threadgroup
+    // This version is slightly faster than the commented out one below,
+    // which I copy-pasted from ggerganov's q4_0 dot product for metal.
+    //
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%4 == 0) {
+        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%16 == 0) {
+        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith == 0) {
+        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+        dst[r1*ne0 + r0] = sum[0];
+    }
+
+    //// accumulate the sum from all threads in the threadgroup
+    //threadgroup_barrier(mem_flags::mem_threadgroup);
+    //for (uint i = nth/2; i > 0; i /= 2) {
+    //    if (ith < i) {
+    //        sum[ith] += sum[ith + i];
+    //    }
+    //    threadgroup_barrier(mem_flags::mem_threadgroup);
+    //}
+
+    //if (ith == 0) {
+    //    dst[r1*ne0 + r0] = sum[0];
+    //}
+}
+
 kernel void kernel_mul_mat_q4_k_f32(
         device const  void * src0,
        device const float * src1,
@@ -724,22 +905,6 @@ kernel void kernel_mul_mat_q4_k_f32(
     //}
 }
 
-kernel void kernel_get_rows_q6_k(
-        device const  void * src0,
-        device const   int * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb1,
-        uint tpig[[thread_position_in_grid]]) {
-    const int i = tpig;
-    const int r = ((device int32_t *) src1)[i];
-
-    dequantize_row_q6_k(
-            (device const block_q6_k *) ((device char *) src0 + r*nb01),
-                       (device float *) ((device char *)  dst + i*nb1), ne00);
-}
-
 kernel void kernel_mul_mat_q6_k_f32(
         device const  void * src0,
        device const float * src1,

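A quick way to see what the block_q2_k layout introduced above costs in storage: the sketch below re-declares the struct on the host side (uint16_t stands in for Metal's half purely to get the same 2-byte size; the struct name and the program are illustrative, not part of the patch) and prints the bytes per 256-weight super-block and the resulting bits per weight.

    /* Illustrative host-side sketch of the storage cost of block_q2_k as
     * declared in ggml-metal.metal above. */
    #include <stdint.h>
    #include <stdio.h>

    #define QK_K 256

    typedef struct {
        uint8_t  scales[QK_K/16]; /* 16 bytes: 4-bit scales and 4-bit mins        */
        uint8_t  qs[QK_K/4];      /* 64 bytes: 2-bit quants, four per byte        */
        uint16_t d;               /*  2 bytes: super-block scale (half in Metal)  */
        uint16_t dmin;            /*  2 bytes: super-block min scale (half)       */
    } block_q2_k_host;

    int main(void) {
        const size_t bytes = sizeof(block_q2_k_host);                 /* 84 bytes  */
        printf("bytes per super-block: %zu\n", bytes);
        printf("bits per weight      : %.3f\n", 8.0 * bytes / QK_K);  /* 2.625 bpw */
        return 0;
    }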
0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.