Commit b95b0ff

Use pre-allocated buffers to store input_ids and scores

1 parent a5e059c · commit b95b0ff
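
What this commit does, in short: instead of growing the eval_tokens / eval_logits deques and rebuilding numpy arrays with np.concatenate on every eval() call, Llama now allocates an (n_ctx,) token buffer and an (n_ctx, n_vocab) score buffer once, and tracks how much of each is valid with a single n_tokens counter. A minimal standalone sketch of the pattern, with toy sizes rather than the library's real ones:

import numpy as np

n_ctx, n_vocab = 8, 5  # toy sizes; real context windows and vocabularies are much larger

# Old pattern: grow by concatenation, re-copying everything seen so far on each call
ids = np.array([], dtype=np.intc)
ids = np.concatenate((ids, np.array([1, 2, 3], dtype=np.intc)))

# New pattern: allocate once, write in place, track the logical length separately
input_ids = np.ndarray((n_ctx,), dtype=np.intc)
scores = np.ndarray((n_ctx, n_vocab), dtype=np.single)
n_tokens = 0

batch = [1, 2, 3]
input_ids[n_tokens : n_tokens + len(batch)] = batch
n_tokens += len(batch)
assert input_ids[:n_tokens].tolist() == batch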

1 file changed: 44 additions, 42 deletions

‎llama_cpp/llama.py

@@ -141,7 +141,7 @@ def __getitem__(self, key: Sequence[int]) -> "LlamaState":
         if _key is None:
             raise KeyError("Key not found")
         value: "LlamaState" = self.cache.pop(_key)  # type: ignore
-        # NOTE: This puts an integer as key in cache, which breaks,
+        # NOTE: This puts an integer as key in cache, which breaks,
         # Llama.longest_token_prefix(k, key) above since k is not a tuple of ints/tokens
         # self.cache.push(_key, side="front") # type: ignore
         return value
@@ -166,17 +166,15 @@ def __setitem__(self, key: Sequence[int], value: "LlamaState"):
 class LlamaState:
     def __init__(
         self,
-        eval_tokens: Deque[int],
-        eval_logits: Deque[List[float]],
         input_ids: npt.NDArray[np.intc],
         scores: npt.NDArray[np.single],
+        n_tokens: int,
         llama_state: bytes,
         llama_state_size: int,
     ):
-        self.eval_tokens = eval_tokens
-        self.eval_logits = eval_logits
        self.input_ids = input_ids
        self.scores = scores
+        self.n_tokens = n_tokens
        self.llama_state = llama_state
        self.llama_state_size = llama_state_size
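
LlamaState drops the two deques and instead carries the full pre-allocated buffers plus the n_tokens fill level. A self-contained stand-in with the same field layout (illustrative values; the real class also receives the opaque llama.cpp state blob shown later in the diff):

from dataclasses import dataclass
import numpy as np
import numpy.typing as npt

@dataclass
class StateSketch:  # stand-in for LlamaState, mirroring the fields in the diff
    input_ids: npt.NDArray[np.intc]  # full (n_ctx,) buffer, valid up to n_tokens
    scores: npt.NDArray[np.single]   # full (n_ctx, n_vocab) buffer
    n_tokens: int                    # how many entries/rows are meaningful
    llama_state: bytes               # opaque llama.cpp state data
    llama_state_size: int

snap = StateSketch(
    input_ids=np.zeros(8, dtype=np.intc),
    scores=np.zeros((8, 5), dtype=np.single),
    n_tokens=3,
    llama_state=b"",
    llama_state_size=0,
)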

@@ -267,8 +265,6 @@ def __init__(
 
         self.last_n_tokens_size = last_n_tokens_size
         self.n_batch = min(n_ctx, n_batch)
-        self.eval_tokens: Deque[int] = deque(maxlen=n_ctx)
-        self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx if logits_all else 1)
 
         self.cache: Optional[BaseLlamaCache] = None

@@ -329,8 +325,30 @@ def __init__(
         self._token_nl = Llama.token_nl()
         self._token_eos = Llama.token_eos()
 
-        self._input_ids = np.array([], dtype=np.intc)
-        self._scores: npt.NDArray[np.single] = np.ndarray((0, self._n_vocab), dtype=np.single)
+        self.n_tokens = 0
+        self.input_ids: npt.NDArray[np.intc] = np.ndarray((n_ctx,), dtype=np.intc)
+        self.scores: npt.NDArray[np.single] = np.ndarray(
+            (n_ctx, self._n_vocab), dtype=np.single
+        )
+
+    @property
+    def _input_ids(self) -> npt.NDArray[np.intc]:
+        return self.input_ids[: self.n_tokens]
+
+    @property
+    def _scores(self) -> npt.NDArray[np.single]:
+        return self.scores[: self.n_tokens, :]
+
+    @property
+    def eval_tokens(self) -> Deque[int]:
+        return deque(self.input_ids[: self.n_tokens].tolist(), maxlen=self._n_ctx)
+
+    @property
+    def eval_logits(self) -> Deque[List[float]]:
+        return deque(
+            self.scores[: self.n_tokens, :].tolist(),
+            maxlen=self._n_ctx if self.params.logits_all else 1,
+        )
 
     def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
         """Tokenize a string.
@@ -397,10 +415,7 @@ def set_cache(self, cache: Optional[BaseLlamaCache]):
 
     def reset(self):
         """Reset the model state."""
-        self.eval_tokens.clear()
-        self.eval_logits.clear()
-        self._input_ids = np.array([], dtype=np.intc)
-        self._scores = np.ndarray((0, self._n_vocab), dtype=np.single)
+        self.n_tokens = 0
 
     def eval(self, tokens: Sequence[int]):
         """Evaluate a list of tokens.
@@ -410,7 +425,6 @@ def eval(self, tokens: Sequence[int]):
         """
         assert self.ctx is not None
         n_ctx = self._n_ctx
-        scores: List[npt.NDArray[np.single]] = []
         for i in range(0, len(tokens), self.n_batch):
             batch = tokens[i : min(len(tokens), i + self.n_batch)]
             n_past = min(n_ctx - len(batch), len(self._input_ids))
@@ -425,19 +439,16 @@ def eval(self, tokens: Sequence[int]):
             if return_code != 0:
                 raise RuntimeError(f"llama_eval returned {return_code}")
             # Save tokens
-            self.eval_tokens.extend(batch)
-            self._input_ids: npt.NDArray[np.intc] = np.concatenate(
-                (self._input_ids, np.array(batch, dtype=np.intc)), axis=0
-            )
+            self.input_ids[self.n_tokens : self.n_tokens + n_tokens] = batch
             # Save logits
             rows = n_tokens if self.params.logits_all else 1
             n_vocab = self._n_vocab
             cols = n_vocab
             logits_view = llama_cpp.llama_get_logits(self.ctx)
             logits = [logits_view[i * cols : (i + 1) * cols] for i in range(rows)]
-            self.eval_logits.extend(logits)
-            scores.append(np.array(logits, dtype=np.single))
-        self._scores = np.concatenate(scores)
+            self.scores[self.n_tokens : self.n_tokens + n_tokens, :] = logits
+            # Update n_tokens
+            self.n_tokens += n_tokens
 
     def _sample(
         self,
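
eval() now writes each batch straight into the buffers with slice assignment: numpy coerces the Python token list into the intc slice and the logits rows into the single slice, and n_tokens advances once per batch. Note that when logits_all is off, logits holds a single row, which numpy broadcasts across the batch-wide destination slice. A standalone sketch of the write loop, assuming logits_all and using a stand-in for llama_get_logits:

import numpy as np

n_ctx, n_vocab, n_batch = 8, 5, 2
input_ids = np.ndarray((n_ctx,), dtype=np.intc)
scores = np.ndarray((n_ctx, n_vocab), dtype=np.single)
n_tokens = 0

tokens = [10, 11, 12, 13, 14]
for i in range(0, len(tokens), n_batch):
    batch = tokens[i : min(len(tokens), i + n_batch)]
    logits = [[0.0] * n_vocab for _ in batch]  # stand-in for the llama.cpp logits
    input_ids[n_tokens : n_tokens + len(batch)] = batch   # list -> intc slice
    scores[n_tokens : n_tokens + len(batch), :] = logits  # rows -> single slice
    n_tokens += len(batch)

assert input_ids[:n_tokens].tolist() == tokens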
@@ -457,8 +468,7 @@ def _sample(
         logits_processor: Optional[LogitsProcessorList] = None,
     ):
         assert self.ctx is not None
-        assert len(self.eval_logits) > 0
-        assert self._scores.shape[0] > 0
+        assert self.n_tokens > 0
         n_vocab = self._n_vocab
         n_ctx = self._n_ctx
         top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
@@ -475,7 +485,6 @@ def _sample(
                 dtype=np.single,
             )
             self._scores[-1, :] = logits
-            self.eval_logits[-1] = logits.tolist()
 
         nl_logit = logits[self._token_nl]
         candidates = self._candidates
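
Because _scores is now a view into the backing scores buffer, assigning the processor-modified logits to self._scores[-1, :] writes through to storage, which is why the separate self.eval_logits[-1] update can simply be dropped. A standalone demonstration of the write-through:

import numpy as np

scores = np.zeros((8, 5), dtype=np.single)
n_tokens = 3

_scores = scores[:n_tokens, :]  # view, as the new property returns
_scores[-1, :] = np.arange(5, dtype=np.single)  # e.g. logits after a processor

# the backing buffer saw the write; no second structure needs syncing
assert scores[n_tokens - 1].tolist() == [0.0, 1.0, 2.0, 3.0, 4.0]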
@@ -672,14 +681,7 @@ def generate(
                 if self.verbose:
                     print("Llama.generate: prefix-match hit", file=sys.stderr)
                 reset = False
                 tokens = tokens[longest_prefix:]
-                self._input_ids = self._input_ids[:longest_prefix]
-                self._scores = self._scores[:longest_prefix, :]
-                for _ in range(len(self.eval_tokens) - longest_prefix):
-                    self.eval_tokens.pop()
-                    try:
-                        self.eval_logits.pop()
-                    except IndexError:
-                        pass
+                self.n_tokens = longest_prefix
 
         if reset:
             self.reset()
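
The cache prefix-match rollback in generate() collapses from slicing two arrays and popping two deques in a loop to one assignment: moving n_tokens back to longest_prefix logically truncates both buffers at once, and the next eval() simply overwrites the stale tail. Sketch with a hypothetical prefix length:

import numpy as np

input_ids = np.array([1, 2, 3, 4, 5, 0, 0, 0], dtype=np.intc)
n_tokens = 5
longest_prefix = 3  # hypothetical: first three tokens match the cached state

n_tokens = longest_prefix  # the entire rollback
assert input_ids[:n_tokens].tolist() == [1, 2, 3]
# tokens 4 and 5 remain in memory but will be overwritten by the next eval()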
@@ -819,7 +821,9 @@ def _create_completion(
             llama_cpp.llama_reset_timings(self.ctx)
 
         if len(prompt_tokens) > self._n_ctx:
-            raise ValueError(f"Requested tokens ({len(prompt_tokens)}) exceed context window of {self._n_ctx}")
+            raise ValueError(
+                f"Requested tokens ({len(prompt_tokens)}) exceed context window of {self._n_ctx}"
+            )
 
         # Truncate max_tokens if requested tokens would exceed the context window
         max_tokens = (
@@ -1513,22 +1517,20 @@ def save_state(self) -> LlamaState:
                 file=sys.stderr,
             )
         return LlamaState(
-            eval_tokens=self.eval_tokens.copy(),
-            eval_logits=self.eval_logits.copy(),
-            scores=self._scores.copy(),
-            input_ids=self._input_ids.copy(),
+            scores=self.scores.copy(),
+            input_ids=self.input_ids.copy(),
+            n_tokens=self.n_tokens,
             llama_state=bytes(llama_state_compact),
             llama_state_size=n_bytes,
         )
 
     def load_state(self, state: LlamaState) -> None:
         assert self.ctx is not None
-        self.eval_tokens = state.eval_tokens.copy()
-        self.eval_logits = state.eval_logits.copy()
-        self._scores = state.scores.copy()
-        self._input_ids = state.input_ids.copy()
+        self.scores = state.scores.copy()
+        self.input_ids = state.input_ids.copy()
+        self.n_tokens = state.n_tokens
         state_size = state.llama_state_size
-        LLamaStateArrayType = (llama_cpp.c_uint8 * state_size)
+        LLamaStateArrayType = llama_cpp.c_uint8 * state_size
         llama_state = LLamaStateArrayType.from_buffer_copy(state.llama_state)
 
         if llama_cpp.llama_set_state_data(self.ctx, llama_state) != state_size:
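
save_state() and load_state() now copy the whole fixed-size buffers plus the counter, so a snapshot always costs O(n_ctx * n_vocab) memory, but restoring never reallocates or rebuilds deques. A toy round trip under the same convention (a plain dict stands in for LlamaState):

import numpy as np

n_ctx, n_vocab = 8, 5
input_ids = np.arange(n_ctx, dtype=np.intc)
scores = np.ones((n_ctx, n_vocab), dtype=np.single)
n_tokens = 4

saved = {  # stand-in for LlamaState; the real one also carries the C state blob
    "input_ids": input_ids.copy(),
    "scores": scores.copy(),
    "n_tokens": n_tokens,
}

input_ids[:] = 0  # clobber the live buffers
n_tokens = 0

input_ids = saved["input_ids"].copy()  # restore, as load_state() does
scores = saved["scores"].copy()
n_tokens = saved["n_tokens"]
assert input_ids[:n_tokens].tolist() == [0, 1, 2, 3]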
