@@ -2232,6 +2232,7 @@ def __call__(
22322232 typical_p : float = 1.0 ,
22332233 stream : bool = False ,
22342234 stop : Optional [Union [str , List [str ]]] = [],
2235+ seed : Optional [int ] = None ,
22352236 response_format : Optional [
22362237 llama_types .ChatCompletionRequestResponseFormat
22372238 ] = None ,
@@ -2246,6 +2247,9 @@ def __call__(
22462247 model : Optional [str ] = None ,
22472248 logits_processor : Optional [llama .LogitsProcessorList ] = None ,
22482249 grammar : Optional [llama .LlamaGrammar ] = None ,
2250+ logit_bias : Optional [Dict [str , float ]] = None ,
2251+ logprobs : Optional [bool ] = None ,
2252+ top_logprobs : Optional [int ] = None ,
22492253 ** kwargs , # type: ignore
22502254 ) -> Union [
22512255 llama_types .CreateChatCompletionResponse ,
@@ -2309,32 +2313,77 @@ def free_embed():
23092313 if response_format is not None and response_format ["type" ] == "json_object" :
23102314 grammar = _grammar_for_response_format (response_format )
23112315
2312- # TODO: Add function call support
2316+ # Convert legacy functions to tools
2317+ if functions is not None :
2318+ tools = [
2319+ {
2320+ "type" : "function" ,
2321+ "function" : function ,
2322+ }
2323+ for function in functions
2324+ ]
23132325
2314- return _convert_completion_to_chat (
2315- llama .create_completion (
2316- prompt = prompt ,
2317- temperature = temperature ,
2318- top_p = top_p ,
2319- top_k = top_k ,
2320- min_p = min_p ,
2321- typical_p = typical_p ,
2322- stream = stream ,
2323- stop = stop ,
2324- max_tokens = max_tokens ,
2325- presence_penalty = presence_penalty ,
2326- frequency_penalty = frequency_penalty ,
2327- repeat_penalty = repeat_penalty ,
2328- tfs_z = tfs_z ,
2329- mirostat_mode = mirostat_mode ,
2330- mirostat_tau = mirostat_tau ,
2331- mirostat_eta = mirostat_eta ,
2332- model = model ,
2333- logits_processor = logits_processor ,
2334- grammar = grammar ,
2335- ),
2326+ # Convert legacy function_call to tool_choice
2327+ if function_call is not None :
2328+ if isinstance (function_call , str ) and (
2329+ function_call == "none" or function_call == "auto"
2330+ ):
2331+ tool_choice = function_call
2332+ if isinstance (function_call , dict ) and "name" in function_call :
2333+ tool_choice = {
2334+ "type" : "function" ,
2335+ "function" : {
2336+ "name" : function_call ["name" ],
2337+ },
2338+ }
2339+
2340+ tool = None
2341+ if tool_choice is not None and isinstance (tool_choice , dict ) and tools is not None :
2342+ name = tool_choice ["function" ]["name" ]
2343+ tool = next ((t for t in tools if t ["function" ]["name" ] == name ), None )
2344+ if tool is None :
2345+ raise ValueError (f"Tool choice '{ name } ' not found in tools." )
2346+ schema = tool ["function" ]["parameters" ]
2347+ try :
2348+ # create grammar from json schema
2349+ grammar = llama_grammar .LlamaGrammar .from_json_schema (
2350+ json .dumps (schema ), verbose = llama .verbose
2351+ )
2352+ except Exception as e :
2353+ grammar = llama_grammar .LlamaGrammar .from_string (
2354+ llama_grammar .JSON_GBNF , verbose = llama .verbose
2355+ )
2356+
2357+ completion_or_chunks = llama .create_completion (
2358+ prompt = prompt ,
2359+ temperature = temperature ,
2360+ top_p = top_p ,
2361+ top_k = top_k ,
2362+ min_p = min_p ,
2363+ typical_p = typical_p ,
2364+ logprobs = top_logprobs if logprobs else None ,
23362365 stream = stream ,
2366+ stop = stop ,
2367+ seed = seed ,
2368+ max_tokens = max_tokens ,
2369+ presence_penalty = presence_penalty ,
2370+ frequency_penalty = frequency_penalty ,
2371+ repeat_penalty = repeat_penalty ,
2372+ tfs_z = tfs_z ,
2373+ mirostat_mode = mirostat_mode ,
2374+ mirostat_tau = mirostat_tau ,
2375+ mirostat_eta = mirostat_eta ,
2376+ model = model ,
2377+ logits_processor = logits_processor ,
2378+ grammar = grammar ,
2379+ logit_bias = logit_bias ,
23372380 )
2381+ if tool is not None :
2382+ tool_name = tool ["function" ]["name" ]
2383+ return _convert_completion_to_chat_function (
2384+ tool_name , completion_or_chunks , stream
2385+ )
2386+ return _convert_completion_to_chat (completion_or_chunks , stream = stream )
23382387
23392388 @staticmethod
23402389 def _load_image (image_url : str ) -> bytes :
0 commit comments