Commit b30b9c3

Add JSON mode support. Closes abetlen#881
1 parent 4852a6a commit b30b9c3

4 files changed: 116 additions & 39 deletions

llama_cpp/llama.py

2 additions & 0 deletions

@@ -1901,6 +1901,7 @@ def create_chat_completion(
         stream: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
         seed: Optional[int] = None,
+        response_format: Optional[ChatCompletionRequestResponseFormat] = None,
         max_tokens: int = 256,
         presence_penalty: float = 0.0,
         frequency_penalty: float = 0.0,
@@ -1946,6 +1947,7 @@ def create_chat_completion(
             stream=stream,
             stop=stop,
             seed=seed,
+            response_format=response_format,
             max_tokens=max_tokens,
             presence_penalty=presence_penalty,
             frequency_penalty=frequency_penalty,
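
For orientation, a minimal usage sketch of the new parameter (the model path and prompts are placeholders, not part of this commit): passing response_format={"type": "json_object"} asks the chat handler to constrain the reply to valid JSON.

from llama_cpp import Llama

# Hypothetical local GGUF model path; substitute your own.
llm = Llama(model_path="./models/model.gguf")

result = llm.create_chat_completion(
    messages=[
        {"role": "system", "content": "You are a helpful assistant that answers in JSON."},
        {"role": "user", "content": "Name three colors."},
    ],
    response_format={"type": "json_object"},  # new parameter introduced by this commit
)
print(result["choices"][0]["message"]["content"])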

llama_cpp/llama_chat_format.py

106 additions & 38 deletions

@@ -5,8 +5,9 @@
 import dataclasses
 from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Protocol
 
-import llama_cpp.llama_types as llama_types
 import llama_cpp.llama as llama
+import llama_cpp.llama_types as llama_types
+import llama_cpp.llama_grammar as llama_grammar
 
 
 class LlamaChatCompletionHandler(Protocol):
@@ -25,6 +26,9 @@ def __call__(
         stream: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
         seed: Optional[int] = None,
+        response_format: Optional[
+            llama_types.ChatCompletionRequestResponseFormat
+        ] = None,
         max_tokens: int = 256,
         presence_penalty: float = 0.0,
         frequency_penalty: float = 0.0,
@@ -37,7 +41,10 @@ def __call__(
         logits_processor: Optional[llama.LogitsProcessorList] = None,
         grammar: Optional[llama.LlamaGrammar] = None,
         **kwargs, # type: ignore
-    ) -> Union[llama_types.CreateChatCompletionResponse, Iterator[llama_types.CreateChatCompletionStreamResponse]]:
+    ) -> Union[
+        llama_types.CreateChatCompletionResponse,
+        Iterator[llama_types.CreateChatCompletionStreamResponse],
+    ]:
         ...
 
 
@@ -169,6 +176,7 @@ class ChatFormatterResponse:
 class ChatFormatter(Protocol):
     def __call__(
         self,
+        *,
         messages: List[llama_types.ChatCompletionRequestMessage],
         **kwargs: Any,
     ) -> ChatFormatterResponse:
@@ -264,17 +272,24 @@ def _convert_completion_to_chat(
 def register_chat_format(name: str):
     def decorator(f: ChatFormatter):
         def basic_create_chat_completion(
+            *,
             llama: llama.Llama,
             messages: List[llama_types.ChatCompletionRequestMessage],
             functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
             function_call: Optional[
-                Union[str, llama_types.ChatCompletionFunctionCall]
+                llama_types.ChatCompletionRequestFunctionCall
             ] = None,
+            tools: Optional[List[llama_types.ChatCompletionTool]] = None,
+            tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
             temperature: float = 0.2,
             top_p: float = 0.95,
             top_k: int = 40,
             stream: bool = False,
             stop: Optional[Union[str, List[str]]] = [],
+            seed: Optional[int] = None,
+            response_format: Optional[
+                llama_types.ChatCompletionRequestResponseFormat
+            ] = None,
             max_tokens: int = 256,
             presence_penalty: float = 0.0,
             frequency_penalty: float = 0.0,
@@ -286,8 +301,10 @@ def basic_create_chat_completion(
             model: Optional[str] = None,
             logits_processor: Optional[llama.LogitsProcessorList] = None,
             grammar: Optional[llama.LlamaGrammar] = None,
+            **kwargs, # type: ignore
         ) -> Union[
-            llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]
+            llama_types.CreateChatCompletionResponse,
+            Iterator[llama_types.CreateChatCompletionStreamResponse],
         ]:
             result = f(
                 messages=messages,
@@ -299,6 +316,10 @@ def basic_create_chat_completion(
             stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
             rstop = result.stop if isinstance(result.stop, list) else [result.stop]
             stop = stop + rstop
+
+            if response_format is not None and response_format["type"] == "json_object":
+                print("hello world")
+                grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)
 
             completion_or_chunks = llama.create_completion(
                 prompt=prompt,
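
The json_object branch above is the heart of JSON mode: the bundled JSON_GBNF grammar is compiled with LlamaGrammar.from_string and handed to create_completion, which restricts token sampling to text that parses as JSON. A standalone sketch of the same mechanism, assuming a placeholder model path:

import llama_cpp.llama_grammar as llama_grammar
from llama_cpp import Llama

# Compile the bundled GBNF description of JSON into a sampling grammar.
grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)

llm = Llama(model_path="./models/model.gguf")  # hypothetical local GGUF model
out = llm.create_completion(
    prompt="Describe a cat as a JSON object:",
    grammar=grammar,  # sampling may only follow the JSON grammar
    max_tokens=128,
)
print(out["choices"][0]["text"])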
@@ -307,6 +328,7 @@ def basic_create_chat_completion(
                 top_k=top_k,
                 stream=stream,
                 stop=stop,
+                seed=seed,
                 max_tokens=max_tokens,
                 presence_penalty=presence_penalty,
                 frequency_penalty=frequency_penalty,
@@ -319,7 +341,7 @@ def basic_create_chat_completion(
                 logits_processor=logits_processor,
                 grammar=grammar,
             )
-            return _convert_completion_to_chat(completion_or_chunks, stream=stream) # type: ignore
+            return _convert_completion_to_chat(completion_or_chunks, stream=stream)
 
         register_chat_completion_handler(name)(basic_create_chat_completion)
         return f
@@ -727,7 +749,7 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
 
     assert "usage" in completion
     assert isinstance(function_call, str)
-    assert stream is False # TODO: support stream mode
+    assert stream is False  # TODO: support stream mode
 
     return llama_types.CreateChatCompletionResponse(
         id="chat" + completion["id"],
@@ -759,7 +781,9 @@ def __init__(self, clip_model_path: str):
         self._llava_cpp = llava_cpp
         self.clip_model_path = clip_model_path
 
-        self.clip_ctx = self._llava_cpp.clip_model_load(self.clip_model_path.encode(), 0)
+        self.clip_ctx = self._llava_cpp.clip_model_load(
+            self.clip_model_path.encode(), 0
+        )
 
     def __del__(self):
         if self.clip_ctx is not None:
@@ -805,64 +829,108 @@ def __call__(
         logits_processor: Optional[llama.LogitsProcessorList] = None,
         grammar: Optional[llama.LlamaGrammar] = None,
         **kwargs, # type: ignore
-    ) -> Union[llama_types.CreateChatCompletionResponse, Iterator[llama_types.CreateChatCompletionStreamResponse]]:
-        assert llama.context_params.logits_all is True # BUG: logits_all=True is required for llava
+    ) -> Union[
+        llama_types.CreateChatCompletionResponse,
+        Iterator[llama_types.CreateChatCompletionStreamResponse],
+    ]:
+        assert (
+            llama.context_params.logits_all is True
+        ) # BUG: logits_all=True is required for llava
         assert self.clip_ctx is not None
         system_prompt = _get_system_message(messages)
-        system_prompt = system_prompt if system_prompt != "" else "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
-        system_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
+        system_prompt = (
+            system_prompt
+            if system_prompt != ""
+            else "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
+        )
+        system_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
         user_role = "\nUSER:"
         assistant_role = "\nASSISTANT:"
         llama.reset()
         llama.eval(llama.tokenize(system_prompt.encode("utf8"), add_bos=True))
         for message in messages:
             if message["role"] == "user" and message["content"] is not None:
                 if isinstance(message["content"], str):
-                    llama.eval(llama.tokenize(f"{user_role} {message['content']}".encode("utf8"), add_bos=False))
+                    llama.eval(
+                        llama.tokenize(
+                            f"{user_role} {message['content']}".encode("utf8"),
+                            add_bos=False,
+                        )
+                    )
                 else:
                     assert isinstance(message["content"], list)
-                    llama.eval(llama.tokenize(f"{user_role} ".encode("utf8"), add_bos=False))
+                    llama.eval(
+                        llama.tokenize(f"{user_role} ".encode("utf8"), add_bos=False)
+                    )
                     for content in message["content"]:
                         if content["type"] == "text":
-                            llama.eval(llama.tokenize(f"{content['text']}".encode("utf8"), add_bos=False))
+                            llama.eval(
+                                llama.tokenize(
+                                    f"{content['text']}".encode("utf8"), add_bos=False
+                                )
+                            )
                         if content["type"] == "image_url":
-                            image_bytes = self.load_image(content["image_url"]["url"]) if isinstance(content["image_url"], dict) else self.load_image(content["image_url"])
+                            image_bytes = (
+                                self.load_image(content["image_url"]["url"])
+                                if isinstance(content["image_url"], dict)
+                                else self.load_image(content["image_url"])
+                            )
                             import array
-                            data_array = array.array('B', image_bytes)
-                            c_ubyte_ptr = (ctypes.c_ubyte * len(data_array)).from_buffer(data_array)
-                            embed = self._llava_cpp.llava_image_embed_make_with_bytes(ctx_clip=self.clip_ctx, n_threads=llama.context_params.n_threads, image_bytes=c_ubyte_ptr, image_bytes_length=len(image_bytes))
+
+                            data_array = array.array("B", image_bytes)
+                            c_ubyte_ptr = (
+                                ctypes.c_ubyte * len(data_array)
+                            ).from_buffer(data_array)
+                            embed = self._llava_cpp.llava_image_embed_make_with_bytes(
+                                ctx_clip=self.clip_ctx,
+                                n_threads=llama.context_params.n_threads,
+                                image_bytes=c_ubyte_ptr,
+                                image_bytes_length=len(image_bytes),
+                            )
                             # image_bytes_p = (ctypes.c_uint8 * len(image_bytes)).from_buffer_copy(image_bytes)
                             # embed = self._llava_cpp.llava_image_embed_make_with_bytes(ctx_clip=self.clip_ctx, n_threads=1, image_bytes=image_bytes_p, image_bytes_length=len(image_bytes))
                             try:
                                 n_past = ctypes.c_int(llama.n_tokens)
                                 n_past_p = ctypes.pointer(n_past)
-                                self._llava_cpp.llava_eval_image_embed(ctx_llama=llama.ctx, embed=embed, n_batch=llama.n_batch, n_past=n_past_p)
+                                self._llava_cpp.llava_eval_image_embed(
+                                    ctx_llama=llama.ctx,
+                                    embed=embed,
+                                    n_batch=llama.n_batch,
+                                    n_past=n_past_p,
+                                )
                                 assert llama.n_ctx() >= n_past.value
                                 llama.n_tokens = n_past.value
                             finally:
                                 self._llava_cpp.llava_image_embed_free(embed)
             if message["role"] == "assistant" and message["content"] is not None:
-                llama.eval(llama.tokenize(f"ASSISTANT: {message['content']}".encode("utf8"), add_bos=False))
+                llama.eval(
+                    llama.tokenize(
+                        f"ASSISTANT: {message['content']}".encode("utf8"), add_bos=False
+                    )
+                )
         llama.eval(llama.tokenize(f"{assistant_role}".encode("utf8"), add_bos=False))
 
         prompt = llama._input_ids.tolist()
 
-        return _convert_completion_to_chat(llama.create_completion(
-            prompt=prompt,
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
+        return _convert_completion_to_chat(
+            llama.create_completion(
+                prompt=prompt,
+                temperature=temperature,
+                top_p=top_p,
+                top_k=top_k,
+                stream=stream,
+                stop=stop,
+                max_tokens=max_tokens,
+                presence_penalty=presence_penalty,
+                frequency_penalty=frequency_penalty,
+                repeat_penalty=repeat_penalty,
+                tfs_z=tfs_z,
+                mirostat_mode=mirostat_mode,
+                mirostat_tau=mirostat_tau,
+                mirostat_eta=mirostat_eta,
+                model=model,
+                logits_processor=logits_processor,
+                grammar=grammar,
+            ),
             stream=stream,
-            stop=stop,
-            max_tokens=max_tokens,
-            presence_penalty=presence_penalty,
-            frequency_penalty=frequency_penalty,
-            repeat_penalty=repeat_penalty,
-            tfs_z=tfs_z,
-            mirostat_mode=mirostat_mode,
-            mirostat_tau=mirostat_tau,
-            mirostat_eta=mirostat_eta,
-            model=model,
-            logits_processor=logits_processor,
-            grammar=grammar,
-        ), stream=stream)
+        )

llama_cpp/llama_types.py

5 additions & 1 deletion

@@ -152,6 +152,10 @@ class ChatCompletionFunctionCallOption(TypedDict):
     name: str
 
 
+class ChatCompletionRequestResponseFormat(TypedDict):
+    type: Literal["text", "json_object"]
+
+
 class ChatCompletionRequestMessageContentPartText(TypedDict):
     type: Literal["text"]
     text: str
@@ -241,7 +245,7 @@ class ChatCompletionRequestFunctionCallOption(TypedDict):
     Literal["none", "auto"], ChatCompletionRequestFunctionCallOption
 ]
 
-ChatCompletionFunctionParameters = Dict[str, JsonType] # TODO: make this more specific
+ChatCompletionFunctionParameters = Dict[str, JsonType]  # TODO: make this more specific
 
 
 class ChatCompletionToolFunction(TypedDict):
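
The new TypedDict describes a plain dict mirroring the OpenAI request field; a small sketch of a value that satisfies it:

from llama_cpp.llama_types import ChatCompletionRequestResponseFormat

fmt: ChatCompletionRequestResponseFormat = {"type": "json_object"}
# {"type": "text"} would also type-check; any other "type" value is
# rejected by a static type checker because of the Literal annotation.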

llama_cpp/server/app.py

3 additions & 0 deletions

@@ -792,6 +792,9 @@ class CreateChatCompletionRequest(BaseModel):
     frequency_penalty: Optional[float] = frequency_penalty_field
     logit_bias: Optional[Dict[str, float]] = Field(None)
     seed: Optional[int] = Field(None)
+    response_format: Optional[llama_cpp.ChatCompletionRequestResponseFormat] = Field(
+        default=None,
+    )
 
     # ignored or currently unsupported
     model: Optional[str] = model_field
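
Finally, a hedged sketch of exercising the new server field over HTTP, assuming the llama-cpp-python server is running locally on its default port (8000) and exposing the OpenAI-compatible /v1/chat/completions route:

import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",  # assumed default host and port
    json={
        "messages": [
            {"role": "system", "content": "Reply only with a JSON object."},
            {"role": "user", "content": "Who won the world series in 2020?"},
        ],
        "response_format": {"type": "json_object"},  # new field added by this commit
    },
)
print(resp.json()["choices"][0]["message"]["content"])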
