Commit 2252eef

kv-cache : improve find_slot() using min/max seq pos info

ggml-ci

1 parent 332f460

5 files changed: +63 −58 lines changed
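
The gist of the change: per-sequence min/max position bookkeeping moves out of llama_ubatch, where update() recomputed it for every micro-batch, and into the KV cells, which already maintain one ordered set of positions per sequence (the seq_pos member visible in the llama-kv-cells.h diff below). As a minimal, self-contained sketch of that bookkeeping pattern, assuming illustrative names (CellsSketch, insert_pos, erase_pos) rather than the actual llama_kv_cells_unified interface:

#include <set>

// sketch: one ordered set of positions per sequence, so min/max queries are
// cheap lookups at the ends of the set instead of a scan over every cell
struct CellsSketch {
    static const int MAX_SEQ = 4; // stand-in for LLAMA_MAX_PARALLEL_SEQUENCES

    std::set<int> seq_pos[MAX_SEQ]; // positions of each sequence, kept ordered

    void insert_pos(int seq_id, int p) { seq_pos[seq_id].insert(p); }
    void erase_pos (int seq_id, int p) { seq_pos[seq_id].erase (p); }

    // return -1 when the sequence has no cells, mirroring seq_pos_min()/seq_pos_max()
    int pos_min(int seq_id) const {
        return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].begin();
    }
    int pos_max(int seq_id) const {
        return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].rbegin();
    }
};

Because the sets are maintained incrementally as cells are filled and cleared, find_slot() can consult the cache-wide min/max of any sequence at negligible cost, instead of only the min/max of the current ubatch.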

‎src/llama-batch.cpp

0 additions & 20 deletions
@@ -4,21 +4,6 @@
 #include <cstring>
 #include <algorithm>
 
-void llama_ubatch::update() {
-    if (equal_seqs) {
-        // TODO: for now don't compute min/max for recurrent batches since we don't need this.
-        //       the batches will be refactored anyway, so we'll fix this later
-        return;
-    }
-
-    for (uint32_t i = 0; i < n_tokens; ++i) {
-        const llama_seq_id s = seq_id[i][0];
-
-        seq_pos_min[s] = seq_pos_min[s] == -1 ? pos[i] : std::min(seq_pos_min[s], pos[i]);
-        seq_pos_max[s] = seq_pos_max[s] == -1 ? pos[i] : std::max(seq_pos_max[s], pos[i]);
-    }
-}
-
 llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
     // clear empty sequences
     // the previous ubatch is assumed to be gone,
@@ -47,8 +32,6 @@ llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
         /*n_tokens     =*/ 0,
         /*n_seq_tokens =*/ 0,
         /*n_seqs       =*/ 0,
-        /*seq_pos_min  =*/ {-1},
-        /*seq_pos_max  =*/ {-1},
         /*token        =*/ !has_embd ? udata.token.data() : nullptr,
         /*embd         =*/ has_embd ? udata.embd.data() : nullptr,
         /*pos          =*/ udata.pos.data(),
@@ -172,7 +155,6 @@ llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) {
         GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
         add_seq_to_ubatch(ubatch, s, length);
     }
-    ubatch.update();
     return ubatch;
 }
 
@@ -200,7 +182,6 @@ llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
             if (length + n_tokens_in_ubatch > n_ubatch) { break; }
         }
     }
-    ubatch.update();
     return ubatch;
 }
 
@@ -213,7 +194,6 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
         GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
         add_seq_to_ubatch(ubatch, s, length);
     }
-    ubatch.update();
     return ubatch;
 }

‎src/llama-batch.h

0 additions & 6 deletions
@@ -1,26 +1,20 @@
 #pragma once
 
 #include "llama.h"
-#include "llama-cparams.h"
 
 #include <array>
 #include <vector>
 
 // very similar to llama_batch,
 // but has more metadata about sequences
 struct llama_ubatch {
-    void update();
-
     bool equal_seqs;
     // TODO: whole_seqs for embeddings?
 
     uint32_t n_tokens;     // total tokens (n_seq_tokens * n_seqs)
     uint32_t n_seq_tokens; // tokens per sequence
     uint32_t n_seqs;
 
-    llama_pos seq_pos_min[LLAMA_MAX_PARALLEL_SEQUENCES]; // min position of each sequence
-    llama_pos seq_pos_max[LLAMA_MAX_PARALLEL_SEQUENCES]; // max position of each sequence
-
     llama_token  * token; // [n_tokens]
     float        * embd;  // [n_embd, n_tokens]
     llama_pos    * pos;   // [n_tokens]

‎src/llama-context.cpp

1 addition & 1 deletion
@@ -1233,7 +1233,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs) {
     this->n_outputs = n_outputs;
 
     llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
-    llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, {-1}, {-1}, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+    llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
 
     auto * gf = graph_init();
     auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);

‎src/llama-kv-cache.cpp

37 additions & 15 deletions
@@ -548,7 +548,7 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
             if (cells.is_empty(i)) {
                 ss += '.';
             } else {
-                ss += 'x';
+                ss += std::to_string(cells.seq_get(i));
             }
             if (i%256 == 255) {
                 ss += '\n';
@@ -557,6 +557,10 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
         }
         LLAMA_LOG_WARN("\n%s\n", ss.c_str());
     }
+
+    LLAMA_LOG_WARN("kv_cells: n_swa = %4d, min[0] = %5d, max[0] = %5d\n", n_swa, cells.seq_pos_min(0), cells.seq_pos_max(0));
+    LLAMA_LOG_WARN("kv_cells: n_swa = %4d, min[1] = %5d, max[1] = %5d\n", n_swa, cells.seq_pos_min(1), cells.seq_pos_max(1));
+    LLAMA_LOG_WARN("kv_cells: n_swa = %4d, min[2] = %5d, max[2] = %5d\n", n_swa, cells.seq_pos_min(2), cells.seq_pos_max(2));
 #endif
 
     uint32_t n_tested = 0;
@@ -568,24 +572,44 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
             continue;
         }
 
+        // keep track of what the minimum sequence positions would be if we accept the ubatch
+        llama_seq_id seq_pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];
+        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+            seq_pos_min[s] = cells.seq_pos_min(s);
+        }
+
         bool found = true;
         for (uint32_t i = 0; i < n_tokens; i++) {
             const llama_pos    pos    = ubatch.pos[i];
             const llama_seq_id seq_id = ubatch.seq_id[i][0];
 
             // can we use this cell? either:
             // - the cell is empty
-            // - the cell is occupied only by the same sequence, and the pos is masked
-            const bool can_use =
-                cells.is_empty(head_cur + i) ||
-                (
-                    cells.seq_has  (head_cur + i, seq_id) && // sequence mask
-                    cells.seq_count(head_cur + i) == 1    &&
-                    (
-                        cells.pos_get (head_cur + i) >= pos ||                                 // causal mask
-                        is_masked_swa(cells.pos_get(head_cur + i), ubatch.seq_pos_min[seq_id]) // SWA mask
-                    )
-                );
+            // - the cell is occupied only by one sequence:
+            //   - mask causally, if the sequence is the same as the one we are inserting
+            //   - mask SWA, using current max pos for that sequence in the cache
+            //     always insert in the cell with minimum pos
+            bool can_use = cells.is_empty(head_cur + i);
+
+            if (!can_use && cells.seq_count(head_cur + i) == 1) {
+                const llama_pos pos_cell = cells.pos_get(head_cur + i);
+
+                // causal mask
+                if (cells.seq_has(head_cur + i, seq_id)) {
+                    can_use = pos_cell >= pos;
+                }
+
+                if (!can_use) {
+                    const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i);
+
+                    // SWA mask
+                    if (pos_cell == seq_pos_min[seq_id_cell] &&
+                        is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
+                        seq_pos_min[seq_id_cell]++;
+                        can_use = true;
+                    }
+                }
+            }
 
             if (!can_use) {
                 found = false;
@@ -613,9 +637,7 @@ void llama_kv_cache_unified::fill_slot(uint32_t head_cur, const llama_ubatch & ubatch) {
 
     for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
         if (!cells.is_empty(head + i)) {
-            cells.pos_chg(head + i, ubatch.pos[i]);
-
-            continue;
+            cells.rm(head + i);
         }
 
         cells.pos_set(head + i, ubatch.pos[i]);
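
The heart of the commit is the new reuse rule in find_slot() above. Restated as a standalone predicate, it reads roughly as follows; this is a hedged sketch, not the actual member function: CellView and its fields are hypothetical stand-ins for the cells.* accessors, and swa_masked stands in for is_masked_swa():

#include <cstdint>

using llama_pos    = int32_t;
using llama_seq_id = int32_t;

// hypothetical flattened view of one KV cell, standing in for the cells.* accessors
struct CellView {
    bool         empty;     // cells.is_empty(i)
    int          seq_count; // cells.seq_count(i): number of sequences in the cell
    llama_seq_id seq_id;    // cells.seq_get(i): valid only when seq_count == 1
    llama_pos    pos;       // cells.pos_get(i): valid only when not empty
};

// sketch of the reuse rule; seq_pos_min is the running "would-be" minimum per
// sequence (mutated on SWA eviction), seq_pos_max the current per-sequence
// maximum in the cache, and swa_masked(p, p_max) stands in for is_masked_swa()
bool can_use_cell(const CellView & cell, llama_pos pos, llama_seq_id seq_id,
                  llama_pos * seq_pos_min, const llama_pos * seq_pos_max,
                  bool (*swa_masked)(llama_pos, llama_pos)) {
    // an empty cell can always be used
    if (cell.empty) {
        return true;
    }

    // only cells owned by exactly one sequence may be overwritten
    if (cell.seq_count != 1) {
        return false;
    }

    // causal mask: same sequence, and the cached pos is not behind the incoming
    // pos, so that entry would be masked out for this token anyway
    if (cell.seq_id == seq_id && cell.pos >= pos) {
        return true;
    }

    // SWA mask: the cell holds the minimum pos of its owning sequence and that
    // pos already falls outside the sliding window relative to the sequence's
    // max pos; evicting it bumps the would-be minimum
    if (cell.pos == seq_pos_min[cell.seq_id] &&
        swa_masked(cell.pos, seq_pos_max[cell.seq_id] + 1)) {
        seq_pos_min[cell.seq_id]++;
        return true;
    }

    return false;
}

Note the minimum-pos condition in the SWA branch: only the oldest cached position of a sequence is ever displaced, and the seq_pos_min[...]++ update assumes those positions are consumed oldest-first, which keeps the surviving entries a contiguous window.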

‎src/llama-kv-cells.h

25 additions & 16 deletions
@@ -68,12 +68,6 @@ class llama_kv_cells_unified {
     // the index of the last cell that is used + 1
     // return 0 if no cells are used
     uint32_t used_max_p1() const {
-#if 0
-        if (!seq_pos[0].empty()) printf("kv_cells: min[0] = %5d, max[0] = %5d\n", *seq_pos[0].begin(), *seq_pos[0].rbegin());
-        if (!seq_pos[1].empty()) printf("kv_cells: min[1] = %5d, max[1] = %5d\n", *seq_pos[1].begin(), *seq_pos[1].rbegin());
-        if (!seq_pos[2].empty()) printf("kv_cells: min[2] = %5d, max[2] = %5d\n", *seq_pos[2].begin(), *seq_pos[2].rbegin());
-#endif
-
         return used.empty() ? 0 : *used.rbegin() + 1;
     }
 
@@ -144,6 +138,18 @@ class llama_kv_cells_unified {
         }
     }
 
+    void rm(uint32_t i) {
+        assert(i < pos.size());
+        assert(pos[i] != -1);
+
+        seq_pos_rm(i);
+
+        pos[i] = -1;
+        seq[i].reset();
+
+        used.erase(i);
+    }
+
     // note: call only if the cell has seq_id
     // return true if the cell becomes empty
     bool seq_rm(uint32_t i, llama_seq_id seq_id) {
@@ -220,6 +226,18 @@ class llama_kv_cells_unified {
         seq_pos[seq_id].insert(pos[i]);
     }
 
+    llama_seq_id seq_get(uint32_t i) const {
+        assert(seq[i].count() == 1);
+
+        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+            if (seq[i].test(s)) {
+                return s;
+            }
+        }
+
+        return -1;
+    }
+
     // the minimum position of sequence seq_id currently present in any of the cells
     // return -1 if the sequence is not present
     llama_pos seq_pos_min(llama_seq_id seq_id) const {
@@ -275,22 +293,13 @@ class llama_kv_cells_unified {
     void pos_set(uint32_t i, llama_pos p) {
         assert(i < pos.size());
         assert(pos[i] == -1);
+        assert(seq[i].none());
 
         pos[i] = p;
 
        used.insert(i);
     }
 
-    // change the position of a non-empty cell
-    // does not modify "has_shift"
-    // note: call only if the cell is not empty
-    void pos_chg(uint32_t i, llama_pos p) {
-        assert(i < pos.size());
-        assert(pos[i] != -1);
-
-        pos[i] = p;
-    }
-
     // pos[i] = pos[i] + d
     // sets "has_shift" to true
     // note: call only if the cell is not empty
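
Two of the new helpers above are worth a note. seq_get() recovers the single owner of a cell by scanning its sequence bitset for the one set bit, and the stricter pos_set() assertion (seq[i].none()) forces fill_slot() to fully clear a reused cell with rm() before claiming it, rather than patching the position in place with the removed pos_chg(). A self-contained sketch of the seq_get() idea using std::bitset; the names here are illustrative, not the actual class:

#include <bitset>
#include <cassert>
#include <cstdio>

constexpr int MAX_SEQ = 8; // stand-in for LLAMA_MAX_PARALLEL_SEQUENCES

// mirror of seq_get(): precondition is exactly one owning sequence,
// recovered by scanning the bitset for the single set bit
int seq_get(const std::bitset<MAX_SEQ> & seq) {
    assert(seq.count() == 1);

    for (int s = 0; s < MAX_SEQ; ++s) {
        if (seq.test(s)) {
            return s;
        }
    }

    return -1; // unreachable while the precondition holds
}

int main() {
    std::bitset<MAX_SEQ> cell;
    cell.set(3); // a cell owned only by sequence 3

    // find_slot() uses this owner to index the per-sequence min/max when
    // deciding whether an occupied cell is evictable under the SWA mask
    std::printf("owner = %d\n", seq_get(cell)); // prints: owner = 3
}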

0 commit comments