Commit 9b64bb5

misc: Format

1 parent 1e64664 commit 9b64bb5

6 files changed: +147 -80 lines changed

‎llama_cpp/_internals.py

+50 -16 (50 additions & 16 deletions)

@@ -100,7 +100,6 @@ def n_params(self) -> int:
     def get_tensor(self, name: str) -> ctypes.c_void_p:
         return llama_cpp.llama_get_model_tensor(self.model, name.encode("utf-8"))
 
-
     # Vocab
 
     def token_get_text(self, token: int) -> str:

@@ -460,9 +459,7 @@ def __init__(
         self.verbose = verbose
         self._exit_stack = ExitStack()
 
-        batch = llama_cpp.llama_batch_init(
-            self._n_tokens, self.embd, self.n_seq_max
-        )
+        batch = llama_cpp.llama_batch_init(self._n_tokens, self.embd, self.n_seq_max)
 
         if batch is None:
             raise ValueError("Failed to create llama_batch")

@@ -541,6 +538,7 @@ def copy_logits(self, logits: npt.NDArray[np.single]):
 
 # Embedding functions
 
+
 def normalize_embedding(embedding):
     norm = float(np.linalg.norm(embedding))
     if norm == 0.0:

@@ -713,11 +711,17 @@ def accept(self, ctx_main: LlamaContext, id: int, apply_grammar: bool):
 import ctypes
 import llama_cpp
 
+
 class CustomSampler:
-    def __init__(self, apply_func: typing.Callable[[llama_cpp.llama_token_data_array], None]):
+    def __init__(
+        self, apply_func: typing.Callable[[llama_cpp.llama_token_data_array], None]
+    ):
         self.apply_func = apply_func
 
-        def apply_wrapper(sampler: llama_cpp.llama_sampler_p, cur_p: llama_cpp.llama_token_data_array_p):
+        def apply_wrapper(
+            sampler: llama_cpp.llama_sampler_p,
+            cur_p: llama_cpp.llama_token_data_array_p,
+        ):
             self.apply_func(cur_p)
 
         def free_wrapper(sampler: llama_cpp.llama_sampler_p):

@@ -740,6 +744,7 @@ def free_wrapper(sampler: llama_cpp.llama_sampler_p):
     def get_sampler(self) -> llama_cpp.llama_sampler_p:
         return ctypes.pointer(self.sampler)
 
+
 class LlamaSampler:
     def __init__(self):
         params = llama_cpp.llama_sampler_chain_params()

@@ -788,33 +793,62 @@ def add_temp_ext(self, t: float, delta: float, exponent: float):
         self._add_sampler(sampler)
 
     def add_mirostat(self, n_vocab: int, seed: int, tau: float, eta: float, m: int):
-        sampler = llama_cpp.llama_sampler_init_mirostat(
-            n_vocab, seed, tau, eta, m
-        )
+        sampler = llama_cpp.llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m)
         self._add_sampler(sampler)
 
     def add_mirostat_v2(self, seed: int, tau: float, eta: float):
         sampler = llama_cpp.llama_sampler_init_mirostat_v2(seed, tau, eta)
         self._add_sampler(sampler)
 
     def add_grammar(self, model: LlamaModel, grammar: LlamaGrammar):
-        sampler = llama_cpp.llama_sampler_init_grammar(model.model, grammar._grammar.encode("utf-8"), grammar._root.encode("utf-8"))
+        sampler = llama_cpp.llama_sampler_init_grammar(
+            model.model, grammar._grammar.encode("utf-8"), grammar._root.encode("utf-8")
+        )
         self._add_sampler(sampler)
 
-    def add_penalties(self, n_vocab: int, special_eos_id: int, linefeed_id: int, penalty_last_n: int, penalty_repeat: float, penalty_freq: float, penalty_present: float, penalize_nl: bool, ignore_eos: bool):
-        sampler = llama_cpp.llama_sampler_init_penalties(n_vocab, special_eos_id, linefeed_id, penalty_last_n, penalty_repeat, penalty_freq, penalty_present, penalize_nl, ignore_eos)
+    def add_penalties(
+        self,
+        n_vocab: int,
+        special_eos_id: int,
+        linefeed_id: int,
+        penalty_last_n: int,
+        penalty_repeat: float,
+        penalty_freq: float,
+        penalty_present: float,
+        penalize_nl: bool,
+        ignore_eos: bool,
+    ):
+        sampler = llama_cpp.llama_sampler_init_penalties(
+            n_vocab,
+            special_eos_id,
+            linefeed_id,
+            penalty_last_n,
+            penalty_repeat,
+            penalty_freq,
+            penalty_present,
+            penalize_nl,
+            ignore_eos,
+        )
         self._add_sampler(sampler)
 
-    def init_logit_bias(self, n_vocab: int, n_logit_bias, logit_bias: llama_cpp.llama_logit_bias_p):
-        sampler = llama_cpp.llama_sampler_init_logit_bias(n_vocab, n_logit_bias, logit_bias)
+    def init_logit_bias(
+        self, n_vocab: int, n_logit_bias, logit_bias: llama_cpp.llama_logit_bias_p
+    ):
+        sampler = llama_cpp.llama_sampler_init_logit_bias(
+            n_vocab, n_logit_bias, logit_bias
+        )
         self._add_sampler(sampler)
 
-    def add_custom(self, apply_func: Callable[[llama_cpp.llama_token_data_array], None]):
+    def add_custom(
+        self, apply_func: Callable[[llama_cpp.llama_token_data_array], None]
+    ):
         custom_sampler = CustomSampler(apply_func)
         sampler = custom_sampler.get_sampler()
         self._add_sampler(sampler)
         # NOTE: Must remove custom samplers before free or llama.cpp will try to free them
-        self.custom_samplers.append((llama_cpp.llama_sampler_chain_n(self.sampler) - 1, custom_sampler))
+        self.custom_samplers.append(
+            (llama_cpp.llama_sampler_chain_n(self.sampler) - 1, custom_sampler)
+        )
 
     def _add_sampler(self, sampler: llama_cpp.llama_sampler_p):
         assert self.sampler is not None
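
Below is a hypothetical usage sketch, not part of this commit, showing how the reformatted LlamaSampler methods chain together; every argument value is illustrative, and in practice these internal classes are driven by Llama._init_sampler in llama.py.

import llama_cpp._internals as internals

def my_apply_func(token_data_array):
    # A real callback would modify the llama_token_data_array logits in place;
    # this placeholder leaves them untouched.
    pass

sampler = internals.LlamaSampler()
sampler.add_penalties(
    n_vocab=32000,        # illustrative vocabulary size
    special_eos_id=2,     # illustrative special-token ids
    linefeed_id=13,
    penalty_last_n=64,
    penalty_repeat=1.1,
    penalty_freq=0.0,
    penalty_present=0.0,
    penalize_nl=False,
    ignore_eos=False,
)
sampler.add_mirostat_v2(seed=1234, tau=5.0, eta=0.1)
sampler.add_custom(my_apply_func)  # tracked in sampler.custom_samplers so it can be removed before free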

‎llama_cpp/llama.py

+56 -42 (56 additions & 42 deletions)

@@ -255,28 +255,28 @@ def __init__(
             for i, (k, v) in enumerate(kv_overrides.items()):
                 self._kv_overrides_array[i].key = k.encode("utf-8")
                 if isinstance(v, bool):
-                    self._kv_overrides_array[i].tag = (
-                        llama_cpp.LLAMA_KV_OVERRIDE_TYPE_BOOL
-                    )
+                    self._kv_overrides_array[
+                        i
+                    ].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_BOOL
                     self._kv_overrides_array[i].value.val_bool = v
                 elif isinstance(v, int):
-                    self._kv_overrides_array[i].tag = (
-                        llama_cpp.LLAMA_KV_OVERRIDE_TYPE_INT
-                    )
+                    self._kv_overrides_array[
+                        i
+                    ].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_INT
                     self._kv_overrides_array[i].value.val_i64 = v
                 elif isinstance(v, float):
-                    self._kv_overrides_array[i].tag = (
-                        llama_cpp.LLAMA_KV_OVERRIDE_TYPE_FLOAT
-                    )
+                    self._kv_overrides_array[
+                        i
+                    ].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_FLOAT
                     self._kv_overrides_array[i].value.val_f64 = v
                 elif isinstance(v, str):  # type: ignore
                     v_bytes = v.encode("utf-8")
                     if len(v_bytes) > 128:  # TODO: Make this a constant
                         raise ValueError(f"Value for {k} is too long: {v}")
                     v_bytes = v_bytes.ljust(128, b"\0")
-                    self._kv_overrides_array[i].tag = (
-                        llama_cpp.LLAMA_KV_OVERRIDE_TYPE_STR
-                    )
+                    self._kv_overrides_array[
+                        i
+                    ].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_STR
                     # copy min(v_bytes, 128) to str_value
                     address = typing.cast(
                         int,

@@ -292,9 +292,9 @@ def __init__(
                 else:
                     raise ValueError(f"Unknown value type for {k}: {v}")
 
-            self._kv_overrides_array[-1].key = (
-                b"\0"  # ensure sentinel element is zeroed
-            )
+            self._kv_overrides_array[
+                -1
+            ].key = b"\0"  # ensure sentinel element is zeroed
             self.model_params.kv_overrides = self._kv_overrides_array
 
         self.n_batch = min(n_ctx, n_batch)  # ???

@@ -431,9 +431,9 @@ def free_lora_adapter():
 
         self.chat_format = chat_format
         self.chat_handler = chat_handler
-        self._chat_handlers: Dict[str, llama_chat_format.LlamaChatCompletionHandler] = (
-            {}
-        )
+        self._chat_handlers: Dict[
+            str, llama_chat_format.LlamaChatCompletionHandler
+        ] = {}
 
         self.draft_model = draft_model
 

@@ -580,7 +580,10 @@ def tokenize(
         return self.tokenizer_.tokenize(text, add_bos, special)
 
     def detokenize(
-        self, tokens: List[int], prev_tokens: Optional[List[int]] = None, special: bool = False
+        self,
+        tokens: List[int],
+        prev_tokens: Optional[List[int]] = None,
+        special: bool = False,
     ) -> bytes:
         """Detokenize a list of tokens.
 

@@ -592,7 +595,9 @@ def detokenize(
         Returns:
             The detokenized string.
         """
-        return self.tokenizer_.detokenize(tokens, prev_tokens=prev_tokens, special=special)
+        return self.tokenizer_.detokenize(
+            tokens, prev_tokens=prev_tokens, special=special
+        )
 
     def set_cache(self, cache: Optional[BaseLlamaCache]):
         """Set the cache.

@@ -681,12 +686,16 @@ def apply_func(token_data_array: llama_cpp.llama_token_data_array_p):
                 recarray = np.recarray(
                     shape=(size,),
                     dtype=np.dtype(
-                        [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
+                        [("id", np.intc), ("logit", np.single), ("p", np.single)],
+                        align=True,
+                    ),
+                    buf=(llama_cpp.llama_token_data * size).from_address(
+                        data_soa_address
                     ),
-                    buf=(llama_cpp.llama_token_data * size).from_address(data_soa_address),
                 )
                 for logit_processor in logits_processor:
                     recarray.logit[:] = logit_processor(self._input_ids, recarray.logit)
+
             sampler.add_custom(apply_func)
 
         sampler.add_penalties(

@@ -698,7 +707,7 @@ def apply_func(token_data_array: llama_cpp.llama_token_data_array_p):
             penalty_freq=frequency_penalty,
             penalty_present=presence_penalty,
             penalize_nl=penalize_nl,
-            ignore_eos=False
+            ignore_eos=False,
        )
 
        if grammar is not None:

@@ -841,22 +850,22 @@ def generate(
         # Reset mirostat sampling
         self._mirostat_mu = ctypes.c_float(2.0 * mirostat_tau)
         self._sampler = self._init_sampler(
-                top_k=top_k,
-                top_p=top_p,
-                min_p=min_p,
-                typical_p=typical_p,
-                temp=temp,
-                repeat_penalty=repeat_penalty,
-                frequency_penalty=frequency_penalty,
-                presence_penalty=presence_penalty,
-                tfs_z=tfs_z,
-                mirostat_mode=mirostat_mode,
-                mirostat_tau=mirostat_tau,
-                mirostat_eta=mirostat_eta,
-                penalize_nl=penalize_nl,
-                logits_processor=logits_processor,
-                grammar=grammar,
-                seed=seed,
+            top_k=top_k,
+            top_p=top_p,
+            min_p=min_p,
+            typical_p=typical_p,
+            temp=temp,
+            repeat_penalty=repeat_penalty,
+            frequency_penalty=frequency_penalty,
+            presence_penalty=presence_penalty,
+            tfs_z=tfs_z,
+            mirostat_mode=mirostat_mode,
+            mirostat_tau=mirostat_tau,
+            mirostat_eta=mirostat_eta,
+            penalize_nl=penalize_nl,
+            logits_processor=logits_processor,
+            grammar=grammar,
+            seed=seed,
         )
 
         # Check for kv cache prefix match

@@ -872,8 +881,11 @@ def generate(
                 tokens = tokens[longest_prefix:]
                 self.n_tokens = longest_prefix
                 if self.verbose:
-                    print(f"Llama.generate: {longest_prefix} prefix-match hit, "
-                          f"remaining {len(tokens)} prompt tokens to eval", file=sys.stderr)
+                    print(
+                        f"Llama.generate: {longest_prefix} prefix-match hit, "
+                        f"remaining {len(tokens)} prompt tokens to eval",
+                        file=sys.stderr,
+                    )
 
         # Reset the model state
         if reset:

@@ -1032,7 +1044,9 @@ def decode_batch(seq_sizes: List[int]):
                        for j in range(size)
                    ]
                    if normalize:
-                        embedding = [internals.normalize_embedding(e) for e in embedding]
+                        embedding = [
+                            internals.normalize_embedding(e) for e in embedding
+                        ]
                    data.append(embedding)
                    pos += size
            else:
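
Below is a hypothetical usage sketch, not part of this commit: the kv_overrides dict accepted by the public Llama constructor is what the loop above packs into llama_model_kv_override entries; the model path and override keys are illustrative.

from llama_cpp import Llama

llm = Llama(
    model_path="./models/example.gguf",  # illustrative path
    kv_overrides={
        "example.bool_key": True,     # tagged LLAMA_KV_OVERRIDE_TYPE_BOOL
        "example.int_key": 42,        # tagged LLAMA_KV_OVERRIDE_TYPE_INT
        "example.float_key": 0.5,     # tagged LLAMA_KV_OVERRIDE_TYPE_FLOAT
        "example.str_key": "value",   # UTF-8 encoded, padded to 128 bytes; longer values raise ValueError
    },
)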

‎llama_cpp/llama_cache.py

+3 -3 (3 additions & 3 deletions)

@@ -52,9 +52,9 @@ class LlamaRAMCache(BaseLlamaCache):
     def __init__(self, capacity_bytes: int = (2 << 30)):
         super().__init__(capacity_bytes)
         self.capacity_bytes = capacity_bytes
-        self.cache_state: OrderedDict[Tuple[int, ...], "llama_cpp.llama.LlamaState"] = (
-            OrderedDict()
-        )
+        self.cache_state: OrderedDict[
+            Tuple[int, ...], "llama_cpp.llama.LlamaState"
+        ] = OrderedDict()
 
     @property
     def cache_size(self):
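
Below is a hypothetical usage sketch, not part of this commit: the OrderedDict-typed cache_state annotated above backs the in-RAM prompt cache, which is attached to a model with Llama.set_cache (see the llama.py hunks); the path and capacity are illustrative.

from llama_cpp import Llama
from llama_cpp.llama_cache import LlamaRAMCache

llm = Llama(model_path="./models/example.gguf")       # illustrative path
llm.set_cache(LlamaRAMCache(capacity_bytes=2 << 30))  # 2 GiB, the default capacity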

‎llama_cpp/llama_grammar.py

+2 -1 (2 additions & 1 deletion)

@@ -15,6 +15,7 @@
 
 LLAMA_GRAMMAR_DEFAULT_ROOT = "root"
 
+
 class LlamaGrammar:
     def __init__(self, *args, _grammar: str, **kwargs):
         self._grammar = _grammar

@@ -23,7 +24,7 @@ def __init__(self, *args, _grammar: str, **kwargs):
     @classmethod
     def from_string(cls, grammar: str, verbose: bool = True) -> "LlamaGrammar":
         return cls(_grammar=grammar)
-
+
     @classmethod
     def from_file(cls, file: Union[str, Path], verbose: bool = True) -> "LlamaGrammar":
         try:
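
Below is a hypothetical usage sketch, not part of this commit, for the classmethods touched above; the GBNF text is illustrative, and "root" is the default entry rule (LLAMA_GRAMMAR_DEFAULT_ROOT).

from llama_cpp.llama_grammar import LlamaGrammar

grammar = LlamaGrammar.from_string('root ::= "yes" | "no"')  # verbose defaults to True
# LlamaGrammar.from_file("./answer.gbnf") loads the same GBNF format from disk.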

‎llama_cpp/llama_tokenizer.py

+22 -11 (22 additions & 11 deletions)

@@ -27,7 +27,10 @@ def tokenize(
 
     @abc.abstractmethod
     def detokenize(
-        self, tokens: List[int], prev_tokens: Optional[List[int]] = None, special: bool = False
+        self,
+        tokens: List[int],
+        prev_tokens: Optional[List[int]] = None,
+        special: bool = False,
     ) -> bytes:
         """Detokenize the tokens into text.
 

@@ -49,7 +52,10 @@ def tokenize(
         return self._model.tokenize(text, add_bos=add_bos, special=special)
 
     def detokenize(
-        self, tokens: List[int], prev_tokens: Optional[List[int]] = None, special: bool = False
+        self,
+        tokens: List[int],
+        prev_tokens: Optional[List[int]] = None,
+        special: bool = False,
     ) -> bytes:
         return self._model.detokenize(tokens, special=special)
 

@@ -80,19 +86,24 @@ def tokenize(
         )
 
     def detokenize(
-        self, tokens: List[int], prev_tokens: Optional[List[int]] = None, special: bool = False
+        self,
+        tokens: List[int],
+        prev_tokens: Optional[List[int]] = None,
+        special: bool = False,
     ) -> bytes:
-         skip_special_tokens = not special
+        skip_special_tokens = not special
         if prev_tokens is not None:
-            text = self.hf_tokenizer.decode(prev_tokens + tokens, skip_special_tokens=skip_special_tokens).encode(
-                "utf-8", errors="ignore"
-            )
-            prev_text = self.hf_tokenizer.decode(prev_tokens, skip_special_tokens=skip_special_tokens).encode(
-                "utf-8", errors="ignore"
-            )
+            text = self.hf_tokenizer.decode(
+                prev_tokens + tokens, skip_special_tokens=skip_special_tokens
+            ).encode("utf-8", errors="ignore")
+            prev_text = self.hf_tokenizer.decode(
+                prev_tokens, skip_special_tokens=skip_special_tokens
+            ).encode("utf-8", errors="ignore")
             return text[len(prev_text) :]
         else:
-            return self.hf_tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens).encode("utf-8", errors="ignore")
+            return self.hf_tokenizer.decode(
+                tokens, skip_special_tokens=skip_special_tokens
+            ).encode("utf-8", errors="ignore")
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path: str) -> "LlamaHFTokenizer":
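
Below is a hypothetical usage sketch, not part of this commit: the prev_tokens branch reformatted above decodes prev_tokens + tokens and prev_tokens separately, then returns only the new suffix; the model id and the token split are illustrative, and this class requires the optional transformers dependency.

from llama_cpp.llama_tokenizer import LlamaHFTokenizer

tokenizer = LlamaHFTokenizer.from_pretrained("gpt2")   # illustrative model id
prev = tokenizer.tokenize(b"Hello")                    # tokens already emitted
new = tokenizer.tokenize(b"Hello world")[len(prev):]   # newly sampled tokens (illustrative split)
suffix = tokenizer.detokenize(new, prev_tokens=prev)   # bytes for the new tokens only, roughly b" world"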

0 commit comments