Commit ae47e4f

Add chat format

1 parent 9c68382 commit ae47e4f

2 files changed: +315 −91 lines

llama_cpp/llama.py

+23 −91 (23 additions, 91 deletions)

@@ -24,6 +24,7 @@
 from . import llama_cpp
 from .llama_types import *
 from .llama_grammar import LlamaGrammar
+from . import llama_chat_format

 import numpy as np
 import numpy.typing as npt
@@ -243,6 +244,8 @@ def __init__(
         lora_path: Optional[str] = None,
         # Backend Params
         numa: bool = False,
+        # Chat Format Params
+        chat_format: str = "llama-2",
         # Misc
         verbose: bool = True,
         # Extra Params
@@ -273,6 +276,7 @@ def __init__(
             lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
             lora_path: Path to a LoRA file to apply to the model.
             numa: Enable NUMA support. (NOTE: The initial value of this parameter is used for the remainder of the program as this value is set in llama_backend_init)
+            chat_format: String specifying the chat format to use when calling create_chat_completion.
             verbose: Print verbose output to stderr.
             kwargs: Unused keyword arguments (for additional backwards compatibility).
@@ -387,6 +391,8 @@ def __init__(

         if self.verbose:
             print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
+
+        self.chat_format = chat_format

         self._n_vocab = self.n_vocab()
         self._n_ctx = self.n_ctx()
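
Taken together, the constructor changes above mean the chat format is chosen once per Llama instance and stored on self.chat_format. A minimal usage sketch of the new parameter; the model path and messages below are placeholders, not part of this commit:

from llama_cpp import Llama

# Hypothetical local model path; any compatible model file works here.
llm = Llama(
    model_path="./models/llama-2-7b-chat.ggmlv3.q4_0.bin",
    chat_format="llama-2",  # new parameter introduced by this commit (also the default)
)

# OpenAI-style role/content messages; the selected chat format renders them
# into a single prompt before create_completion is called internally.
response = llm.create_chat_completion(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Name the planets in the solar system."},
    ],
)
print(response["choices"][0]["message"]["content"])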
@@ -1578,7 +1584,7 @@ def _convert_completion_to_chat(

     def create_chat_completion(
         self,
-        messages: List[ChatCompletionMessage],
+        messages: List[ChatCompletionRequestMessage],
         functions: Optional[List[ChatCompletionFunction]] = None,
         function_call: Optional[Union[str, ChatCompletionFunctionCall]] = None,
         temperature: float = 0.2,
@@ -1613,11 +1619,19 @@ def create_chat_completion(
         Returns:
             Generated chat completion or a stream of chat completion chunks.
         """
-        completion_or_chunks = self.chat_completion_template.create_chat_completion(
-            self,
+
+        format = llama_chat_format.get_chat_format(self.chat_format)
+        result = format(
             messages=messages,
-            functions=functions,
-            function_call=function_call,
+        )
+        prompt = result.prompt
+        if result.stop is not None:
+            stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
+            rstop = result.stop if isinstance(result.stop, list) else [result.stop]
+            stop = stop + rstop
+
+        completion_or_chunks = self.create_completion(
+            prompt=prompt,
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
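
The stop handling added in this hunk normalizes both the caller-supplied stop value and the format-supplied one to lists before concatenating them. The same logic as a standalone sketch; the helper name merge_stop is invented for illustration:

from typing import List, Optional, Union

def merge_stop(
    stop: Optional[Union[str, List[str]]],
    format_stop: Optional[Union[str, List[str]]],
) -> Optional[Union[str, List[str]]]:
    # No format-supplied stop tokens: pass the caller's value through unchanged.
    if format_stop is None:
        return stop
    # Normalize both values to lists, then concatenate (same branches as the diff above).
    stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
    rstop = format_stop if isinstance(format_stop, list) else [format_stop]
    return stop + rstop

# Example: merge_stop("</s>", ["### Human:"]) -> ["</s>", "### Human:"]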
@@ -1675,6 +1689,8 @@ def __getstate__(self):
             lora_path=self.lora_path,
             # Backend Params
             numa=self.numa,
+            # Chat Format Params
+            chat_format=self.chat_format,
             # Misc
             verbose=self.verbose,
         )
@@ -1708,6 +1724,8 @@ def __setstate__(self, state):
             lora_path=state["lora_path"],
             # Backend Params
             numa=state["numa"],
+            # Chat Format Params
+            chat_format=state["chat_format"],
             # Misc
             verbose=state["verbose"],
         )
@@ -1821,89 +1839,3 @@ def decode(self, tokens: List[int]) -> str:
     @classmethod
     def from_ggml_file(cls, path: str) -> "LlamaTokenizer":
         return cls(Llama(model_path=path, vocab_only=True))
-
-
-class ChatCompletionFormat(ABC):
-    """Base class for chat completion templates."""
-
-    @abstractmethod
-    def create_chat_completion(
-        self,
-        llama: Llama,
-        messages: List[ChatCompletionMessage],
-        functions: Optional[List[ChatCompletionFunction]] = None,
-        function_call: Optional[Union[str, ChatCompletionFunctionCall]] = None,
-        temperature: float = 0.2,
-        top_p: float = 0.95,
-        top_k: int = 40,
-        stream: bool = False,
-        stop: Optional[Union[str, List[str]]] = [],
-        max_tokens: int = 256,
-        presence_penalty: float = 0.0,
-        frequency_penalty: float = 0.0,
-        repeat_penalty: float = 1.1,
-        tfs_z: float = 1.0,
-        mirostat_mode: int = 0,
-        mirostat_tau: float = 5.0,
-        mirostat_eta: float = 0.1,
-        model: Optional[str] = None,
-        logits_processor: Optional[LogitsProcessorList] = None,
-        grammar: Optional[LlamaGrammar] = None,
-    ) -> Union[Completion, Iterator[CompletionChunk]]:
-        raise NotImplementedError
-
-
-class DefaultChatCompletionFormat(ABC):
-    """Base class for chat completion templates."""
-
-    def create_chat_completion(
-        self,
-        llama: Llama,
-        messages: List[ChatCompletionMessage],
-        functions: Optional[List[ChatCompletionFunction]] = None,
-        function_call: Optional[Union[str, ChatCompletionFunctionCall]] = None,
-        temperature: float = 0.2,
-        top_p: float = 0.95,
-        top_k: int = 40,
-        stream: bool = False,
-        stop: Optional[Union[str, List[str]]] = [],
-        max_tokens: int = 256,
-        presence_penalty: float = 0.0,
-        frequency_penalty: float = 0.0,
-        repeat_penalty: float = 1.1,
-        tfs_z: float = 1.0,
-        mirostat_mode: int = 0,
-        mirostat_tau: float = 5.0,
-        mirostat_eta: float = 0.1,
-        model: Optional[str] = None,
-        logits_processor: Optional[LogitsProcessorList] = None,
-        grammar: Optional[LlamaGrammar] = None,
-    ) -> Union[Completion, Iterator[CompletionChunk]]:
-        stop = (
-            stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
-        )
-        chat_history = "".join(
-            f'### {"Human" if message["role"] == "user" else "Assistant"}:{message["content"]}'
-            for message in messages
-        )
-        PROMPT = chat_history + "### Assistant:"
-        PROMPT_STOP = ["### Assistant:", "### Human:"]
-        return llama.create_completion(
-            prompt=PROMPT,
-            stop=PROMPT_STOP + stop,
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            stream=stream,
-            max_tokens=max_tokens,
-            repeat_penalty=repeat_penalty,
-            presence_penalty=presence_penalty,
-            frequency_penalty=frequency_penalty,
-            tfs_z=tfs_z,
-            mirostat_mode=mirostat_mode,
-            mirostat_tau=mirostat_tau,
-            mirostat_eta=mirostat_eta,
-            model=model,
-            logits_processor=logits_processor,
-            grammar=grammar,
-        )
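
The formatter machinery itself lives in the llama_chat_format module imported at the top of the diff; that second changed file is not shown on this page. From the call site, a chat format is simply a callable that accepts messages= and returns an object exposing .prompt and .stop. A hypothetical formatter of that shape, reproducing the "### Human:/### Assistant:" rendering of the removed DefaultChatCompletionFormat; the names FormatResult and simple_chat_format are invented here for illustration:

from dataclasses import dataclass
from typing import Any, List, Optional, Union

@dataclass
class FormatResult:
    # Shape inferred from the call site above: create_chat_completion reads
    # result.prompt and result.stop.
    prompt: str
    stop: Optional[Union[str, List[str]]] = None

def simple_chat_format(messages: List[Any], **kwargs) -> FormatResult:
    # Render OpenAI-style role/content messages the same way the removed
    # DefaultChatCompletionFormat did.
    chat_history = "".join(
        f'### {"Human" if message["role"] == "user" else "Assistant"}:{message["content"]}'
        for message in messages
    )
    return FormatResult(
        prompt=chat_history + "### Assistant:",
        stop=["### Assistant:", "### Human:"],
    )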
