Commit 527b6fb

didzis and ggerganov authored
llama : make model stateless and context stateful (llama_state) (ggml-org#1797)
* llama : make model stateless and context stateful
* llama : minor cleanup
* llama : update internal API declaration
* Apply suggestions from code review (fix style)
  Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* Missing model memory release
* Fix style
* Add deprecated warning for public API function llama_init_from_file
* Update public API use cases: move away from deprecated llama_init_from_file
* Deprecate public API function llama_apply_lora_from_file

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
1 parent d7b7484 commit 527b6fb

13 files changed: 243 additions, 91 deletions
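At the API level, the change splits the old one-shot llama_init_from_file() into two steps: llama_load_model_from_file() loads the stateless weights, and llama_new_context_with_model() builds a stateful context (rng, logits, embedding, KV cache) on top of them; teardown gains a matching llama_free_model(). The following is a minimal sketch of the new call sequence, based only on the calls visible in the diffs below; the model path and the elided eval loop are placeholders, not part of the commit.

#include "llama.h"
#include <cstdio>

int main() {
    llama_init_backend();

    llama_context_params lparams = llama_context_default_params();

    // step 1: load the stateless model weights (path is a placeholder)
    llama_model * model = llama_load_model_from_file("models/7B/ggml-model.bin", lparams);
    if (model == NULL) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    // step 2: create a stateful context on top of the model
    llama_context * ctx = llama_new_context_with_model(model, lparams);
    if (ctx == NULL) {
        fprintf(stderr, "failed to create context\n");
        llama_free_model(model);
        return 1;
    }

    // ... tokenize, llama_eval, sampling ...

    // step 3: free the context first, then the model it was built on
    llama_free(ctx);
    llama_free_model(model);
    return 0;
}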

‎examples/common.cpp

15 additions, 7 deletions

@@ -536,7 +536,7 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
     return res;
 }
 
-struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
+std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params) {
     auto lparams = llama_context_default_params();
 
     lparams.n_ctx = params.n_ctx;
@@ -552,25 +552,33 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
     lparams.logits_all = params.perplexity;
     lparams.embedding = params.embedding;
 
-    llama_context * lctx = llama_init_from_file(params.model.c_str(), lparams);
+    llama_model * model = llama_load_model_from_file(params.model.c_str(), lparams);
+    if (model == NULL) {
+        fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
+        return std::make_tuple(nullptr, nullptr);
+    }
 
+    llama_context * lctx = llama_new_context_with_model(model, lparams);
     if (lctx == NULL) {
-        fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
-        return NULL;
+        fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str());
+        llama_free_model(model);
+        return std::make_tuple(nullptr, nullptr);
     }
 
     if (!params.lora_adapter.empty()) {
-        int err = llama_apply_lora_from_file(lctx,
+        int err = llama_model_apply_lora_from_file(model,
                                              params.lora_adapter.c_str(),
                                              params.lora_base.empty() ? NULL : params.lora_base.c_str(),
                                              params.n_threads);
         if (err != 0) {
             fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
-            return NULL;
+            llama_free(lctx);
+            llama_free_model(model);
+            return std::make_tuple(nullptr, nullptr);
         }
     }
 
-    return lctx;
+    return std::make_tuple(model, lctx);
 }
 
 void console_init(console_state & con_st) {

‎examples/common.h

2 additions, 1 deletion

@@ -9,6 +9,7 @@
 #include <random>
 #include <thread>
 #include <unordered_map>
+#include <tuple>
 
 #if !defined (_WIN32)
 #include <stdio.h>
@@ -95,7 +96,7 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
 // Model utils
 //
 
-struct llama_context * llama_init_from_gpt_params(const gpt_params & params);
+std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params);
 
 //
 // Console utils

‎examples/embedding/embedding.cpp

4 additions, 2 deletions

@@ -37,11 +37,12 @@ int main(int argc, char ** argv) {
 
     llama_init_backend();
 
+    llama_model * model;
     llama_context * ctx;
 
     // load the model
-    ctx = llama_init_from_gpt_params(params);
-    if (ctx == NULL) {
+    std::tie(model, ctx) = llama_init_from_gpt_params(params);
+    if (model == NULL) {
         fprintf(stderr, "%s: error: unable to load model\n", __func__);
         return 1;
     }
@@ -90,6 +91,7 @@ int main(int argc, char ** argv) {
 
     llama_print_timings(ctx);
     llama_free(ctx);
+    llama_free_model(model);
 
     return 0;
 }

‎examples/main/main.cpp

6 additions, 2 deletions

@@ -107,12 +107,13 @@ int main(int argc, char ** argv) {
 
     llama_init_backend();
 
+    llama_model * model;
     llama_context * ctx;
     g_ctx = &ctx;
 
     // load the model and apply lora adapter, if any
-    ctx = llama_init_from_gpt_params(params);
-    if (ctx == NULL) {
+    std::tie(model, ctx) = llama_init_from_gpt_params(params);
+    if (model == NULL) {
         fprintf(stderr, "%s: error: unable to load model\n", __func__);
         return 1;
     }
@@ -139,6 +140,7 @@ int main(int argc, char ** argv) {
 
         llama_print_timings(ctx);
         llama_free(ctx);
+        llama_free_model(model);
 
         return 0;
     }
@@ -147,6 +149,7 @@ int main(int argc, char ** argv) {
     if (params.export_cgraph) {
         llama_eval_export(ctx, "llama.ggml");
         llama_free(ctx);
+        llama_free_model(model);
 
         return 0;
     }
@@ -666,6 +669,7 @@ int main(int argc, char ** argv) {
 
     llama_print_timings(ctx);
     llama_free(ctx);
+    llama_free_model(model);
 
     return 0;
 }

‎examples/perplexity/perplexity.cpp

4 additions, 2 deletions

@@ -149,11 +149,12 @@ int main(int argc, char ** argv) {
 
     llama_init_backend();
 
+    llama_model * model;
     llama_context * ctx;
 
     // load the model and apply lora adapter, if any
-    ctx = llama_init_from_gpt_params(params);
-    if (ctx == NULL) {
+    std::tie(model, ctx) = llama_init_from_gpt_params(params);
+    if (model == NULL) {
         fprintf(stderr, "%s: error: unable to load model\n", __func__);
         return 1;
     }
@@ -169,6 +170,7 @@ int main(int argc, char ** argv) {
 
     llama_print_timings(ctx);
     llama_free(ctx);
+    llama_free_model(model);
 
     return 0;
 }

‎examples/quantize-stats/quantize-stats.cpp

13 additions, 2 deletions

@@ -320,6 +320,7 @@ int main(int argc, char ** argv) {
     fprintf(stderr, "Loading model\n");
 
     const int64_t t_main_start_us = ggml_time_us();
+    llama_model * model;
     llama_context * ctx;
 
     {
@@ -330,12 +331,20 @@ int main(int argc, char ** argv) {
         lparams.f16_kv = false;
         lparams.use_mlock = false;
 
-        ctx = llama_init_from_file(params.model.c_str(), lparams);
+        model = llama_load_model_from_file(params.model.c_str(), lparams);
 
-        if (ctx == NULL) {
+        if (model == NULL) {
             fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
             return 1;
         }
+
+        ctx = llama_new_context_with_model(model, lparams);
+
+        if (ctx == NULL) {
+            fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str());
+            llama_free_model(model);
+            return 1;
+        }
     }
 
     const auto &tensors = llama_internal_get_tensor_map(ctx);
@@ -357,6 +366,7 @@ int main(int argc, char ** argv) {
             fprintf(stderr, "%s: error: Quantization should be tested with a float model, "
                 "this model contains already quantized layers (%s is type %d)\n", __func__, kv_tensor.first.c_str(), kv_tensor.second->type);
             llama_free(ctx);
+            llama_free_model(model);
             return 1;
         }
         included_layers++;
@@ -415,6 +425,7 @@ int main(int argc, char ** argv) {
 
 
     llama_free(ctx);
+    llama_free_model(model);
     // report timing
     {
         const int64_t t_main_end_us = ggml_time_us();

‎examples/save-load-state/save-load-state.cpp

25 additions, 4 deletions

@@ -35,12 +35,22 @@ int main(int argc, char ** argv) {
     auto last_n_tokens_data = std::vector<llama_token>(params.repeat_last_n, 0);
 
     // init
-    auto ctx = llama_init_from_file(params.model.c_str(), lparams);
+    auto model = llama_load_model_from_file(params.model.c_str(), lparams);
+    if (model == nullptr) {
+        return 1;
+    }
+    auto ctx = llama_new_context_with_model(model, lparams);
+    if (ctx == nullptr) {
+        llama_free_model(model);
+        return 1;
+    }
     auto tokens = std::vector<llama_token>(params.n_ctx);
     auto n_prompt_tokens = llama_tokenize(ctx, params.prompt.c_str(), tokens.data(), int(tokens.size()), true);
 
     if (n_prompt_tokens < 1) {
         fprintf(stderr, "%s : failed to tokenize prompt\n", __func__);
+        llama_free(ctx);
+        llama_free_model(model);
         return 1;
     }
 
@@ -84,30 +94,36 @@ int main(int argc, char ** argv) {
         printf("%s", next_token_str);
         if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads)) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
+            llama_free(ctx);
+            llama_free_model(model);
             return 1;
         }
         n_past += 1;
     }
 
     printf("\n\n");
 
-    // free old model
+    // free old context
     llama_free(ctx);
 
-    // load new model
-    auto ctx2 = llama_init_from_file(params.model.c_str(), lparams);
+    // make new context
+    auto ctx2 = llama_new_context_with_model(model, lparams);
 
     // Load state (rng, logits, embedding and kv_cache) from file
     {
         FILE *fp_read = fopen("dump_state.bin", "rb");
         if (state_size != llama_get_state_size(ctx2)) {
             fprintf(stderr, "\n%s : failed to validate state size\n", __func__);
+            llama_free(ctx2);
+            llama_free_model(model);
             return 1;
         }
 
         const size_t ret = fread(state_mem, 1, state_size, fp_read);
         if (ret != state_size) {
             fprintf(stderr, "\n%s : failed to read state\n", __func__);
+            llama_free(ctx2);
+            llama_free_model(model);
             return 1;
         }
 
@@ -138,12 +154,17 @@ int main(int argc, char ** argv) {
         printf("%s", next_token_str);
         if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads)) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
+            llama_free(ctx2);
+            llama_free_model(model);
             return 1;
         }
         n_past += 1;
     }
 
     printf("\n\n");
 
+    llama_free(ctx2);
+    llama_free_model(model);
+
     return 0;
 }
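Since the weights now live in the model object rather than the context, save-load-state no longer re-opens the model file to obtain a second context; it simply creates another context from the already-loaded model. The fragment below sketches that pattern under the new API; the variable names and the elided state save/restore steps are illustrative only.

// load the weights once
llama_model * model = llama_load_model_from_file(model_path, lparams);

llama_context * ctx = llama_new_context_with_model(model, lparams);
// ... evaluate the prompt in ctx and dump its state to a file ...
llama_free(ctx);                 // drop the context; the weights stay loaded

llama_context * ctx2 = llama_new_context_with_model(model, lparams);
// ... read the saved state back into ctx2 and continue generating ...
llama_free(ctx2);

llama_free_model(model);         // free the weights last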

‎examples/server/server.cpp

7 additions, 2 deletions

@@ -115,6 +115,7 @@ struct llama_server_context {
     std::vector<llama_token> embd;
     std::vector<llama_token> last_n_tokens;
 
+    llama_model * model = nullptr;
     llama_context * ctx = nullptr;
     gpt_params params;
 
@@ -130,6 +131,10 @@ struct llama_server_context {
             llama_free(ctx);
             ctx = nullptr;
         }
+        if (model) {
+            llama_free_model(model);
+            model = nullptr;
+        }
     }
 
     void rewind() {
@@ -150,8 +155,8 @@ struct llama_server_context {
 
     bool loadModel(const gpt_params & params_) {
         params = params_;
-        ctx = llama_init_from_gpt_params(params);
-        if (ctx == nullptr) {
+        std::tie(model, ctx) = llama_init_from_gpt_params(params);
+        if (model == nullptr) {
             LOG_ERROR("unable to load model", { { "model", params_.model } });
             return false;
         }

‎examples/simple/simple.cpp

5 additions, 3 deletions

@@ -68,11 +68,12 @@ int main(int argc, char ** argv)
 
     llama_init_backend();
 
-    llama_context * ctx ;
+    llama_model * model;
+    llama_context * ctx;
 
-    ctx = llama_init_from_gpt_params( params );
+    std::tie(model, ctx) = llama_init_from_gpt_params( params );
 
-    if ( ctx == NULL )
+    if ( model == NULL )
     {
         fprintf( stderr , "%s: error: unable to load model\n" , __func__ );
         return 1;
@@ -170,6 +171,7 @@ int main(int argc, char ** argv)
     } // wend of main loop
 
     llama_free( ctx );
+    llama_free_model( model );
 
     return 0;
 }

‎examples/train-text-from-scratch/train-text-from-scratch.cpp

4 additions, 1 deletion

@@ -3054,7 +3054,8 @@ int main(int argc, char ** argv) {
     struct llama_context_params llama_params = llama_context_default_params();
     llama_params.vocab_only = true;
 
-    struct llama_context * lctx = llama_init_from_file(params.fn_vocab_model, llama_params);
+    struct llama_model * lmodel = llama_load_model_from_file(params.fn_vocab_model, llama_params);
+    struct llama_context * lctx = llama_new_context_with_model(lmodel, llama_params);
 
     struct llama_vocab vocab;
     {
@@ -3395,6 +3396,8 @@ int main(int argc, char ** argv) {
     delete[] compute_addr;
     delete[] compute_buf_0;
     delete[] compute_buf_1;
+    llama_free(lctx);
+    llama_free_model(lmodel);
     ggml_free(model.ctx);
 
     return 0;

0 commit comments
