Commit abb8e00

compilade authored and hazelnutcloud committed
llama : fix llama_copy_state_data with fragmented KV cache (ggml-org#5840)
The row size of the saved states was based on kv_self.head while
it should be based on llama_kv_cache_cell_max.

Existing session files should still work.

* llama : fix llama_kv_cache_cell_max inability to return 1

I've also changed its return type to uint32_t,
because this function is always used to set the value of uint32_t variables,
and because the index already has this type.

* llama : fix state size calculation

Some bytes in the state were unaccounted for in llama_get_state_size.
Since the logits reserve so much space, it did not cause problems.
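The fragmentation issue is easiest to see in isolation. Below is a minimal stand-alone sketch, not the llama.cpp implementation: the toy `cell` struct and `cell_max` helper are hypothetical stand-ins for llama_kv_cell and the fixed llama_kv_cache_cell_max, illustrating why kv_self.head is not a safe row count once the cache has gaps.

// Hypothetical toy model of the KV cell scan (not llama.cpp code).
#include <cstdint>
#include <cstdio>
#include <vector>

struct cell {
    int32_t pos = -1;                     // -1 marks an unused cell
    bool is_empty() const { return pos < 0; }
};

// Scan from the back and return one past the last occupied cell,
// mirroring the fixed llama_kv_cache_cell_max: it can now return 1
// (only cell 0 in use) and returns 0 for an empty cache.
static uint32_t cell_max(const std::vector<cell> & cells) {
    for (uint32_t i = (uint32_t) cells.size(); i > 0; --i) {
        if (!cells[i - 1].is_empty()) {
            return i;
        }
    }
    return 0;
}

int main() {
    std::vector<cell> cells(8);
    cells[0].pos = 0;
    cells[5].pos = 3;                     // fragmentation: cells 1..4 are a gap

    const uint32_t head = 1;              // head can sit before used cells after erases

    // Sizing the saved rows from head would drop cell 5 entirely;
    // cell_max covers every used cell.
    printf("head = %u, cell_max = %u\n", head, cell_max(cells));
    return 0;
}

With head == 1 but cell 5 occupied, cell_max returns 6, which is the row count the save path actually needs.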
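For the state-size fix, the accounting can be written out explicitly. This is a sketch under assumptions, not the real llama_get_state_size: `kv_state_size`, `kv_buf_bytes`, and `n_cells` are hypothetical names standing in for the function's use of kv_self.total_size() and kv_self.size, with the per-field sizes taken from the diff below.

// Hypothetical sketch of the KV portion of the state-size accounting.
#include <cstddef>
#include <cstdint>
#include <cstdio>

using llama_pos    = int32_t;  // matches the typedefs in llama.h
using llama_seq_id = int32_t;

static size_t kv_state_size(size_t kv_buf_bytes, uint32_t n_cells) {
    const size_t s_kv_buf_size = sizeof(size_t);    // serialized K/V buffer size
    const size_t s_kv_head     = sizeof(uint32_t);  // number of saved cells
    const size_t s_kv_size     = sizeof(uint32_t);  // total cache size
    const size_t s_kv_used     = sizeof(uint32_t);  // used-cell counter
    // per cell: one pos, one seq_id count, and (at most) one seq_id
    const size_t s_kv_cell     = sizeof(llama_pos) + sizeof(size_t) + sizeof(llama_seq_id);

    return s_kv_buf_size + s_kv_head + s_kv_size + s_kv_used
         + kv_buf_bytes + (size_t) n_cells * s_kv_cell;
}

int main() {
    // e.g. a 512-cell cache whose K and V buffers total 1 MiB
    printf("KV state bytes: %zu\n", kv_state_size(1024 * 1024, 512));
    return 0;
}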
1 parent 7a5c8bd commit abb8e00

File tree

1 file changed: 30 additions, 17 deletions

llama.cpp

30 additions & 17 deletions
@@ -2156,10 +2156,12 @@ static bool llama_kv_cache_find_slot(
 }
 
 // find how many cells are currently in use
-static int32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
-    for (uint32_t i = cache.size - 1; i > 0; --i) {
-        if (cache.cells[i].pos >= 0 && !cache.cells[i].is_empty()) {
-            return i + 1;
+static uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
+    for (uint32_t i = cache.size; i > 0; --i) {
+        const llama_kv_cell & cell = cache.cells[i - 1];
+
+        if (cell.pos >= 0 && !cell.is_empty()) {
+            return i;
         }
     }
 
@@ -8178,7 +8180,7 @@ static int llama_decode_internal(
         // a heuristic, to avoid attending the full cache if it is not yet utilized
         // after enough generations, the benefit from this heuristic disappears
         // if we start defragmenting the cache, the benefit from this will be more important
-        kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
+        kv_self.n = std::min(cparams.n_ctx, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
         //kv_self.n = llama_kv_cache_cell_max(kv_self);
 
         //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
@@ -12615,9 +12617,14 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
     const size_t s_logits          = ctx->logits.capacity() * sizeof(float);
     const size_t s_embedding_size  = sizeof(size_t);
     const size_t s_embedding       = ctx->embedding.size() * sizeof(float);
-    const size_t s_kv_size         = sizeof(size_t);
-    const size_t s_kv_ntok         = sizeof(int);
+    const size_t s_kv_buf_size     = sizeof(size_t);
+    const size_t s_kv_head         = sizeof(uint32_t);
+    const size_t s_kv_size         = sizeof(uint32_t);
+    const size_t s_kv_used         = sizeof(uint32_t);
     const size_t s_kv              = ctx->kv_self.total_size();
+    // TODO: assume the max is more than 1 seq_id per KV cell
+    const size_t s_kv_cell         = sizeof(llama_pos) + sizeof(size_t) + sizeof(llama_seq_id);
+    const size_t s_kv_cells        = ctx->kv_self.size * s_kv_cell;
 
     const size_t s_total = (
         + s_rng_size
@@ -12626,9 +12633,12 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
         + s_logits
         + s_embedding_size
         + s_embedding
+        + s_kv_buf_size
+        + s_kv_head
         + s_kv_size
-        + s_kv_ntok
+        + s_kv_used
         + s_kv
+        + s_kv_cells
     );
 
     return s_total;
@@ -12728,15 +12738,13 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
     {
         const auto & kv_self = ctx->kv_self;
         const auto & hparams = ctx->model.hparams;
-        const auto & cparams = ctx->cparams;
 
         const uint32_t n_layer      = hparams.n_layer;
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
-        const uint32_t n_ctx        = cparams.n_ctx;
 
         const size_t   kv_buf_size = kv_self.total_size();
-        const uint32_t kv_head     = kv_self.head;
+        const uint32_t kv_head     = llama_kv_cache_cell_max(kv_self);
         const uint32_t kv_size     = kv_self.size;
         const uint32_t kv_used     = kv_self.used;
 
@@ -12756,7 +12764,7 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
 
             // v is not contiguous, copy row by row
             const size_t v_row_size   = ggml_row_size(kv_self.v_l[il]->type, kv_head);
-            const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, n_ctx);
+            const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size);
 
             tmp_buf.resize(v_row_size);
             for (int ir = 0; ir < (int) n_embd_v_gqa; ++ir) {
@@ -12766,7 +12774,7 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
             }
         }
 
-        for (uint32_t i = 0; i < kv_size; ++i) {
+        for (uint32_t i = 0; i < kv_head; ++i) {
             const auto & cell = kv_self.cells[i];
 
             const llama_pos pos = cell.pos;
@@ -12842,12 +12850,10 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
     {
         const auto & kv_self = ctx->kv_self;
         const auto & hparams = ctx->model.hparams;
-        const auto & cparams = ctx->cparams;
 
         const uint32_t n_layer      = hparams.n_layer;
        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
-        const uint32_t n_ctx        = cparams.n_ctx;
 
         size_t   kv_buf_size;
         uint32_t kv_head;
@@ -12870,7 +12876,7 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
 
             // v is not contiguous, copy row by row
             const size_t v_row_size   = ggml_row_size(kv_self.v_l[il]->type, kv_head);
-            const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, n_ctx);
+            const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size);
 
             for (int ir = 0; ir < (int) n_embd_v_gqa; ++ir) {
                 ggml_backend_tensor_set(kv_self.v_l[il], inp, ir*v_row_stride, v_row_size);
@@ -12879,13 +12885,15 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
             }
         }
 
+        GGML_ASSERT(kv_self.size == kv_size);
+
         ctx->kv_self.head = kv_head;
         ctx->kv_self.size = kv_size;
         ctx->kv_self.used = kv_used;
 
         ctx->kv_self.cells.resize(kv_size);
 
-        for (uint32_t i = 0; i < kv_size; ++i) {
+        for (uint32_t i = 0; i < kv_head; ++i) {
             llama_pos pos;
             size_t    seq_id_size;
 
@@ -12901,6 +12909,11 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
                 ctx->kv_self.cells[i].seq_id.insert(seq_id);
             }
         }
+
+        for (uint32_t i = kv_head; i < kv_size; ++i) {
+            ctx->kv_self.cells[i].pos = -1;
+            ctx->kv_self.cells[i].seq_id.clear();
+        }
     }
 
     const size_t nread = inp - src;
