Commit 9018270

feat: Integrate functionary v1.4 and v2 models + add custom tokenizer support to Llama class (abetlen#1078)
* convert functionary-v1 chat handler to use hf autotokenizer
* add hf_tokenizer + integrate functionary-v1.4 prompt template
* integrate functionary v2 prompt template
* update readme
* set up parallel function calling wip
* set up parallel function calling
* Update README.md
* Update README.md
* refactor tokenizers
* include old functionary handler for backward compatibility
* add hf_tokenizer_path in server ModelSettings
* convert functionary-v1 chat handler to use hf autotokenizer
* add hf_tokenizer + integrate functionary-v1.4 prompt template
* integrate functionary v2 prompt template
* update readme
* set up parallel function calling wip
* resolve merge conflict
* Update README.md
* Update README.md
* refactor tokenizers
* include old functionary handler for backward compatibility
* add hf_tokenizer_path in server ModelSettings
* Cleanup PR, fix breaking changes
* Use hf_pretrained_model_name_or_path for tokenizer
* fix hf tokenizer in streaming
* update README
* refactor offset mapping

---------

Co-authored-by: Andrei <abetlen@gmail.com>
1 parent 34f3104 commit 9018270

4 files changed: +525, -34 lines changed

README.md

8 additions & 11 deletions
@@ -293,19 +293,16 @@ To constrain the response to a specific JSON Schema, you can use the `schema` pr
 
 The high-level API also provides a simple interface for function calling.
 
-Note that the only model that supports full function calling at this time is "functionary".
-The gguf-converted files for this model can be found here: [functionary-7b-v1](https://huggingface.co/abetlen/functionary-7b-v1-GGUF)
+The only set of models that supports full function calling at this time is [functionary](https://github.com/MeetKai/functionary). The various gguf-converted files for this set of models can be found [here](https://huggingface.co/meetkai). Functionary is able to intelligently call functions and also analyze any provided function outputs to generate coherent responses. All v2 models of functionary supports **parallel function calling**. You can provide either `functionary-v1` or `functionary-v2` for the `chat_format` when initializing the Llama class.
+
+Note that due to discrepancies between llama.cpp and HuggingFace's tokenizers, it is required to provide HF Tokenizer for functionary. The `LlamaHFTokenizer` class can be initialized and passed into the Llama class. This will override the default llama.cpp tokenizer used in Llama class. The tokenizer files are already included in the respective HF repositories hosting the gguf files.
 
 ```python
->>> from llama_cpp import Llama
->>> llm = Llama(model_path="path/to/functionary/llama-model.gguf", chat_format="functionary")
+>>> from llama_cpp import Llama, LlamaHFTokenizer
+>>> tokenizer = LlamaHFTokenizer.from_pretrained("path/to/functionary/")
+>>> llm = Llama(model_path="path/to/functionary/llama-model.gguf", tokenizer=tokenizer, chat_format="functionary-v2")
 >>> llm.create_chat_completion(
       messages = [
-        {
-          "role": "system",
-          "content": "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary"
-
-        },
         {
           "role": "user",
           "content": "Extract Jason is 25 years old"
@@ -332,12 +329,12 @@ The gguf-converted files for this model can be found here: [functionary-7b-v1](h
       }
     }
   }],
-  tool_choice=[{
+  tool_choice={
    "type": "function",
    "function": {
      "name": "UserDetail"
    }
-  }]
+  },
)
```

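The README section above points out that the functionary HF repositories host both the gguf files and the matching tokenizer files. Below is a hedged sketch of pulling both from a single repo with `huggingface_hub`; the repo id and gguf filename are illustrative placeholders, not taken from this commit.

```python
# Hedged sketch: the repo id and filename below are illustrative placeholders.
# The point is that LlamaHFTokenizer.from_pretrained can target the same HF repo
# that hosts the gguf file, since the tokenizer files live alongside it.
from huggingface_hub import hf_hub_download
from llama_cpp import Llama, LlamaHFTokenizer

repo_id = "meetkai/functionary-small-v2.2-GGUF"  # illustrative repo id
model_path = hf_hub_download(repo_id, "functionary-small-v2.2.q4_0.gguf")  # illustrative filename
tokenizer = LlamaHFTokenizer.from_pretrained(repo_id)  # tokenizer files live in the same repo
llm = Llama(model_path=model_path, tokenizer=tokenizer, chat_format="functionary-v2")
```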
llama_cpp/llama.py

79 additions & 22 deletions
@@ -2,6 +2,7 @@
 
 import os
 import sys
+import abc
 import uuid
 import time
 import multiprocessing
@@ -14,11 +15,14 @@
     Iterator,
     Deque,
     Callable,
+    Any,
 )
 from collections import deque
 
 import ctypes
 
+from llama_cpp.llama_types import List
+
 from .llama_types import *
 from .llama_grammar import LlamaGrammar
 from .llama_cache import (
9599
chat_handler: Optional[llama_chat_format.LlamaChatCompletionHandler] = None,
96100
# Speculative Decoding
97101
draft_model: Optional[LlamaDraftModel] = None,
102+
# Tokenizer Override
103+
tokenizer: Optional[BaseLlamaTokenizer] = None,
98104
# Misc
99105
verbose: bool = True,
100106
# Extra Params
@@ -159,6 +165,7 @@ def __init__(
             chat_format: String specifying the chat format to use when calling create_chat_completion.
             chat_handler: Optional chat handler to use when calling create_chat_completion.
             draft_model: Optional draft model to use for speculative decoding.
+            tokenizer: Optional tokenizer to override the default tokenizer from llama.cpp.
             verbose: Print verbose output to stderr.
 
         Raises:
@@ -235,6 +242,7 @@ def __init__(
         self.n_threads_batch = n_threads_batch or max(
             multiprocessing.cpu_count() // 2, 1
         )
+
         # Context Params
         self.context_params = llama_cpp.llama_context_default_params()
         self.context_params.seed = seed
@@ -286,6 +294,10 @@ def __init__(
         self._model = _LlamaModel(
             path_model=self.model_path, params=self.model_params, verbose=self.verbose
         )
+
+        # Override tokenizer
+        self.tokenizer_ = tokenizer or LlamaTokenizer(self)
+
         # Set the default value for the context and correct the batch
         if n_ctx == 0:
             n_ctx = self._model.n_ctx_train()
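With this override in place, `Llama.tokenize()` and `Llama.detokenize()` delegate to whichever `tokenizer_` was set (the next hunk shows the delegation); omitting the argument keeps the built-in llama.cpp tokenizer. A minimal sketch, with an illustrative model path:

```python
# Minimal sketch of the fallback behaviour of the new tokenizer override.
from llama_cpp import Llama
from llama_cpp.llama import LlamaTokenizer

llm = Llama(model_path="path/to/model.gguf", vocab_only=True)  # the vocab alone is enough to tokenize
assert isinstance(llm.tokenizer_, LlamaTokenizer)  # no override given, so the built-in tokenizer wraps the model

tokens = llm.tokenize(b"Hello, world!")  # forwards to llm.tokenizer_.tokenize(...)
text = llm.detokenize(tokens)            # forwards to llm.tokenizer_.detokenize(...)
```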
@@ -431,18 +443,19 @@ def tokenize(
         Returns:
             A list of tokens.
         """
-        return self._model.tokenize(text, add_bos, special)
+        return self.tokenizer_.tokenize(text, add_bos, special)
 
-    def detokenize(self, tokens: List[int]) -> bytes:
+    def detokenize(self, tokens: List[int], prev_tokens: Optional[List[int]] = None) -> bytes:
         """Detokenize a list of tokens.
 
         Args:
             tokens: The list of tokens to detokenize.
+            prev_tokens: The list of previous tokens. Offset mapping will be performed if provided
 
         Returns:
             The detokenized string.
         """
-        return self._model.detokenize(tokens)
+        return self.tokenizer_.detokenize(tokens, prev_tokens)
 
     def set_cache(self, cache: Optional[BaseLlamaCache]):
         """Set the cache.
@@ -935,7 +948,8 @@ def logit_bias_processor(
 
         if stream:
             remaining_tokens = completion_tokens[returned_tokens:]
-            remaining_text = self.detokenize(remaining_tokens)
+            prev_tokens = completion_tokens[:returned_tokens]
+            remaining_text = self.detokenize(completion_tokens, prev_tokens)
             remaining_length = len(remaining_text)
 
             # We want to avoid yielding any characters from
@@ -957,13 +971,13 @@ def logit_bias_processor(
                 for token in remaining_tokens:
                     if token == self.token_bos():
                         continue
-                    token_end_position += len(self.detokenize([token]))
+                    token_end_position += len(remaining_text)
                     # Check if stop sequence is in the token
                     if token_end_position > (
                         remaining_length - first_stop_position
                     ):
                         break
-                    token_str = self.detokenize([token]).decode(
+                    token_str = remaining_text.decode(
                         "utf-8", errors="ignore"
                     )
                     text_offset = len(prompt) + len(
@@ -988,11 +1002,7 @@ def logit_bias_processor(
                         }
                         top_logprob.update({token_str: current_logprobs[int(token)]})
                         logprobs_or_none = {
-                            "tokens": [
-                                self.detokenize([token]).decode(
-                                    "utf-8", errors="ignore"
-                                )
-                            ],
+                            "tokens": [token_str],
                             "text_offset": [text_offset],
                             "token_logprobs": [current_logprobs[int(token)]],
                             "top_logprobs": [top_logprob],
@@ -1005,9 +1015,7 @@ def logit_bias_processor(
                         "model": model_name,
                         "choices": [
                             {
-                                "text": self.detokenize([token]).decode(
-                                    "utf-8", errors="ignore"
-                                ),
+                                "text": token_str,
                                 "index": 0,
                                 "logprobs": logprobs_or_none,
                                 "finish_reason": None,
@@ -1019,7 +1027,7 @@ def logit_bias_processor(
                 decode_success = False
                 for i in range(1, len(remaining_tokens) + 1):
                     try:
-                        bs = self.detokenize(remaining_tokens[:i])
+                        bs = remaining_text
                         ts = bs.decode("utf-8")
                         decode_success = True
                         break
@@ -1055,6 +1063,7 @@ def logit_bias_processor(
 
             if len(completion_tokens) >= max_tokens:
                 text = self.detokenize(completion_tokens)
+
                 finish_reason = "length"
                 break

@@ -1693,8 +1702,8 @@ def n_vocab(self) -> int:
         """Return the vocabulary size."""
         return self._model.n_vocab()
 
-    def tokenizer(self) -> "LlamaTokenizer":
-        """Return the tokenizer for this model."""
+    def tokenizer(self) -> LlamaTokenizer:
+        """Return the llama tokenizer for this model."""
         return LlamaTokenizer(self)
 
     def token_eos(self) -> int:
@@ -1738,23 +1747,71 @@ def longest_token_prefix(a: Sequence[int], b: Sequence[int]):
     return longest_prefix
 
 
-class LlamaTokenizer:
+class BaseLlamaTokenizer(abc.ABC):
+    @abc.abstractmethod
+    def tokenize(self, text: bytes, add_bos: bool = True, special: bool = True) -> List[int]:
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def detokenize(self, tokens: List[int], prev_tokens: Optional[List[int]] = None) -> bytes:
+        raise NotImplementedError
+
+
+class LlamaTokenizer(BaseLlamaTokenizer):
     def __init__(self, llama: Llama):
         self.llama = llama
+        self._model = llama._model  # type: ignore
+
+    def tokenize(self, text: bytes, add_bos: bool = True, special: bool = True) -> List[int]:
+        return self._model.tokenize(text, add_bos=add_bos, special=special)
+
+    def detokenize(self, tokens: List[int], prev_tokens: Optional[List[int]] = None) -> bytes:
+        if prev_tokens is not None:
+            return self._model.detokenize(tokens[len(prev_tokens):])
+        else:
+            return self._model.detokenize(tokens)
 
-    def encode(self, text: str, add_bos: bool = True) -> List[int]:
-        return self.llama.tokenize(
-            text.encode("utf-8", errors="ignore"), add_bos=add_bos, special=True
+    def encode(self, text: str, add_bos: bool = True, special: bool = True) -> List[int]:
+        return self.tokenize(
+            text.encode("utf-8", errors="ignore"), add_bos=add_bos, special=special
         )
 
     def decode(self, tokens: List[int]) -> str:
-        return self.llama.detokenize(tokens).decode("utf-8", errors="ignore")
+        return self.detokenize(tokens).decode("utf-8", errors="ignore")
 
     @classmethod
     def from_ggml_file(cls, path: str) -> "LlamaTokenizer":
         return cls(Llama(model_path=path, vocab_only=True))
 
 
+class LlamaHFTokenizer(BaseLlamaTokenizer):
+    def __init__(self, hf_tokenizer: Any):
+        self.hf_tokenizer = hf_tokenizer
+
+    def tokenize(self, text: bytes, add_bos: bool = True, special: bool = True) -> List[int]:
+        return self.hf_tokenizer.encode(text.decode("utf-8", errors="ignore"), add_special_tokens=special)
+
+    def detokenize(self, tokens: List[int], prev_tokens: Optional[List[int]] = None) -> bytes:
+        if prev_tokens is not None:
+            text = self.hf_tokenizer.decode(tokens).encode("utf-8", errors="ignore")
+            prev_text = self.hf_tokenizer.decode(prev_tokens).encode("utf-8", errors="ignore")
+            return text[len(prev_text):]
+        else:
+            return self.hf_tokenizer.decode(tokens).encode("utf-8", errors="ignore")
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path: str) -> "LlamaHFTokenizer":
+        try:
+            from transformers import AutoTokenizer
+        except ImportError:
+            raise ImportError(
+                "The `transformers` library is required to use the `HFTokenizer`."
+                "You can install it with `pip install transformers`."
+            )
+        hf_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=pretrained_model_name_or_path)
+        return cls(hf_tokenizer)
+
+
 class LlamaState:
     def __init__(
         self,
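`BaseLlamaTokenizer` is the abstraction the new `tokenizer` constructor argument accepts, so any object implementing `tokenize`/`detokenize` can be plugged into `Llama`. A purely illustrative sketch (not part of this commit) of a custom wrapper that delegates to an existing tokenizer while counting calls:

```python
# Illustrative only: a custom BaseLlamaTokenizer that wraps another tokenizer
# and counts calls, showing the minimal surface the Llama class relies on.
from typing import List, Optional

from llama_cpp.llama import BaseLlamaTokenizer


class CountingTokenizer(BaseLlamaTokenizer):
    def __init__(self, inner: BaseLlamaTokenizer):
        self.inner = inner
        self.calls = 0

    def tokenize(self, text: bytes, add_bos: bool = True, special: bool = True) -> List[int]:
        self.calls += 1
        return self.inner.tokenize(text, add_bos=add_bos, special=special)

    def detokenize(self, tokens: List[int], prev_tokens: Optional[List[int]] = None) -> bytes:
        self.calls += 1
        return self.inner.detokenize(tokens, prev_tokens=prev_tokens)
```

Passing, say, `CountingTokenizer(LlamaHFTokenizer.from_pretrained("path/to/functionary/"))` as `tokenizer=` when constructing `Llama` would then route every tokenize/detokenize call through the wrapper.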
