Commit da003d8

Automatically set chat format from gguf (abetlen#1110)
* Use jinja formatter to load chat format from gguf
* Fix off-by-one error in metadata loader
* Implement chat format auto-detection
1 parent 059f6b3 commit da003d8

4 files changed: +68 −7 lines

llama_cpp/_internals.py

2 additions, 2 deletions
@@ -216,13 +216,13 @@ def metadata(self) -> Dict[str, str]:
         for i in range(llama_cpp.llama_model_meta_count(self.model)):
             nbytes = llama_cpp.llama_model_meta_key_by_index(self.model, i, buffer, buffer_size)
             if nbytes > buffer_size:
-                buffer_size = nbytes
+                buffer_size = nbytes + 1
                 buffer = ctypes.create_string_buffer(buffer_size)
                 nbytes = llama_cpp.llama_model_meta_key_by_index(self.model, i, buffer, buffer_size)
             key = buffer.value.decode("utf-8")
             nbytes = llama_cpp.llama_model_meta_val_str_by_index(self.model, i, buffer, buffer_size)
             if nbytes > buffer_size:
-                buffer_size = nbytes
+                buffer_size = nbytes + 1
                 buffer = ctypes.create_string_buffer(buffer_size)
                 nbytes = llama_cpp.llama_model_meta_val_str_by_index(self.model, i, buffer, buffer_size)
             value = buffer.value.decode("utf-8")
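The two changed lines fix an off-by-one in the metadata loader: the llama.cpp getters write at most buffer_size bytes (including the NUL terminator) and return the length of the full string, so growing the buffer to exactly nbytes still drops the last character. A minimal sketch of the failure mode, with a hypothetical write_meta standing in for llama_model_meta_key_by_index (the snprintf-like behaviour is an assumption here, not library documentation):

import ctypes

# Hypothetical stand-in for the C metadata getter: truncates to the buffer,
# keeps one byte for the NUL terminator, and returns the full string length.
def write_meta(buffer: ctypes.Array, buffer_size: int, value: bytes) -> int:
    truncated = value[: buffer_size - 1]
    ctypes.memmove(buffer, truncated + b"\x00", len(truncated) + 1)
    return len(value)

key = b"tokenizer.chat_template"
buffer_size = 16
buffer = ctypes.create_string_buffer(buffer_size)

nbytes = write_meta(buffer, buffer_size, key)
if nbytes > buffer_size:
    buffer_size = nbytes + 1  # the old `buffer_size = nbytes` left no room for the terminator
    buffer = ctypes.create_string_buffer(buffer_size)
    nbytes = write_meta(buffer, buffer_size, key)

assert buffer.value == key  # with the old sizing the last byte of the key was lost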

llama_cpp/llama.py

36 additions, 1 deletion
@@ -87,7 +87,7 @@ def __init__(
         # Backend Params
         numa: bool = False,
         # Chat Format Params
-        chat_format: str = "llama-2",
+        chat_format: Optional[str] = None,
         chat_handler: Optional[llama_chat_format.LlamaChatCompletionHandler] = None,
         # Misc
         verbose: bool = True,
@@ -343,6 +343,41 @@ def __init__(
         if self.verbose:
             print(f"Model metadata: {self.metadata}", file=sys.stderr)

+        if self.chat_format is None and self.chat_handler is None and "tokenizer.chat_template" in self.metadata:
+            chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(self.metadata)
+
+            if chat_format is not None:
+                self.chat_format = chat_format
+                if self.verbose:
+                    print(f"Guessed chat format: {chat_format}", file=sys.stderr)
+            else:
+                template = self.metadata["tokenizer.chat_template"]
+                try:
+                    eos_token_id = int(self.metadata["tokenizer.ggml.eos_token_id"])
+                except:
+                    eos_token_id = self.token_eos()
+                try:
+                    bos_token_id = int(self.metadata["tokenizer.ggml.bos_token_id"])
+                except:
+                    bos_token_id = self.token_bos()
+
+                eos_token = self.detokenize([eos_token_id]).decode("utf-8")
+                bos_token = self.detokenize([bos_token_id]).decode("utf-8")
+
+                if self.verbose:
+                    print(f"Using chat template: {template}", file=sys.stderr)
+                    print(f"Using chat eos_token: {eos_token}", file=sys.stderr)
+                    print(f"Using chat bos_token: {bos_token}", file=sys.stderr)
+
+                self.chat_handler = llama_chat_format.Jinja2ChatFormatter(
+                    template=template,
+                    eos_token=eos_token,
+                    bos_token=bos_token
+                ).to_chat_handler()
+
+        if self.chat_format is None and self.chat_handler is None:
+            self.chat_format = "llama-2"
+
     @property
     def ctx(self) -> llama_cpp.llama_context_p:
         assert self._ctx.ctx is not None
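With this change a caller can omit chat_format entirely: if the GGUF metadata carries a known tokenizer.chat_template the constructor maps it to a registered format name, and an unknown template falls back to a Jinja2ChatFormatter built from the embedded template plus the detokenized BOS/EOS tokens. A hedged usage sketch (the model path is a placeholder):

from llama_cpp import Llama

# No chat_format argument; with verbose=True the constructor logs either
# "Guessed chat format: ..." or the template and special tokens it will use.
llm = Llama(model_path="./models/openhermes-2.5-mistral-7b.Q4_K_M.gguf", verbose=True)

response = llm.create_chat_completion(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Summarize what GGUF metadata is in one sentence."},
    ]
)
print(response["choices"][0]["message"]["content"])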

llama_cpp/llama_chat_format.py

28 additions, 2 deletions
@@ -14,6 +14,20 @@

 from ._utils import suppress_stdout_stderr, Singleton

+### Common Chat Templates and Special Tokens ###
+
+# Source: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/tokenizer_config.json
+CHATML_CHAT_TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
+CHATML_BOS_TOKEN = "<s>"
+CHATML_EOS_TOKEN = "<|im_end|>"
+
+# Source: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json
+MISTRAL_INSTRUCT_CHAT_TEMPLATE = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
+MISTRAL_INSTRUCT_BOS_TOKEN = "<s>"
+MISTRAL_INSTRUCT_EOS_TOKEN = "</s>"
+
+
+### Chat Completion Handler ###

 class LlamaChatCompletionHandler(Protocol):
     """Base Protocol for a llama chat completion handler.
@@ -118,7 +132,6 @@ def decorator(f: LlamaChatCompletionHandler):

 ### Chat Formatter ###

-
 @dataclasses.dataclass
 class ChatFormatterResponse:
     """Dataclass that stores completion parameters for a given chat format and
@@ -440,7 +453,20 @@ def hf_tokenizer_config_to_chat_completion_handler(
     return chat_formatter_to_chat_completion_handler(chat_formatter)


+def guess_chat_format_from_gguf_metadata(metadata: Dict[str, str]) -> Optional[str]:
+    if "tokenizer.chat_template" not in metadata:
+        return None
+
+    if metadata["tokenizer.chat_template"] == CHATML_CHAT_TEMPLATE:
+        return "chatml"
+
+    if metadata["tokenizer.chat_template"] == MISTRAL_INSTRUCT_CHAT_TEMPLATE:
+        return "mistral-instruct"
+
+    return None
+
 ### Utility functions for formatting chat prompts ###
+# TODO: Replace these with jinja2 templates


 def _get_system_message(
@@ -929,7 +955,6 @@ def format_openchat(
     _prompt = _format_chatml(system_message, _messages, _sep)
     return ChatFormatterResponse(prompt=_prompt, stop=_sep)

-
 # Chat format for Saiga models, see more details and available models:
 # https://huggingface.co/collections/IlyaGusev/saiga2-saigamistral-6505d4ccc3d1e53166b636cd
 @register_chat_format("saiga")
@@ -951,6 +976,7 @@ def format_saiga(
     _prompt += "<s>bot"
     return ChatFormatterResponse(prompt=_prompt.strip())

+# Tricky chat formats that require custom chat handlers

 @register_chat_completion_handler("functionary")
 def functionary_chat_handler(
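The new guess_chat_format_from_gguf_metadata helper only does exact string comparisons against the two templates defined above; anything else returns None, which is what triggers the Jinja fallback in llama.py. A small sketch of exercising it directly, using hand-built metadata dicts in place of Llama.metadata:

from llama_cpp import llama_chat_format

chatml_meta = {"tokenizer.chat_template": llama_chat_format.CHATML_CHAT_TEMPLATE}
unknown_meta = {"tokenizer.chat_template": "{{ messages }}"}  # not one of the known templates

print(llama_chat_format.guess_chat_format_from_gguf_metadata(chatml_meta))   # chatml
print(llama_chat_format.guess_chat_format_from_gguf_metadata(unknown_meta))  # None
print(llama_chat_format.guess_chat_format_from_gguf_metadata({}))            # None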

llama_cpp/server/settings.py

2 additions, 2 deletions
@@ -113,8 +113,8 @@ class ModelSettings(BaseSettings):
         description="Enable NUMA support.",
     )
     # Chat Format Params
-    chat_format: str = Field(
-        default="llama-2",
+    chat_format: Optional[str] = Field(
+        default=None,
         description="Chat format to use.",
     )
     clip_model_path: Optional[str] = Field(
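On the server side the default moves from "llama-2" to None, so leaving the chat format unset defers to the same GGUF-based auto-detection when the Llama object is created. A minimal sketch with the pydantic settings class (the model path is a placeholder):

from llama_cpp.server.settings import ModelSettings

settings = ModelSettings(model="./models/mistral-7b-instruct-v0.1.Q4_K_M.gguf")
print(settings.chat_format)  # None -> Llama() will try to guess the format from metadata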

0 commit comments
