Commit 92c0771

Add experimental cache

1 parent a6372a7 · commit 92c0771

2 files changed: +69 −5 lines changed

llama_cpp/llama.py

+65 −4: 65 additions & 4 deletions
@@ -11,6 +11,15 @@
 from .llama_types import *
 
 
+class LlamaCache:
+    """Cache for a llama.cpp model.
+
+    NOTE: This implementation currently only tells the Llama class to avoid reprocessing bytes and continue from the last
+    completion. It does not actually cache the results."""
+
+    pass
+
+
 class Llama:
     """High-level Python wrapper for a llama.cpp model."""
 
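As an aside (not part of the commit), a minimal usage sketch of the new cache hook, assuming the package-level Llama/LlamaCache exports and the create_completion API present in this version of the library; the model path, prompts, and generation parameters are placeholders:

from llama_cpp import Llama, LlamaCache

# Hypothetical model path; any ggml model file would do here.
llama = Llama(model_path="./models/7B/ggml-model.bin", verbose=True)
llama.set_cache(LlamaCache())  # opt in to the experimental prefix-continuation behaviour

# First call processes the whole prompt.
first = llama.create_completion("Q: Name the planets in the solar system. A:", max_tokens=32)

# If the next prompt starts with the previous prompt plus the previous completion,
# the "completion cache hit" path added below lets generation continue without
# reprocessing those bytes.
previous_text = first["choices"][0]["text"]
second = llama.create_completion(
    "Q: Name the planets in the solar system. A:" + previous_text + " Q: Which is largest? A:",
    max_tokens=32,
)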

@@ -82,6 +91,14 @@ def __init__(
         self.n_past = 0
         self.all_logits: List[List[float]] = []  # TODO: Use an array instead of a list.
 
+        ### HACK: This is a hack to work around the fact that the llama.cpp API does not yet support
+        ### saving and restoring state, this allows us to continue a completion if the last
+        ### completion_bytes is a prefix to the prompt passed in. However this is actually incorrect
+        ### because it does not take into account stop tokens which have been processed by the model.
+        self._completion_bytes: List[bytes] = []
+        self._cache: Optional[LlamaCache] = None
+        ###
+
         self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
 
         if not os.path.exists(model_path):
@@ -135,6 +152,14 @@ def detokenize(self, tokens: List[llama_cpp.llama_token]) -> bytes:
             output += llama_cpp.llama_token_to_str(self.ctx, token)
         return output
 
+    def set_cache(self, cache: Optional[LlamaCache]):
+        """Set the cache.
+
+        Args:
+            cache: The cache to set.
+        """
+        self._cache = cache
+
     def reset(self):
         """Reset the model state."""
         self.last_n_tokens_data.extend(
@@ -245,6 +270,17 @@ def generate(
             The generated tokens.
         """
         assert self.ctx is not None
+        ### HACK
+        if (
+            reset
+            and self._cache
+            and len(self.tokens) > 0
+            and self.tokens == tokens[: len(self.tokens)]
+        ):
+            if self.verbose:
+                print("generate cache hit", file=sys.stderr)
+            reset = False
+        ###
         if reset:
             self.reset()
         while True:
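The reset-skipping condition above is simply a list-prefix test on the previously evaluated tokens. A standalone illustration of that check (hypothetical helper, not part of the commit):

from typing import List

def is_prefix(previous_tokens: List[int], new_tokens: List[int]) -> bool:
    # Mirrors the condition in generate(): the old token sequence must be non-empty
    # and match the start of the new one for the existing model state to be reused.
    return len(previous_tokens) > 0 and previous_tokens == new_tokens[: len(previous_tokens)]

assert is_prefix([1, 2, 3], [1, 2, 3, 4, 5])      # cache hit: skip reset, evaluate only [4, 5]
assert not is_prefix([1, 2, 9], [1, 2, 3, 4, 5])  # mismatch: full reset and re-evaluation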
@@ -361,13 +397,29 @@ def _create_completion(
                 "logprobs is not supported for models created with logits_all=False"
             )
 
+        ### HACK
+        reset: bool = True
+        _prompt: bytes = prompt.encode("utf-8")
+        _completion: bytes = b"".join(self._completion_bytes)
+        if len(_completion) and self._cache and _prompt.startswith(_completion):
+            if self.verbose:
+                print("completion cache hit", file=sys.stderr)
+            reset = False
+            _prompt = _prompt[len(_completion) :]
+            prompt_tokens = self.tokenize(b" " + _prompt)
+            self._completion_bytes.append(_prompt)
+        else:
+            self._completion_bytes = [prompt.encode("utf-8")]
+        ###
+
         finish_reason = "length"
         for token in self.generate(
             prompt_tokens,
             top_k=top_k,
             top_p=top_p,
             temp=temperature,
             repeat_penalty=repeat_penalty,
+            reset=reset,
         ):
             if token == llama_cpp.llama_token_eos():
                 text = self.detokenize(completion_tokens)
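_create_completion applies the same idea at the byte level: if the bytes already seen (prompt plus generated text) are a prefix of the new prompt, only the unseen tail is tokenized and evaluated. A simplified restatement of that bookkeeping (hypothetical names, not the commit's code verbatim):

from typing import List, Tuple

def split_cached_prefix(completion_bytes: List[bytes], prompt: str) -> Tuple[bool, bytes]:
    # Returns (cache_hit, remaining_prompt_bytes). On a hit only the tail needs to be
    # tokenized and evaluated; on a miss the whole prompt is reprocessed from scratch.
    new_prompt = prompt.encode("utf-8")
    seen = b"".join(completion_bytes)
    if seen and new_prompt.startswith(seen):
        return True, new_prompt[len(seen):]
    return False, new_prompt

hit, tail = split_cached_prefix([b"Q: 2+2? A:", b" 4"], "Q: 2+2? A: 4 Q: 3+3? A:")
assert hit and tail == b" Q: 3+3? A:"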
@@ -397,6 +449,9 @@ def _create_completion(
                             break
                 text = all_text[: len(all_text) - longest]
                 returned_characters += len(text[start:])
+                ### HACK
+                self._completion_bytes.append(text[start:])
+                ###
                 yield {
                     "id": completion_id,
                     "object": "text_completion",
@@ -418,6 +473,9 @@ def _create_completion(
                 break
 
         if stream:
+            ### HACK
+            self._completion_bytes.append(text[returned_characters:])
+            ###
             yield {
                 "id": completion_id,
                 "object": "text_completion",
@@ -434,13 +492,16 @@ def _create_completion(
             }
             return
 
-        text = text.decode("utf-8")
+        ### HACK
+        self._completion_bytes.append(text)
+        ###
+        text_str = text.decode("utf-8")
 
         if echo:
-            text = prompt + text
+            text_str = prompt + text_str
 
         if suffix is not None:
-            text = text + suffix
+            text_str = text_str + suffix
 
         logprobs_or_none: Optional[CompletionLogprobs] = None
         if logprobs is not None:
@@ -493,7 +554,7 @@ def _create_completion(
             "model": self.model_path,
             "choices": [
                 {
-                    "text": text,
+                    "text": text_str,
                     "index": 0,
                     "logprobs": logprobs_or_none,
                     "finish_reason": finish_reason,

llama_cpp/server/__main__.py

+4 −1: 4 additions & 1 deletion
@@ -35,6 +35,7 @@ class Settings(BaseSettings):
     embedding: bool = True
     last_n_tokens_size: int = 64
     logits_all: bool = False
+    cache: bool = False  # WARNING: This is an experimental feature
 
 
 app = FastAPI(
@@ -60,6 +61,9 @@ class Settings(BaseSettings):
     n_ctx=settings.n_ctx,
     last_n_tokens_size=settings.last_n_tokens_size,
 )
+if settings.cache:
+    cache = llama_cpp.LlamaCache()
+    llama.set_cache(cache)
 llama_lock = Lock()
 
 
@@ -68,7 +72,6 @@ def get_llama():
         yield llama
 
 
-
 class CreateCompletionRequest(BaseModel):
     prompt: Union[str, List[str]]
     suffix: Optional[str] = Field(None)
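Because Settings extends pydantic's BaseSettings, the new flag should be settable from the environment when launching the server. A hedged launch sketch (the MODEL and CACHE variable names are assumed from the field names; the model path is a placeholder):

import os
import subprocess

# Assumed mapping: pydantic BaseSettings reads the `model` and `cache` fields from the
# MODEL and CACHE environment variables (case-insensitive by default), so setting
# CACHE=true is expected to attach a LlamaCache to the server's Llama instance.
env = dict(os.environ, MODEL="./models/7B/ggml-model.bin", CACHE="true")
subprocess.run(["python3", "-m", "llama_cpp.server"], env=env, check=True)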
