Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit ce21cc8

Browse filesBrowse files
committed
ggml : add cpy
ggml-ci
1 parent 18ddfbc commit ce21cc8
Copy full SHA for ce21cc8

File tree

Expand file treeCollapse file tree

3 files changed

+37
-139
lines changed
Filter options
Expand file treeCollapse file tree

3 files changed

+37
-139
lines changed

‎common/common.cpp

Copy file name to clipboardExpand all lines: common/common.cpp
+3Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2160,6 +2160,9 @@ static ggml_type kv_cache_type_from_str(const std::string & s) {
21602160
if (s == "f32") {
21612161
return GGML_TYPE_F32;
21622162
}
2163+
if (s == "bf16") {
2164+
return GGML_TYPE_BF16;
2165+
}
21632166
if (s == "f16") {
21642167
return GGML_TYPE_F16;
21652168
}

‎ggml/src/ggml-metal.m

Copy file name to clipboardExpand all lines: ggml/src/ggml-metal.m
+20-10Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -199,16 +199,18 @@
199199
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
200200
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
201201
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
202-
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
203202
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
203+
GGML_METAL_KERNEL_TYPE_CPY_F32_BF16,
204+
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
205+
GGML_METAL_KERNEL_TYPE_CPY_BF16_F32,
206+
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
207+
GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
204208
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
205209
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
206210
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
207211
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
208212
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
209213
GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
210-
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
211-
GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
212214
GGML_METAL_KERNEL_TYPE_CONCAT,
213215
GGML_METAL_KERNEL_TYPE_SQR,
214216
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
@@ -661,16 +663,18 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
661663
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
662664
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
663665
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
666+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, true);
664667
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
665668
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
669+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, true);
670+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
671+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
666672
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
667673
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
668674
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
669675
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
670676
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
671677
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
672-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
673-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
674678
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
675679
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
676680
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
@@ -750,7 +754,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
750754
for (size_t i = 0, n = 3; i < n; ++i) {
751755
if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16 &&
752756
op->op != GGML_OP_GET_ROWS &&
753-
op->op != GGML_OP_MUL_MAT) {
757+
op->op != GGML_OP_MUL_MAT &&
758+
op->op != GGML_OP_VIEW &&
759+
op->op != GGML_OP_CPY) {
754760
printf("op = %s, src[%zu] = %s\n", ggml_op_name(op->op), i, ggml_type_name(op->src[i]->type));
755761
GGML_ASSERT(false);
756762
}
@@ -826,6 +832,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
826832
case GGML_TYPE_F32:
827833
switch (op->type) {
828834
case GGML_TYPE_F16:
835+
case GGML_TYPE_BF16:
829836
case GGML_TYPE_F32:
830837
case GGML_TYPE_Q8_0:
831838
case GGML_TYPE_Q4_0:
@@ -840,6 +847,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
840847
case GGML_TYPE_F16:
841848
switch (op->type) {
842849
case GGML_TYPE_F16:
850+
case GGML_TYPE_BF16:
843851
case GGML_TYPE_F32:
844852
return true;
845853
default:
@@ -2812,8 +2820,9 @@ static enum ggml_status ggml_metal_graph_compute(
28122820
GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
28132821

28142822
switch (dstt) {
2815-
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
2816-
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
2823+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
2824+
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_BF16].pipeline; break;
2825+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
28172826
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
28182827
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
28192828
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
@@ -2826,8 +2835,9 @@ static enum ggml_status ggml_metal_graph_compute(
28262835
case GGML_TYPE_F16:
28272836
{
28282837
switch (dstt) {
2829-
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
2830-
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
2838+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
2839+
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break;
2840+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
28312841
default: GGML_ASSERT(false && "not implemented");
28322842
};
28332843
} break;

‎ggml/src/ggml-metal.metal

Copy file name to clipboardExpand all lines: ggml/src/ggml-metal.metal
+14-129Lines changed: 14 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -2597,91 +2597,10 @@ kernel void kernel_flash_attn_ext_vec_f16(
25972597
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
25982598
//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
25992599

2600-
kernel void kernel_cpy_f16_f16(
2601-
device const half * src0,
2602-
device half * dst,
2603-
constant int64_t & ne00,
2604-
constant int64_t & ne01,
2605-
constant int64_t & ne02,
2606-
constant int64_t & ne03,
2607-
constant uint64_t & nb00,
2608-
constant uint64_t & nb01,
2609-
constant uint64_t & nb02,
2610-
constant uint64_t & nb03,
2611-
constant int64_t & ne0,
2612-
constant int64_t & ne1,
2613-
constant int64_t & ne2,
2614-
constant int64_t & ne3,
2615-
constant uint64_t & nb0,
2616-
constant uint64_t & nb1,
2617-
constant uint64_t & nb2,
2618-
constant uint64_t & nb3,
2619-
uint3 tgpig[[threadgroup_position_in_grid]],
2620-
uint3 tpitg[[thread_position_in_threadgroup]],
2621-
uint3 ntg[[threads_per_threadgroup]]) {
2622-
const int64_t i03 = tgpig[2];
2623-
const int64_t i02 = tgpig[1];
2624-
const int64_t i01 = tgpig[0];
2625-
2626-
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
2627-
2628-
const int64_t i3 = n / (ne2*ne1*ne0);
2629-
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
2630-
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
2631-
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
2632-
2633-
device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2634-
2635-
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
2636-
device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2637-
dst_data[i00] = src[0];
2638-
}
2639-
}
2640-
2641-
kernel void kernel_cpy_f16_f32(
2642-
device const half * src0,
2643-
device float * dst,
2644-
constant int64_t & ne00,
2645-
constant int64_t & ne01,
2646-
constant int64_t & ne02,
2647-
constant int64_t & ne03,
2648-
constant uint64_t & nb00,
2649-
constant uint64_t & nb01,
2650-
constant uint64_t & nb02,
2651-
constant uint64_t & nb03,
2652-
constant int64_t & ne0,
2653-
constant int64_t & ne1,
2654-
constant int64_t & ne2,
2655-
constant int64_t & ne3,
2656-
constant uint64_t & nb0,
2657-
constant uint64_t & nb1,
2658-
constant uint64_t & nb2,
2659-
constant uint64_t & nb3,
2660-
uint3 tgpig[[threadgroup_position_in_grid]],
2661-
uint3 tpitg[[thread_position_in_threadgroup]],
2662-
uint3 ntg[[threads_per_threadgroup]]) {
2663-
const int64_t i03 = tgpig[2];
2664-
const int64_t i02 = tgpig[1];
2665-
const int64_t i01 = tgpig[0];
2666-
2667-
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
2668-
2669-
const int64_t i3 = n / (ne2*ne1*ne0);
2670-
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
2671-
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
2672-
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
2673-
2674-
device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2675-
2676-
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
2677-
device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2678-
dst_data[i00] = src[0];
2679-
}
2680-
}
2681-
2682-
kernel void kernel_cpy_f32_f16(
2683-
device const float * src0,
2684-
device half * dst,
2600+
template<typename T0, typename T1>
2601+
kernel void kernel_cpy(
2602+
device const void * src0,
2603+
device void * dst,
26852604
constant int64_t & ne00,
26862605
constant int64_t & ne01,
26872606
constant int64_t & ne02,
@@ -2712,56 +2631,22 @@ kernel void kernel_cpy_f32_f16(
27122631
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
27132632
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
27142633

2715-
device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2634+
device T1 * dst_data = (device T1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
27162635

27172636
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
2718-
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2719-
2720-
dst_data[i00] = src[0];
2637+
device const T0 * src = (device T0 *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2638+
dst_data[i00] = (T1) src[0];
27212639
}
27222640
}
27232641

2724-
kernel void kernel_cpy_f32_f32(
2725-
device const float * src0,
2726-
device float * dst,
2727-
constant int64_t & ne00,
2728-
constant int64_t & ne01,
2729-
constant int64_t & ne02,
2730-
constant int64_t & ne03,
2731-
constant uint64_t & nb00,
2732-
constant uint64_t & nb01,
2733-
constant uint64_t & nb02,
2734-
constant uint64_t & nb03,
2735-
constant int64_t & ne0,
2736-
constant int64_t & ne1,
2737-
constant int64_t & ne2,
2738-
constant int64_t & ne3,
2739-
constant uint64_t & nb0,
2740-
constant uint64_t & nb1,
2741-
constant uint64_t & nb2,
2742-
constant uint64_t & nb3,
2743-
uint3 tgpig[[threadgroup_position_in_grid]],
2744-
uint3 tpitg[[thread_position_in_threadgroup]],
2745-
uint3 ntg[[threads_per_threadgroup]]) {
2746-
const int64_t i03 = tgpig[2];
2747-
const int64_t i02 = tgpig[1];
2748-
const int64_t i01 = tgpig[0];
2749-
2750-
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
2751-
2752-
const int64_t i3 = n / (ne2*ne1*ne0);
2753-
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
2754-
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
2755-
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
2756-
2757-
device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2642+
typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
27582643

2759-
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
2760-
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2761-
2762-
dst_data[i00] = src[0];
2763-
}
2764-
}
2644+
template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy<float, float>;
2645+
template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy<float, bfloat>;
2646+
template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy<float, half>;
2647+
template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy<bfloat, float>;
2648+
template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy<half, half>;
2649+
template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy<half, float>;
27652650

27662651
kernel void kernel_cpy_f32_q8_0(
27672652
device const float * src0,

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.