Commit d389d64

Merge of 2 parents: 06548c5 + 4ff8def

File tree

3 files changed: +46 −55 lines changed

‎llama_cpp/llama_chat_format.py

+35 −11 (35 additions & 11 deletions)

@@ -73,13 +73,16 @@ def _map_roles(
 
 
 def _format_llama2(
-    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
+    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
 ) -> str:
     """Format the prompt with the llama2 style."""
+    seps = [sep, sep2]
     ret = system_message + sep
-    for role, message in messages:
-        if message:
-            ret += role + message + " "
+    for i, (role, message) in enumerate(messages):
+        if system_message and i == 0:
+            ret += message + seps[i % 2]
+        elif message:
+            ret += role + message + " " + seps[i % 2]
         else:
             ret += role + " "
     return ret
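
For reference, a quick sketch of how the reworked helper alternates its two separators. This is illustrative only: the toy role tags, messages, and system string below are hypothetical inputs, and the behavior described in the comments is traced by hand from the definition above.

prompt = _format_llama2(
    system_message="<<SYS>>sys<</SYS>>",
    messages=[("[INST]", "hi"), ("[/INST]", "hello"), ("[INST]", "bye")],
    sep=" ",
    sep2="</s>",
)
# i == 0 with a non-empty system message: "hi" is appended without its role tag and closed by sep (" ").
# i == 1: "[/INST]hello " is closed by sep2 ("</s>").
# i == 2: "[INST]bye " is closed by sep (" ") again, and so on, alternating via seps[i % 2].
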
@@ -324,19 +327,20 @@ def get_chat_format(name: str):
         )
 
 
+# see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py
+# system prompt is "embedded" in the first message
 @register_chat_format("llama-2")
 def format_llama2(
     messages: List[llama_types.ChatCompletionRequestMessage],
     **kwargs: Any,
 ) -> ChatFormatterResponse:
-    _system_template = "[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n"
-    _roles = dict(user="[INST]", assistant="[/INST]")
-    _sep = "\n\n"
-    system_message = _get_system_message(messages)
-    system_message = _system_template.format(system_message=system_message)
+    _system_template = "<s>[INST] <<SYS>>\n{system_message}\n<</SYS>>"
+    _roles = dict(user="<s>[INST]", assistant="[/INST]")
     _messages = _map_roles(messages, _roles)
-    _messages.append((_roles["assistant"], None))
-    _prompt = _format_llama2(system_message, _messages, _sep)
+    system_message = _get_system_message(messages)
+    if system_message:
+        system_message = _system_template.format(system_message=system_message)
+    _prompt = _format_llama2(system_message, _messages, " ", "</s>") + "[/INST]"
     return ChatFormatterResponse(prompt=_prompt)
 
 
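
As a sanity check, here is a minimal sketch of what the updated "llama-2" formatter should now emit for a system-plus-user conversation. It is not part of the commit; the message contents are made up, and the expected prompt is traced by hand from the code above (the system prompt is folded into the first [INST] block, and the trailing [/INST] invites the assistant reply).

from llama_cpp.llama_chat_format import format_llama2

response = format_llama2(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello"},
    ]
)
# Expected prompt, roughly:
# "<s>[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>> Hello [/INST]"
print(response.prompt)
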

@@ -506,6 +510,26 @@ def format_chatml(
     _prompt = _format_chatml(system_message, _messages, _sep)
     return ChatFormatterResponse(prompt=_prompt)
 
+# eg, export HF_MODEL=mistralai/Mistral-7B-Instruct-v0.1
+@register_chat_format("autotokenizer")
+def format_autotokenizer(
+    messages: List[llama_types.ChatCompletionRequestMessage],
+    **kwargs: Any,
+) -> ChatFormatterResponse:
+    # https://huggingface.co/docs/transformers/main/chat_templating
+    # https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format
+    # https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json
+    import os
+    from transformers import AutoTokenizer
+    huggingFaceModel = os.getenv("HF_MODEL")  # eg, mistralai/Mistral-7B-Instruct-v0.1
+    print(huggingFaceModel)
+    if not huggingFaceModel:
+        raise Exception("HF_MODEL needs to be set in env to use chat format 'autotokenizer'")
+    tokenizer = AutoTokenizer.from_pretrained(huggingFaceModel)
+    tokenizer.use_default_system_prompt = False
+    _prompt = tokenizer.apply_chat_template(messages, tokenize=False)
+    # Return formatted prompt and eos token by default
+    return ChatFormatterResponse(prompt=_prompt, stop=tokenizer.eos_token)
 
 @register_chat_completion_handler("functionary")
 def functionary_chat_handler(
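
One possible way to exercise the new "autotokenizer" format end to end (a sketch, not part of this diff: it assumes transformers is installed, that the high-level Llama constructor accepts a chat_format argument as in recent llama-cpp-python releases, and a hypothetical local GGUF path):

import os

from llama_cpp import Llama

# The formatter reads HF_MODEL when it is invoked, so set it before the first chat call.
os.environ["HF_MODEL"] = "mistralai/Mistral-7B-Instruct-v0.1"

llm = Llama(
    model_path="./mistral-7b-instruct-v0.1.Q4_K_M.gguf",  # hypothetical path
    chat_format="autotokenizer",
)
out = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Write a haiku about llamas."}],
)
print(out["choices"][0]["message"]["content"])

Because the formatter also returns tokenizer.eos_token as the stop sequence, generation should halt where the Hugging Face chat template expects it to.
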

‎llama_cpp/llama_cpp.py

+10 −43 (10 additions & 43 deletions)

@@ -827,7 +827,7 @@ def llama_kv_cache_clear(ctx: llama_context_p):
 # llama_pos p1);
 def llama_kv_cache_seq_rm(
     ctx: llama_context_p,
-    seq_id: llama_seq_id,
+    seq_id: Union[llama_seq_id, int],
     p0: Union[llama_pos, int],
     p1: Union[llama_pos, int],
 ):
@@ -855,8 +855,8 @@ def llama_kv_cache_seq_rm(
 # llama_pos p1);
 def llama_kv_cache_seq_cp(
     ctx: llama_context_p,
-    seq_id_src: llama_seq_id,
-    seq_id_dst: llama_seq_id,
+    seq_id_src: Union[llama_seq_id, int],
+    seq_id_dst: Union[llama_seq_id, int],
     p0: Union[llama_pos, int],
     p1: Union[llama_pos, int],
 ):
@@ -879,7 +879,7 @@ def llama_kv_cache_seq_cp(
 # llama_seq_id seq_id);
 def llama_kv_cache_seq_keep(
     ctx: llama_context_p,
-    seq_id: llama_seq_id,
+    seq_id: Union[llama_seq_id, int],
 ):
     return _lib.llama_kv_cache_seq_keep(ctx, seq_id)
 
@@ -900,7 +900,7 @@ def llama_kv_cache_seq_keep(
 # llama_pos delta);
 def llama_kv_cache_seq_shift(
     ctx: llama_context_p,
-    seq_id: llama_seq_id,
+    seq_id: Union[llama_seq_id, int],
     p0: Union[llama_pos, int],
     p1: Union[llama_pos, int],
     delta: Union[llama_pos, int],
@@ -1204,7 +1204,7 @@ def llama_get_embeddings(
 
 
 # LLAMA_API const char * llama_token_get_text(const struct llama_model * model, llama_token token);
-def llama_token_get_text(model: llama_model_p, token: llama_token) -> bytes:
+def llama_token_get_text(model: llama_model_p, token: Union[llama_token, int]) -> bytes:
     return _lib.llama_token_get_text(model, token)
 
 
@@ -1213,7 +1213,7 @@ def llama_token_get_text(model: llama_model_p, token: llama_token) -> bytes:
 
 
 # LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token);
-def llama_token_get_score(model: llama_model_p, token: llama_token) -> float:
+def llama_token_get_score(model: llama_model_p, token: Union[llama_token, int]) -> float:
     return _lib.llama_token_get_score(model, token)
 
 
@@ -1222,7 +1222,7 @@ def llama_token_get_score(model: llama_model_p, token: llama_token) -> float:
 
 
 # LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token);
-def llama_token_get_type(model: llama_model_p, token: llama_token) -> int:
+def llama_token_get_type(model: llama_model_p, token: Union[llama_token, int]) -> int:
     return _lib.llama_token_get_type(model, token)
 
 
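
These signature changes are annotation-only: ctypes already coerces a plain Python int to the underlying llama_token C integer type when argtypes are set, so the widened Union hints mainly let type checkers accept ordinary ints returned by other bindings. A minimal sketch, assuming a model handle has already been obtained via llama_load_model_from_file and that the token helpers take the model pointer in this revision:

from llama_cpp import llama_cpp

# model: llama_model_p, loaded earlier with llama_cpp.llama_load_model_from_file(...)
bos = llama_cpp.llama_token_bos(model)               # plain Python int
text = llama_cpp.llama_token_get_text(model, bos)    # no llama_token(...) wrapping needed
score = llama_cpp.llama_token_get_score(model, bos)
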

@@ -1302,39 +1302,6 @@ def llama_token_eot(model: llama_model_p) -> int:
 # //
 
 
-# // Convert the provided text into tokens.
-# // The tokens pointer must be large enough to hold the resulting tokens.
-# // Returns the number of tokens on success, no more than n_max_tokens
-# // Returns a negative number on failure - the number of tokens that would have been returned
-# LLAMA_API int llama_tokenize(
-#     const struct llama_model * model,
-#     const char * text,
-#     int text_len,
-#     llama_token * tokens,
-#     int n_max_tokens,
-#     bool add_bos);
-def llama_tokenize(
-    model: llama_model_p,
-    text: bytes,
-    text_len: Union[c_int, int],
-    tokens,  # type: Array[llama_token]
-    n_max_tokens: Union[c_int, int],
-    add_bos: Union[c_bool, bool],
-) -> int:
-    return _lib.llama_tokenize(model, text, text_len, tokens, n_max_tokens, add_bos)
-
-
-_lib.llama_tokenize.argtypes = [
-    llama_model_p,
-    c_char_p,
-    c_int,
-    llama_token_p,
-    c_int,
-    c_bool,
-]
-_lib.llama_tokenize.restype = c_int
-
-
 # /// @details Convert the provided text into tokens.
 # /// @param tokens The tokens pointer must be large enough to hold the resulting tokens.
 # /// @return Returns the number of tokens on success, no more than n_max_tokens
@@ -1386,7 +1353,7 @@ def llama_tokenize(
 # int length);
 def llama_token_to_piece(
     model: llama_model_p,
-    token: llama_token,
+    token: Union[llama_token, int],
     buf: Union[c_char_p, bytes],
     length: Union[c_int, int],
 ) -> int:
@@ -1835,7 +1802,7 @@ def llama_sample_token(
 def llama_grammar_accept_token(
     ctx: llama_context_p,
     grammar: llama_grammar_p,
-    token: llama_token,
+    token: Union[llama_token, int],
 ) -> None:
     _lib.llama_grammar_accept_token(ctx, grammar, token)
 
‎vendor/llama.cpp
