Commit d484c56

Bugfix: Check cache keys as prefix to prompt tokens
1 parent b75fa96 commit d484c56

1 file changed: +26 -5 lines changed

llama_cpp/llama.py

@@ -4,7 +4,7 @@
 import time
 import math
 import multiprocessing
-from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque
+from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple
 from collections import deque
 
 from . import llama_cpp
@@ -15,15 +15,34 @@ class LlamaCache:
     """Cache for a llama.cpp model."""
 
     def __init__(self):
-        self.cache_state: Dict[Sequence[llama_cpp.llama_token], "LlamaState"] = dict()
+        self.cache_state: Dict[Tuple[llama_cpp.llama_token, ...], "LlamaState"] = dict()
+
+    def _sorted_keys(self) -> List[Tuple[llama_cpp.llama_token, ...]]:
+        return [
+            key
+            for _, key in sorted(
+                ((len(key), key) for key in self.cache_state.keys()), reverse=True
+            )
+        ]
+
+    def _find_key(
+        self, key: Tuple[llama_cpp.llama_token, ...]
+    ) -> Optional[Tuple[llama_cpp.llama_token, ...]]:
+        for k in self._sorted_keys():
+            if key[: len(k)] == k:
+                return k
+        return None
 
     def __getitem__(
         self, key: Sequence[llama_cpp.llama_token]
     ) -> Optional["LlamaState"]:
-        return self.cache_state.get(tuple(key), None)
+        _key = self._find_key(tuple(key))
+        if _key is None:
+            return None
+        return self.cache_state[_key]
 
     def __contains__(self, key: Sequence[llama_cpp.llama_token]) -> bool:
-        return tuple(key) in self.cache_state
+        return self._find_key(tuple(key)) is not None
 
     def __setitem__(self, key: Sequence[llama_cpp.llama_token], value: "LlamaState"):
         self.cache_state = dict()  # NOTE: Currently limit to one cache entry.
@@ -295,7 +314,7 @@ def generate(
         if (
             reset
            and len(self.eval_tokens) > 0
-            and self.eval_tokens == tokens[: len(self.eval_tokens)]
+            and tuple(self.eval_tokens) == tuple(tokens[: len(self.eval_tokens)])
         ):
             if self.verbose:
                 print("generate cache hit", file=sys.stderr)
@@ -438,6 +457,8 @@ def _create_completion(
 
             if self.cache and len(completion_tokens) == 0:
                 if prompt_tokens not in self.cache:
+                    if self.verbose:
+                        print("cache miss", file=sys.stderr)
                     self.cache[prompt_tokens] = self.save_state()
 
             completion_tokens.append(token)

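In effect, the fix makes LlamaCache resolve lookups by longest-prefix match rather than exact match: a stored key can be reused whenever it is a prefix of the queried prompt tokens, and _sorted_keys orders candidates longest-first so the most complete saved state wins. The tuple() conversions in generate also make the comparison between eval_tokens (a deque) and a token list behave as a plain sequence comparison. A minimal standalone sketch of the lookup behaviour follows, assuming plain int tokens and a toy PrefixCache class that stands in for LlamaCache/LlamaState; neither name is part of llama_cpp.

from typing import Dict, List, Optional, Tuple


class PrefixCache:
    """Toy cache keyed by token tuples, resolved by longest matching prefix."""

    def __init__(self) -> None:
        self.cache_state: Dict[Tuple[int, ...], str] = {}

    def _sorted_keys(self) -> List[Tuple[int, ...]]:
        # Longest keys first, so the longest cached prefix is preferred.
        return sorted(self.cache_state.keys(), key=len, reverse=True)

    def _find_key(self, key: Tuple[int, ...]) -> Optional[Tuple[int, ...]]:
        for k in self._sorted_keys():
            if key[: len(k)] == k:  # k is a prefix of the queried tokens
                return k
        return None

    def __contains__(self, key: Tuple[int, ...]) -> bool:
        return self._find_key(tuple(key)) is not None

    def __getitem__(self, key: Tuple[int, ...]) -> Optional[str]:
        _key = self._find_key(tuple(key))
        return None if _key is None else self.cache_state[_key]

    def __setitem__(self, key: Tuple[int, ...], value: str) -> None:
        self.cache_state[tuple(key)] = value


cache = PrefixCache()
cache[(1, 2, 3)] = "saved-state-for-1-2-3"

# The stored key (1, 2, 3) is a prefix of the longer prompt (1, 2, 3, 4, 5),
# so this lookup now hits where an exact-match cache would miss.
assert (1, 2, 3, 4, 5) in cache
assert cache[(1, 2, 3, 4, 5)] == "saved-state-for-1-2-3"

Sorting candidates by descending length matters once the cache holds more than one entry: matching the longest cached prefix lets the caller restore the saved state that skips re-evaluating the most prompt tokens.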