Commit be09318

feat: Add Jinja2ChatFormatter

1 parent 5a34c57

1 file changed: llama_cpp/llama_chat_format.py (188 additions, 135 deletions)
@@ -121,14 +121,21 @@ def decorator(f: LlamaChatCompletionHandler):
 
 @dataclasses.dataclass
 class ChatFormatterResponse:
+    """Dataclass that stores completion parameters for a given chat format and
+    create_chat_completion request.
+
+    prompt contains the formatted prompt generated from the chat format and messages.
+    stop contains the stop token or list of stop tokens to use for the chat format."""
+
     prompt: str
     stop: Optional[Union[str, List[str]]] = None
 
 
 class ChatFormatter(Protocol):
     """Base Protocol for a chat formatter. A chat formatter is a function that
-    takes a list of messages and returns a formatted prompt. It can also return
-    a stop token or list of stop tokens to use for the completion."""
+    takes a list of messages and returns a chat format response which can be used
+    to generate a completion. The response can also include a stop token or list
+    of stop tokens to use for the completion."""
 
     def __call__(
         self,
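The protocol's call signature (continued in the next hunk) is keyword-only: implementations take messages plus arbitrary keyword arguments and return a ChatFormatterResponse. As a rough illustration of a conforming formatter (not part of this commit; my_formatter and its "role: content" layout are invented here):

from typing import Any, List

from llama_cpp import llama_types
from llama_cpp.llama_chat_format import ChatFormatterResponse


def my_formatter(
    *,
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    # Render each message as "role: content" and leave the assistant turn open.
    prompt = "".join(f"{m['role']}: {m.get('content') or ''}\n" for m in messages)
    return ChatFormatterResponse(prompt=prompt + "assistant:", stop="</s>")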
@@ -139,131 +146,43 @@ def __call__(
         ...
 
 
-### Utility functions for formatting chat prompts ###
-
-
-def _get_system_message(
-    messages: List[llama_types.ChatCompletionRequestMessage],
-) -> str:
-    """Get the first system message."""
-    for message in messages:
-        if message["role"] == "system":
-            return message["content"] or ""
-    return ""
-
-
-def _map_roles(
-    messages: List[llama_types.ChatCompletionRequestMessage],
-    role_map: Dict[str, str],
-) -> List[Tuple[str, Optional[str]]]:
-    """Map the message roles."""
-    output: List[Tuple[str, Optional[str]]] = []
-    for message in messages:
-        role = message["role"]
-        if role in role_map:
-            content: str | None = (
-                message["content"] if isinstance(message["content"], str) else None
-            )
-            output.append((role_map[role], content))
-    return output
-
-
-def _format_llama2(
-    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
-) -> str:
-    """Format the prompt with the llama2 style."""
-    seps = [sep, sep2]
-    ret = system_message + sep
-    for i, (role, message) in enumerate(messages):
-        if system_message and i == 0:
-            m = message or ""
-            ret += m + seps[i % 2]
-        elif message:
-            ret += role + message + " " + seps[i % 2]
-        else:
-            ret += role + " "
-    return ret
-
-
-def _format_add_colon_single(
-    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
-) -> str:
-    """Format the prompt with the add-colon-single style."""
-    ret = system_message + sep
-    for role, message in messages:
-        if message:
-            ret += role + ": " + message + sep
-        else:
-            ret += role + ":"
-    return ret
-
-
-def _format_add_colon_two(
-    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
-) -> str:
-    """Format the prompt with the add-colon-two style."""
-    seps = [sep, sep2]
-    ret = system_message + seps[0]
-    for i, (role, message) in enumerate(messages):
-        if message:
-            ret += role + ": " + message + seps[i % 2]
-        else:
-            ret += role + ":"
-    return ret
-
-
-def _format_no_colon_single(
-    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
-) -> str:
-    """Format the prompt with the no-colon-single style."""
-    ret = system_message
-    for role, message in messages:
-        if message:
-            ret += role + message + sep
-        else:
-            ret += role
-    return ret
-
-
-def _format_add_colon_space_single(
-    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
-) -> str:
-    """Format the prompt with the add-colon-space-single style."""
-    ret = system_message + sep
-    for role, message in messages:
-        if message:
-            ret += role + ": " + message + sep
-        else:
-            ret += role + ": "  # must be end with a space
-    return ret
-
+class Jinja2ChatFormatter(ChatFormatter):
+    def __init__(
+        self,
+        template: str,
+        eos_token: str,
+        bos_token: str,
+    ):
+        """A chat formatter that uses jinja2 templates to format the prompt."""
+        self.template = template
+        self.eos_token = eos_token
+        self.bos_token = bos_token
 
-def _format_chatml(
-    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
-) -> str:
-    """Format the prompt with the chatml style."""
-    ret = "" if system_message == "" else system_message + sep + "\n"
-    for role, message in messages:
-        if message:
-            ret += role + "\n" + message + sep + "\n"
-        else:
-            ret += role + "\n"
-    return ret
+        self._environment = jinja2.Environment(
+            loader=jinja2.BaseLoader(),
+            trim_blocks=True,
+            lstrip_blocks=True,
+        ).from_string(self.template)
 
+    def __call__(
+        self,
+        *,
+        messages: List[llama_types.ChatCompletionRequestMessage],
+        **kwargs: Any,
+    ) -> ChatFormatterResponse:
+        messages = [
+            *messages,
+            llama_types.ChatCompletionRequestAssistantMessage(
+                role="assistant", content=""
+            ),
+        ]
+        prompt = self._environment.render(
+            messages=messages, eos_token=self.eos_token, bos_token=self.bos_token
+        )
+        return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token])
 
-def _format_chatglm3(
-    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
-) -> str:
-    """Format the prompt with the chatglm3 style."""
-    ret = ""
-    if system_message:
-        ret += system_message
-    for role, message in messages:
-        if message:
-            ret += role + "\n" + " " + message
-        else:
-            ret += role
-    return ret
+    def to_chat_handler(self) -> LlamaChatCompletionHandler:
+        return chat_formatter_to_chat_completion_handler(self)
 
 
 def _convert_text_completion_to_chat(
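This hunk is the core of the commit: the hand-rolled prompt-formatting helpers are deleted here (they reappear verbatim under a new section later in the file) and replaced by Jinja2ChatFormatter, which renders the same Jinja2 chat templates that Hugging Face tokenizers ship in tokenizer_config.json. Note that __call__ appends an empty assistant message before rendering and returns the eos_token as the stop sequence. A hedged usage sketch follows; the ChatML-style template string and token literals are illustrative, not taken from the commit:

from llama_cpp.llama_chat_format import Jinja2ChatFormatter

# Illustrative ChatML-style template. Because __call__ appends an empty
# assistant message, the final turn renders as an open "<|im_start|>assistant\n"
# that the model is expected to complete.
template = (
    "{% for message in messages %}"
    "<|im_start|>{{ message.role }}\n"
    "{{ message.content }}"
    "{% if message.content %}<|im_end|>\n{% endif %}"
    "{% endfor %}"
)

formatter = Jinja2ChatFormatter(
    template=template,
    eos_token="<|im_end|>",
    bos_token="<|im_start|>",
)
response = formatter(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hi!"},
    ]
)
print(response.prompt)  # ends with the open assistant turn
print(response.stop)    # ["<|im_end|>"]

to_chat_handler() then adapts the formatter into a LlamaChatCompletionHandler via the existing chat_formatter_to_chat_completion_handler helper, so a template-driven formatter can back create_chat_completion directly.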
@@ -426,16 +345,6 @@ def chat_completion_handler(
     return chat_completion_handler
 
 
-def register_chat_format(name: str):
-    def decorator(f: ChatFormatter):
-        chat_completion_handler = chat_formatter_to_chat_completion_handler(f)
-        LlamaChatCompletionHandlerRegistry().register_chat_completion_handler(
-            name, chat_completion_handler
-        )
-        return f
-    return decorator
-
-
 def hf_autotokenizer_to_chat_formatter(
     pretrained_model_name_or_path: Union[str, os.PathLike[str]]
 ) -> ChatFormatter:
@@ -466,7 +375,9 @@ def hf_autotokenizer_to_chat_completion_handler(
     return chat_formatter_to_chat_completion_handler(chat_formatter)
 
 
-def hf_tokenizer_config_to_chat_formatter(tokenizer_config: Dict[str, Any]) -> ChatFormatter:
+def hf_tokenizer_config_to_chat_formatter(
+    tokenizer_config: Dict[str, Any]
+) -> ChatFormatter:
     assert isinstance(tokenizer_config, dict)
 
     assert "chat_template" in tokenizer_config
@@ -504,6 +415,7 @@ def format_autotokenizer(
             eos_token=eos_token,
         )
         return ChatFormatterResponse(prompt=prompt, stop=eos_token)
+
     return format_autotokenizer
 
 
@@ -514,6 +426,147 @@ def hf_tokenizer_config_to_chat_completion_handler(
     return chat_formatter_to_chat_completion_handler(chat_formatter)
 
 
+### Utility functions for formatting chat prompts ###
+
+
+def _get_system_message(
+    messages: List[llama_types.ChatCompletionRequestMessage],
+) -> str:
+    """Get the first system message."""
+    for message in messages:
+        if message["role"] == "system":
+            return message["content"] or ""
+    return ""
+
+
+def _map_roles(
+    messages: List[llama_types.ChatCompletionRequestMessage],
+    role_map: Dict[str, str],
+) -> List[Tuple[str, Optional[str]]]:
+    """Map the message roles."""
+    output: List[Tuple[str, Optional[str]]] = []
+    for message in messages:
+        role = message["role"]
+        if role in role_map:
+            content: str | None = (
+                message["content"] if isinstance(message["content"], str) else None
+            )
+            output.append((role_map[role], content))
+    return output
+
+
+def _format_llama2(
+    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
+) -> str:
+    """Format the prompt with the llama2 style."""
+    seps = [sep, sep2]
+    ret = system_message + sep
+    for i, (role, message) in enumerate(messages):
+        if system_message and i == 0:
+            m = message or ""
+            ret += m + seps[i % 2]
+        elif message:
+            ret += role + message + " " + seps[i % 2]
+        else:
+            ret += role + " "
+    return ret
+
+
+def _format_add_colon_single(
+    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
+) -> str:
+    """Format the prompt with the add-colon-single style."""
+    ret = system_message + sep
+    for role, message in messages:
+        if message:
+            ret += role + ": " + message + sep
+        else:
+            ret += role + ":"
+    return ret
+
+
+def _format_add_colon_two(
+    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
+) -> str:
+    """Format the prompt with the add-colon-two style."""
+    seps = [sep, sep2]
+    ret = system_message + seps[0]
+    for i, (role, message) in enumerate(messages):
+        if message:
+            ret += role + ": " + message + seps[i % 2]
+        else:
+            ret += role + ":"
+    return ret
+
+
+def _format_no_colon_single(
+    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
+) -> str:
+    """Format the prompt with the no-colon-single style."""
+    ret = system_message
+    for role, message in messages:
+        if message:
+            ret += role + message + sep
+        else:
+            ret += role
+    return ret
+
+
+def _format_add_colon_space_single(
+    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
+) -> str:
+    """Format the prompt with the add-colon-space-single style."""
+    ret = system_message + sep
+    for role, message in messages:
+        if message:
+            ret += role + ": " + message + sep
+        else:
+            ret += role + ": "  # must be end with a space
+    return ret
+
+
+def _format_chatml(
+    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
+) -> str:
+    """Format the prompt with the chatml style."""
+    ret = "" if system_message == "" else system_message + sep + "\n"
+    for role, message in messages:
+        if message:
+            ret += role + "\n" + message + sep + "\n"
+        else:
+            ret += role + "\n"
+    return ret
+
+
+def _format_chatglm3(
+    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
+) -> str:
+    """Format the prompt with the chatglm3 style."""
+    ret = ""
+    if system_message:
+        ret += system_message
+    for role, message in messages:
+        if message:
+            ret += role + "\n" + " " + message
+        else:
+            ret += role
+    return ret
+
+
+### Chat Formats ###
+
+
+def register_chat_format(name: str):
+    def decorator(f: ChatFormatter):
+        chat_completion_handler = chat_formatter_to_chat_completion_handler(f)
+        LlamaChatCompletionHandlerRegistry().register_chat_completion_handler(
+            name, chat_completion_handler
+        )
+        return f
+
+    return decorator
+
+
 # see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py
 # system prompt is "embedded" in the first message
 @register_chat_format("llama-2")

0 commit comments