Commit 027f7bc

CISC and abetlen authored
fix: Avoid duplicate special tokens in chat formats (abetlen#1439)
* Templates sometimes have BOS in them, remove duplicate
* tokenize chat format prompts before completion

  This is to ensure that we don't duplicate any special tokens.

  Hopefully I amended the existing formats correctly?

* updated comment
* corrected a few
* add some missing internals
* proper bos/eos detection
* just let tokenizer do the job
* typo--
* align test with new response
* changed to a warning
* move to another PR
* Use python warnings module

---------

Co-authored-by: Andrei Betlen <abetlen@gmail.com>
1 parent 951e39c commit 027f7bc
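To illustrate the problem the commit message describes, here is a hedged sketch (not part of the diff): when a chat template already renders "<s>" into the prompt text and the tokenizer is also asked to add a BOS token, the model sees two BOS ids. `llm` is assumed to be an initialized `llama_cpp.Llama` for a Llama-2-style model whose BOS id is 1; the exact ids are illustrative.

# Illustrative only: assumes `llm = Llama(model_path=...)` is already loaded
# and that the model's BOS token id is 1.
prompt = "<s>[INST] Hello [/INST]"  # template text that already embeds BOS

dup = llm.tokenize(prompt.encode("utf-8"), add_bos=True, special=True)
# dup would start with [1, 1, ...]: one BOS from add_bos, one from "<s>"

ok = llm.tokenize(prompt.encode("utf-8"), add_bos=False, special=True)
# ok starts with a single BOS id, matching how the model was trained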

File tree


4 files changed: 25 additions and 10 deletions

llama_cpp/_internals.py (8 additions, 0 deletions)
@@ -142,6 +142,14 @@ def token_eos(self) -> int:
         assert self.model is not None
         return llama_cpp.llama_token_eos(self.model)

+    def token_cls(self) -> int:
+        assert self.model is not None
+        return llama_cpp.llama_token_cls(self.model)
+
+    def token_sep(self) -> int:
+        assert self.model is not None
+        return llama_cpp.llama_token_sep(self.model)
+
     def token_nl(self) -> int:
         assert self.model is not None
         return llama_cpp.llama_token_nl(self.model)
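
For orientation, a minimal usage sketch (not part of the commit) of the new accessors; `model` is assumed to be an initialized `_LlamaModel` wrapper, and the `-1` sentinel for undefined tokens is an assumption based on llama.cpp's usual convention:

# Usage sketch (assumed API shape; `model` is a _LlamaModel instance).
cls_id = model.token_cls()
sep_id = model.token_sep()

# llama.cpp typically reports -1 when a model does not define these tokens,
# e.g. most decoder-only chat models have BOS/EOS but no CLS/SEP.
if cls_id != -1 and sep_id != -1:
    print(f"model defines CLS={cls_id} and SEP={sep_id}")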

llama_cpp/llama.py (7 additions, 0 deletions)
@@ -8,6 +8,7 @@
 import ctypes
 import typing
 import fnmatch
+import warnings
 import multiprocessing

 from typing import (

@@ -1019,6 +1020,12 @@ def _create_completion(
         )
         model_name: str = model if model is not None else self.model_path

+        if prompt_tokens[:2] == [self.token_bos()] * 2:
+            warnings.warn(
+                f'Detected duplicate leading "{self._model.token_get_text(self.token_bos())}" in prompt, this will likely reduce response quality, consider removing it...',
+                RuntimeWarning,
+            )
+
         # NOTE: This likely doesn't work correctly for the first token in the prompt
         # because of the extra space added to the start of the prompt_tokens
         if logit_bias is not None:
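
As a standalone illustration of the check added above (not code from the repository), the detection and warning can be exercised in isolation; the token ids below are made up and `1` stands in for a BOS id:

import warnings

def warn_on_duplicate_bos(prompt_tokens, bos_id, bos_text="<s>"):
    # Mirrors the added check: two identical leading BOS ids usually mean
    # both the chat template and the tokenizer inserted one.
    if prompt_tokens[:2] == [bos_id] * 2:
        warnings.warn(
            f'Detected duplicate leading "{bos_text}" in prompt, '
            "this will likely reduce response quality, consider removing it...",
            RuntimeWarning,
        )

warn_on_duplicate_bos([1, 1, 733, 16289], bos_id=1)    # emits RuntimeWarning
warn_on_duplicate_bos([1, 733, 16289, 28793], bos_id=1)  # silent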

llama_cpp/llama_chat_format.py (8 additions, 9 deletions)
@@ -160,6 +160,7 @@ class ChatFormatterResponse:
     prompt: str
     stop: Optional[Union[str, List[str]]] = None
     stopping_criteria: Optional[llama.StoppingCriteriaList] = None
+    added_special: bool = False


 class ChatFormatter(Protocol):

@@ -232,7 +233,7 @@ def stop_on_last_token(
             return tokens[-1] in self.stop_token_ids
         stopping_criteria = llama.StoppingCriteriaList([stop_on_last_token])

-        return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token], stopping_criteria=stopping_criteria)
+        return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token], stopping_criteria=stopping_criteria, added_special=True)

     def to_chat_handler(self) -> LlamaChatCompletionHandler:
         return chat_formatter_to_chat_completion_handler(self)

@@ -548,7 +549,7 @@ def chat_completion_handler(
             tools=tools,
             tool_choice=tool_choice,
         )
-        prompt = result.prompt
+        prompt = llama.tokenize(result.prompt.encode("utf-8"), add_bos=not result.added_special, special=True)
         if result.stop is not None:
             stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
             rstop = result.stop if isinstance(result.stop, list) else [result.stop]

@@ -655,7 +656,7 @@ def format_autotokenizer(
         prompt: str = tokenizer.apply_chat_template(messages, tokenize=False)  # type: ignore
         assert isinstance(prompt, str)
         # Return formatted prompt and eos token by default
-        return ChatFormatterResponse(prompt=prompt, stop=tokenizer.eos_token)
+        return ChatFormatterResponse(prompt=prompt, stop=tokenizer.eos_token, added_special=True)

     return format_autotokenizer

@@ -708,7 +709,7 @@ def format_tokenizer_config(
             bos_token=bos_token,
             eos_token=eos_token,
         )
-        return ChatFormatterResponse(prompt=prompt, stop=[eos_token, bos_token])
+        return ChatFormatterResponse(prompt=prompt, stop=[eos_token, bos_token], added_special=True)

     return format_tokenizer_config

@@ -918,7 +919,7 @@ def format_llama2(
     messages: List[llama_types.ChatCompletionRequestMessage],
     **kwargs: Any,
 ) -> ChatFormatterResponse:
-    _system_template = "<s>[INST] <<SYS>>\n{system_message}\n<</SYS>>"
+    _system_template = "[INST] <<SYS>>\n{system_message}\n<</SYS>>"
     _roles = dict(user="<s>[INST]", assistant="[/INST]")
     _messages = _map_roles(messages, _roles)
     system_message = _get_system_message(messages)

@@ -940,11 +941,10 @@ def format_llama3(
         user="<|start_header_id|>user<|end_header_id|>\n\n",
         assistant="<|start_header_id|>assistant<|end_header_id|>\n\n",
     )
-    _begin_token = "<|begin_of_text|>"
     _sep = "<|eot_id|>"
     _messages = _map_roles(messages, _roles)
     _messages.append((_roles["assistant"], None))
-    _prompt = _format_no_colon_single(_begin_token, _messages, _sep)
+    _prompt = _format_no_colon_single("", _messages, _sep)
     return ChatFormatterResponse(prompt=_prompt, stop=_sep)

@@ -1229,10 +1229,9 @@ def format_mistral_instruct(
     messages: List[llama_types.ChatCompletionRequestMessage],
     **kwargs: Any,
 ) -> ChatFormatterResponse:
-    bos = "<s>"
     eos = "</s>"
     stop = eos
-    prompt = bos
+    prompt = ""
     for message in messages:
         if (
             message["role"] == "user"

tests/test_llama_chat_format.py (2 additions, 1 deletion)
@@ -21,12 +21,13 @@ def test_mistral_instruct():
     response = llama_chat_format.format_mistral_instruct(
         messages=messages,
     )
+    prompt = ("" if response.added_special else "<s>") + response.prompt
     reference = chat_formatter.render(
         messages=messages,
         bos_token="<s>",
         eos_token="</s>",
     )
-    assert response.prompt == reference
+    assert prompt == reference


 mistral_7b_tokenizer_config = """{
