Commit 3bca770

Configurable Chat Formats (abetlen#711)

* Add configurable default chat completion format.
* Remove chat_template file to avoid circular import
* Update llama_types
* Add chat format
1 parent a945404 commit 3bca770

2 files changed: +330 additions, -19 deletions

‎llama_cpp/llama.py

Lines changed: 38 additions & 19 deletions (+38 -19)
@@ -24,6 +24,7 @@
 from . import llama_cpp
 from .llama_types import *
 from .llama_grammar import LlamaGrammar
+from . import llama_chat_format
 
 import numpy as np
 import numpy.typing as npt
@@ -243,6 +244,8 @@ def __init__(
         lora_path: Optional[str] = None,
         # Backend Params
         numa: bool = False,
+        # Chat Format Params
+        chat_format: str = "llama-2",
         # Misc
         verbose: bool = True,
         # Extra Params
@@ -273,6 +276,7 @@ def __init__(
             lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
             lora_path: Path to a LoRA file to apply to the model.
             numa: Enable NUMA support. (NOTE: The initial value of this parameter is used for the remainder of the program as this value is set in llama_backend_init)
+            chat_format: String specifying the chat format to use when calling create_chat_completion.
             verbose: Print verbose output to stderr.
             kwargs: Unused keyword arguments (for additional backwards compatibility).
@@ -388,6 +392,8 @@ def __init__(
 
         if self.verbose:
             print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
+
+        self.chat_format = chat_format
 
         self._n_vocab = self.n_vocab()
         self._n_ctx = self.n_ctx()
@@ -1565,9 +1571,21 @@ def _convert_text_completion_chunks_to_chat(
             ],
         }
 
+    def _convert_completion_to_chat(
+        self,
+        completion_or_chunks: Union[Completion, Iterator[CompletionChunk]],
+        stream: bool = False,
+    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+        if stream:
+            chunks: Iterator[CompletionChunk] = completion_or_chunks  # type: ignore
+            return self._convert_text_completion_chunks_to_chat(chunks)
+        else:
+            completion: Completion = completion_or_chunks  # type: ignore
+            return self._convert_text_completion_to_chat(completion)
+
     def create_chat_completion(
         self,
-        messages: List[ChatCompletionMessage],
+        messages: List[ChatCompletionRequestMessage],
         functions: Optional[List[ChatCompletionFunction]] = None,
         function_call: Optional[Union[str, ChatCompletionFunctionCall]] = None,
         temperature: float = 0.2,
@@ -1602,26 +1620,28 @@ def create_chat_completion(
         Returns:
             Generated chat completion or a stream of chat completion chunks.
         """
-        stop = (
-            stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
-        )
-        chat_history = "".join(
-            f'### {"Human" if message["role"] == "user" else "Assistant"}:{message["content"]}'
-            for message in messages
+
+        format = llama_chat_format.get_chat_format(self.chat_format)
+        result = format(
+            messages=messages,
         )
-        PROMPT = chat_history + "### Assistant:"
-        PROMPT_STOP = ["### Assistant:", "### Human:"]
-        completion_or_chunks = self(
-            prompt=PROMPT,
-            stop=PROMPT_STOP + stop,
+        prompt = result.prompt
+        if result.stop is not None:
+            stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
+            rstop = result.stop if isinstance(result.stop, list) else [result.stop]
+            stop = stop + rstop
+
+        completion_or_chunks = self.create_completion(
+            prompt=prompt,
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
             stream=stream,
+            stop=stop,
             max_tokens=max_tokens,
-            repeat_penalty=repeat_penalty,
             presence_penalty=presence_penalty,
             frequency_penalty=frequency_penalty,
+            repeat_penalty=repeat_penalty,
             tfs_z=tfs_z,
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
@@ -1630,12 +1650,7 @@ def create_chat_completion(
             logits_processor=logits_processor,
             grammar=grammar,
         )
-        if stream:
-            chunks: Iterator[CompletionChunk] = completion_or_chunks  # type: ignore
-            return self._convert_text_completion_chunks_to_chat(chunks)
-        else:
-            completion: Completion = completion_or_chunks  # type: ignore
-            return self._convert_text_completion_to_chat(completion)
+        return self._convert_completion_to_chat(completion_or_chunks, stream=stream)  # type: ignore
 
     def __del__(self):
         if hasattr(self, "model") and self.model is not None:
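
The rewritten create_chat_completion above no longer hard-codes a prompt template: it looks up a formatter by name via llama_chat_format.get_chat_format(self.chat_format), calls it with the messages, and uses the returned object's .prompt and optional .stop before delegating to create_completion. As a rough illustration of that interface only (FormatResult and alpaca_style_format are hypothetical names, not the API added by this commit), a handler reproducing the previously hard-coded "### Human:"/"### Assistant:" prompt might look like this:

```python
# Hypothetical sketch of a chat format handler compatible with the call pattern in
# the hunks above: a callable taking the messages and returning an object that
# exposes .prompt and an optional .stop. Names here are illustrative only.
from dataclasses import dataclass
from typing import Dict, List, Optional, Union


@dataclass
class FormatResult:
    # Stand-in for whatever the registered formatters actually return.
    prompt: str
    stop: Optional[Union[str, List[str]]] = None


def alpaca_style_format(messages: List[Dict[str, str]]) -> FormatResult:
    # Rebuilds the prompt that create_chat_completion used to construct inline.
    history = "".join(
        f'### {"Human" if m["role"] == "user" else "Assistant"}:{m["content"]}'
        for m in messages
    )
    return FormatResult(
        prompt=history + "### Assistant:",
        stop=["### Assistant:", "### Human:"],
    )
```

Whatever stop strings the formatter returns are merged with any caller-supplied stop list before the prompt is passed to create_completion, as the hunk above shows.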
@@ -1675,6 +1690,8 @@ def __getstate__(self):
             lora_path=self.lora_path,
             # Backend Params
             numa=self.numa,
+            # Chat Format Params
+            chat_format=self.chat_format,
             # Misc
             verbose=self.verbose,
         )
@@ -1708,6 +1725,8 @@ def __setstate__(self, state):
             lora_path=state["lora_path"],
             # Backend Params
             numa=state["numa"],
+            # Chat Format Params
+            chat_format=state["chat_format"],
             # Misc
             verbose=state["verbose"],
         )
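
With these changes, the chat template becomes a constructor-level setting rather than a hard-coded prompt. A minimal usage sketch, assuming a local GGUF model at a placeholder path and the "llama-2" default shown in the diff:

```python
from llama_cpp import Llama

# chat_format defaults to "llama-2" per this commit; pass another registered
# format name to change how messages are rendered into a prompt.
llm = Llama(
    model_path="./models/llama-2-7b-chat.gguf",  # placeholder path
    chat_format="llama-2",
)

response = llm.create_chat_completion(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Name the planets in the solar system."},
    ],
    max_tokens=128,
)
print(response["choices"][0]["message"]["content"])
```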
