@@ -710,22 +710,56 @@ def _create_completion(
                 # We want to avoid yielding any characters from
                 # the generated text if they are part of a stop
                 # sequence.
-                longest = 0
+                first_stop_position = 0
                 for s in stop_sequences:
                     for i in range(len(s), 0, -1):
                         if all_text.endswith(s[:i]):
-                            if i > longest:
-                                longest = i
+                            if i > first_stop_position:
+                                first_stop_position = i
                             break
 
-                offset = 0
+                token_end_position = 0
                 remaining_tokens = completion_tokens[returned_tokens:]
                 remaining_length = len(self.detokenize(remaining_tokens))
                 for token in remaining_tokens:
-                    offset += len(self.detokenize([token]))
-                    # Check if stop sequence is not in the token
-                    if offset >= (remaining_length - longest - 1):
+                    token_end_position += len(self.detokenize([token]))
+                    # Check if stop sequence is in the token
+                    if token_end_position >= (remaining_length - first_stop_position - 1):
                         break
+                    logprobs_or_none: Optional[CompletionLogprobs] = None
+                    if logprobs is not None:
+                        token_str = self.detokenize([token]).decode(
+                            "utf-8", errors="ignore"
+                        )
+                        text_offset = len(prompt) + len(
+                            self.detokenize(completion_tokens[:returned_tokens])
+                        )
+                        token_offset = len(prompt_tokens) + returned_tokens
+                        logits = self.eval_logits[token_offset - 1]
+                        current_logprobs = Llama.logits_to_logprobs(logits)
+                        sorted_logprobs = list(
+                            sorted(
+                                zip(current_logprobs, range(len(current_logprobs))),
+                                reverse=True,
+                            )
+                        )
+                        top_logprob = {
+                            self.detokenize([llama_cpp.llama_token(i)]).decode(
+                                "utf-8", errors="ignore"
+                            ): logprob
+                            for logprob, i in sorted_logprobs[:logprobs]
+                        }
+                        top_logprob.update({token_str: current_logprobs[int(token)]})
+                        logprobs_or_none = {
+                            "tokens": [
+                                self.detokenize([token]).decode(
+                                    "utf-8", errors="ignore"
+                                )
+                            ],
+                            "text_offset": [text_offset],
+                            "token_logprobs": [sorted_logprobs[int(token)][0]],
+                            "top_logprobs": [top_logprob],
+                        }
                     returned_tokens += 1
                     yield {
                         "id": completion_id,
@@ -738,7 +772,7 @@ def _create_completion(
                                     "utf-8", errors="ignore"
                                 ),
                                 "index": 0,
-                                "logprobs": None,
+                                "logprobs": logprobs_or_none,
                                 "finish_reason": None,
                             }
                         ],
@@ -766,13 +800,48 @@ def _create_completion(
             else:
                 end = len(all_text)
 
-            offset = 0
+            token_end_position = 0
             for token in remaining_tokens:
-                offset += len(self.detokenize([token]))
-                if offset >= end:
+                token_end_position += len(self.detokenize([token]))
+
+                logprobs_or_none: Optional[CompletionLogprobs] = None
+                if logprobs is not None:
+                    token_str = self.detokenize([token]).decode(
+                        "utf-8", errors="ignore"
+                    )
+                    text_offset = len(prompt) + len(
+                        self.detokenize(completion_tokens[:returned_tokens])
+                    )
+                    token_offset = len(prompt_tokens) + returned_tokens - 1
+                    logits = self.eval_logits[token_offset]
+                    current_logprobs = Llama.logits_to_logprobs(logits)
+                    sorted_logprobs = list(
+                        sorted(
+                            zip(current_logprobs, range(len(current_logprobs))),
+                            reverse=True,
+                        )
+                    )
+                    top_logprob = {
+                        self.detokenize([llama_cpp.llama_token(i)]).decode(
+                            "utf-8", errors="ignore"
+                        ): logprob
+                        for logprob, i in sorted_logprobs[:logprobs]
+                    }
+                    top_logprob.update({token_str: current_logprobs[int(token)]})
+                    logprobs_or_none = {
+                        "tokens": [
+                            self.detokenize([token]).decode("utf-8", errors="ignore")
+                        ],
+                        "text_offset": [text_offset],
+                        "token_logprobs": [sorted_logprobs[int(token)][0]],
+                        "top_logprobs": [top_logprob],
+                    }
+
+                if token_end_position >= end:
                     last_text = self.detokenize([token])
-                    if offset == end - 1:
+                    if token_end_position == end - 1:
                         break
+                    returned_tokens += 1
                     yield {
                         "id": completion_id,
                         "object": "text_completion",
@@ -781,10 +850,10 @@ def _create_completion(
                         "choices": [
                             {
                                 "text": last_text[
-                                    : len(last_text) - (offset - end)
+                                    : len(last_text) - (token_end_position - end)
                                 ].decode("utf-8", errors="ignore"),
                                 "index": 0,
-                                "logprobs": None,
+                                "logprobs": logprobs_or_none,
                                 "finish_reason": finish_reason,
                             }
                         ],
@@ -802,7 +871,7 @@ def _create_completion(
                                 "utf-8", errors="ignore"
                             ),
                             "index": 0,
-                            "logprobs": None,
+                            "logprobs": logprobs_or_none,
                             "finish_reason": finish_reason
                             if returned_tokens == len(completion_tokens)
                             else None,
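
Every branch converts a row of raw logits into log probabilities with `Llama.logits_to_logprobs`. A sketch of what such a conversion typically looks like, assuming the helper is a numerically stable log-softmax (an assumption about its behavior, not the library's actual source):

import math

# Assumed behavior of Llama.logits_to_logprobs: a numerically stable
# log-softmax over one vocabulary-sized row of logits.
def logits_to_logprobs(logits):
    m = max(logits)  # subtract the max so exp() cannot overflow
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum_exp for x in logits]

# Exponentiating the result gives a proper probability distribution.
probs = [math.exp(lp) for lp in logits_to_logprobs([1.0, 2.0, 3.0])]
assert abs(sum(probs) - 1.0) < 1e-9
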
@@ -821,21 +890,27 @@ def _create_completion(
 
         logprobs_or_none: Optional[CompletionLogprobs] = None
         if logprobs is not None:
-            text_offset = 0
+            text_offset = 0 if echo else len(prompt)
+            token_offset = 0 if echo else len(prompt_tokens[1:])
             text_offsets: List[int] = []
-            token_logprobs: List[float] = []
+            token_logprobs: List[Optional[float]] = []
             tokens: List[str] = []
-            top_logprobs: List[Dict[str, float]] = []
+            top_logprobs: List[Optional[Dict[str, float]]] = []
+
+            if echo:
+                # Remove leading BOS token
+                all_tokens = prompt_tokens[1:] + completion_tokens
+            else:
+                all_tokens = completion_tokens
 
-            all_tokens = prompt_tokens + completion_tokens
             all_token_strs = [
                 self.detokenize([token]).decode("utf-8", errors="ignore")
                 for token in all_tokens
             ]
             all_logprobs = [
                 Llama.logits_to_logprobs(list(map(float, row)))
                 for row in self.eval_logits
-            ]
+            ][token_offset:]
             for token, token_str, logprobs_token in zip(
                 all_tokens, all_token_strs, all_logprobs
             ):
@@ -848,14 +923,20 @@ def _create_completion(
                     )
                 )
                 token_logprobs.append(sorted_logprobs[int(token)][0])
-                top_logprob = {
+                top_logprob: Optional[Dict[str, float]] = {
                     self.detokenize([llama_cpp.llama_token(i)]).decode(
                         "utf-8", errors="ignore"
                     ): logprob
                     for logprob, i in sorted_logprobs[:logprobs]
                 }
-                top_logprob.update({token_str: sorted_logprobs[int(token)][0]})
+                top_logprob.update({token_str: logprobs_token[int(token)]})
                 top_logprobs.append(top_logprob)
+            # Weird idiosyncrasy of the OpenAI API where
+            # token_logprobs and top_logprobs are null for
+            # the first token.
+            if echo and len(all_tokens) > 0:
+                token_logprobs[0] = None
+                top_logprobs[0] = None
             logprobs_or_none = {
                 "tokens": tokens,
                 "text_offset": text_offsets,