@@ -21,8 +21,6 @@
 import diskcache
 import ctypes
 
-from llama_cpp.llama_chat_templates import ChatCompletionFormat, DefaultChatCompletionFormat
-
 from . import llama_cpp
 from .llama_types import *
 from .llama_grammar import LlamaGrammar
@@ -240,7 +238,7 @@ def __init__(
         lora_base: Optional[str] = None,
         lora_path: Optional[str] = None,
         numa: bool = False,
-        chat_completion_template: Optional[ChatCompletionFormat] = None,
+        chat_completion_template: Optional["ChatCompletionFormat"] = None,
         verbose: bool = True,
         **kwargs,  # type: ignore
     ):
@@ -323,7 +321,9 @@ def __init__(
         self.last_n_tokens_size = last_n_tokens_size
         self.n_batch = min(n_ctx, n_batch)
 
-        self.chat_completion_template = chat_completion_template or DefaultChatCompletionFormat()
+        self.chat_completion_template = (
+            chat_completion_template or DefaultChatCompletionFormat()
+        )
 
         self.cache: Optional[BaseLlamaCache] = None
 
@@ -1783,3 +1783,91 @@ def decode(self, tokens: List[int]) -> str:
     @classmethod
     def from_ggml_file(cls, path: str) -> "LlamaTokenizer":
         return cls(Llama(model_path=path, vocab_only=True))
+
+
+class ChatCompletionFormat(ABC):
+    """Base class for chat completion templates."""
+
+    @abstractmethod
+    def create_chat_completion(
+        self,
+        llama: Llama,
+        messages: List[ChatCompletionMessage],
+        functions: Optional[List[ChatCompletionFunction]] = None,
+        function_call: Optional[Union[str, ChatCompletionFunctionCall]] = None,
+        temperature: float = 0.2,
+        top_p: float = 0.95,
+        top_k: int = 40,
+        stream: bool = False,
+        stop: Optional[Union[str, List[str]]] = [],
+        max_tokens: int = 256,
+        presence_penalty: float = 0.0,
+        frequency_penalty: float = 0.0,
+        repeat_penalty: float = 1.1,
+        tfs_z: float = 1.0,
+        mirostat_mode: int = 0,
+        mirostat_tau: float = 5.0,
+        mirostat_eta: float = 0.1,
+        model: Optional[str] = None,
+        logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
+    ) -> Union[Completion, Iterator[CompletionChunk]]:
+        raise NotImplementedError
+
+
+class DefaultChatCompletionFormat(ChatCompletionFormat):
+    """Default chat completion template using ### Human / ### Assistant turns."""
+
+    def create_chat_completion(
+        self,
+        llama: Llama,
+        messages: List[ChatCompletionMessage],
+        functions: Optional[List[ChatCompletionFunction]] = None,
+        function_call: Optional[Union[str, ChatCompletionFunctionCall]] = None,
+        temperature: float = 0.2,
+        top_p: float = 0.95,
+        top_k: int = 40,
+        stream: bool = False,
+        stop: Optional[Union[str, List[str]]] = [],
+        max_tokens: int = 256,
+        presence_penalty: float = 0.0,
+        frequency_penalty: float = 0.0,
+        repeat_penalty: float = 1.1,
+        tfs_z: float = 1.0,
+        mirostat_mode: int = 0,
+        mirostat_tau: float = 5.0,
+        mirostat_eta: float = 0.1,
+        model: Optional[str] = None,
+        logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
+    ) -> Union[Completion, Iterator[CompletionChunk]]:
+        # Normalize stop into a list of stop strings.
+        stop = (
+            stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
+        )
+        # Render the conversation as alternating ### Human / ### Assistant turns.
+        chat_history = "".join(
+            f'### {"Human" if message["role"] == "user" else "Assistant"}:{message["content"]}'
+            for message in messages
+        )
+        PROMPT = chat_history + "### Assistant:"
+        PROMPT_STOP = ["### Assistant:", "### Human:"]
+        return llama.create_completion(
+            prompt=PROMPT,
+            stop=PROMPT_STOP + stop,
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            stream=stream,
+            max_tokens=max_tokens,
+            repeat_penalty=repeat_penalty,
+            presence_penalty=presence_penalty,
+            frequency_penalty=frequency_penalty,
+            tfs_z=tfs_z,
+            mirostat_mode=mirostat_mode,
+            mirostat_tau=mirostat_tau,
+            mirostat_eta=mirostat_eta,
+            model=model,
+            logits_processor=logits_processor,
+            grammar=grammar,
+        )
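
Taken together, these hunks move the chat-template classes into llama.py itself and expose them through the new chat_completion_template constructor argument. As a minimal usage sketch, assuming only what this diff shows (the wiring of Llama.create_chat_completion to the stored template is not part of these hunks), a custom format could be plugged in as below; PlainChatCompletionFormat and the model path are hypothetical:

    from llama_cpp.llama import Llama, ChatCompletionFormat

    class PlainChatCompletionFormat(ChatCompletionFormat):
        # Hypothetical format: bare "User:" / "Assistant:" turns, no ### markup.
        def create_chat_completion(self, llama, messages, **kwargs):
            prompt = "".join(
                f"{'User' if m['role'] == 'user' else 'Assistant'}: {m['content']}\n"
                for m in messages
            ) + "Assistant:"
            # Delegate to the plain completion API with format-specific stops.
            return llama.create_completion(prompt=prompt, stop=["User:"], **kwargs)

    llm = Llama(
        model_path="./model.bin",  # placeholder path
        chat_completion_template=PlainChatCompletionFormat(),
    )
    # Invoke the stored template directly with a chat-style message list.
    result = llm.chat_completion_template.create_chat_completion(
        llm, [{"role": "user", "content": "Hello"}], max_tokens=32
    )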