Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

metal : refactor kernel args into structs #10238

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Nov 17, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
cont : mul mm id
ggml-ci
  • Loading branch information
ggerganov committed Nov 17, 2024
commit ec18f96891ccdaabfe42c21d440591632679be8b
18 changes: 18 additions & 0 deletions 18 ggml/src/ggml-common.h
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,24 @@ typedef struct {
int16_t r3;
} ggml_metal_kargs_mul_mv;

// Kernel arguments for the Metal MUL_MAT_ID matrix-matrix kernel (kernel_mul_mm_id).
// The host fills one instance of this struct and the kernel receives it as a single
// `constant ggml_metal_kargs_mul_mm_id &` parameter instead of one binding per value.
//
// NOTE: the host initializes the fields positionally (the field names appear only as
// comments in the initializer), so the declaration order below must stay in sync with
// that initializer.
// NOTE(review): the ne*/nb* names presumably follow the ggml convention of
// "number of elements" / "stride in bytes" per dimension - confirm against ggml.h.
typedef struct {
// loop bounds used by the kernel when scanning the ids tensor for matching rows;
// the host passes ne20 / ne21 here
int32_t nei0;
int32_t nei1;
// byte offset between consecutive ii1 rows of ids (the kernel indexes a char
// pointer with ii1*nbi1); the host passes nb21 here
uint64_t nbi1;
// src0 description; nb02 is the byte offset the kernel applies per threadgroup z
// index (i02) to select the src0 slice
int32_t ne00;
int32_t ne02;
uint64_t nb01;
uint64_t nb02;
// src1 description - presumably extents (ne1x) and byte strides (nb1x); TODO confirm
int32_t ne11;
int32_t ne12;
int32_t ne13;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
// dst description; the kernel also forms ne0*ne1 (widened to int64_t) and uses
// ne0 and that product as multipliers for the two rowid components when
// computing the output offset
int32_t ne0;
int32_t ne1;
} ggml_metal_kargs_mul_mm_id;

typedef struct {
int32_t nei0;
int32_t nei1;
Expand Down
43 changes: 23 additions & 20 deletions 43 ggml/src/ggml-metal/ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -2301,27 +2301,30 @@ static void ggml_metal_encode_node(
default: GGML_ABORT("MUL_MAT_ID not implemented");
}

ggml_metal_kargs_mul_mm_id args = {
/*.nei0 =*/ ne20,
/*.nei1 =*/ ne21,
/*.nbi1 =*/ nb21,
/*.ne00 =*/ ne00,
/*.ne02 =*/ ne02,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.ne11 =*/ ne11,
/*.ne12 =*/ ne12,
/*.ne13 =*/ ne13,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
};

[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
[encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18];
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];

[encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];

Expand Down
153 changes: 78 additions & 75 deletions 153 ggml/src/ggml-metal/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -5769,31 +5769,32 @@ kernel void kernel_mul_mm(
}

// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids
// TODO: this kernel needs to be reimplemented from scratch for better performance
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
void kernel_mul_mm_id_impl(
device const uchar * src0,
device const uchar * src1,
int32_t ne00,
int32_t ne02,
uint64_t nb01,
uint64_t nb02,
int32_t ne11,
int32_t ne12,
uint64_t nb10,
uint64_t nb11,
uint64_t nb12,
int32_t ne0,
int32_t ne1,
int64_t ne0ne1,
device const char * src0,
device const char * src1,
threadgroup ushort2 * rowids,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne02,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
int64_t ne1,
int64_t ne0ne1,
threadgroup uchar * shared_memory,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {

threadgroup half * sa = (threadgroup half *)(shared_memory);
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
device char * dst,
threadgroup char * shmem,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {

threadgroup half * sa = (threadgroup half *)(shmem);
threadgroup float * sb = (threadgroup float *)(shmem + 4096);

const uint r0 = tgpig.y;
const uint r1 = tgpig.x;
Expand All @@ -5810,9 +5811,9 @@ void kernel_mul_mm_id_impl(

simdgroup_half8x8 ma[4];
simdgroup_float8x8 mb[2];
simdgroup_float8x8 c_res[8];
simdgroup_float8x8 mc[8];
for (int i = 0; i < 8; i++){
c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
}
short il = (tiitg % THREAD_PER_ROW);

Expand Down Expand Up @@ -5850,41 +5851,57 @@ void kernel_mul_mm_id_impl(
threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));

#pragma unroll(BLOCK_SIZE_K/8)
for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
#pragma unroll(4)
for (int i = 0; i < 4; i++) {
simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
}
simdgroup_barrier(mem_flags::mem_none);
#pragma unroll(2)
for (int i = 0; i < 2; i++) {
simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
}

lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;

#pragma unroll(8)
for (int i = 0; i < 8; i++){
simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
}
}
}

{
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
threadgroup float * temp_str = ((threadgroup float *) shmem) \
+ 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
for (int i = 0; i < 8; i++) {
simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
simdgroup_store(mc[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
}

threadgroup_barrier(mem_flags::mem_threadgroup);

device float * C = dst + (BLOCK_SIZE_M * r0);
if (sgitg == 0) {
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j];
int joff = jid[0] * ne0 + jid[1] * ne0ne1;
for (int i = 0; i < n_rows; i++) {
*(C + i + joff) = *(temp_str + i + j * BLOCK_SIZE_M);
int64_t joff = jid[0]*ne0 + jid[1]*ne0ne1;

device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + joff;
device float4 * D4 = (device float4 *) D;

threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
threadgroup float4 * C4 = (threadgroup float4 *) C;

int i = 0;
for (; i < n_rows/4; i++) {
*(D4 + i) = *(C4 + i);
}

i *= 4;
for (; i < n_rows; i++) {
*(D + i) = *(C + i);
}
}
}
Expand All @@ -5893,48 +5910,34 @@ void kernel_mul_mm_id_impl(

template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
kernel void kernel_mul_mm_id(
device const uchar * src0s,
device const uchar * src1,
device float * dst,
device const uchar * ids,
constant int64_t & nei0,
constant int64_t & nei1,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne02,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne11,
constant int64_t & ne12,
constant int64_t & ne13,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint64_t & nb1,
threadgroup uchar * shared_memory [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
constant ggml_metal_kargs_mul_mm_id & args,
device const char * src0s,
device const char * src1,
device char * dst,
device const char * ids,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {

const int32_t i02 = tgpig.z;

tgpig.z = 0;

device const uchar * src0 = src0s + i02*nb02;
device const char * src0 = src0s + i02*args.nb02;

// row indices
threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shmem + 8192);

// TODO: parallelize this loop
int64_t _ne1 = 0;
for (ushort ii1 = 0; ii1 < nei1; ii1++) {
for (ushort ii0 = 0; ii0 < nei0; ii0++) {
int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
for (ushort ii1 = 0; ii1 < args.nei1; ii1++) {
for (ushort ii0 = 0; ii0 < args.nei0; ii0++) {
int32_t id = ((device int32_t *) (ids + ii1*args.nbi1))[ii0];
if (id == i02) {
//if (tiitg == 0) {
if (tiitg == 0) {
rowids[_ne1] = ushort2(ii0, ii1);
//}
}
_ne1++;
}
}
Expand All @@ -5943,23 +5946,23 @@ kernel void kernel_mul_mm_id(
threadgroup_barrier(mem_flags::mem_threadgroup);

kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
args.ne00,
args.ne02,
args.nb01,
args.nb02,
args.ne11,
args.ne12,
args.nb10,
args.nb11,
args.nb12,
args.ne0,
_ne1,
(int64_t)args.ne0*args.ne1,
src0,
src1,
rowids,
dst,
ne00,
ne02,
nb01,
nb02,
ne11,
ne12,
nb10,
nb11,
nb12,
ne0,
_ne1,
ne0*ne1,
shared_memory,
shmem,
tgpig,
tiitg,
sgitg);
Expand Down
Morty Proxy This is a proxified and sanitized view of the page, visit original site.