Commit 1ae3abb

fix: missing logprobs in response, incorrect response type for functionary, minor type issues. Closes abetlen#1328 Closes abetlen#1314
1 parent 9111b6e commit 1ae3abb


llama_cpp/llama_chat_format.py

1 file changed: 29 additions & 19 deletions
@@ -6,7 +6,7 @@
 import dataclasses
 import random
 import string
-from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union, Protocol
+from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union, Protocol, cast

 import jinja2

@@ -338,6 +338,7 @@ def _convert_completion_to_chat_function(
                             }
                         ],
                     },
+                    "logprobs": None,
                     "finish_reason": "tool_calls",
                 }
             ],
@@ -1191,7 +1192,6 @@ def format_mistral_instruct(
         elif (
             message["role"] == "assistant"
             and message["content"] is not None
-            and isinstance(message["content"], str)
         ):
             prompt += " [/INST]" + message["content"] + eos
     prompt += " [/INST]"
@@ -1263,7 +1263,7 @@ def format_gemma(
     **kwargs: Any,
 ) -> ChatFormatterResponse:
     system_message = _get_system_message(messages)
-    if system_message is not None and system_message != "":
+    if system_message != "":
         logger.debug(
             "`role='system'` messages are not allowed on Google's Gemma models."
         )
@@ -1628,6 +1628,7 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
                             }
                         ],
                     },
+                    "logprobs": None,
                     "finish_reason": "tool_calls",
                 }
             ],
@@ -1909,14 +1910,14 @@ def get_grammar(function_call):
         return grammar

     def create_completion(stop):
-        completion: llama_types.Completion = llama.create_completion(
+        completion = cast(llama_types.Completion, llama.create_completion(
            prompt=prompt,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            typical_p=typical_p,
-            stream=stream,
+            stream=False,
            stop=stop,
            max_tokens=max_tokens,
            presence_penalty=presence_penalty,
@@ -1929,7 +1930,7 @@ def create_completion(stop):
            model=model,
            logits_processor=logits_processor,
            grammar=grammar,
-        )
+        ))

        return completion

@@ -2050,7 +2051,7 @@ def create_completion(stop):
    assert "usage" in completion
    assert len(function_calls) == len(function_bodies)

-    tool_calls = []
+    tool_calls: List[llama_types.ChatCompletionMessageToolCall] = []
    for function_call, function_body in zip(function_calls, function_bodies):
        tool_calls.append(
            {
@@ -2070,6 +2071,12 @@ def create_completion(stop):
        )

    # TODO: support stream mode
+    function_call_dict: Union[Dict[str, str], Dict[Literal["function_call"], llama_types.ChatCompletionRequestAssistantMessageFunctionCall]] = {
+        "function_call": {
+            "name": tool_calls[0]["function"]["name"],
+            "arguments": tool_calls[0]["function"]["arguments"],
+        }
+    } if len(tool_calls) == 1 else {}
    return llama_types.CreateChatCompletionResponse(
        id="chat" + completion["id"],
        object="chat.completion",
@@ -2078,14 +2085,12 @@
        choices=[
            {
                "index": 0,
+                "logprobs": None,
                "message": {
                    "role": "assistant",
                    "content": None if content == "" else content,
-                    "function_call": {
-                        "name": tool_calls[0]["function"]["name"],
-                        "arguments": tool_calls[0]["function"]["arguments"],
-                    } if len(tool_calls) > 0 else None,
-                    "tool_calls": tool_calls if len(tool_calls) > 0 else None,
+                    "tool_calls": tool_calls,
+                    **function_call_dict,
                },
                "finish_reason": "tool_calls" if len(tool_calls) > 0 else "stop",
            }
@@ -2565,8 +2570,8 @@ def chatml_function_calling(
    tool_name = text[len("functions.") :]
    tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None)
    if not stream:
-        completions = []
-        completions_tool_name = []
+        completions: List[llama_types.CreateCompletionResponse] = []
+        completions_tool_name: List[str] = []
        while tool is not None:
            prompt += f"functions.{tool_name}:\n"
            try:
@@ -2603,6 +2608,7 @@ def chatml_function_calling(
                logits_processor=logits_processor,
                grammar=grammar,
            )
+            completion_or_chunks = cast(llama_types.CreateCompletionResponse, completion_or_chunks)
            completions.append(completion_or_chunks)
            completions_tool_name.append(tool_name)
            prompt += completion_or_chunks["choices"][0]["text"]
@@ -2631,14 +2637,15 @@ def chatml_function_calling(
                    follow_up_gbnf_tool_grammar, verbose=llama.verbose
                ),
            )
+            response = cast(llama_types.CreateCompletionResponse, response)

            tool_name = response["choices"][0]["text"][len("functions.") :]
            tool = next(
                (tool for tool in tools if tool["function"]["name"] == tool_name), None
            )

        # Merge completions
-        function_call = {
+        function_call_dict: Union[Dict[str, str], Dict[Literal["function_call"], llama_types.ChatCompletionRequestAssistantMessageFunctionCall]] = {
            "function_call": {
                "name": tool_name,
                "arguments": completions[0]["choices"][0]["text"],
@@ -2653,6 +2660,7 @@ def chatml_function_calling(
                {
                    "finish_reason": "tool_calls",
                    "index": 0,
+                    "logprobs": None,
                    "message": {
                        "role": "assistant",
                        "content": None,
@@ -2673,20 +2681,22 @@ def chatml_function_calling(
                                zip(completions_tool_name, completions)
                            )
                        ],
-                        **function_call
+                        **function_call_dict
                    },
                }
            ],
            "usage": {
                "completion_tokens": sum(
-                    completion["usage"]["completion_tokens"]
+                    completion["usage"]["completion_tokens"] if "usage" in completion else 0
                    for completion in completions
                ),
                "prompt_tokens": sum(
-                    completion["usage"]["prompt_tokens"] for completion in completions
+                    completion["usage"]["prompt_tokens"] if "usage" in completion else 0
+                    for completion in completions
                ),
                "total_tokens": sum(
-                    completion["usage"]["total_tokens"] for completion in completions
+                    completion["usage"]["total_tokens"] if "usage" in completion else 0
+                    for completion in completions
                ),
            },
        }
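
For reference, here is a minimal standalone sketch (not part of the commit) of the two patterns the diff applies: a cast() that narrows create_completion's streaming/non-streaming union once stream=False is forced, and a conditional function_call_dict spread into the message so the legacy "function_call" field only appears when exactly one tool call was produced, while "logprobs" and "tool_calls" are always present. The FunctionCall/ToolCall TypedDicts and the narrow_completion/build_choice helpers below are hypothetical stand-ins for illustration, not llama_cpp types or functions.

from typing import Any, Dict, List, Literal, TypedDict, Union, cast


# Hypothetical stand-ins for the llama_cpp.llama_types structures; illustrative only.
class FunctionCall(TypedDict):
    name: str
    arguments: str


class ToolCall(TypedDict):
    id: str
    type: Literal["function"]
    function: FunctionCall


def narrow_completion(completion_or_chunks: object) -> Dict[str, Any]:
    # With stream=False the return value is the plain completion dict rather than
    # an iterator of chunks; cast() records that for the type checker at no runtime
    # cost, which is the same trick the diff uses.
    return cast(Dict[str, Any], completion_or_chunks)


def build_choice(content: str, tool_calls: List[ToolCall]) -> Dict[str, Any]:
    # The legacy "function_call" entry is built only when there is exactly one tool
    # call, then spread into the message; spreading an empty dict is a no-op.
    function_call_dict: Union[
        Dict[str, str], Dict[Literal["function_call"], FunctionCall]
    ] = (
        {
            "function_call": {
                "name": tool_calls[0]["function"]["name"],
                "arguments": tool_calls[0]["function"]["arguments"],
            }
        }
        if len(tool_calls) == 1
        else {}
    )
    return {
        "index": 0,
        "logprobs": None,  # always present, even when log probabilities were not requested
        "message": {
            "role": "assistant",
            "content": None if content == "" else content,
            "tool_calls": tool_calls,
            **function_call_dict,
        },
        "finish_reason": "tool_calls" if tool_calls else "stop",
    }


if __name__ == "__main__":
    calls: List[ToolCall] = [
        {
            "id": "call_0",
            "type": "function",
            "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
        }
    ]
    choice = build_choice("", calls)
    assert choice["logprobs"] is None
    assert "function_call" in choice["message"]  # single tool call -> legacy field included
    print(choice["finish_reason"])  # -> "tool_calls"

Keeping "logprobs" and "tool_calls" always present while gating only the deprecated "function_call" field keeps the response shape consistent with the OpenAI chat completion schema regardless of how many tool calls were produced.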

0 commit comments