Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit e9b66ee

Browse filesBrowse files
ikawrakowKawrakow
andauthored
metal : add Q4_1 implementation (abetlen#1785)
23.3 ms / token, so just ~1% slower than q4_0. Achieves 290 GB/s memory throughput. Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
1 parent 4f0154b commit e9b66ee
Copy full SHA for e9b66ee

File tree

Expand file treeCollapse file tree

2 files changed

+138
-1
lines changed
Filter options
Expand file treeCollapse file tree

2 files changed

+138
-1
lines changed

‎ggml-metal.m

Copy file name to clipboardExpand all lines: ggml-metal.m
+15-1Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,14 @@
5050
GGML_METAL_DECL_KERNEL(diag_mask_inf);
5151
GGML_METAL_DECL_KERNEL(get_rows_f16);
5252
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
53+
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
5354
GGML_METAL_DECL_KERNEL(get_rows_q2_k);
5455
GGML_METAL_DECL_KERNEL(get_rows_q4_k);
5556
GGML_METAL_DECL_KERNEL(get_rows_q6_k);
5657
GGML_METAL_DECL_KERNEL(rms_norm);
5758
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
5859
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
60+
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
5961
GGML_METAL_DECL_KERNEL(mul_mat_q2_k_f32);
6062
GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32);
6163
GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32);
@@ -141,12 +143,14 @@
141143
GGML_METAL_ADD_KERNEL(diag_mask_inf);
142144
GGML_METAL_ADD_KERNEL(get_rows_f16);
143145
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
146+
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
144147
GGML_METAL_ADD_KERNEL(get_rows_q2_k);
145148
GGML_METAL_ADD_KERNEL(get_rows_q4_k);
146149
GGML_METAL_ADD_KERNEL(get_rows_q6_k);
147150
GGML_METAL_ADD_KERNEL(rms_norm);
148151
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
149152
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
153+
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
150154
GGML_METAL_ADD_KERNEL(mul_mat_q2_k_f32);
151155
GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32);
152156
GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32);
@@ -545,6 +549,15 @@ void ggml_metal_graph_compute(
545549
nth1 = 8;
546550
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
547551
} break;
552+
case GGML_TYPE_Q4_1:
553+
{
554+
GGML_ASSERT(ne02 == 1);
555+
GGML_ASSERT(ne12 == 1);
556+
557+
nth0 = 8;
558+
nth1 = 8;
559+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
560+
} break;
548561
case GGML_TYPE_Q2_K:
549562
{
550563
GGML_ASSERT(ne02 == 1);
@@ -596,7 +609,7 @@ void ggml_metal_graph_compute(
596609
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
597610
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
598611

599-
if (src0t == GGML_TYPE_Q4_0) {
612+
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) {
600613
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
601614
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
602615
} else if (src0t == GGML_TYPE_Q2_K) {
@@ -623,6 +636,7 @@ void ggml_metal_graph_compute(
623636
switch (src0->type) {
624637
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
625638
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
639+
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
626640
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break;
627641
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break;
628642
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break;

‎ggml-metal.metal

Copy file name to clipboardExpand all lines: ggml-metal.metal
+123Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,13 @@ typedef struct {
1111
uint8_t qs[QK4_0 / 2]; // nibbles / quants
1212
} block_q4_0;
1313

14+
#define QK4_1 32
15+
typedef struct {
16+
half d; // delta
17+
half m; // min
18+
uint8_t qs[QK4_1 / 2]; // nibbles / quants
19+
} block_q4_1;
20+
1421
static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, int k) {
1522
const int qk = QK4_0;
1623

@@ -31,6 +38,27 @@ static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, i
3138
}
3239
}
3340

41+
static void dequantize_row_q4_1(device const block_q4_1 * x, device float * y, int k) {
42+
const int qk = QK4_1;
43+
44+
assert(k % qk == 0);
45+
46+
const int nb = k / qk;
47+
48+
for (int i = 0; i < nb; i++) {
49+
const half d = x[i].d;
50+
const half m = x[i].m;
51+
52+
for (int j = 0; j < qk/2; ++j) {
53+
const int x0 = (x[i].qs[j] & 0x0F);
54+
const int x1 = (x[i].qs[j] >> 4);
55+
56+
y[i*qk + j + 0 ] = x0*d + m;
57+
y[i*qk + j + qk/2] = x1*d + m;
58+
}
59+
}
60+
}
61+
3462
kernel void kernel_add(
3563
device const float * src0,
3664
device const float * src1,
@@ -212,6 +240,22 @@ kernel void kernel_get_rows_q4_0(
212240
(device float *) ((device char *) dst + i*nb1), ne00);
213241
}
214242

243+
kernel void kernel_get_rows_q4_1(
244+
device const void * src0,
245+
device const int * src1,
246+
device float * dst,
247+
constant int64_t & ne00,
248+
constant uint64_t & nb01,
249+
constant uint64_t & nb1,
250+
uint tpig[[thread_position_in_grid]]) {
251+
const int i = tpig;
252+
const int r = ((device int32_t *) src1)[i];
253+
254+
dequantize_row_q4_1(
255+
(device const block_q4_1 *) ((device char *) src0 + r*nb01),
256+
(device float *) ((device char *) dst + i*nb1), ne00);
257+
}
258+
215259
kernel void kernel_rms_norm(
216260
device const void * src0,
217261
device float * dst,
@@ -350,6 +394,85 @@ kernel void kernel_mul_mat_q4_0_f32(
350394
//}
351395
}
352396

397+
kernel void kernel_mul_mat_q4_1_f32(
398+
device const void * src0,
399+
device const float * src1,
400+
device float * dst,
401+
constant int64_t & ne00,
402+
constant int64_t & ne01,
403+
constant uint64_t & nb00,
404+
constant uint64_t & nb01,
405+
constant uint64_t & nb02,
406+
constant int64_t & ne10,
407+
constant int64_t & ne11,
408+
constant uint64_t & nb10,
409+
constant uint64_t & nb11,
410+
constant uint64_t & nb12,
411+
constant int64_t & ne0,
412+
constant int64_t & ne1,
413+
threadgroup float * sum [[threadgroup(0)]],
414+
uint2 tgpig[[threadgroup_position_in_grid]],
415+
uint2 tpig[[thread_position_in_grid]],
416+
uint2 tpitg[[thread_position_in_threadgroup]],
417+
uint2 tptg[[threads_per_threadgroup]]) {
418+
const int nb = ne00/QK4_1;
419+
420+
const int64_t r0 = tgpig.x;
421+
const int64_t r1 = tgpig.y;
422+
423+
device const block_q4_1 * x = (device const block_q4_1 *) src0 + r0*nb;
424+
device const float * y = (device const float *) src1 + r1*ne10;
425+
426+
const uint nth = tptg.x*tptg.y;
427+
const uint ith = tptg.y*tpitg.x + tpitg.y;
428+
429+
const int ix = tpitg.y/4; // 0 or 1
430+
const int iy = tpitg.y - 4*ix; // 0...3
431+
432+
const int first = 4 * iy;
433+
434+
float sumf = 0;
435+
436+
for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {
437+
438+
const float d = (float)x[i].d;
439+
const float m = (float)x[i].m;
440+
441+
device const uint8_t * xl = x[i].qs + first;
442+
device const float * yl = y + i * QK4_1 + first;
443+
444+
float2 acc = {0.0f, 0.0f};
445+
446+
for (int j = 0; j < 4; ++j) {
447+
448+
acc[0] += yl[j+ 0] * (d * (xl[j] & 0xF) + m);
449+
acc[1] += yl[j+16] * (d * (xl[j] >> 4) + m);
450+
451+
}
452+
453+
sumf += acc[0] + acc[1];
454+
}
455+
456+
sum[ith] = sumf;
457+
458+
//
459+
// Accumulate the sum from all threads in the threadgroup
460+
//
461+
threadgroup_barrier(mem_flags::mem_threadgroup);
462+
if (ith%4 == 0) {
463+
for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
464+
}
465+
threadgroup_barrier(mem_flags::mem_threadgroup);
466+
if (ith%16 == 0) {
467+
for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
468+
}
469+
threadgroup_barrier(mem_flags::mem_threadgroup);
470+
if (ith == 0) {
471+
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
472+
dst[r1*ne0 + r0] = sum[0];
473+
}
474+
}
475+
353476
kernel void kernel_mul_mat_f16_f32(
354477
device const char * src0,
355478
device const char * src1,

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.