Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 97c6372

Browse filesBrowse files
committed
Rewind model to longest prefix.
1 parent cabd8b8 commit 97c6372
Copy full SHA for 97c6372

File tree

Expand file treeCollapse file tree

1 file changed

+19
-9
lines changed
Filter options
Expand file treeCollapse file tree

1 file changed

+19
-9
lines changed

‎llama_cpp/llama.py

Copy file name to clipboardExpand all lines: llama_cpp/llama.py
+19-9Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -390,18 +390,28 @@ def generate(
390390
"""
391391
assert self.ctx is not None
392392

393-
if (
394-
reset
395-
and len(self.eval_tokens) > 0
396-
and tuple(self.eval_tokens) == tuple(tokens[: len(self.eval_tokens)])
397-
):
398-
if self.verbose:
399-
print("Llama.generate: cache hit", file=sys.stderr)
400-
reset = False
401-
tokens = tokens[len(self.eval_tokens) :]
393+
if reset and len(self.eval_tokens) > 0:
394+
longest_prefix = 0
395+
for a, b in zip(self.eval_tokens, tokens[:-1]):
396+
if a == b:
397+
longest_prefix += 1
398+
else:
399+
break
400+
if longest_prefix > 0:
401+
if self.verbose:
402+
print("Llama.generate: prefix-match hit", file=sys.stderr)
403+
reset = False
404+
tokens = tokens[longest_prefix:]
405+
for _ in range(len(self.eval_tokens) - longest_prefix):
406+
self.eval_tokens.pop()
407+
try:
408+
self.eval_logits.pop()
409+
except IndexError:
410+
pass
402411

403412
if reset:
404413
self.reset()
414+
405415
while True:
406416
self.eval(tokens)
407417
token = self.sample(

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.