@@ -351,55 +351,19 @@ def _create_completion(
         else:
             stop_sequences = []
 
-        text_offset = 0
-        text_offsets: List[int] = []
-        token_logprobs: List[float] = []
-        tokens: List[str] = []
-        top_logprobs: List[Dict[str, float]] = []
-
-        self.reset()
-        self.eval(prompt_tokens)
-
         if logprobs is not None and self.params.logits_all is False:
             raise ValueError(
                 "logprobs is not supported for models created with logits_all=False"
             )
 
-        if logprobs is not None:
-            token_strs = [
-                self.detokenize([token]).decode("utf-8") for token in prompt_tokens
-            ]
-            logprobs_all = [
-                [Llama.logit_to_logprob(logit) for logit in row]
-                for row in self.all_logits
-            ]
-            for token, token_str, logprobs_token in zip(
-                prompt_tokens, token_strs, logprobs_all
-            ):
-                text_offsets.append(text_offset)
-                text_offset += len(token_str)
-                tokens.append(token_str)
-                sorted_logprobs = list(
-                    sorted(
-                        zip(logprobs_token, range(len(logprobs_token))), reverse=True
-                    )
-                )
-                token_logprobs.append(sorted_logprobs[int(token)][0])
-                top_logprob = {
-                    self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
-                    for logprob, i in sorted_logprobs[:logprobs]
-                }
-                top_logprob.update({token_str: sorted_logprobs[int(token)][0]})
-                top_logprobs.append(top_logprob)
-
         finish_reason = "length"
-        while True:
-            token = self.sample(
-                top_k=top_k,
-                top_p=top_p,
-                temp=temperature,
-                repeat_penalty=repeat_penalty,
-            )
+        for token in self.generate(
+            prompt_tokens,
+            top_k=top_k,
+            top_p=top_p,
+            temp=temperature,
+            repeat_penalty=repeat_penalty,
+        ):
             if token == llama_cpp.llama_token_eos():
                 text = self.detokenize(completion_tokens)
                 finish_reason = "stop"
@@ -443,34 +407,10 @@ def _create_completion(
                 ],
             }
 
-            if logprobs is not None:
-                # TODO: Confirm whether this should happen before or after
-                # next eval.
-                token_str = self.detokenize([token]).decode("utf-8")
-                text_offsets.append(text_offset)
-                text_offset += len(token_str)
-                tokens.append(token_str)
-                logprobs_token = [
-                    Llama.logit_to_logprob(logit) for logit in self.all_logits[-1]
-                ]
-                sorted_logprobs = list(
-                    sorted(
-                        zip(logprobs_token, range(len(logprobs_token))), reverse=True
-                    )
-                )
-                token_logprobs.append(sorted_logprobs[int(token)][0])
-                top_logprob = {
-                    self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
-                    for logprob, i in sorted_logprobs[:logprobs]
-                }
-                top_logprob.update({token_str: logprobs_token[int(token)]})
-                top_logprobs.append(top_logprob)
-
             if len(completion_tokens) >= max_tokens:
                 text = self.detokenize(completion_tokens)
                 finish_reason = "length"
                 break
-            self.eval([token])
 
         if stream:
             yield {
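Both the per-token bookkeeping removed here and the consolidated version added in the next hunk rely on `Llama.logit_to_logprob` to turn raw logits into log-probabilities. That helper's body is not part of this diff; for orientation, the textbook conversion from one row of logits to log-probabilities is a numerically stable log-softmax, sketched below (the function name is ours, not the library's):

```python
import math
from typing import List


def logits_to_logprobs(logits: List[float]) -> List[float]:
    # Numerically stable log-softmax: subtract the max logit before
    # exponentiating so the sum cannot overflow, then normalize by the
    # log of the partition sum.
    max_logit = max(logits)
    exps = [math.exp(x - max_logit) for x in logits]
    log_z = math.log(sum(exps))
    return [(x - max_logit) - log_z for x in logits]
```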
@@ -499,6 +439,38 @@ def _create_completion(
 
         logprobs_or_none: Optional[CompletionLogprobs] = None
         if logprobs is not None:
+            text_offset = 0
+            text_offsets: List[int] = []
+            token_logprobs: List[float] = []
+            tokens: List[str] = []
+            top_logprobs: List[Dict[str, float]] = []
+
+            all_tokens = prompt_tokens + completion_tokens
+            all_token_strs = [
+                self.detokenize([token]).decode("utf-8") for token in all_tokens
+            ]
+            all_logprobs = [
+                [Llama.logit_to_logprob(logit) for logit in row]
+                for row in self.all_logits
+            ]
+            for token, token_str, logprobs_token in zip(
+                all_tokens, all_token_strs, all_logprobs
+            ):
+                text_offsets.append(text_offset)
+                text_offset += len(token_str)
+                tokens.append(token_str)
+                sorted_logprobs = list(
+                    sorted(
+                        zip(logprobs_token, range(len(logprobs_token))), reverse=True
+                    )
+                )
+                token_logprobs.append(sorted_logprobs[int(token)][0])
+                top_logprob = {
+                    self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
+                    for logprob, i in sorted_logprobs[:logprobs]
+                }
+                top_logprob.update({token_str: sorted_logprobs[int(token)][0]})
+                top_logprobs.append(top_logprob)
             logprobs_or_none = {
                 "tokens": tokens,
                 "text_offset": text_offsets,