@@ -11,6 +11,15 @@
 from .llama_types import *
 
 
+class LlamaCache:
+    """Cache for a llama.cpp model.
+
+    NOTE: This implementation currently only tells the Llama class to avoid reprocessing bytes and continue from the last
+    completion. It does not actually cache the results."""
+
+    pass
+
+
 class Llama:
     """High-level Python wrapper for a llama.cpp model."""
 
@@ -82,6 +91,14 @@ def __init__(
         self.n_past = 0
         self.all_logits: List[List[float]] = []  # TODO: Use an array instead of a list.
 
+        ### HACK: This is a hack to work around the fact that the llama.cpp API does not yet support
+        ### saving and restoring state, this allows us to continue a completion if the last
+        ### completion_bytes is a prefix to the prompt passed in. However this is actually incorrect
+        ### because it does not take into account stop tokens which have been processed by the model.
+        self._completion_bytes: List[bytes] = []
+        self._cache: Optional[LlamaCache] = None
+        ###
+
         self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
 
         if not os.path.exists(model_path):
@@ -135,6 +152,14 @@ def detokenize(self, tokens: List[llama_cpp.llama_token]) -> bytes:
             output += llama_cpp.llama_token_to_str(self.ctx, token)
         return output
 
+    def set_cache(self, cache: Optional[LlamaCache]):
+        """Set the cache.
+
+        Args:
+            cache: The cache to set.
+        """
+        self._cache = cache
+
     def reset(self):
         """Reset the model state."""
         self.last_n_tokens_data.extend(
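Taken together with the LlamaCache placeholder above, the intended usage is roughly the sketch below (not part of this diff; the model path, prompts, and parameter values are placeholders, and it assumes LlamaCache is re-exported from the package): the behaviour is opt-in via set_cache(), and a follow-up call whose prompt starts with the previous prompt plus its completion can skip re-evaluating that prefix.

    from llama_cpp import Llama, LlamaCache

    llm = Llama(model_path="./models/7B/ggml-model.bin")  # placeholder path
    llm.set_cache(LlamaCache())  # opt in to the prefix-reuse behaviour

    prompt = "Q: Name the planets in the solar system. A:"
    first = llm(prompt, max_tokens=48)
    answer = first["choices"][0]["text"]

    # The follow-up prompt begins with the previous prompt plus its completion,
    # so the prefix check in _create_completion() lets generate() continue
    # from the last evaluation instead of resetting.
    second = llm(prompt + answer + " Q: Which of them is largest? A:", max_tokens=48)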
@@ -245,6 +270,17 @@ def generate(
             The generated tokens.
         """
         assert self.ctx is not None
+        ### HACK
+        if (
+            reset
+            and self._cache
+            and len(self.tokens) > 0
+            and self.tokens == tokens[: len(self.tokens)]
+        ):
+            if self.verbose:
+                print("generate cache hit", file=sys.stderr)
+            reset = False
+        ###
         if reset:
             self.reset()
         while True:
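The guard added to generate() is a plain list-prefix comparison against the tokens already evaluated. A minimal illustration with made-up token ids (not taken from the diff):

    # Purely illustrative token ids.
    prev_tokens = [1, 320, 883, 102]          # tokens already evaluated (self.tokens)
    next_tokens = [1, 320, 883, 102, 77, 19]  # tokens requested by the new call

    is_prefix = len(prev_tokens) > 0 and prev_tokens == next_tokens[: len(prev_tokens)]
    print(is_prefix)  # True -> reset is flipped to False and only the new tail is evaluated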
@@ -361,13 +397,29 @@ def _create_completion(
                 "logprobs is not supported for models created with logits_all=False"
             )
 
+        ### HACK
+        reset: bool = True
+        _prompt: bytes = prompt.encode("utf-8")
+        _completion: bytes = b"".join(self._completion_bytes)
+        if len(_completion) and self._cache and _prompt.startswith(_completion):
+            if self.verbose:
+                print("completion cache hit", file=sys.stderr)
+            reset = False
+            _prompt = _prompt[len(_completion) :]
+            prompt_tokens = self.tokenize(b" " + _prompt)
+            self._completion_bytes.append(_prompt)
+        else:
+            self._completion_bytes = [prompt.encode("utf-8")]
+        ###
+
         finish_reason = "length"
         for token in self.generate(
             prompt_tokens,
             top_k=top_k,
             top_p=top_p,
             temp=temperature,
             repeat_penalty=repeat_penalty,
+            reset=reset,
         ):
             if token == llama_cpp.llama_token_eos():
                 text = self.detokenize(completion_tokens)
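In _create_completion() the same idea is applied at the byte level: when the accumulated prompt-plus-completion bytes are a prefix of the new prompt, only the remaining suffix is tokenized (with a leading space, matching how full prompts are tokenized here). A rough standalone sketch with placeholder strings:

    prev = b"Q: Name the planets. A: Mercury, Venus, Earth"       # previous prompt + generated text
    new_prompt = b"Q: Name the planets. A: Mercury, Venus, Earth Q: Which is largest? A:"

    if prev and new_prompt.startswith(prev):
        remainder = new_prompt[len(prev):]  # only the unseen suffix needs tokenizing
        print(remainder)                    # b" Q: Which is largest? A:"
    else:
        print("no prefix match; tokenize the whole prompt and reset the model state")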
@@ -397,6 +449,9 @@ def _create_completion(
                             break
                 text = all_text[: len(all_text) - longest]
                 returned_characters += len(text[start:])
+                ### HACK
+                self._completion_bytes.append(text[start:])
+                ###
                 yield {
                     "id": completion_id,
                     "object": "text_completion",
@@ -418,6 +473,9 @@ def _create_completion(
                 break
 
         if stream:
+            ### HACK
+            self._completion_bytes.append(text[returned_characters:])
+            ###
             yield {
                 "id": completion_id,
                 "object": "text_completion",
@@ -434,13 +492,16 @@ def _create_completion(
             }
             return
 
-        text = text.decode("utf-8")
+        ### HACK
+        self._completion_bytes.append(text)
+        ###
+        text_str = text.decode("utf-8")
 
         if echo:
-            text = prompt + text
+            text_str = prompt + text_str
 
         if suffix is not None:
-            text = text + suffix
+            text_str = text_str + suffix
 
         logprobs_or_none: Optional[CompletionLogprobs] = None
         if logprobs is not None:
@@ -493,7 +554,7 @@ def _create_completion(
             "model": self.model_path,
             "choices": [
                 {
-                    "text": text,
+                    "text": text_str,
                     "index": 0,
                     "logprobs": logprobs_or_none,
                     "finish_reason": finish_reason,