Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 245fc3c

Browse filesBrowse files
ikawrakowKawrakow
andauthored
metal : faster q4_0 (abetlen#1775)
* metal : 8% faster q4_0 Avoid copying into local uchar4 anf float4. * metal : 17% faster Q4_0 Use 64 threads in a thread group. --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
1 parent 72ff528 commit 245fc3c
Copy full SHA for 245fc3c

File tree

Expand file treeCollapse file tree

2 files changed

+20
-16
lines changed
Filter options
Expand file treeCollapse file tree

2 files changed

+20
-16
lines changed

‎ggml-metal.m

Copy file name to clipboardExpand all lines: ggml-metal.m
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,7 @@ void ggml_metal_graph_compute(
526526
GGML_ASSERT(ne12 == 1);
527527

528528
nth0 = 8;
529-
nth1 = 4;
529+
nth1 = 8;
530530
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
531531
} break;
532532
case GGML_TYPE_Q2_K:

‎ggml-metal.metal

Copy file name to clipboardExpand all lines: ggml-metal.metal
+19-15Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,8 @@ kernel void kernel_mul_mat_q4_0_f32(
267267
uint2 tptg[[threads_per_threadgroup]]) {
268268
const int nb = ne00/QK4_0;
269269

270+
const int8_t m8 = 8;
271+
270272
const int64_t r0 = tgpig.x;
271273
const int64_t r1 = tgpig.y;
272274

@@ -276,33 +278,34 @@ kernel void kernel_mul_mat_q4_0_f32(
276278
const uint nth = tptg.x*tptg.y;
277279
const uint ith = tptg.y*tpitg.x + tpitg.y;
278280

279-
sum[ith] = 0.0f;
281+
const int ix = tpitg.y/4; // 0 or 1
282+
const int iy = tpitg.y - 4*ix; // 0...3
280283

281-
for (int i = tpitg.x; i < nb; i += tptg.x) {
282-
device const uchar4 * x0p = (device const uchar4 *) (x + i)->qs;
283-
device const float4 * y0p = (device const float4 *) (y + i*QK4_0);
284+
const int first = 4 * iy;
285+
286+
float sumf = 0;
284287

285-
const float d = (float)((x + i)->d);
288+
for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {
286289

287-
const uchar4 x0v = *(x0p + tpitg.y);
288-
const float4 y0v = *(y0p + tpitg.y + 0);
289-
const float4 y1v = *(y0p + tpitg.y + 4);
290+
const float d = (float)x[i].d;
290291

291-
float acc = 0.0f;
292+
device const uint8_t * xl = x[i].qs + first;
293+
device const float * yl = y + i * QK4_0 + first;
294+
295+
float2 acc = {0.0f, 0.0f};
292296

293297
for (int j = 0; j < 4; ++j) {
294-
const int x0 = x0v[j] & 0x0F;
295-
const int x1 = x0v[j] >> 4;
296298

297-
const float y0 = y0v[j];
298-
const float y1 = y1v[j];
299+
acc[0] += yl[j+ 0] * ((int8_t)(xl[j] & 0xF) - m8);
300+
acc[1] += yl[j+16] * ((int8_t)(xl[j] >> 4) - m8);
299301

300-
acc += (x0 - 8)*y0 + (x1 - 8)*y1;
301302
}
302303

303-
sum[ith] += acc*d;
304+
sumf += d * (acc[0] + acc[1]);
304305
}
305306

307+
sum[ith] = sumf;
308+
306309
//
307310
// Accumulate the sum from all threads in the threadgroup
308311
// This version is slightly faster than the commented out one below,
@@ -357,6 +360,7 @@ kernel void kernel_mul_mat_f16_f32(
357360
uint3 tpig[[thread_position_in_grid]],
358361
uint3 tpitg[[thread_position_in_threadgroup]],
359362
uint3 tptg[[threads_per_threadgroup]]) {
363+
360364
const int64_t r0 = tgpig.x;
361365
const int64_t r1 = tgpig.y;
362366
const int64_t im = tgpig.z;

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.