Commit 94fe4bc

Add function calling support
1 parent fd55c29 commit 94fe4bc

llama_cpp/llama_chat_format.py
1 file changed: 72 additions & 23 deletions
@@ -2232,6 +2232,7 @@ def __call__(
         typical_p: float = 1.0,
         stream: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
+        seed: Optional[int] = None,
         response_format: Optional[
             llama_types.ChatCompletionRequestResponseFormat
         ] = None,
@@ -2246,6 +2247,9 @@ def __call__(
         model: Optional[str] = None,
         logits_processor: Optional[llama.LogitsProcessorList] = None,
         grammar: Optional[llama.LlamaGrammar] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
+        logprobs: Optional[bool] = None,
+        top_logprobs: Optional[int] = None,
         **kwargs,  # type: ignore
     ) -> Union[
         llama_types.CreateChatCompletionResponse,
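
The three parameters added here, together with seed in the first hunk, bring the handler's signature in line with arguments that Llama.create_chat_completion already accepts, so they are forwarded instead of being swallowed by **kwargs. A minimal sketch of a call exercising them; the model path and all values are placeholders:

from llama_cpp import Llama

llm = Llama(model_path="./model.gguf")  # placeholder path

response = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Hello!"}],
    seed=1234,               # fixed seed for reproducible sampling
    logit_bias={"2": -10.0}, # bias by token id (illustrative value)
    logprobs=True,           # request token log probabilities...
    top_logprobs=3,          # ...with up to 3 alternatives per token
)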
@@ -2309,32 +2313,77 @@ def free_embed():
         if response_format is not None and response_format["type"] == "json_object":
             grammar = _grammar_for_response_format(response_format)

-        # TODO: Add function call support
+        # Convert legacy functions to tools
+        if functions is not None:
+            tools = [
+                {
+                    "type": "function",
+                    "function": function,
+                }
+                for function in functions
+            ]

-        return _convert_completion_to_chat(
-            llama.create_completion(
-                prompt=prompt,
-                temperature=temperature,
-                top_p=top_p,
-                top_k=top_k,
-                min_p=min_p,
-                typical_p=typical_p,
-                stream=stream,
-                stop=stop,
-                max_tokens=max_tokens,
-                presence_penalty=presence_penalty,
-                frequency_penalty=frequency_penalty,
-                repeat_penalty=repeat_penalty,
-                tfs_z=tfs_z,
-                mirostat_mode=mirostat_mode,
-                mirostat_tau=mirostat_tau,
-                mirostat_eta=mirostat_eta,
-                model=model,
-                logits_processor=logits_processor,
-                grammar=grammar,
-            ),
+        # Convert legacy function_call to tool_choice
+        if function_call is not None:
+            if isinstance(function_call, str) and (
+                function_call == "none" or function_call == "auto"
+            ):
+                tool_choice = function_call
+            if isinstance(function_call, dict) and "name" in function_call:
+                tool_choice = {
+                    "type": "function",
+                    "function": {
+                        "name": function_call["name"],
+                    },
+                }
+
+        tool = None
+        if tool_choice is not None and isinstance(tool_choice, dict) and tools is not None:
+            name = tool_choice["function"]["name"]
+            tool = next((t for t in tools if t["function"]["name"] == name), None)
+            if tool is None:
+                raise ValueError(f"Tool choice '{name}' not found in tools.")
+            schema = tool["function"]["parameters"]
+            try:
+                # create grammar from json schema
+                grammar = llama_grammar.LlamaGrammar.from_json_schema(
+                    json.dumps(schema), verbose=llama.verbose
+                )
+            except Exception as e:
+                grammar = llama_grammar.LlamaGrammar.from_string(
+                    llama_grammar.JSON_GBNF, verbose=llama.verbose
+                )
+
+        completion_or_chunks = llama.create_completion(
+            prompt=prompt,
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            min_p=min_p,
+            typical_p=typical_p,
+            logprobs=top_logprobs if logprobs else None,
             stream=stream,
+            stop=stop,
+            seed=seed,
+            max_tokens=max_tokens,
+            presence_penalty=presence_penalty,
+            frequency_penalty=frequency_penalty,
+            repeat_penalty=repeat_penalty,
+            tfs_z=tfs_z,
+            mirostat_mode=mirostat_mode,
+            mirostat_tau=mirostat_tau,
+            mirostat_eta=mirostat_eta,
+            model=model,
+            logits_processor=logits_processor,
+            grammar=grammar,
+            logit_bias=logit_bias,
         )
+        if tool is not None:
+            tool_name = tool["function"]["name"]
+            return _convert_completion_to_chat_function(
+                tool_name, completion_or_chunks, stream
+            )
+        return _convert_completion_to_chat(completion_or_chunks, stream=stream)

     @staticmethod
     def _load_image(image_url: str) -> bytes:
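
The first two added blocks replace the old "TODO: Add function call support" by translating the deprecated OpenAI-style functions/function_call arguments into the newer tools/tool_choice shape before anything else runs. A standalone sketch of that mapping, using a made-up get_weather function:

functions = [
    {
        "name": "get_weather",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    }
]
function_call = {"name": "get_weather"}

# The same conversion the handler now performs internally:
tools = [{"type": "function", "function": f} for f in functions]
if isinstance(function_call, str) and function_call in ("none", "auto"):
    tool_choice = function_call
elif isinstance(function_call, dict) and "name" in function_call:
    tool_choice = {"type": "function", "function": {"name": function_call["name"]}}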
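
When tool_choice names a specific tool, generation is constrained by a grammar built from that tool's JSON schema, with a fallback to a generic JSON grammar if schema conversion fails. A minimal sketch of that step in isolation, using the same llama_grammar APIs the diff calls (schema reused from the example above):

import json

from llama_cpp import llama_grammar

schema = {
    "type": "object",
    "properties": {"city": {"type": "string"}},
    "required": ["city"],
}
try:
    # Constrain sampled tokens to JSON matching the tool's parameter schema.
    grammar = llama_grammar.LlamaGrammar.from_json_schema(json.dumps(schema), verbose=False)
except Exception:
    # Fall back to "any well-formed JSON", as the handler's except branch does.
    grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF, verbose=False)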
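
Finally, when a tool was selected, the completion is routed through _convert_completion_to_chat_function so the result surfaces as a tool call rather than plain text. A hypothetical end-to-end call reusing the names above (model setup elided; the response path assumes the OpenAI-compatible shape this handler emits):

response = llm.create_chat_completion(
    messages=[{"role": "user", "content": "What's the weather in Paris?"}],
    tools=tools,
    tool_choice={"type": "function", "function": {"name": "get_weather"}},
)

# Grammar-constrained JSON arguments are expected under:
# response["choices"][0]["message"]["tool_calls"][0]["function"]["arguments"]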

0 commit comments
