Commit f0ec6e6
Parent: 21d8f5f

Stream tokens instead of text chunks

File tree

Expand file treeCollapse file tree

1 file changed

+78
-34
lines changed
Filter options
Expand file treeCollapse file tree

1 file changed

+78
-34
lines changed

‎llama_cpp/llama.py

Copy file name to clipboardExpand all lines: llama_cpp/llama.py
+78-34Lines changed: 78 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -623,7 +623,7 @@ def _create_completion(
             b" " + prompt.encode("utf-8")
         )
         text: bytes = b""
-        returned_characters: int = 0
+        returned_tokens: int = 0
         stop = stop if stop is not None else []
         model_name: str = model if model is not None else self.model_path
 
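The rename above is the heart of the commit: streaming progress is now counted in whole tokens rather than decoded characters. A minimal sketch (not from this commit) of why character/byte offsets are fragile for UTF-8 output:

    # Slicing the output buffer at an arbitrary byte offset can split a
    # multi-byte UTF-8 character, which errors="ignore" then silently drops.
    text = "café".encode("utf-8")                  # b'caf\xc3\xa9'
    chunk = text[:4]                               # cuts the 2-byte 'é' in half
    print(chunk.decode("utf-8", errors="ignore"))  # prints 'caf' -- 'é' is lost
    # Yielding whole detokenized tokens keeps each chunk on a token boundary.
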
@@ -707,33 +707,42 @@ def _create_completion(
                 break
 
             if stream:
-                start = returned_characters
-                longest = 0
                 # We want to avoid yielding any characters from
                 # the generated text if they are part of a stop
                 # sequence.
+                longest = 0
                 for s in stop_sequences:
                     for i in range(len(s), 0, -1):
                         if all_text.endswith(s[:i]):
                             if i > longest:
                                 longest = i
                             break
-                text = all_text[: len(all_text) - longest]
-                returned_characters += len(text[start:])
-                yield {
-                    "id": completion_id,
-                    "object": "text_completion",
-                    "created": created,
-                    "model": model_name,
-                    "choices": [
-                        {
-                            "text": text[start:].decode("utf-8", errors="ignore"),
-                            "index": 0,
-                            "logprobs": None,
-                            "finish_reason": None,
-                        }
-                    ],
-                }
+
+                offset = 0
+                remaining_tokens = completion_tokens[returned_tokens:]
+                remaining_length = len(self.detokenize(remaining_tokens))
+                for token in remaining_tokens:
+                    offset += len(self.detokenize([token]))
+                    # Check if stop sequence is not in the token
+                    if offset >= (remaining_length - longest - 1):
+                        break
+                    returned_tokens += 1
+                    yield {
+                        "id": completion_id,
+                        "object": "text_completion",
+                        "created": created,
+                        "model": model_name,
+                        "choices": [
+                            {
+                                "text": self.detokenize([token]).decode(
+                                    "utf-8", errors="ignore"
+                                ),
+                                "index": 0,
+                                "logprobs": None,
+                                "finish_reason": None,
+                            }
+                        ],
+                    }
 
             if len(completion_tokens) >= max_tokens:
                 text = self.detokenize(completion_tokens)
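
The new loop only yields a token once it is safely clear of any pending stop sequence: `longest` is the length of the longest suffix of the generated text that is still a prefix of some stop sequence, and tokens overlapping that tail are held back. The same scan as a standalone sketch (hypothetical helper name, not part of the commit):

    from typing import List

    def longest_partial_stop(all_text: bytes, stop_sequences: List[bytes]) -> int:
        """Length of the longest suffix of all_text that could still grow
        into one of the stop sequences."""
        longest = 0
        for s in stop_sequences:
            for i in range(len(s), 0, -1):
                if all_text.endswith(s[:i]):
                    if i > longest:
                        longest = i
                    break
        return longest

    assert longest_partial_stop(b"Hello ##", [b"###"]) == 2  # "##" may grow into "###"
    assert longest_partial_stop(b"Hello", [b"###"]) == 0     # nothing held back
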
@@ -749,22 +758,57 @@ def _create_completion(
             llama_cpp.llama_print_timings(self.ctx)
 
         if stream:
-            yield {
-                "id": completion_id,
-                "object": "text_completion",
-                "created": created,
-                "model": model_name,
-                "choices": [
-                    {
-                        "text": text[returned_characters:].decode(
-                            "utf-8", errors="ignore"
-                        ),
-                        "index": 0,
-                        "logprobs": None,
-                        "finish_reason": finish_reason,
+            remaining_tokens = completion_tokens[returned_tokens:]
+            all_text = self.detokenize(remaining_tokens)
+            any_stop = [s for s in stop_sequences if s in all_text]
+            if len(any_stop) > 0:
+                end = min(all_text.index(stop) for stop in any_stop)
+            else:
+                end = len(all_text)
+
+            offset = 0
+            for token in remaining_tokens:
+                offset += len(self.detokenize([token]))
+                if offset >= end:
+                    last_text = self.detokenize([token])
+                    if offset == end - 1:
+                        break
+                    yield {
+                        "id": completion_id,
+                        "object": "text_completion",
+                        "created": created,
+                        "model": model_name,
+                        "choices": [
+                            {
+                                "text": last_text[
+                                    : len(last_text) - (offset - end)
+                                ].decode("utf-8", errors="ignore"),
+                                "index": 0,
+                                "logprobs": None,
+                                "finish_reason": finish_reason,
+                            }
+                        ],
                     }
-                ],
-            }
+                    break
+                returned_tokens += 1
+                yield {
+                    "id": completion_id,
+                    "object": "text_completion",
+                    "created": created,
+                    "model": model_name,
+                    "choices": [
+                        {
+                            "text": self.detokenize([token]).decode(
+                                "utf-8", errors="ignore"
+                            ),
+                            "index": 0,
+                            "logprobs": None,
+                            "finish_reason": finish_reason
+                            if returned_tokens == len(completion_tokens)
+                            else None,
+                        }
+                    ],
+                }
             return
 
         text_str = text.decode("utf-8", errors="ignore")
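
After this flush logic, a streaming response yields one detokenized token per chunk, with the final token trimmed at the first stop sequence. A minimal consumer sketch, assuming the package's public `Llama` wrapper around `_create_completion` and a placeholder model path:

    from llama_cpp import Llama

    llm = Llama(model_path="./models/7B/ggml-model.bin")  # placeholder path

    # stream=True returns a generator; each chunk's "text" is now one token.
    for chunk in llm.create_completion(
        "Q: Name the planets in the solar system. A:",
        max_tokens=48,
        stream=True,
    ):
        print(chunk["choices"][0]["text"], end="", flush=True)
    print()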
