Commit 9ae5819

Add chat format test.
1 parent ce38dbd commit 9ae5819

File tree: 2 files changed (+35, -10)

llama_cpp/llama_chat_format.py: 12 additions & 10 deletions
@@ -878,19 +878,21 @@ def format_chatml(
 
 
 @register_chat_format("mistral-instruct")
-def format_mistral(
+def format_mistral_instruct(
     messages: List[llama_types.ChatCompletionRequestMessage],
     **kwargs: Any,
 ) -> ChatFormatterResponse:
-    _roles = dict(user="[INST] ", assistant="[/INST]")
-    _sep = " "
-    system_template = """<s>{system_message}"""
-    system_message = _get_system_message(messages)
-    system_message = system_template.format(system_message=system_message)
-    _messages = _map_roles(messages, _roles)
-    _messages.append((_roles["assistant"], None))
-    _prompt = _format_no_colon_single(system_message, _messages, _sep)
-    return ChatFormatterResponse(prompt=_prompt)
+    bos = "<s>"
+    eos = "</s>"
+    stop = eos
+    prompt = bos
+    for message in messages:
+        if message["role"] == "user" and message["content"] is not None and isinstance(message["content"], str):
+            prompt += "[INST] " + message["content"]
+        elif message["role"] == "assistant" and message["content"] is not None and isinstance(message["content"], str):
+            prompt += " [/INST]" + message["content"] + eos
+    prompt += " [/INST]"
+    return ChatFormatterResponse(prompt=prompt, stop=stop)
 
 
 @register_chat_format("chatglm3")
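
For a concrete picture of what the rewritten formatter produces, here is a minimal sketch tracing it on a short conversation. The expected string follows directly from the concatenation logic in the diff above; plain dicts stand in for the typed message objects, and the prompt/stop attributes are assumed from the ChatFormatterResponse constructor call shown in the diff.

    # Trace format_mistral_instruct() by hand on a three-turn conversation.
    import llama_cpp.llama_chat_format as llama_chat_format

    messages = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there"},
        {"role": "user", "content": "How are you?"},
    ]
    response = llama_chat_format.format_mistral_instruct(messages=messages)

    # bos + "[INST] Hello" + " [/INST]Hi there" + eos + "[INST] How are you?" + " [/INST]"
    assert response.prompt == "<s>[INST] Hello [/INST]Hi there</s>[INST] How are you? [/INST]"
    assert response.stop == "</s>"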

tests/test_llama_chat_format.py: 23 additions & 0 deletions
@@ -1,10 +1,33 @@
 import json
 
+import jinja2
+
 from llama_cpp import (
     ChatCompletionRequestUserMessage,
 )
+import llama_cpp.llama_types as llama_types
+import llama_cpp.llama_chat_format as llama_chat_format
+
 from llama_cpp.llama_chat_format import hf_tokenizer_config_to_chat_formatter
 
+def test_mistral_instruct():
+    chat_template = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
+    chat_formatter = jinja2.Template(chat_template)
+    messages = [
+        llama_types.ChatCompletionRequestUserMessage(role="user", content="Instruction"),
+        llama_types.ChatCompletionRequestAssistantMessage(role="assistant", content="Model answer"),
+        llama_types.ChatCompletionRequestUserMessage(role="user", content="Follow-up instruction"),
+    ]
+    response = llama_chat_format.format_mistral_instruct(
+        messages=messages,
+    )
+    reference = chat_formatter.render(
+        messages=messages,
+        bos_token="<s>",
+        eos_token="</s>",
+    )
+    assert response.prompt == reference
+
 
 mistral_7b_tokenizer_config = """{
     "add_bos_token": true,
