Commit 17d4271

Fix logprobs for completions and implement for streaming logprobs.

1 parent a634a24

File tree: 1 file changed, +103 −22 lines

llama_cpp/llama.py

@@ -710,22 +710,56 @@ def _create_completion(
                 # We want to avoid yielding any characters from
                 # the generated text if they are part of a stop
                 # sequence.
-                longest = 0
+                first_stop_position = 0
                 for s in stop_sequences:
                     for i in range(len(s), 0, -1):
                         if all_text.endswith(s[:i]):
-                            if i > longest:
-                                longest = i
+                            if i > first_stop_position:
+                                first_stop_position = i
                             break

-                offset = 0
+                token_end_position = 0
                 remaining_tokens = completion_tokens[returned_tokens:]
                 remaining_length = len(self.detokenize(remaining_tokens))
                 for token in remaining_tokens:
-                    offset += len(self.detokenize([token]))
-                    # Check if stop sequence is not in the token
-                    if offset >= (remaining_length - longest - 1):
+                    token_end_position += len(self.detokenize([token]))
+                    # Check if stop sequence is in the token
+                    if token_end_position >= (remaining_length - first_stop_position - 1):
                         break
+                    logprobs_or_none: Optional[CompletionLogprobs] = None
+                    if logprobs is not None:
+                        token_str = self.detokenize([token]).decode(
+                            "utf-8", errors="ignore"
+                        )
+                        text_offset = len(prompt) + len(
+                            self.detokenize(completion_tokens[:returned_tokens])
+                        )
+                        token_offset = len(prompt_tokens) + returned_tokens
+                        logits = self.eval_logits[token_offset - 1]
+                        current_logprobs = Llama.logits_to_logprobs(logits)
+                        sorted_logprobs = list(
+                            sorted(
+                                zip(current_logprobs, range(len(current_logprobs))),
+                                reverse=True,
+                            )
+                        )
+                        top_logprob = {
+                            self.detokenize([llama_cpp.llama_token(i)]).decode(
+                                "utf-8", errors="ignore"
+                            ): logprob
+                            for logprob, i in sorted_logprobs[:logprobs]
+                        }
+                        top_logprob.update({token_str: current_logprobs[int(token)]})
+                        logprobs_or_none = {
+                            "tokens": [
+                                self.detokenize([token]).decode(
+                                    "utf-8", errors="ignore"
+                                )
+                            ],
+                            "text_offset": [text_offset],
+                            "token_logprobs": [sorted_logprobs[int(token)][0]],
+                            "top_logprobs": [top_logprob],
+                        }
                     returned_tokens += 1
                     yield {
                         "id": completion_id,
@@ -738,7 +772,7 @@ def _create_completion(
                                     "utf-8", errors="ignore"
                                 ),
                                 "index": 0,
-                                "logprobs": None,
+                                "logprobs": logprobs_or_none,
                                 "finish_reason": None,
                             }
                         ],
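For context on the streaming hunks above: the per-token logprobs added here amount to a log-softmax over the logits row for the current position, followed by keeping the top-N entries. A minimal standalone sketch of that computation (the helper names logits_to_logprobs and top_n_logprobs below are illustrative, only mirroring how Llama.logits_to_logprobs is used in the diff):

import math
from typing import Dict, List

def logits_to_logprobs(logits: List[float]) -> List[float]:
    # Numerically stable log-softmax:
    # logprob_i = x_i - max(x) - log(sum_j exp(x_j - max(x)))
    m = max(logits)
    log_sum = math.log(sum(math.exp(x - m) for x in logits))
    return [x - m - log_sum for x in logits]

def top_n_logprobs(logits: List[float], n: int) -> Dict[int, float]:
    # Pair each logprob with its token id, sort descending, keep the n best,
    # just like the sorted_logprobs / top_logprob construction in the diff.
    logprobs = logits_to_logprobs(logits)
    ranked = sorted(zip(logprobs, range(len(logprobs))), reverse=True)
    return {token_id: lp for lp, token_id in ranked[:n]}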
@@ -766,13 +800,48 @@ def _create_completion(
             else:
                 end = len(all_text)

-            offset = 0
+            token_end_position = 0
             for token in remaining_tokens:
-                offset += len(self.detokenize([token]))
-                if offset >= end:
+                token_end_position += len(self.detokenize([token]))
+
+                logprobs_or_none: Optional[CompletionLogprobs] = None
+                if logprobs is not None:
+                    token_str = self.detokenize([token]).decode(
+                        "utf-8", errors="ignore"
+                    )
+                    text_offset = len(prompt) + len(
+                        self.detokenize(completion_tokens[:returned_tokens])
+                    )
+                    token_offset = len(prompt_tokens) + returned_tokens - 1
+                    logits = self.eval_logits[token_offset]
+                    current_logprobs = Llama.logits_to_logprobs(logits)
+                    sorted_logprobs = list(
+                        sorted(
+                            zip(current_logprobs, range(len(current_logprobs))),
+                            reverse=True,
+                        )
+                    )
+                    top_logprob = {
+                        self.detokenize([llama_cpp.llama_token(i)]).decode(
+                            "utf-8", errors="ignore"
+                        ): logprob
+                        for logprob, i in sorted_logprobs[:logprobs]
+                    }
+                    top_logprob.update({token_str: current_logprobs[int(token)]})
+                    logprobs_or_none = {
+                        "tokens": [
+                            self.detokenize([token]).decode("utf-8", errors="ignore")
+                        ],
+                        "text_offset": [text_offset],
+                        "token_logprobs": [sorted_logprobs[int(token)][0]],
+                        "top_logprobs": [top_logprob],
+                    }
+
+                if token_end_position >= end:
                     last_text = self.detokenize([token])
-                    if offset == end - 1:
+                    if token_end_position == end - 1:
                         break
+                    returned_tokens += 1
                     yield {
                         "id": completion_id,
                         "object": "text_completion",
@@ -781,10 +850,10 @@ def _create_completion(
                         "choices": [
                             {
                                 "text": last_text[
-                                    : len(last_text) - (offset - end)
+                                    : len(last_text) - (token_end_position - end)
                                 ].decode("utf-8", errors="ignore"),
                                 "index": 0,
-                                "logprobs": None,
+                                "logprobs": logprobs_or_none,
                                 "finish_reason": finish_reason,
                             }
                         ],
@@ -802,7 +871,7 @@ def _create_completion(
                                 "utf-8", errors="ignore"
                             ),
                             "index": 0,
-                            "logprobs": None,
+                            "logprobs": logprobs_or_none,
                             "finish_reason": finish_reason
                             if returned_tokens == len(completion_tokens)
                             else None,
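Taken together, the streaming hunks above mean each streamed chunk now carries a one-token logprobs entry instead of a hard-coded None. A hypothetical consumer loop (the model path, prompt, and parameter values are placeholders, not taken from this commit) might look like:

from llama_cpp import Llama

llm = Llama(model_path="./models/7B/ggml-model.bin")  # placeholder model path
for chunk in llm("Q: Name the planets. A:", max_tokens=16, logprobs=5, stream=True):
    logprobs = chunk["choices"][0]["logprobs"]  # populated instead of always None
    if logprobs is not None:
        # Each chunk reports exactly one token: its text, its logprob,
        # and the top-5 alternatives for that position.
        print(
            logprobs["tokens"][0],
            logprobs["token_logprobs"][0],
            logprobs["top_logprobs"][0],
        )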
@@ -821,21 +890,27 @@ def _create_completion(

         logprobs_or_none: Optional[CompletionLogprobs] = None
         if logprobs is not None:
-            text_offset = 0
+            text_offset = 0 if echo else len(prompt)
+            token_offset = 0 if echo else len(prompt_tokens[1:])
             text_offsets: List[int] = []
-            token_logprobs: List[float] = []
+            token_logprobs: List[Optional[float]] = []
             tokens: List[str] = []
-            top_logprobs: List[Dict[str, float]] = []
+            top_logprobs: List[Optional[Dict[str, float]]] = []
+
+            if echo:
+                # Remove leading BOS token
+                all_tokens = prompt_tokens[1:] + completion_tokens
+            else:
+                all_tokens = completion_tokens

-            all_tokens = prompt_tokens + completion_tokens
             all_token_strs = [
                 self.detokenize([token]).decode("utf-8", errors="ignore")
                 for token in all_tokens
             ]
             all_logprobs = [
                 Llama.logits_to_logprobs(list(map(float, row)))
                 for row in self.eval_logits
-            ]
+            ][token_offset:]
             for token, token_str, logprobs_token in zip(
                 all_tokens, all_token_strs, all_logprobs
             ):
@@ -848,14 +923,20 @@ def _create_completion(
                     )
                 )
                 token_logprobs.append(sorted_logprobs[int(token)][0])
-                top_logprob = {
+                top_logprob: Optional[Dict[str, float]] = {
                     self.detokenize([llama_cpp.llama_token(i)]).decode(
                         "utf-8", errors="ignore"
                     ): logprob
                     for logprob, i in sorted_logprobs[:logprobs]
                 }
-                top_logprob.update({token_str: sorted_logprobs[int(token)][0]})
+                top_logprob.update({token_str: logprobs_token[int(token)]})
                 top_logprobs.append(top_logprob)
+            # Weird idosincracy of the OpenAI API where
+            # token_logprobs and top_logprobs are null for
+            # the first token.
+            if echo and len(all_tokens) > 0:
+                token_logprobs[0] = None
+                top_logprobs[0] = None
             logprobs_or_none = {
                 "tokens": tokens,
                 "text_offset": text_offsets,

0 commit comments