Commit ca69f32

llama : auto-batch
ggml-ci
1 parent f23e4cc commit ca69f32

File tree

3 files changed: +80 −87 lines changed

src/llama-context.cpp

+54 −36 (54 additions, 36 deletions)
@@ -424,9 +424,9 @@ const llama_kv_cache * llama_context::get_kv_self() const {
     return kv_self;
 }
 
-void llama_context::kv_self_update() {
+bool llama_context::kv_self_update() {
     if (!memory) {
-        return;
+        return false;
     }
 
     llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
@@ -445,7 +445,11 @@ void llama_context::kv_self_update() {
         if (!gf) {
             LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
         }
+
+        return true;
     }
+
+    return false;
 }
 
 enum llama_pooling_type llama_context::pooling_type() const {
@@ -933,25 +937,53 @@ int llama_context::decode(llama_batch & inp_batch) {
     // handle any pending defrags/shifts
     kv_self_update();
 
-    auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
-    if (!kv_state) {
-        return -2;
-    }
+    llama_memory_state_ptr kv_state;
 
-    switch (kv_state->get_status()) {
-        case LLAMA_MEMORY_STATUS_SUCCESS:
-            {
-            } break;
-        case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
-            {
-                // not a fatal error, we can re-try with a different batch
-                return 1;
-            }
-        case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
-            {
-                return -2;
-            }
-    }
+    bool did_defrag = false;
+    auto n_ubatch = cparams.n_ubatch;
+
+    do {
+        kv_state = kv_self->init_batch(batch, n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
+        if (!kv_state) {
+            return -2;
+        }
+
+        switch (kv_state->get_status()) {
+            case LLAMA_MEMORY_STATUS_SUCCESS:
+                {
+                } break;
+            case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+                {
+                    if (!did_defrag) {
+                        did_defrag = true;
+
+                        kv_self->defrag_sched(-1.0f);
+                        if (kv_self_update()) {
+                            LLAMA_LOG_DEBUG("%s: failed to init batch of size %d, retrying after defrag\n", __func__, batch.n_tokens);
+
+                            continue;
+                        }
+                    }
+
+                    if (n_ubatch > 1) {
+                        n_ubatch /= 2;
+
+                        LLAMA_LOG_DEBUG("%s: failed to find free space in the KV cache, retrying with smaller ubatch size: n_ubatch = %d\n", __func__, n_ubatch);
+                        continue;
+                    }
+
+                    LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
+
+                    return 1;
+                }
+            case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+                {
+                    return -2;
+                }
+        }
+
+        break;
+    } while(true);
 
     // reserve output buffer
     if (output_reserve(n_outputs_all) < n_outputs_all) {
@@ -2646,22 +2678,8 @@ int32_t llama_encode(
 int32_t llama_decode(
         llama_context * ctx,
           llama_batch   batch) {
-    int ret = ctx->decode(batch);
-
-    // defrag and try again
-    // TODO: distinguish return code when we are sure that even after defrag there is no space available
-    if (ret == 1) {
-        llama_kv_self_defrag(ctx);
-        ret = ctx->decode(batch);
-
-        if (ret == 1) {
-            LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
-
-            return ret;
-        }
-    }
-
-    if (ret != 0) {
+    const int ret = ctx->decode(batch);
+    if (ret != 0 && ret != 1) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }
 

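Net effect of the changes above: llama_context::decode() now recovers from a failed KV-cache allocation on its own. On LLAMA_MEMORY_STATUS_FAILED_PREPARE it schedules an immediate defrag and retries if the cache was actually updated, then falls back to halving n_ubatch down to 1, and only then returns 1. Below is a minimal, hypothetical caller-side sketch; only llama_decode and its return codes come from this commit, the helper itself is illustrative:

#include "llama.h"

// Hypothetical helper: with auto-batch the caller no longer defrags or splits
// the batch itself - it only interprets llama_decode's return value.
static int try_decode(llama_context * ctx, llama_batch batch) {
    const int ret = llama_decode(ctx, batch);

    if (ret == 0) {
        return 0;   // processed (possibly in smaller ubatches internally)
    }
    if (ret == 1) {
        // no free KV cache slot even after the internal defrag and ubatch
        // halving - the caller has to free sequences or use a larger context
        return 1;
    }
    return ret;     // ret < 0: invalid batch or compute error
}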
src/llama-context.h

+2 −1 (2 additions, 1 deletion)
@@ -50,8 +50,9 @@ struct llama_context {
     llama_kv_cache * get_kv_self();
     const llama_kv_cache * get_kv_self() const;
 
+    // return true if the KV cache was updated
     // TODO: remove
-    void kv_self_update();
+    bool kv_self_update();
 
     enum llama_pooling_type pooling_type() const;
 

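For reference, a condensed fragment (not standalone code) of how the new boolean is consumed in llama_context::decode() from the first file above:

// condensed from llama_context::decode() in this commit
kv_self->defrag_sched(-1.0f);   // request an immediate defrag
if (kv_self_update()) {
    // the KV cache was actually modified -> retry init_batch() at the same ubatch size
} else {
    // nothing changed -> fall back to halving n_ubatch instead
}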
tools/server/server.cpp

+24 −50 (24 additions, 50 deletions)
@@ -3385,75 +3385,49 @@ struct server_context {
         }
 
         // process the created batch of tokens
-        for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
-            const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
-
-            llama_batch batch_view = {
-                n_tokens,
-                batch.token    + i,
-                nullptr,
-                batch.pos      + i,
-                batch.n_seq_id + i,
-                batch.seq_id   + i,
-                batch.logits   + i,
-            };
-
-            const int ret = llama_decode(ctx, batch_view);
-
-            metrics.on_decoded(slots);
+        {
+            const int ret = llama_decode(ctx, batch);
 
             if (ret != 0) {
-                {
-                    std::string err;
-
-                    if (n_batch == 1 && ret == 1) {
-                        err = "Context size has been exceeded.";
-                    }
-
-                    if (ret == -1) {
-                        err = "Invalid input batch.";
-                    }
+                std::string err;
 
-                    if (ret < -1) {
-                        err = "Compute error.";
-                    }
-
-                    if (!err.empty()) {
-                        SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
-                        for (auto & slot : slots) {
-                            slot.release();
-                            send_error(slot, err);
-                        }
-                        break;
-                    }
+                if (ret == 1) {
+                    err = "Context size has been exceeded.";
                 }
 
-                // retry with half the batch size to try to find a free slot in the KV cache
-                n_batch /= 2;
+                if (ret == -1) {
+                    err = "Invalid input batch.";
+                }
 
-                SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
+                if (ret < -1) {
+                    err = "Compute error.";
+                }
 
-                i -= n_batch;
+                if (!err.empty()) {
+                    SRV_ERR("%s, n_batch = %d, ret = %d\n", err.c_str(), n_batch, ret);
+                    for (auto & slot : slots) {
+                        slot.release();
+                        send_error(slot, err);
+                    }
 
-                continue; // continue loop of n_batch
+                    return;
+                }
             }
 
-            for (auto & slot : slots) {
-                if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
-                    continue; // continue loop of slots
-                }
+            metrics.on_decoded(slots);
 
+            for (auto & slot : slots) {
                 if (slot.state == SLOT_STATE_DONE_PROMPT) {
                     if (slot.task_type == SERVER_TASK_TYPE_EMBEDDING) {
                         // prompt evaluated for embedding
-                        send_embedding(slot, batch_view);
+                        send_embedding(slot, batch);
                        slot.release();
                        slot.i_batch = -1;
                        continue; // continue loop of slots
                    }
 
                    if (slot.task_type == SERVER_TASK_TYPE_RERANK) {
-                        send_rerank(slot, batch_view);
+                        send_rerank(slot, batch);
                        slot.release();
                        slot.i_batch = -1;
                        continue; // continue loop of slots
@@ -3465,7 +3439,7 @@ struct server_context {
                    continue; // continue loop of slots
                }
 
-                const int tok_idx = slot.i_batch - i;
+                const int tok_idx = slot.i_batch;
 
                llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);

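Since the whole batch is now decoded in one llama_decode call, the server drops the i/n_batch splitting loop, slot.i_batch indexes directly into the decoded batch, and error handling reduces to mapping the return code to a single user-facing message. A condensed, illustrative sketch of that mapping (the helper is not verbatim from the diff):

#include <string>

// maps llama_decode's return code to the error reported to all slots;
// an empty string means the batch was decoded successfully
static std::string decode_error(int ret) {
    if (ret == 1)  return "Context size has been exceeded.";
    if (ret == -1) return "Invalid input batch.";
    if (ret < -1)  return "Compute error.";
    return "";
}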
0 commit comments
