@@ -127,7 +127,9 @@ def __init__(
         self.last_n_tokens_size = last_n_tokens_size
         self.n_batch = min(n_ctx, n_batch)
         self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
-        self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx)
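+        # logits_all keeps one row of logits per context token; otherwise only
+        # the most recent row is ever read, so a single-slot deque suffices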
+        self.eval_logits: Deque[List[llama_cpp.c_float]] = deque(
+            maxlen=n_ctx if logits_all else 1
+        )

         self.cache: Optional[LlamaCache] = None

@@ -236,17 +238,90 @@ def eval(self, tokens: Sequence[llama_cpp.llama_token]):
         )
         if int(return_code) != 0:
             raise RuntimeError(f"llama_eval returned {return_code}")
+        # Save tokens
         self.eval_tokens.extend(batch)
-        if self.params.logits_all:
-            n_vocab = llama_cpp.llama_n_vocab(self.ctx)
-            cols = int(n_vocab)
-            rows = n_tokens
-            logits_view = llama_cpp.llama_get_logits(self.ctx)
-            logits = [
-                [logits_view[i * cols + j] for j in range(cols)]
-                for i in range(rows)
-            ]
-            self.eval_logits.extend(logits)
+        # Save logits
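+        # llama_get_logits returns a flat (rows x n_vocab) float view; only the
+        # last token's row is populated unless logits_all was requested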
+        rows = n_tokens if self.params.logits_all else 1
+        n_vocab = llama_cpp.llama_n_vocab(self.ctx)
+        cols = int(n_vocab)
+        logits_view = llama_cpp.llama_get_logits(self.ctx)
+        logits: List[List[llama_cpp.c_float]] = [
+            [logits_view[i * cols + j] for j in range(cols)] for i in range(rows)
+        ]
+        self.eval_logits.extend(logits)
+
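+    # Python-side replacement for the old llama_cpp.llama_sample_top_p_top_k
+    # call, composed from llama.cpp's fine-grained sampling primitives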
+    def _sample_top_p_top_k(
+        self,
+        last_n_tokens_data,  # type: llama_cpp.Array[llama_cpp.llama_token]
+        last_n_tokens_size: llama_cpp.c_int,
+        top_k: llama_cpp.c_int,
+        top_p: llama_cpp.c_float,
+        temp: llama_cpp.c_float,
+        repeat_penalty: llama_cpp.c_float,
+    ):
+        assert self.ctx is not None
+        assert len(self.eval_logits) > 0
+        n_vocab = int(llama_cpp.llama_n_vocab(self.ctx))
+        logits = self.eval_logits[-1]
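+        # Wrap the last row of logits in a ctypes llama_token_data array
+        # spanning the whole vocabulary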
+        data = (llama_cpp.llama_token_data * n_vocab)(
+            *[
+                llama_cpp.llama_token_data(
+                    id=llama_cpp.llama_token(i),
+                    logit=logits[i],
+                    p=llama_cpp.c_float(0.0),
+                )
+                for i in range(n_vocab)
+            ]
+        )
+        size = llama_cpp.c_size_t(n_vocab)
+        sorted = False
+        candidates = llama_cpp.llama_token_data_array(
+            data=data,
+            size=size,
+            sorted=sorted,
+        )
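+        # Penalize recently generated tokens before any truncation of the
+        # candidate set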
+        llama_cpp.llama_sample_repetition_penalty(
+            ctx=self.ctx,
+            last_tokens_data=last_n_tokens_data,
+            last_tokens_size=last_n_tokens_size,
+            candidates=llama_cpp.ctypes.pointer(candidates),
+            penalty=repeat_penalty,
+        )
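+        # A temperature of zero short-circuits to greedy (argmax) sampling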
+        if temp == 0.0:
+            return llama_cpp.llama_sample_token_greedy(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+            )
+        else:
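+            # Truncate candidates with top-k, tail-free and locally typical
+            # sampling (z=1.0 and p=1.0 leave the latter two as no-ops), then
+            # top-p, then rescale by temperature and draw the final token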
+            llama_cpp.llama_sample_top_k(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+                k=top_k,
+            )
+            llama_cpp.llama_sample_tail_free(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+                z=llama_cpp.c_float(1.0),
+            )
+            llama_cpp.llama_sample_typical(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+                p=llama_cpp.c_float(1.0),
+            )
+            llama_cpp.llama_sample_top_p(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+                p=top_p,
+            )
+            llama_cpp.llama_sample_temperature(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+                temp=temp,
+            )
+            return llama_cpp.llama_sample_token(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+            )

     def sample(
         self,
@@ -270,8 +345,7 @@ def sample(
         last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
             0, self.last_n_tokens_size - len(self.eval_tokens)
         ) + list(self.eval_tokens)[-self.last_n_tokens_size :]
-        return llama_cpp.llama_sample_top_p_top_k(
-            ctx=self.ctx,
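+        # Sample via the Python-side _sample_top_p_top_k defined above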
+        return self._sample_top_p_top_k(
             last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
                 *last_n_tokens_data
             ),
@@ -470,15 +544,15 @@ def _create_completion(
             all_text = self.detokenize(completion_tokens)

             # Contains multi-byte UTF8
-            for k,char in enumerate(all_text[-3:]):
+            for k, char in enumerate(all_text[-3:]):
                 k = 3 - k
-                for num,pattern in [(2, 192), (3, 224), (4, 240)]:
+                for num, pattern in [(2, 192), (3, 224), (4, 240)]:
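+                    # 192, 224 and 240 are the UTF-8 lead-byte patterns for
+                    # 2-, 3- and 4-byte sequences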
                     # Bitwise AND check
-                    if (num > k and pattern & char == pattern):
+                    if num > k and pattern & char == pattern:
                         multibyte_fix = num - k

             # Stop incomplete bytes from passing
-            if (multibyte_fix > 0):
+            if multibyte_fix > 0:
                 multibyte_fix -= 1
                 continue
@@ -531,7 +605,9 @@ def _create_completion(
             "model": self.model_path,
             "choices": [
                 {
-                    "text": text[returned_characters:].decode("utf-8", errors="ignore"),
+                    "text": text[returned_characters:].decode(
+                        "utf-8", errors="ignore"
+                    ),
                     "index": 0,
                     "logprobs": None,
                     "finish_reason": finish_reason,
@@ -558,7 +634,8 @@ def _create_completion(

             all_tokens = prompt_tokens + completion_tokens
             all_token_strs = [
-                self.detokenize([token]).decode("utf-8", errors="ignore") for token in all_tokens
+                self.detokenize([token]).decode("utf-8", errors="ignore")
+                for token in all_tokens
             ]
             all_logprobs = [
                 [Llama.logit_to_logprob(logit) for logit in row]
@@ -577,7 +654,9 @@ def _create_completion(
                 )
                 token_logprobs.append(sorted_logprobs[int(token)][0])
                 top_logprob = {
-                    self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8", errors="ignore"): logprob
+                    self.detokenize([llama_cpp.llama_token(i)]).decode(
+                        "utf-8", errors="ignore"
+                    ): logprob
                     for logprob, i in sorted_logprobs[:logprobs]
                 }
                 top_logprob.update({token_str: sorted_logprobs[int(token)][0]})