Commit 56904ca

Merge remote-tracking branch 'refs/remotes/upstream/master' into grammar

2 parents: 98a9587 + fa84c4b
18 files changed: +409 −77 lines

‎CMakeLists.txt

3 additions, 0 deletions

@@ -432,6 +432,9 @@ target_link_libraries(llama PRIVATE
 if (BUILD_SHARED_LIBS)
     set_target_properties(llama PROPERTIES POSITION_INDEPENDENT_CODE ON)
     target_compile_definitions(llama PRIVATE LLAMA_SHARED LLAMA_BUILD)
+    if (LLAMA_METAL)
+        set_target_properties(llama PROPERTIES RESOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal")
+    endif()
 endif()

 if (GGML_SOURCES_CUDA)

‎Makefile

4 additions, 0 deletions

@@ -107,6 +107,10 @@ ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
     # Usage AVX-only
     #CFLAGS += -mfma -mf16c -mavx
     #CXXFLAGS += -mfma -mf16c -mavx
+
+    # Usage SSSE3-only (Not is SSE3!)
+    #CFLAGS += -mssse3
+    #CXXFLAGS += -mssse3
 endif

 ifneq ($(filter ppc64%,$(UNAME_M)),)

‎README.md

1 addition, 1 deletion

@@ -308,7 +308,7 @@ Building the program with BLAS support may lead to some performance improvements

 - #### BLIS

-  Check [BLIS.md](BLIS.md) for more information.
+  Check [BLIS.md](docs/BLIS.md) for more information.

 - #### Intel MKL

‎SHA256SUMS

4 additions, 4 deletions

@@ -1,14 +1,14 @@
 700df0d3013b703a806d2ae7f1bfb8e59814e3d06ae78be0c66368a50059f33d models/7B/consolidated.00.pth
 666a4bb533b303bdaf89e1b6a3b6f93535d868de31d903afdc20983dc526c847 models/7B/ggml-model-f16.bin
-ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff models/7B/ggml-model-q4_0.bin
+ec2f2d1f0dfb73b72a4cbac7fa121abbe04c37ab327125a38248f930c0f09ddf models/7B/ggml-model-q4_0.bin
 ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff models/7B/ggml-model-q4_1.bin
 ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff models/7B/ggml-model-q5_0.bin
 ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff models/7B/ggml-model-q5_1.bin
 7e89e242ddc0dd6f060b43ca219ce8b3e8f08959a72cb3c0855df8bb04d46265 models/7B/params.json
 745bf4e29a4dd6f411e72976d92b452da1b49168a4f41c951cfcc8051823cf08 models/13B/consolidated.00.pth
 d5ccbcc465c71c0de439a5aeffebe8344c68a519bce70bc7f9f92654ee567085 models/13B/consolidated.01.pth
 2b206e9b21fb1076f11cafc624e2af97c9e48ea09312a0962153acc20d45f808 models/13B/ggml-model-f16.bin
-ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff models/13B/ggml-model-q4_0.bin
+fad169e6f0f575402cf75945961cb4a8ecd824ba4da6be2af831f320c4348fa5 models/13B/ggml-model-q4_0.bin
 ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff models/13B/ggml-model-q4_1.bin
 ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff models/13B/ggml-model-q5_0.bin
 ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff models/13B/ggml-model-q5_1.bin
@@ -18,7 +18,7 @@ e23294a58552d8cdec5b7e8abb87993b97ea6eced4178ff2697c02472539d067 models/30B/con
 24a87f01028cbd3a12de551dcedb712346c0b5cbdeff1454e0ddf2df9b675378 models/30B/consolidated.02.pth
 1adfcef71420886119544949767f6a56cb6339b4d5fcde755d80fe68b49de93b models/30B/consolidated.03.pth
 7e1b524061a9f4b27c22a12d6d2a5bf13b8ebbea73e99f218809351ed9cf7d37 models/30B/ggml-model-f16.bin
-ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff models/30B/ggml-model-q4_0.bin
+d2a441403944819492ec8c2002cc36fa38468149bfb4b7b4c52afc7bd9a7166d models/30B/ggml-model-q4_0.bin
 ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff models/30B/ggml-model-q4_1.bin
 ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff models/30B/ggml-model-q5_0.bin
 ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff models/30B/ggml-model-q5_1.bin
@@ -32,7 +32,7 @@ a287c0dfe49081626567c7fe87f74cce5831f58e459b427b5e05567641f47b78 models/65B/con
 72b4eba67a1a3b18cb67a85b70f8f1640caae9b40033ea943fb166bd80a7b36b models/65B/consolidated.06.pth
 d27f5b0677d7ff129ceacd73fd461c4d06910ad7787cf217b249948c3f3bc638 models/65B/consolidated.07.pth
 60758f2384d74e423dffddfd020ffed9d3bb186ebc54506f9c4a787d0f5367b0 models/65B/ggml-model-f16.bin
-ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff models/65B/ggml-model-q4_0.bin
+cde053439fa4910ae454407e2717cc46cc2c2b4995c00c93297a2b52e790fa92 models/65B/ggml-model-q4_0.bin
 ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff models/65B/ggml-model-q4_1.bin
 ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff models/65B/ggml-model-q5_0.bin
 ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff models/65B/ggml-model-q5_1.bin

‎examples/common.cpp

3 additions, 0 deletions

@@ -656,6 +656,9 @@ void console_set_color(console_state & con_st, console_color_t color) {
             case CONSOLE_COLOR_USER_INPUT:
                 fprintf(con_st.out, ANSI_BOLD ANSI_COLOR_GREEN);
                 break;
+            case CONSOLE_COLOR_ERROR:
+                fprintf(con_st.out, ANSI_BOLD ANSI_COLOR_RED);
+                break;
         }
         con_st.color = color;
         fflush(con_st.out);

‎examples/common.h

2 additions, 1 deletion

@@ -113,7 +113,8 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params);
 enum console_color_t {
     CONSOLE_COLOR_DEFAULT=0,
     CONSOLE_COLOR_PROMPT,
-    CONSOLE_COLOR_USER_INPUT
+    CONSOLE_COLOR_USER_INPUT,
+    CONSOLE_COLOR_ERROR
 };

 struct console_state {

‎examples/main/main.cpp

16 additions, 0 deletions

@@ -82,6 +82,9 @@ int main(int argc, char ** argv) {
     if (params.n_ctx > 2048) {
         fprintf(stderr, "%s: warning: model does not support context sizes greater than 2048 tokens (%d specified);"
                 "expect poor results\n", __func__, params.n_ctx);
+    } else if (params.n_ctx < 8) {
+        fprintf(stderr, "%s: warning: minimum context size is 8, using minimum size.\n", __func__);
+        params.n_ctx = 8;
     }

     fprintf(stderr, "%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT);
@@ -347,6 +350,19 @@ int main(int argc, char ** argv) {
     while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
         // predict
         if (embd.size() > 0) {
+            // Note: n_ctx - 4 here is to match the logic for commandline prompt handling via
+            // --prompt or --file which uses the same value.
+            auto max_embd_size = n_ctx - 4;
+            // Ensure the input doesn't exceed the context size by truncating embd if necessary.
+            if ((int)embd.size() > max_embd_size) {
+                auto skipped_tokens = embd.size() - max_embd_size;
+                console_set_color(con_st, CONSOLE_COLOR_ERROR);
+                printf("<<input too long: skipped %ld token%s>>", skipped_tokens, skipped_tokens != 1 ? "s" : "");
+                console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
+                fflush(stdout);
+                embd.resize(max_embd_size);
+            }
+
             // infinite text generation via context swapping
             // if we run out of context:
             // - take the n_keep first tokens from the original prompt (via n_past)
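The truncation guard added above can also be read in isolation. Below is a minimal standalone C++ sketch, assuming an illustrative context size and token buffer; only the n_ctx - 4 reserve, the skipped-token message, and the resize come from the diff, the rest is made up for the example.

    #include <cstdio>
    #include <vector>

    int main() {
        const int n_ctx = 512;            // illustrative context size (not from the diff)
        std::vector<int> embd(600, 0);    // illustrative token batch that is too long

        // Reserve 4 tokens, matching the --prompt/--file handling the diff refers to.
        const int max_embd_size = n_ctx - 4;
        if ((int) embd.size() > max_embd_size) {
            const size_t skipped_tokens = embd.size() - max_embd_size;
            std::printf("<<input too long: skipped %zu token%s>>\n",
                        skipped_tokens, skipped_tokens != 1 ? "s" : "");
            embd.resize(max_embd_size);   // keep only what fits in the context window
        }

        std::printf("embd now holds %zu tokens\n", embd.size());
        return 0;
    }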

‎examples/quantize/quantize.cpp

38 additions, 19 deletions

@@ -3,6 +3,7 @@
 #include "llama.h"

 #include <cstdio>
+#include <cstring>
 #include <map>
 #include <string>

@@ -53,27 +54,49 @@ bool try_parse_ftype(const std::string & ftype_str, llama_ftype & ftype, std::st
 // usage:
 // ./quantize models/llama/ggml-model.bin [models/llama/ggml-model-quant.bin] type [nthreads]
 //
+void usage(const char * executable) {
+    fprintf(stderr, "usage: %s [--help] [--allow-requantize] [--leave-output-tensor] model-f32.bin [model-quant.bin] type [nthreads]\n", executable);
+    fprintf(stderr, " --allow-requantize: Allows requantizing tensors that have already been quantized. Warning: This can severely reduce quality compared to quantizing from 16bit or 32bit\n");
+    fprintf(stderr, " --leave-output-tensor: Will leave output.weight un(re)quantized. Increases model size but may also increase quality, especially when requantizing\n");
+    fprintf(stderr, "Allowed quantization types:\n");
+    for (auto it = LLAMA_FTYPE_MAP.begin(); it != LLAMA_FTYPE_MAP.end(); it++) {
+        fprintf(stderr, " type = \"%s\" or %d\n", it->first.c_str(), it->second);
+    }
+    exit(1);
+}
+
 int main(int argc, char ** argv) {
     if (argc < 3) {
-        fprintf(stderr, "usage: %s model-f32.bin [model-quant.bin] type [nthreads]\n", argv[0]);
-        for (auto it = LLAMA_FTYPE_MAP.begin(); it != LLAMA_FTYPE_MAP.end(); it++) {
-            fprintf(stderr, " type = \"%s\" or %d\n", it->first.c_str(), it->second);
+        usage(argv[0]);
+    }
+
+    llama_model_quantize_params params = llama_model_quantize_default_params();
+
+    int arg_idx = 1;
+
+    for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) {
+        if (strcmp(argv[arg_idx], "--leave-output-tensor") == 0) {
+            params.quantize_output_tensor = false;
+        } else if (strcmp(argv[arg_idx], "--allow-requantize") == 0) {
+            params.allow_requantize = true;
+        } else {
+            usage(argv[0]);
         }
-        return 1;
+    }
+
+    if (argc - arg_idx < 3) {
+        usage(argv[0]);
     }

     llama_init_backend();

     // parse command line arguments
-    const std::string fname_inp = argv[1];
+    const std::string fname_inp = argv[arg_idx];
+    arg_idx++;
     std::string fname_out;
-    int nthread;
-    llama_ftype ftype;

-    int arg_idx = 2;
     std::string ftype_str;
-    if (try_parse_ftype(argv[arg_idx], ftype, ftype_str)) {
-        // argv[2] is the ftype
+    if (try_parse_ftype(argv[arg_idx], params.ftype, ftype_str)) {
         std::string fpath;
         const size_t pos = fname_inp.find_last_of('/');
         if (pos != std::string::npos) {
@@ -84,16 +107,14 @@ int main(int argc, char ** argv) {
         arg_idx++;
     }
     else {
-        // argv[2] is the output path
         fname_out = argv[arg_idx];
         arg_idx++;

         if (argc <= arg_idx) {
             fprintf(stderr, "%s: missing ftype\n", __func__);
             return 1;
         }
-        // argv[3] is the ftype
-        if (!try_parse_ftype(argv[arg_idx], ftype, ftype_str)) {
+        if (!try_parse_ftype(argv[arg_idx], params.ftype, ftype_str)) {
             fprintf(stderr, "%s: invalid ftype '%s'\n", __func__, argv[3]);
             return 1;
         }
@@ -103,21 +124,19 @@ int main(int argc, char ** argv) {
     // parse nthreads
     if (argc > arg_idx) {
         try {
-            nthread = std::stoi(argv[arg_idx]);
+            params.nthread = std::stoi(argv[arg_idx]);
         }
         catch (const std::exception & e) {
             fprintf(stderr, "%s: invalid nthread '%s' (%s)\n", __func__, argv[arg_idx], e.what());
             return 1;
         }
-    } else {
-        nthread = 0;
     }

     fprintf(stderr, "%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT);

     fprintf(stderr, "%s: quantizing '%s' to '%s' as %s", __func__, fname_inp.c_str(), fname_out.c_str(), ftype_str.c_str());
-    if (nthread > 0) {
-        fprintf(stderr, " using %d threads", nthread);
+    if (params.nthread > 0) {
+        fprintf(stderr, " using %d threads", params.nthread);
     }
     fprintf(stderr, "\n");

@@ -129,7 +148,7 @@ int main(int argc, char ** argv) {
     {
         const int64_t t_start_us = llama_time_us();

-        if (llama_model_quantize(fname_inp.c_str(), fname_out.c_str(), ftype, nthread)) {
+        if (llama_model_quantize(fname_inp.c_str(), fname_out.c_str(), &params)) {
             fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str());
             return 1;
         }
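To make the new argument handling easier to follow outside the interleaved diff: leading "--" flags are consumed first and written into the llama_model_quantize_params struct, then the positional arguments (input model, optional output path, ftype, optional nthreads) are parsed from arg_idx onward. A hedged standalone C++ sketch of that pattern follows; the flag names match the diff, while the local booleans and the simplified positional check stand in for the real params fields and are illustrative only.

    #include <cstdio>
    #include <cstring>

    int main(int argc, char ** argv) {
        bool allow_requantize       = false;  // stands in for params.allow_requantize
        bool quantize_output_tensor = true;   // stands in for params.quantize_output_tensor

        // Consume leading "--" flags, mirroring the loop added in quantize.cpp.
        int arg_idx = 1;
        for (; arg_idx < argc && std::strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) {
            if (std::strcmp(argv[arg_idx], "--leave-output-tensor") == 0) {
                quantize_output_tensor = false;
            } else if (std::strcmp(argv[arg_idx], "--allow-requantize") == 0) {
                allow_requantize = true;
            } else {
                std::fprintf(stderr, "unknown flag: %s\n", argv[arg_idx]);
                return 1;
            }
        }

        // Whatever remains is positional: model-f32.bin [model-quant.bin] type [nthreads].
        if (argc - arg_idx < 1) {
            std::fprintf(stderr, "usage: %s [--allow-requantize] [--leave-output-tensor] model-f32.bin [model-quant.bin] type [nthreads]\n", argv[0]);
            return 1;
        }

        std::printf("allow_requantize=%d quantize_output_tensor=%d, first positional arg: %s\n",
                    allow_requantize, quantize_output_tensor, argv[arg_idx]);
        return 0;
    }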

‎flake.nix

1 addition, 1 deletion

@@ -28,7 +28,7 @@
         postPatch =
           if isM1 then ''
             substituteInPlace ./ggml-metal.m \
-              --replace '[[NSBundle mainBundle] pathForResource:@"ggml-metal" ofType:@"metal"];' "@\"$out/ggml-metal.metal\";"
+              --replace '[bundle pathForResource:@"ggml-metal" ofType:@"metal"];' "@\"$out/ggml-metal.metal\";"
           '' else "";
         nativeBuildInputs = with pkgs; [ cmake ];
         buildInputs = osSpecific;

‎ggml-cuda.cu

12 additions, 0 deletions

@@ -1105,6 +1105,9 @@ void * ggml_cuda_host_malloc(size_t size) {
     void * ptr = nullptr;
     cudaError_t err = cudaMallocHost((void **) &ptr, size);
     if (err != cudaSuccess) {
+        // The allocation error can be bypassed. A null ptr will assigned out of this function.
+        // This can fixed the OOM error in WSL.
+        cudaGetLastError();
         fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory: %s\n",
             size/1024.0/1024.0, cudaGetErrorString(err));
         return nullptr;
@@ -1512,6 +1515,14 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                 i01_high = row_high % ne01;
             }
         }
+
+        // There is possibly a bug in the Windows nvcc compiler regarding instruction reordering or optimizing out local variables.
+        // Removing the first assert or changing the order of the arguments causes the second assert to fail.
+        // Removing both asserts results in i01_high becoming 0 which in turn results in garbage output.
+        // The root cause seems to be a problem with i0_offset_high becoming 0 when it should always be >0 (for single GPU).
+        GGML_ASSERT(i01_low == 0 || g_device_count > 1);
+        GGML_ASSERT(i01_high == ne01 || g_device_count > 1);
+
         const int64_t i01_diff = i01_high - i01_low;
         if (i01_diff == 0) {
             continue;
@@ -1727,6 +1738,7 @@ void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensor, const
         row_low -= row_low % GGML_CUDA_DMMV_Y;
         row_high = id == g_device_count - 1 ? nrows : nrows*g_tensor_split[id + 1];
         row_high -= row_high % GGML_CUDA_DMMV_Y;
+        GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
     } else {
         GGML_ASSERT(false);
     }
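The pinned-memory change above clears the sticky CUDA error and returns nullptr so that callers can fall back to ordinary pageable memory instead of aborting (the WSL case mentioned in the comment). Below is a hedged sketch of that pattern as a standalone CUDA-runtime C++ program; the try_host_malloc helper and the plain-malloc fallback are illustrative and not part of the ggml API; only the cudaMallocHost / cudaGetLastError handling mirrors the diff. Build with nvcc.

    #include <cuda_runtime.h>
    #include <cstdio>
    #include <cstdlib>

    // Illustrative helper: prefer pinned (page-locked) memory, but never fail hard.
    static void * try_host_malloc(size_t size, bool * pinned) {
        void * ptr = nullptr;
        cudaError_t err = cudaMallocHost((void **) &ptr, size);
        if (err != cudaSuccess) {
            // Clear the sticky error state so later CUDA calls don't observe this failure,
            // then fall back to ordinary pageable memory (slower transfers, but it works).
            cudaGetLastError();
            std::fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory: %s\n",
                         size/1024.0/1024.0, cudaGetErrorString(err));
            *pinned = false;
            return std::malloc(size);
        }
        *pinned = true;
        return ptr;
    }

    int main() {
        bool pinned = false;
        void * buf = try_host_malloc(64 * 1024 * 1024, &pinned);  // 64 MB scratch buffer
        std::printf("got %s host buffer at %p\n", pinned ? "pinned" : "pageable", buf);
        if (pinned) { cudaFreeHost(buf); } else { std::free(buf); }
        return 0;
    }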

‎ggml-metal.m

40 additions, 3 deletions

@@ -45,16 +45,19 @@
     GGML_METAL_DECL_KERNEL(scale);
     GGML_METAL_DECL_KERNEL(silu);
     GGML_METAL_DECL_KERNEL(relu);
+    GGML_METAL_DECL_KERNEL(gelu);
     GGML_METAL_DECL_KERNEL(soft_max);
     GGML_METAL_DECL_KERNEL(diag_mask_inf);
     GGML_METAL_DECL_KERNEL(get_rows_f16);
     GGML_METAL_DECL_KERNEL(get_rows_q4_0);
+    GGML_METAL_DECL_KERNEL(get_rows_q4_1);
     GGML_METAL_DECL_KERNEL(get_rows_q2_k);
     GGML_METAL_DECL_KERNEL(get_rows_q4_k);
     GGML_METAL_DECL_KERNEL(get_rows_q6_k);
     GGML_METAL_DECL_KERNEL(rms_norm);
     GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q2_k_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32);
@@ -70,6 +73,12 @@
 // for now it is easier to work in a separate file
 static NSString * const msl_library_source = @"see metal.metal";

+// Here to assist with NSBundle Path Hack
+@interface GGMLMetalClass : NSObject
+@end
+@implementation GGMLMetalClass
+@end
+
 struct ggml_metal_context * ggml_metal_init(void) {
     fprintf(stderr, "%s: allocating\n", __func__);

@@ -105,7 +114,8 @@
         NSError * error = nil;

         //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
-        NSString * path = [[NSBundle mainBundle] pathForResource:@"ggml-metal" ofType:@"metal"];
+        NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
+        NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
         fprintf(stderr, "%s: loading '%s'\n", __func__, [path UTF8String]);

         NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
@@ -135,16 +145,19 @@
         GGML_METAL_ADD_KERNEL(scale);
         GGML_METAL_ADD_KERNEL(silu);
         GGML_METAL_ADD_KERNEL(relu);
+        GGML_METAL_ADD_KERNEL(gelu);
         GGML_METAL_ADD_KERNEL(soft_max);
         GGML_METAL_ADD_KERNEL(diag_mask_inf);
         GGML_METAL_ADD_KERNEL(get_rows_f16);
         GGML_METAL_ADD_KERNEL(get_rows_q4_0);
+        GGML_METAL_ADD_KERNEL(get_rows_q4_1);
         GGML_METAL_ADD_KERNEL(get_rows_q2_k);
         GGML_METAL_ADD_KERNEL(get_rows_q4_k);
         GGML_METAL_ADD_KERNEL(get_rows_q6_k);
         GGML_METAL_ADD_KERNEL(rms_norm);
         GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q2_k_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32);
@@ -420,6 +433,20 @@ void ggml_metal_graph_compute(

                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
+            case GGML_OP_GELU:
+                {
+                    if (encoder == nil) {
+                        encoder = [command_buffer computeCommandEncoder];
+                    }
+
+                    [encoder setComputePipelineState:ctx->pipeline_gelu];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
             case GGML_OP_SOFT_MAX:
                 {
                     if (encoder == nil) {
@@ -526,9 +553,18 @@ void ggml_metal_graph_compute(
                             GGML_ASSERT(ne12 == 1);

                             nth0 = 8;
-                            nth1 = 4;
+                            nth1 = 8;
                             [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
                         } break;
+                    case GGML_TYPE_Q4_1:
+                        {
+                            GGML_ASSERT(ne02 == 1);
+                            GGML_ASSERT(ne12 == 1);
+
+                            nth0 = 8;
+                            nth1 = 8;
+                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
+                        } break;
                     case GGML_TYPE_Q2_K:
                         {
                             GGML_ASSERT(ne02 == 1);
@@ -580,7 +616,7 @@ void ggml_metal_graph_compute(
                     [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
                     [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];

-                    if (src0t == GGML_TYPE_Q4_0) {
+                    if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) {
                         [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
                         [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                     } else if (src0t == GGML_TYPE_Q2_K) {
@@ -607,6 +643,7 @@ void ggml_metal_graph_compute(
                     switch (src0->type) {
                         case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
                         case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
+                        case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
                         case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break;
                         case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break;
                         case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break;
