Integrate functionary v1.4 and v2 models + add HF AutoTokenizer as optional parameter in llama.create_completion #1078


Merged: 33 commits, Feb 8, 2024

Changes from all commits (33 commits)
5903e2f
convert functionary-v1 chat handler to use hf autotokenizer
jeffrey-fong Jan 10, 2024
485f129
add hf_tokenizer + inteegrate functionary-v1.4 prompt template
jeffrey-fong Jan 11, 2024
9580bee
integrate functionary v2 prompt template
jeffrey-fong Jan 11, 2024
bb48a83
update readme
jeffrey-fong Jan 11, 2024
0369931
set up parallel function calling wip
jeffrey-fong Jan 11, 2024
9540df9
set up parallel function calling
jeffrey-fong Jan 11, 2024
c71863c
Update README.md
jeffrey-fong Jan 12, 2024
4cf8736
Update README.md
jeffrey-fong Jan 12, 2024
ae7009b
refactor tokenizers
jeffrey-fong Jan 23, 2024
ebb4ec0
include old functionary handler for backward compatibility
jeffrey-fong Jan 23, 2024
9594d5c
add hf_tokenizer_path in server ModelSettings
jeffrey-fong Jan 23, 2024
7ea2e6e
Merge branch 'main' into integrate-functionary
jeffrey-fong Jan 30, 2024
43b4529
convert functionary-v1 chat handler to use hf autotokenizer
jeffrey-fong Jan 10, 2024
c9c6947
add hf_tokenizer + inteegrate functionary-v1.4 prompt template
jeffrey-fong Jan 11, 2024
f912c62
integrate functionary v2 prompt template
jeffrey-fong Jan 11, 2024
3b5fe39
update readme
jeffrey-fong Jan 11, 2024
4dd6b62
set up parallel function calling wip
jeffrey-fong Jan 11, 2024
03b68fe
resolve merge conflict
jeffrey-fong Jan 31, 2024
2957baf
Update README.md
jeffrey-fong Jan 12, 2024
7a98b04
Update README.md
jeffrey-fong Jan 12, 2024
bc9447b
refactor tokenizers
jeffrey-fong Jan 23, 2024
8d334df
include old functionary handler for backward compatibility
jeffrey-fong Jan 23, 2024
8d08b2d
add hf_tokenizer_path in server ModelSettings
jeffrey-fong Jan 23, 2024
5647bea
resolve merge conflict
jeffrey-fong Jan 31, 2024
951a6c9
Merge branch 'main' into integrate-functionary
abetlen Jan 31, 2024
1825688
Merge branch 'main' into integrate-functionary
abetlen Jan 31, 2024
3657cba
Cleanup PR, fix breaking changes
abetlen Jan 31, 2024
a79743b
Use hf_pretrained_model_name_or_path for tokenizer
abetlen Jan 31, 2024
7b36eb3
fix hf tokenizer in streaming
jeffrey-fong Feb 1, 2024
5ea9b19
update README
jeffrey-fong Feb 1, 2024
c8b5257
pull from main
jeffrey-fong Feb 2, 2024
24eb0db
refactor offset mapping
jeffrey-fong Feb 2, 2024
28c401c
Merge branch 'main' into integrate-functionary
abetlen Feb 8, 2024
README.md (19 changes: 8 additions & 11 deletions)
@@ -293,19 +293,16 @@ To constrain the response to a specific JSON Schema, you can use the `schema` pr

The high-level API also provides a simple interface for function calling.

Note that the only model that supports full function calling at this time is "functionary".
The gguf-converted files for this model can be found here: [functionary-7b-v1](https://huggingface.co/abetlen/functionary-7b-v1-GGUF)
The only set of models that supports full function calling at this time is [functionary](https://github.com/MeetKai/functionary). The various gguf-converted files for this set of models can be found [here](https://huggingface.co/meetkai). Functionary is able to intelligently call functions and also analyze any provided function outputs to generate coherent responses. All functionary v2 models support **parallel function calling**. You can provide either `functionary-v1` or `functionary-v2` for the `chat_format` when initializing the Llama class.

Note that due to discrepancies between llama.cpp and HuggingFace's tokenizers, an HF tokenizer must be provided for functionary. The `LlamaHFTokenizer` class can be initialized and passed into the Llama class; this overrides the default llama.cpp tokenizer used by the Llama class. The tokenizer files are already included in the respective HF repositories hosting the gguf files.

```python
>>> from llama_cpp import Llama
>>> llm = Llama(model_path="path/to/functionary/llama-model.gguf", chat_format="functionary")
>>> from llama_cpp import Llama, LlamaHFTokenizer
>>> tokenizer = LlamaHFTokenizer.from_pretrained("path/to/functionary/")
>>> llm = Llama(model_path="path/to/functionary/llama-model.gguf", tokenizer=tokenizer, chat_format="functionary-v2")
>>> llm.create_chat_completion(
messages = [
{
"role": "system",
"content": "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary"

},
{
"role": "user",
"content": "Extract Jason is 25 years old"
@@ -332,12 +329,12 @@ The gguf-converted files for this model can be found here: [functionary-7b-v1](h
}
}
}],
tool_choice=[{
tool_choice={
"type": "function",
"function": {
"name": "UserDetail"
}
}]
},
)
```

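For parallel function calling with the functionary-v2 models, the returned message can carry several tool calls at once. Below is a minimal sketch of reading them back out, assuming the OpenAI-style response dictionary returned by `create_chat_completion`; the `response` variable stands for the return value of the call above.

```python
import json

# `response` is the dict returned by the create_chat_completion call above.
message = response["choices"][0]["message"]

# With a functionary-v2 model, several tool calls may come back in one message.
for tool_call in message.get("tool_calls", []):
    name = tool_call["function"]["name"]                        # e.g. "UserDetail"
    arguments = json.loads(tool_call["function"]["arguments"])  # arguments arrive as a JSON string
    print(name, arguments)
```
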
llama_cpp/llama.py (101 changes: 79 additions & 22 deletions)
@@ -2,6 +2,7 @@

import os
import sys
import abc
import uuid
import time
import multiprocessing
@@ -14,11 +15,14 @@
Iterator,
Deque,
Callable,
Any,
)
from collections import deque

import ctypes

from llama_cpp.llama_types import List

from .llama_types import *
from .llama_grammar import LlamaGrammar
from .llama_cache import (
@@ -95,6 +99,8 @@ def __init__(
chat_handler: Optional[llama_chat_format.LlamaChatCompletionHandler] = None,
# Speculative Decoding
draft_model: Optional[LlamaDraftModel] = None,
# Tokenizer Override
tokenizer: Optional[BaseLlamaTokenizer] = None,
# Misc
verbose: bool = True,
# Extra Params
@@ -159,6 +165,7 @@ def __init__(
chat_format: String specifying the chat format to use when calling create_chat_completion.
chat_handler: Optional chat handler to use when calling create_chat_completion.
draft_model: Optional draft model to use for speculative decoding.
tokenizer: Optional tokenizer to override the default tokenizer from llama.cpp.
verbose: Print verbose output to stderr.

Raises:
@@ -235,6 +242,7 @@ def __init__(
self.n_threads_batch = n_threads_batch or max(
multiprocessing.cpu_count() // 2, 1
)

# Context Params
self.context_params = llama_cpp.llama_context_default_params()
self.context_params.seed = seed
@@ -286,6 +294,10 @@ def __init__(
self._model = _LlamaModel(
path_model=self.model_path, params=self.model_params, verbose=self.verbose
)

# Override tokenizer
self.tokenizer_ = tokenizer or LlamaTokenizer(self)

# Set the default value for the context and correct the batch
if n_ctx == 0:
n_ctx = self._model.n_ctx_train()
@@ -431,18 +443,19 @@ def tokenize(
Returns:
A list of tokens.
"""
return self._model.tokenize(text, add_bos, special)
return self.tokenizer_.tokenize(text, add_bos, special)

def detokenize(self, tokens: List[int]) -> bytes:
def detokenize(self, tokens: List[int], prev_tokens: Optional[List[int]] = None) -> bytes:
"""Detokenize a list of tokens.

Args:
tokens: The list of tokens to detokenize.
prev_tokens: The list of previous tokens. Offset mapping will be performed if provided

Returns:
The detokenized string.
"""
return self._model.detokenize(tokens)
return self.tokenizer_.detokenize(tokens, prev_tokens)

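A minimal usage sketch of the new `prev_tokens` argument (the model path is hypothetical): when the tokens already emitted are passed in, only the text produced beyond that prefix is returned, which is what the streaming code further down relies on.

```python
from llama_cpp import Llama

llm = Llama(model_path="path/to/model.gguf", vocab_only=True)  # hypothetical path
tokens = llm.tokenize(b"Hello world")

streamed = b""
for i in range(1, len(tokens) + 1):
    # Only the bytes contributed by the newest token are returned here.
    streamed += llm.detokenize(tokens[:i], prev_tokens=tokens[: i - 1])

print(streamed.decode("utf-8", errors="ignore"))
```
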
def set_cache(self, cache: Optional[BaseLlamaCache]):
"""Set the cache.
@@ -935,7 +948,8 @@ def logit_bias_processor(

if stream:
remaining_tokens = completion_tokens[returned_tokens:]
remaining_text = self.detokenize(remaining_tokens)
prev_tokens = completion_tokens[:returned_tokens]
remaining_text = self.detokenize(completion_tokens, prev_tokens)
remaining_length = len(remaining_text)

# We want to avoid yielding any characters from
@@ -957,13 +971,13 @@
for token in remaining_tokens:
if token == self.token_bos():
continue
token_end_position += len(self.detokenize([token]))
token_end_position += len(remaining_text)
# Check if stop sequence is in the token
if token_end_position > (
remaining_length - first_stop_position
):
break
token_str = self.detokenize([token]).decode(
token_str = remaining_text.decode(
"utf-8", errors="ignore"
)
text_offset = len(prompt) + len(
@@ -988,11 +1002,7 @@
}
top_logprob.update({token_str: current_logprobs[int(token)]})
logprobs_or_none = {
"tokens": [
self.detokenize([token]).decode(
"utf-8", errors="ignore"
)
],
"tokens": [token_str],
"text_offset": [text_offset],
"token_logprobs": [current_logprobs[int(token)]],
"top_logprobs": [top_logprob],
@@ -1005,9 +1015,7 @@
"model": model_name,
"choices": [
{
"text": self.detokenize([token]).decode(
"utf-8", errors="ignore"
),
"text": token_str,
"index": 0,
"logprobs": logprobs_or_none,
"finish_reason": None,
@@ -1019,7 +1027,7 @@
decode_success = False
for i in range(1, len(remaining_tokens) + 1):
try:
bs = self.detokenize(remaining_tokens[:i])
bs = remaining_text
ts = bs.decode("utf-8")
decode_success = True
break
@@ -1055,6 +1063,7 @@ def logit_bias_processor(

if len(completion_tokens) >= max_tokens:
text = self.detokenize(completion_tokens)

finish_reason = "length"
break

@@ -1693,8 +1702,8 @@ def n_vocab(self) -> int:
"""Return the vocabulary size."""
return self._model.n_vocab()

def tokenizer(self) -> "LlamaTokenizer":
"""Return the tokenizer for this model."""
def tokenizer(self) -> LlamaTokenizer:
"""Return the llama tokenizer for this model."""
return LlamaTokenizer(self)

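The returned `LlamaTokenizer` also exposes the string-level `encode`/`decode` helpers defined further down. A quick sketch, assuming a loaded `Llama` instance named `llm`:

```python
tok = llm.tokenizer()              # LlamaTokenizer bound to this model
ids = tok.encode("Hello world")    # list of token ids (BOS included by default)
text = tok.decode(ids)             # back to a (roughly) identical string
```
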
def token_eos(self) -> int:
@@ -1738,23 +1747,71 @@ def longest_token_prefix(a: Sequence[int], b: Sequence[int]):
return longest_prefix


class LlamaTokenizer:
class BaseLlamaTokenizer(abc.ABC):
@abc.abstractmethod
def tokenize(self, text: bytes, add_bos: bool = True, special: bool = True) -> List[int]:
raise NotImplementedError

@abc.abstractmethod
def detokenize(self, tokens: List[int], prev_tokens: Optional[List[int]] = None) -> bytes:
raise NotImplementedError


class LlamaTokenizer(BaseLlamaTokenizer):
def __init__(self, llama: Llama):
self.llama = llama
self._model = llama._model # type: ignore

def tokenize(self, text: bytes, add_bos: bool = True, special: bool = True) -> List[int]:
return self._model.tokenize(text, add_bos=add_bos, special=special)

def detokenize(self, tokens: List[int], prev_tokens: Optional[List[int]] = None) -> bytes:
if prev_tokens is not None:
return self._model.detokenize(tokens[len(prev_tokens):])
else:
return self._model.detokenize(tokens)

def encode(self, text: str, add_bos: bool = True) -> List[int]:
return self.llama.tokenize(
text.encode("utf-8", errors="ignore"), add_bos=add_bos, special=True
def encode(self, text: str, add_bos: bool = True, special: bool = True) -> List[int]:
return self.tokenize(
text.encode("utf-8", errors="ignore"), add_bos=add_bos, special=special
)

def decode(self, tokens: List[int]) -> str:
return self.llama.detokenize(tokens).decode("utf-8", errors="ignore")
return self.detokenize(tokens).decode("utf-8", errors="ignore")

@classmethod
def from_ggml_file(cls, path: str) -> "LlamaTokenizer":
return cls(Llama(model_path=path, vocab_only=True))


class LlamaHFTokenizer(BaseLlamaTokenizer):
def __init__(self, hf_tokenizer: Any):
self.hf_tokenizer = hf_tokenizer

def tokenize(self, text: bytes, add_bos: bool = True, special: bool = True) -> List[int]:
return self.hf_tokenizer.encode(text.decode("utf-8", errors="ignore"), add_special_tokens=special)

def detokenize(self, tokens: List[int], prev_tokens: Optional[List[int]] = None) -> bytes:
if prev_tokens is not None:
text = self.hf_tokenizer.decode(tokens).encode("utf-8", errors="ignore")
prev_text = self.hf_tokenizer.decode(prev_tokens).encode("utf-8", errors="ignore")
return text[len(prev_text):]
else:
return self.hf_tokenizer.decode(tokens).encode("utf-8", errors="ignore")

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str) -> "LlamaHFTokenizer":
try:
from transformers import AutoTokenizer
except ImportError:
raise ImportError(
"The `transformers` library is required to use the `HFTokenizer`."
"You can install it with `pip install transformers`."
)
hf_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=pretrained_model_name_or_path)
return cls(hf_tokenizer)

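Because the `tokenizer` argument of `Llama` accepts any `BaseLlamaTokenizer`, a custom implementation only needs the two methods above. A hedged sketch of a wrapper class (illustrative only, not part of this PR; the import path matches where `BaseLlamaTokenizer` lives in this PR's `llama_cpp/llama.py`):

```python
from typing import List, Optional

from llama_cpp.llama import BaseLlamaTokenizer


class CountingTokenizer(BaseLlamaTokenizer):
    """Illustrative wrapper that counts tokenize calls before delegating."""

    def __init__(self, inner: BaseLlamaTokenizer):
        self.inner = inner
        self.calls = 0

    def tokenize(self, text: bytes, add_bos: bool = True, special: bool = True) -> List[int]:
        self.calls += 1
        return self.inner.tokenize(text, add_bos, special)

    def detokenize(self, tokens: List[int], prev_tokens: Optional[List[int]] = None) -> bytes:
        return self.inner.detokenize(tokens, prev_tokens)
```
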

class LlamaState:
def __init__(
self,