Commit 5ab40e6

CISC and abetlen authored
feat: Support multiple chat templates - step 1 (abetlen#1396)
* Support multiple chat templates - step 1

  As a first step, allow the user to select a template from metadata with the chat_format parameter in the form of `chat_template.name`.

* register chat templates to self.chat_formats instead of globally

* Don't expose internal chat handlers yet

---------

Co-authored-by: Andrei <abetlen@gmail.com>
1 parent bf66a28 commit 5ab40e6
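
To make the new `chat_format` selection concrete, here is a minimal usage sketch of what this change enables. It assumes a local GGUF model whose metadata carries a `tokenizer.chat_template` entry; the model path is a placeholder, and any extra named templates (selectable as `chat_template.<name>`) depend entirely on the model file.

```python
from llama_cpp import Llama

llm = Llama(
    model_path="./models/example.gguf",  # placeholder path, not from this commit
    # Pick the template stored under tokenizer.chat_template in the GGUF metadata.
    # A model that ships additional templates (tokenizer.chat_template.<name>) could
    # be addressed the same way, e.g. chat_format="chat_template.<name>".
    chat_format="chat_template.default",
    verbose=True,  # logs "Available chat formats from metadata: ..." when templates are found
)

response = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Hello!"}],
)
print(response["choices"][0]["message"]["content"])
```

When both chat_format and chat_handler are left unset, the diff below falls back to "chat_template.default" only if a known format cannot be guessed from the metadata.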

File tree: 1 file changed (+28 −22 lines)

‎llama_cpp/llama.py

+28 −22 lines changed: 28 additions & 22 deletions

@@ -378,6 +378,7 @@ def __init__(
 
         self.chat_format = chat_format
         self.chat_handler = chat_handler
+        self._chat_handlers: Dict[str, llama_chat_format.LlamaChatCompletionHandler] = {}
 
         self.draft_model = draft_model
 
@@ -409,10 +410,33 @@ def __init__(
         if self.verbose:
             print(f"Model metadata: {self.metadata}", file=sys.stderr)
 
+        eos_token_id = int(self.metadata.get("tokenizer.ggml.eos_token_id", self.token_eos()))
+        bos_token_id = int(self.metadata.get("tokenizer.ggml.bos_token_id", self.token_bos()))
+
+        eos_token = self._model.token_get_text(eos_token_id)
+        bos_token = self._model.token_get_text(bos_token_id)
+
+        # Unfortunately the llama.cpp API does not return metadata arrays, so we can't get template names from tokenizer.chat_templates
+        template_choices = dict((name[10:], template) for name, template in self.metadata.items() if name.startswith("tokenizer.chat_template."))
+
+        if "tokenizer.chat_template" in self.metadata:
+            template_choices["chat_template.default"] = self.metadata["tokenizer.chat_template"]
+
+        if self.verbose and template_choices:
+            print(f"Available chat formats from metadata: {', '.join(template_choices.keys())}", file=sys.stderr)
+
+        for name, template in template_choices.items():
+            self._chat_handlers[name] = llama_chat_format.Jinja2ChatFormatter(
+                template=template,
+                eos_token=eos_token,
+                bos_token=bos_token,
+                stop_token_ids=[eos_token_id],
+            ).to_chat_handler()
+
         if (
             self.chat_format is None
             and self.chat_handler is None
-            and "tokenizer.chat_template" in self.metadata
+            and "chat_template.default" in template_choices
         ):
             chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(
                 self.metadata
@@ -423,30 +447,12 @@ def __init__(
                 if self.verbose:
                     print(f"Guessed chat format: {chat_format}", file=sys.stderr)
             else:
-                template = self.metadata["tokenizer.chat_template"]
-                try:
-                    eos_token_id = int(self.metadata["tokenizer.ggml.eos_token_id"])
-                except:
-                    eos_token_id = self.token_eos()
-                try:
-                    bos_token_id = int(self.metadata["tokenizer.ggml.bos_token_id"])
-                except:
-                    bos_token_id = self.token_bos()
-
-                eos_token = self._model.token_get_text(eos_token_id)
-                bos_token = self._model.token_get_text(bos_token_id)
-
                 if self.verbose:
-                    print(f"Using gguf chat template: {template}", file=sys.stderr)
+                    print(f"Using gguf chat template: {template_choices['chat_template.default']}", file=sys.stderr)
                     print(f"Using chat eos_token: {eos_token}", file=sys.stderr)
                     print(f"Using chat bos_token: {bos_token}", file=sys.stderr)
 
-                self.chat_handler = llama_chat_format.Jinja2ChatFormatter(
-                    template=template,
-                    eos_token=eos_token,
-                    bos_token=bos_token,
-                    stop_token_ids=[eos_token_id],
-                ).to_chat_handler()
+                self.chat_format = "chat_template.default"
 
         if self.chat_format is None and self.chat_handler is None:
             self.chat_format = "llama-2"
@@ -1719,7 +1725,7 @@ def create_chat_completion(
         Returns:
             Generated chat completion or a stream of chat completion chunks.
         """
-        handler = self.chat_handler or llama_chat_format.get_chat_completion_handler(
+        handler = self.chat_handler or self._chat_handlers.get(self.chat_format) or llama_chat_format.get_chat_completion_handler(
             self.chat_format
         )
         return handler(
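
As a standalone illustration of the metadata handling in the second hunk above, the sketch below shows how GGUF template keys are turned into the names accepted by `chat_format`. The metadata values and the `tool_use` entry are made up for the example.

```python
# Toy metadata dict standing in for self.metadata as read from a GGUF file.
metadata = {
    "tokenizer.chat_template": "<default jinja2 template>",
    "tokenizer.chat_template.tool_use": "<hypothetical named template>",
}

# Same expression as in the diff: drop the leading "tokenizer." (10 characters),
# so "tokenizer.chat_template.tool_use" becomes "chat_template.tool_use".
template_choices = dict(
    (name[10:], template)
    for name, template in metadata.items()
    if name.startswith("tokenizer.chat_template.")
)
if "tokenizer.chat_template" in metadata:
    template_choices["chat_template.default"] = metadata["tokenizer.chat_template"]

print(sorted(template_choices))
# ['chat_template.default', 'chat_template.tool_use']
```

Each of these names is registered as a Jinja2-based chat handler in self._chat_handlers, and create_chat_completion resolves handlers in the order: explicit chat_handler, then a registered template matching chat_format, then the globally registered formats.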
