Commit dbcf64c

CISC and abetlen authored

feat: Support SPM infill (abetlen#1492)

* Support SPM infill
* typo--
* one less layer of parenthesis necessary
* new required internals
* manually add bos/eos if model requires it
* add bos even when unknown
  This is identical behaviour to llama.cpp; I guess any model that doesn't use BOS is recent enough to have the add_bos_token metadata.
* don't add bos/eos on non-infill pre-tokenized prompt
* add tokenizer hack to remove leading space in suffix
* I keep forgetting metadata are strings
* check if bos exists
* add example
* add cls/sep instead of bos/eos for WPM vocab
* simplify
* color-code filtered suffix

Co-authored-by: Andrei Betlen <abetlen@gmail.com>
1 parent e342161 commit dbcf64c

File tree

3 files changed: +87 -27 lines changed
New example file: 33 additions & 0 deletions

@@ -0,0 +1,33 @@
+import argparse
+
+from llama_cpp import Llama
+
+parser = argparse.ArgumentParser()
+parser.add_argument("-m", "--model", type=str, default="../models/7B/ggml-models.bin")
+parser.add_argument("-p", "--prompt", type=str, default="def add(")
+parser.add_argument("-s", "--suffix", type=str, default="\n    return sum\n\n")
+parser.add_argument("-i", "--spm-infill", action='store_true')
+args = parser.parse_args()
+
+llm = Llama(model_path=args.model, n_gpu_layers=-1, spm_infill=args.spm_infill)
+
+output = llm.create_completion(
+    temperature = 0.0,
+    repeat_penalty = 1.0,
+    prompt = args.prompt,
+    suffix = args.suffix,
+)
+
+# Models sometimes repeat suffix in response, attempt to filter that
+response = output["choices"][0]["text"]
+response_stripped = response.rstrip()
+unwanted_response_suffix = args.suffix.rstrip()
+unwanted_response_length = len(unwanted_response_suffix)
+
+filtered = False
+if unwanted_response_suffix and response_stripped[-unwanted_response_length:] == unwanted_response_suffix:
+    response = response_stripped[:-unwanted_response_length]
+    filtered = True
+
+print(f"Fill-in-Middle completion{' (filtered)' if filtered else ''}:\n\n{args.prompt}\033[32m{response}\033[{'33' if filtered else '0'}m{args.suffix}\033[0m")
+
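To make the suffix-repetition filter in the example concrete, here is a small standalone sketch; the model output string is made up for illustration, and only the trimming logic mirrors the example above:

suffix = "\n    return sum\n\n"
response = "a, b):\n    sum = a + b\n    return sum"  # hypothetical model output that echoes the suffix

response_stripped = response.rstrip()
unwanted = suffix.rstrip()
if unwanted and response_stripped.endswith(unwanted):
    # Drop the echoed suffix so it is not printed twice after the completion.
    response = response_stripped[: -len(unwanted)]

print(repr(response))  # 'a, b):\n    sum = a + b'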

llama_cpp/_internals.py: 8 additions & 0 deletions

@@ -170,6 +170,14 @@ def token_eot(self) -> int:
         assert self.model is not None
         return llama_cpp.llama_token_eot(self.model)
 
+    def add_bos_token(self) -> int:
+        assert self.model is not None
+        return llama_cpp.llama_add_bos_token(self.model)
+
+    def add_eos_token(self) -> int:
+        assert self.model is not None
+        return llama_cpp.llama_add_eos_token(self.model)
+
     # Tokenization
 
     def tokenize(self, text: bytes, add_bos: bool, special: bool):
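For orientation, a minimal sketch of how the two new accessors can be queried on a loaded model; the model path is a placeholder and _model is an internal attribute, so treat this purely as illustration rather than a supported API:

from llama_cpp import Llama

llm = Llama(model_path="./model.gguf")  # placeholder path
# Both return an int flag from the model metadata; the calling code in this
# change treats 0 as "do not add BOS" and anything other than 1 as "do not force EOS".
print(llm._model.add_bos_token())
print(llm._model.add_eos_token())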

llama_cpp/llama.py: 46 additions & 27 deletions
@@ -115,6 +115,7 @@ def __init__(
         type_k: Optional[int] = None,
         type_v: Optional[int] = None,
         # Misc
+        spm_infill: bool = False,
         verbose: bool = True,
         # Extra Params
         **kwargs,  # type: ignore
@@ -185,6 +186,7 @@ def __init__(
             verbose: Print verbose output to stderr.
             type_k: KV cache data type for K (default: f16)
             type_v: KV cache data type for V (default: f16)
+            spm_infill: Use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this.
 
         Raises:
             ValueError: If the model path does not exist.
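To illustrate the new spm_infill option documented above: with purely illustrative token ids, the infill prompt assembled later in _create_completion differs only in the ordering of the prefix and suffix blocks. A sketch, not code from this diff:

# Illustrative special-token ids, not taken from any real model.
BOS, PRE, SUF, MID = 1, 32007, 32008, 32009
prefix_ids = [100, 101]  # tokenized prompt (made-up ids)
suffix_ids = [200, 201]  # tokenized suffix (made-up ids)

# Default Prefix/Suffix/Middle layout vs. Suffix/Prefix/Middle (spm_infill=True).
psm = [BOS, PRE, *prefix_ids, SUF, *suffix_ids, MID]
spm = [BOS, SUF, *suffix_ids, PRE, *prefix_ids, MID]
print(psm)
print(spm)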
@@ -343,6 +345,8 @@ def __init__(
         self.lora_scale = lora_scale
         self.lora_path = lora_path
 
+        self.spm_infill = spm_infill
+
         if not os.path.exists(model_path):
             raise ValueError(f"Model path does not exist: {model_path}")

@@ -972,14 +976,33 @@ def _create_completion(
 
         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
         created: int = int(time.time())
+        bos_token_id: int = self.token_bos()
+        cls_token_id: int = self._model.token_cls()
+        sep_token_id: int = self._model.token_sep()
         prefix_token_id: int = self._model.token_prefix()
         middle_token_id: int = self._model.token_middle()
         suffix_token_id: int = self._model.token_suffix()
+        add_space_prefix: bool = self.metadata.get("tokenizer.ggml.add_space_prefix", "true") == "true"
+        bos_tokens: List[int] = [cls_token_id if cls_token_id != -1 else bos_token_id]
+        eos_tokens: List[int] = [sep_token_id if sep_token_id != -1 else self.token_eos()]
+
+        if (isinstance(prompt, list) and suffix is None) or self._model.add_bos_token() == 0 or bos_tokens[:1] == [-1]:
+            bos_tokens = []
+
+        if (isinstance(prompt, list) and suffix is None) or (self._model.add_eos_token() != 1 and sep_token_id == -1):
+            eos_tokens = []
+
+        suffix_space_prefix: int = 0
+        # Tokenizer hack to remove leading space
+        if add_space_prefix and suffix_token_id >= 0 and suffix:
+            suffix = "☺" + suffix
+            suffix_space_prefix = 2
+
         # If prompt is empty, initialize completion with BOS token to avoid
         # detokenization including a space at the beginning of the completion
-        completion_tokens: List[int] = [] if len(prompt) > 0 else [self.token_bos()]
+        completion_tokens: List[int] = [] if len(prompt) > 0 else [bos_token_id]
         # Add blank space to start of prompt to match OG llama tokenizer
-        prompt_tokens: List[int] = (
+        prefix_tokens: List[int] = (
             (
                 [prefix_token_id]
                 if prefix_token_id >= 0 and suffix is not None
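The block above decides which wrapper tokens to use: CLS/SEP when the vocabulary defines them (WPM-style models), BOS/EOS otherwise, and none at all for pre-tokenized non-infill prompts or models whose metadata disables them. A self-contained sketch of the same decision, with illustrative token ids:

from typing import List, Tuple

def pick_wrappers(bos_id: int, eos_id: int, cls_id: int, sep_id: int,
                  add_bos: int, add_eos: int,
                  pretokenized_non_infill: bool) -> Tuple[List[int], List[int]]:
    # Prefer CLS/SEP when present (WPM vocab), otherwise fall back to BOS/EOS.
    bos_tokens = [cls_id if cls_id != -1 else bos_id]
    eos_tokens = [sep_id if sep_id != -1 else eos_id]
    # Skip BOS for pre-tokenized non-infill prompts, when the model says not to
    # add it, or when no BOS token exists at all.
    if pretokenized_non_infill or add_bos == 0 or bos_tokens[:1] == [-1]:
        bos_tokens = []
    # Skip EOS/SEP unless the model explicitly wants EOS or a SEP token exists.
    if pretokenized_non_infill or (add_eos != 1 and sep_id == -1):
        eos_tokens = []
    return bos_tokens, eos_tokens

print(pick_wrappers(1, 2, -1, -1, add_bos=1, add_eos=0, pretokenized_non_infill=False))     # ([1], [])
print(pick_wrappers(-1, 2, 101, 102, add_bos=-1, add_eos=0, pretokenized_non_infill=False)) # ([101], [102])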
@@ -988,38 +1011,33 @@ def _create_completion(
             +
             (
                 (
-                    self.tokenize(prompt.encode("utf-8"), add_bos=(prefix_token_id < 0 or suffix is None), special=(prefix_token_id < 0 or suffix is None))
+                    self.tokenize(prompt.encode("utf-8"), add_bos=False, special=(prefix_token_id < 0 or suffix is None))
                     if prompt != ""
-                    else (
-                        []
-                        if prefix_token_id >= 0 and suffix is not None
-                        else [self.token_bos()]
-                    )
+                    else []
                 )
                 if isinstance(prompt, str)
                 else prompt
             )
-            +
+        )
+        suffix_tokens: List[int] = (
             (
+                [suffix_token_id]
+                +
                 (
-                    [suffix_token_id]
-                    +
-                    (
-                        self.tokenize(suffix.encode("utf-8"), add_bos=False, special=False)
-                        if suffix
-                        else []
-                    )
+                    self.tokenize(suffix.encode("utf-8"), add_bos=False, special=False)[suffix_space_prefix:]
+                    if suffix
+                    else []
                 )
-                if suffix_token_id >= 0 and suffix is not None
-                else []
-            )
-            +
-            (
-                [middle_token_id]
-                if middle_token_id >= 0 and suffix is not None
-                else []
             )
+            if suffix_token_id >= 0 and suffix is not None
+            else []
+        )
+        middle_tokens: List[int] = (
+            [middle_token_id]
+            if middle_token_id >= 0 and suffix is not None
+            else []
         )
+        prompt_tokens: List[int] = bos_tokens + ((suffix_tokens + prefix_tokens + middle_tokens) if self.spm_infill else (prefix_tokens + suffix_tokens + middle_tokens)) + eos_tokens
         text: bytes = b""
         returned_tokens: int = 0
         stop = (
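The suffix_space_prefix slicing above is the "tokenizer hack" from the commit message: SPM tokenizers with add_space_prefix insert a space in front of any fragment, so the suffix is tokenized with a throw-away "☺" sentinel prepended and the first two resulting tokens (the injected space and the sentinel) are dropped. A deliberately simplified stand-in tokenizer to show the idea; it is not llama.cpp's real tokenizer:

from typing import List

def toy_tokenize(text: str) -> List[str]:
    # Stand-in for an SPM tokenizer with add_space_prefix: it emits a leading
    # space piece, keeps the "☺" sentinel as its own piece, then splits on spaces.
    pieces = ["▁"]
    if text.startswith("☺"):
        pieces.append("☺")
        text = text[1:]
    pieces += text.split(" ")
    return pieces

suffix = "return x"
print(toy_tokenize(suffix))            # ['▁', 'return', 'x']  <- spurious leading space
print(toy_tokenize("☺" + suffix)[2:])  # ['return', 'x']       <- sentinel and space dropped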
@@ -1176,7 +1194,7 @@ def logit_bias_processor(
                 # not sure how to handle this branch when dealing
                 # with CJK output, so keep it unchanged
                 for token in remaining_tokens:
-                    if token == self.token_bos():
+                    if token == bos_token_id:
                         continue
                     token_end_position += len(self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]))
                     # Check if stop sequence is in the token
@@ -1303,7 +1321,7 @@ def logit_bias_processor(
 
                     logprobs_or_none: Optional[CompletionLogprobs] = None
                     if logprobs is not None:
-                        if token == self.token_bos():
+                        if token == bos_token_id:
                             continue
                         token_str = self.detokenize([token]).decode(
                             "utf-8", errors="ignore"
@@ -1431,7 +1449,7 @@ def logit_bias_processor(
             for idx, (token, token_str, logprobs_token) in enumerate(
                 zip(all_tokens, all_token_strs, all_logprobs)
             ):
-                if token == self.token_bos():
+                if token == bos_token_id:
                     continue
                 text_offsets.append(
                     text_offset
@@ -1858,6 +1876,7 @@ def __getstate__(self):
             type_k=self.context_params.type_k,
             type_v=self.context_params.type_v,
             # Misc
+            spm_infill=self.spm_infill,
             verbose=self.verbose,
         )

0 commit comments