@@ -24,6 +24,7 @@
 from . import llama_cpp
 from .llama_types import *
 from .llama_grammar import LlamaGrammar
+from . import llama_chat_format
 
 import numpy as np
 import numpy.typing as npt
@@ -243,6 +244,8 @@ def __init__(
         lora_path: Optional[str] = None,
         # Backend Params
         numa: bool = False,
+        # Chat Format Params
+        chat_format: str = "llama-2",
         # Misc
         verbose: bool = True,
         # Extra Params
@@ -273,6 +276,7 @@ def __init__(
             lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
             lora_path: Path to a LoRA file to apply to the model.
             numa: Enable NUMA support. (NOTE: The initial value of this parameter is used for the remainder of the program as this value is set in llama_backend_init)
+            chat_format: String specifying the chat format to use when calling create_chat_completion.
             verbose: Print verbose output to stderr.
             kwargs: Unused keyword arguments (for additional backwards compatibility).
 
@@ -387,6 +391,8 @@ def __init__(
 
         if self.verbose:
             print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
+
+        self.chat_format = chat_format
 
         self._n_vocab = self.n_vocab()
         self._n_ctx = self.n_ctx()
@@ -1578,7 +1584,7 @@ def _convert_completion_to_chat(
 
     def create_chat_completion(
         self,
-        messages: List[ChatCompletionMessage],
+        messages: List[ChatCompletionRequestMessage],
         functions: Optional[List[ChatCompletionFunction]] = None,
         function_call: Optional[Union[str, ChatCompletionFunctionCall]] = None,
         temperature: float = 0.2,
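Note: `ChatCompletionRequestMessage` replaces `ChatCompletionMessage` to match the OpenAI-style request schema. A rough sketch of the shape this type is assumed to have in `llama_types` (field details are an assumption, not copied from the module):

from typing import Optional, TypedDict

class ChatCompletionRequestMessage(TypedDict, total=False):
    role: str               # e.g. "system", "user", "assistant", "function"
    content: Optional[str]  # message text; may be None for function-call messages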
@@ -1613,11 +1619,19 @@ def create_chat_completion(
         Returns:
             Generated chat completion or a stream of chat completion chunks.
         """
-        completion_or_chunks = self.chat_completion_template.create_chat_completion(
-            self,
+
+        format = llama_chat_format.get_chat_format(self.chat_format)
+        result = format(
             messages=messages,
-            functions=functions,
-            function_call=function_call,
+        )
+        prompt = result.prompt
+        if result.stop is not None:
+            stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
+            rstop = result.stop if isinstance(result.stop, list) else [result.stop]
+            stop = stop + rstop
+
+        completion_or_chunks = self.create_completion(
+            prompt=prompt,
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
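Note: the new call path relies on a small contract from `llama_chat_format`: `get_chat_format(name)` looks up a formatter callable that takes the chat `messages` and returns an object carrying a rendered `prompt` plus optional extra `stop` sequences. A minimal sketch of that contract, assuming a plain dict registry and an illustrative llama-2 template (the real module's registration and template details may differ):

from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Union

@dataclass
class ChatFormatterResponse:
    prompt: str
    stop: Optional[Union[str, List[str]]] = None  # extra stop sequences, if any

_FORMATTERS: Dict[str, Callable[..., ChatFormatterResponse]] = {}

def get_chat_format(name: str) -> Callable[..., ChatFormatterResponse]:
    return _FORMATTERS[name]

def format_llama2(messages: List[dict], **kwargs) -> ChatFormatterResponse:
    # Flatten the message list into a llama-2 style [INST] prompt.
    prompt = ""
    for message in messages:
        if message["role"] == "system":
            prompt += f"<<SYS>>\n{message['content']}\n<</SYS>>\n\n"
        elif message["role"] == "user":
            prompt += f"[INST] {message['content']} [/INST]"
        else:
            prompt += f" {message['content']} "
    return ChatFormatterResponse(prompt=prompt)

_FORMATTERS["llama-2"] = format_llama2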
@@ -1675,6 +1689,8 @@ def __getstate__(self):
             lora_path=self.lora_path,
             # Backend Params
             numa=self.numa,
+            # Chat Format Params
+            chat_format=self.chat_format,
             # Misc
             verbose=self.verbose,
         )
@@ -1708,6 +1724,8 @@ def __setstate__(self, state):
             lora_path=state["lora_path"],
             # Backend Params
             numa=state["numa"],
+            # Chat Format Params
+            chat_format=state["chat_format"],
             # Misc
             verbose=state["verbose"],
         )
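Note: threading `chat_format` through both `__getstate__` and `__setstate__` keeps pickling consistent with the constructor. An illustrative round-trip check, not from the PR (model path is a placeholder):

import pickle
from llama_cpp import Llama

llm = Llama(model_path="models/7B/llama-model.gguf", chat_format="llama-2")
llm2 = pickle.loads(pickle.dumps(llm))  # re-loads the model via __setstate__
assert llm2.chat_format == "llama-2"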
@@ -1821,89 +1839,3 @@ def decode(self, tokens: List[int]) -> str:
     @classmethod
     def from_ggml_file(cls, path: str) -> "LlamaTokenizer":
         return cls(Llama(model_path=path, vocab_only=True))
-
-
-class ChatCompletionFormat(ABC):
-    """Base class for chat completion templates."""
-
-    @abstractmethod
-    def create_chat_completion(
-        self,
-        llama: Llama,
-        messages: List[ChatCompletionMessage],
-        functions: Optional[List[ChatCompletionFunction]] = None,
-        function_call: Optional[Union[str, ChatCompletionFunctionCall]] = None,
-        temperature: float = 0.2,
-        top_p: float = 0.95,
-        top_k: int = 40,
-        stream: bool = False,
-        stop: Optional[Union[str, List[str]]] = [],
-        max_tokens: int = 256,
-        presence_penalty: float = 0.0,
-        frequency_penalty: float = 0.0,
-        repeat_penalty: float = 1.1,
-        tfs_z: float = 1.0,
-        mirostat_mode: int = 0,
-        mirostat_tau: float = 5.0,
-        mirostat_eta: float = 0.1,
-        model: Optional[str] = None,
-        logits_processor: Optional[LogitsProcessorList] = None,
-        grammar: Optional[LlamaGrammar] = None,
-    ) -> Union[Completion, Iterator[CompletionChunk]]:
-        raise NotImplementedError
-
-
-class DefaultChatCompletionFormat(ABC):
-    """Base class for chat completion templates."""
-
-    def create_chat_completion(
-        self,
-        llama: Llama,
-        messages: List[ChatCompletionMessage],
-        functions: Optional[List[ChatCompletionFunction]] = None,
-        function_call: Optional[Union[str, ChatCompletionFunctionCall]] = None,
-        temperature: float = 0.2,
-        top_p: float = 0.95,
-        top_k: int = 40,
-        stream: bool = False,
-        stop: Optional[Union[str, List[str]]] = [],
-        max_tokens: int = 256,
-        presence_penalty: float = 0.0,
-        frequency_penalty: float = 0.0,
-        repeat_penalty: float = 1.1,
-        tfs_z: float = 1.0,
-        mirostat_mode: int = 0,
-        mirostat_tau: float = 5.0,
-        mirostat_eta: float = 0.1,
-        model: Optional[str] = None,
-        logits_processor: Optional[LogitsProcessorList] = None,
-        grammar: Optional[LlamaGrammar] = None,
-    ) -> Union[Completion, Iterator[CompletionChunk]]:
-        stop = (
-            stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
-        )
-        chat_history = "".join(
-            f'### {"Human" if message["role"] == "user" else "Assistant"}:{message["content"]}'
-            for message in messages
-        )
-        PROMPT = chat_history + "### Assistant:"
-        PROMPT_STOP = ["### Assistant:", "### Human:"]
-        return llama.create_completion(
-            prompt=PROMPT,
-            stop=PROMPT_STOP + stop,
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            stream=stream,
-            max_tokens=max_tokens,
-            repeat_penalty=repeat_penalty,
-            presence_penalty=presence_penalty,
-            frequency_penalty=frequency_penalty,
-            tfs_z=tfs_z,
-            mirostat_mode=mirostat_mode,
-            mirostat_tau=mirostat_tau,
-            mirostat_eta=mirostat_eta,
-            model=model,
-            logits_processor=logits_processor,
-            grammar=grammar,
-        )
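Note: with these changes the chat format is chosen once at construction time rather than via a template object. A hypothetical end-to-end usage of the new parameter (model path is a placeholder):

from llama_cpp import Llama

llm = Llama(model_path="models/7B/llama-model.gguf", chat_format="llama-2")
response = llm.create_chat_completion(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Name the planets in the solar system."},
    ],
)
print(response["choices"][0]["message"]["content"])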