Commit 49fe939

Merge pull request abetlen#277 from abetlen/add-numpy-support
Use numpy for internal buffers
2 parents b61b016 + 8f2b445 commit 49fe939

3 files changed: +60 -35 lines changed

CHANGELOG.md

1 addition & 0 deletions
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Added
 
 - Added first version of the changelog
+- Use numpy for internal buffers to reduce memory usage and improve performance.
 
 ### Fixed
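A rough illustration of the memory claim above (not part of this commit; the vocabulary size is made up): a float32 numpy matrix stores 4 bytes per logit in one contiguous buffer, while a deque of Python float lists pays for a boxed float object plus a list slot per logit.

    # Illustrative numbers only (CPython, 64-bit); n_vocab is a made-up vocabulary size.
    import sys
    import numpy as np

    n_vocab, n_steps = 32_000, 8

    # New-style buffer: one contiguous float32 matrix, the shape/dtype used for self._scores.
    scores = np.zeros((n_steps, n_vocab), dtype=np.single)
    print(scores.nbytes)           # 1_024_000 bytes, i.e. 4 bytes per logit

    # Old-style storage kept one boxed Python float (~24 bytes) plus an 8-byte
    # list slot per logit, several times the footprint for the same data.
    print(sys.getsizeof(1.5) + 8)  # ~32 bytes per logit before list overhead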

llama_cpp/llama.py

58 additions & 32 deletions
@@ -20,6 +20,9 @@
 from . import llama_cpp
 from .llama_types import *
 
+import numpy as np
+import numpy.typing as npt
+
 
 class LlamaCache:
     """Cache for a llama.cpp model."""
@@ -73,11 +76,15 @@ def __init__(
         self,
         eval_tokens: Deque[int],
         eval_logits: Deque[List[float]],
+        input_ids: npt.NDArray[np.intc],
+        scores: npt.NDArray[np.single],
         llama_state,  # type: llama_cpp.Array[llama_cpp.c_uint8]
         llama_state_size: int,
     ):
         self.eval_tokens = eval_tokens
         self.eval_logits = eval_logits
+        self.input_ids = input_ids
+        self.scores = scores
         self.llama_state = llama_state
         self.llama_state_size = llama_state_size

@@ -207,27 +214,27 @@ def __init__(
 
         self._n_vocab = self.n_vocab()
         self._n_ctx = self.n_ctx()
-        data = (llama_cpp.llama_token_data * self._n_vocab)(
-            *[
-                llama_cpp.llama_token_data(
-                    id=llama_cpp.llama_token(i),
-                    logit=llama_cpp.c_float(0.0),
-                    p=llama_cpp.c_float(0.0),
-                )
-                for i in range(self._n_vocab)
-            ]
-        )
         size = llama_cpp.c_size_t(self._n_vocab)
-        sorted = False
+        sorted = llama_cpp.c_bool(False)
+        self._candidates_data = np.array(
+            [],
+            dtype=np.dtype(
+                [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
+            ),
+        )
+        self._candidates_data.resize(3, self._n_vocab)
         candidates = llama_cpp.llama_token_data_array(
-            data=data,
+            data=self._candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p),
             size=size,
             sorted=sorted,
         )
         self._candidates = candidates
         self._token_nl = Llama.token_nl()
         self._token_eos = Llama.token_eos()
 
+        self._input_ids = np.array([], dtype=np.intc)
+        self._scores = np.ndarray((0, self._n_vocab), dtype=np.single)
+
     def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
         """Tokenize a string.
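A side note on the hunk above (not part of the diff): the ctypes array of llama_token_data structs is replaced by a numpy structured array whose layout, with align=True, matches the C struct, so a pointer into its buffer can be handed straight to llama.cpp. A minimal, self-contained sketch of the same trick follows; TokenData here is a hypothetical stand-in for llama_cpp.llama_token_data.

    # Sketch only: TokenData stands in for llama_cpp.llama_token_data.
    import ctypes
    import numpy as np

    class TokenData(ctypes.Structure):
        _fields_ = [("id", ctypes.c_int), ("logit", ctypes.c_float), ("p", ctypes.c_float)]

    n_vocab = 8
    # align=True gives the array the struct's natural alignment (12-byte stride here).
    dtype = np.dtype([("id", np.intc), ("logit", np.single), ("p", np.single)], align=True)
    candidates_data = np.zeros(n_vocab, dtype=dtype)
    candidates_data["id"] = np.arange(n_vocab, dtype=np.intc)
    candidates_data["logit"] = np.linspace(-1.0, 1.0, n_vocab, dtype=np.single)

    # Reinterpret the numpy buffer as TokenData*, as the patch does with
    # self._candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p).
    data_p = candidates_data.ctypes.data_as(ctypes.POINTER(TokenData))
    assert data_p[3].id == 3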
@@ -295,6 +302,8 @@ def reset(self):
         """Reset the model state."""
         self.eval_tokens.clear()
         self.eval_logits.clear()
+        self._input_ids = np.array([], dtype=np.intc)
+        self._scores = np.ndarray((0, self._n_vocab), dtype=np.single)
 
     def eval(self, tokens: Sequence[int]):
         """Evaluate a list of tokens.
@@ -306,7 +315,7 @@ def eval(self, tokens: Sequence[int]):
         n_ctx = self._n_ctx
         for i in range(0, len(tokens), self.n_batch):
             batch = tokens[i : min(len(tokens), i + self.n_batch)]
-            n_past = min(n_ctx - len(batch), len(self.eval_tokens))
+            n_past = min(n_ctx - len(batch), len(self._input_ids))
             n_tokens = len(batch)
             return_code = llama_cpp.llama_eval(
                 ctx=self.ctx,
@@ -319,13 +328,19 @@ def eval(self, tokens: Sequence[int]):
                 raise RuntimeError(f"llama_eval returned {return_code}")
             # Save tokens
             self.eval_tokens.extend(batch)
+            self._input_ids: npt.NDArray[np.intc] = np.concatenate(
+                (self._input_ids, np.array(batch, dtype=np.intc)), axis=0
+            )
             # Save logits
             rows = n_tokens if self.params.logits_all else 1
             n_vocab = self._n_vocab
             cols = n_vocab
             logits_view = llama_cpp.llama_get_logits(self.ctx)
             logits = [logits_view[i * cols : (i + 1) * cols] for i in range(rows)]
             self.eval_logits.extend(logits)
+            self._scores: npt.NDArray[np.single] = np.concatenate(
+                (self._scores, np.array(logits, dtype=np.single)), axis=0
+            )
 
     def _sample(
         self,
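Another aside (not part of the diff): eval() now mirrors its deque bookkeeping into numpy arrays, growing self._input_ids and self._scores by one batch per call with np.concatenate, so later code can slice whole rows of logits. A toy version with made-up sizes:

    # Sketch only; vocabulary size and batch contents are made up.
    import numpy as np

    n_vocab = 10
    input_ids = np.array([], dtype=np.intc)             # like self._input_ids
    scores = np.ndarray((0, n_vocab), dtype=np.single)  # like self._scores

    for step in range(3):                               # pretend three eval() batches
        batch = [step, step + 1]                        # token ids for this batch
        logits = np.random.rand(len(batch), n_vocab)    # stand-in for llama_get_logits output
        input_ids = np.concatenate((input_ids, np.array(batch, dtype=np.intc)), axis=0)
        scores = np.concatenate((scores, np.array(logits, dtype=np.single)), axis=0)

    print(input_ids.shape, scores.shape)                # (6,) (6, 10)
    print(scores[-1, :].shape)                          # last row, as _sample() reads it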
@@ -346,6 +361,7 @@ def _sample(
     ):
         assert self.ctx is not None
         assert len(self.eval_logits) > 0
+        assert self._scores.shape[0] > 0
         n_vocab = self._n_vocab
         n_ctx = self._n_ctx
         top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
@@ -354,18 +370,23 @@ def _sample(
             if last_n_tokens_size.value < 0
             else last_n_tokens_size
         )
-        logits = self.eval_logits[-1]
+        logits: npt.NDArray[np.single] = self._scores[-1, :]
 
         if logits_processor is not None:
-            logits = logits_processor(list(self.eval_tokens), logits)
-            self.eval_logits[-1] = logits
+            logits = np.array(
+                logits_processor(self._input_ids.tolist(), logits.tolist()),
+                dtype=np.single,
+            )
+            self._scores[-1, :] = logits
+            self.eval_logits[-1] = logits.tolist()
 
         nl_logit = logits[self._token_nl]
         candidates = self._candidates
-        for i, logit in enumerate(logits):
-            candidates.data[i].id = llama_cpp.llama_token(i)
-            candidates.data[i].logit = llama_cpp.c_float(logit)
-            candidates.data[i].p = llama_cpp.c_float(0.0)
+        candidates_data = self._candidates_data
+        candidates_data["id"] = np.arange(n_vocab, dtype=np.intc)  # type: ignore
+        candidates_data["logit"] = logits
+        candidates_data["p"] = np.zeros(n_vocab, dtype=np.single)
+        candidates.data = candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p)
         candidates.sorted = llama_cpp.c_bool(False)
         candidates.size = llama_cpp.c_size_t(n_vocab)
         llama_cpp.llama_sample_repetition_penalty(
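One more aside (not part of the diff): the old candidate refill looped over the whole vocabulary in Python, setting id, logit and p one ctypes element at a time; the hunk above does the same work with three whole-array writes. A toy comparison under an assumed tiny vocabulary:

    # Sketch only; the dtype mirrors the structured array set up in __init__.
    import numpy as np

    n_vocab = 5
    dtype = np.dtype([("id", np.intc), ("logit", np.single), ("p", np.single)], align=True)
    candidates_data = np.zeros(n_vocab, dtype=dtype)
    logits = np.array([0.3, -1.2, 0.9, 0.0, 2.5], dtype=np.single)

    # Three vectorized assignments replace the old per-element Python loop.
    candidates_data["id"] = np.arange(n_vocab, dtype=np.intc)
    candidates_data["logit"] = logits
    candidates_data["p"] = np.zeros(n_vocab, dtype=np.single)

    print(int(candidates_data["logit"].argmax()))       # 4, the highest-logit token id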
@@ -483,8 +504,8 @@ def sample(
         """
         assert self.ctx is not None
         last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
-            0, self.last_n_tokens_size - len(self.eval_tokens)
-        ) + list(self.eval_tokens)[-self.last_n_tokens_size :]
+            0, self.last_n_tokens_size - len(self._input_ids)
+        ) + self._input_ids[-self.last_n_tokens_size :].tolist()
         return self._sample(
             last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
                 *last_n_tokens_data
@@ -542,9 +563,9 @@ def generate(
         """
         assert self.ctx is not None
 
-        if reset and len(self.eval_tokens) > 0:
+        if reset and len(self._input_ids) > 0:
             longest_prefix = 0
-            for a, b in zip(self.eval_tokens, tokens[:-1]):
+            for a, b in zip(self._input_ids, tokens[:-1]):
                 if a == b:
                     longest_prefix += 1
                 else:
@@ -554,6 +575,8 @@ def generate(
                     print("Llama.generate: prefix-match hit", file=sys.stderr)
                 reset = False
                 tokens = tokens[longest_prefix:]
+                self._input_ids = self._input_ids[:longest_prefix]
+                self._scores = self._scores[:longest_prefix, :]
                 for _ in range(len(self.eval_tokens) - longest_prefix):
                     self.eval_tokens.pop()
                     try:
@@ -580,7 +603,7 @@ def generate(
                 logits_processor=logits_processor,
             )
             if stopping_criteria is not None and stopping_criteria(
-                list(self.eval_tokens), self.eval_logits[-1]
+                self._input_ids.tolist(), self._scores[-1, :].tolist()
             ):
                 return
             tokens_or_none = yield token
@@ -715,10 +738,10 @@ def _create_completion(
             try:
                 cache_item = self.cache[prompt_tokens]
                 cache_prefix_len = Llama.longest_token_prefix(
-                    cache_item.eval_tokens, prompt_tokens
+                    cache_item.input_ids.tolist(), prompt_tokens
                 )
                 eval_prefix_len = Llama.longest_token_prefix(
-                    self.eval_tokens, prompt_tokens
+                    self._input_ids.tolist(), prompt_tokens
                 )
                 if cache_prefix_len > eval_prefix_len:
                     self.load_state(cache_item)
@@ -807,7 +830,7 @@ def _create_completion(
                         self.detokenize(completion_tokens[:returned_tokens])
                     )
                     token_offset = len(prompt_tokens) + returned_tokens
-                    logits = self.eval_logits[token_offset - 1]
+                    logits = self._scores[token_offset - 1, :].tolist()
                     current_logprobs = Llama.logits_to_logprobs(logits)
                     sorted_logprobs = list(
                         sorted(
@@ -856,7 +879,7 @@ def _create_completion(
                     break
 
         if stopping_criteria is not None and stopping_criteria(
-            list(self.eval_tokens), self.eval_logits[-1]
+            self._input_ids.tolist(), self._scores[-1, :].tolist()
        ):
            text = self.detokenize(completion_tokens)
            finish_reason = "stop"
@@ -886,7 +909,7 @@ def _create_completion(
                         self.detokenize(completion_tokens[:returned_tokens])
                     )
                     token_offset = len(prompt_tokens) + returned_tokens - 1
-                    logits = self.eval_logits[token_offset]
+                    logits = self._scores[token_offset, :].tolist()
                     current_logprobs = Llama.logits_to_logprobs(logits)
                     sorted_logprobs = list(
                         sorted(
@@ -988,8 +1011,7 @@ def _create_completion(
                 for token in all_tokens
             ]
             all_logprobs = [
-                Llama.logits_to_logprobs(list(map(float, row)))
-                for row in self.eval_logits
+                Llama.logits_to_logprobs(row.tolist()) for row in self._scores
             ][token_offset:]
             for token, token_str, logprobs_token in zip(
                 all_tokens, all_token_strs, all_logprobs
@@ -1373,6 +1395,8 @@ def save_state(self) -> LlamaState:
         return LlamaState(
             eval_tokens=self.eval_tokens.copy(),
             eval_logits=self.eval_logits.copy(),
+            scores=self._scores.copy(),
+            input_ids=self._input_ids.copy(),
             llama_state=llama_state_compact,
             llama_state_size=n_bytes,
         )
@@ -1381,6 +1405,8 @@ def load_state(self, state: LlamaState) -> None:
         assert self.ctx is not None
         self.eval_tokens = state.eval_tokens.copy()
         self.eval_logits = state.eval_logits.copy()
+        self._scores = state.scores.copy()
+        self._input_ids = state.input_ids.copy()
         state_size = state.llama_state_size
         if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
             raise RuntimeError("Failed to set llama state data")

setup.py

1 addition & 3 deletions
@@ -16,9 +16,7 @@
     license="MIT",
     package_dir={"llama_cpp": "llama_cpp", "llama_cpp.server": "llama_cpp/server"},
     packages=["llama_cpp", "llama_cpp.server"],
-    install_requires=[
-        "typing-extensions>=4.5.0",
-    ],
+    install_requires=["typing-extensions>=4.5.0", "numpy>=1.20.0"],
     extras_require={
         "server": ["uvicorn>=0.21.1", "fastapi>=0.95.0", "sse-starlette>=1.3.3"],
     },
