Commit 01a010b

Fix llama_cpp and Llama type signatures. Closes abetlen#221
1 parent fb57b94 commit 01a010b

3 files changed: +58 −64 lines changed
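The diff below narrows the Python-facing type hints from llama_cpp.llama_token to plain int, constructing the low-level token type only at the llama_cpp call sites (e.g. llama_token_to_str). As a rough, hedged sketch of how the public methods read after this change (not part of the commit; the model path is a placeholder assumption):

from llama_cpp import Llama

llm = Llama(model_path="./models/7B/ggml-model.bin")  # placeholder path, assumption

tokens = llm.tokenize(b"Hello, world!")         # now typed as List[int]
assert all(isinstance(t, int) for t in tokens)

text = llm.detokenize(tokens)                   # accepts plain ints; the low-level
print(text.decode("utf-8"))                     # llama_token wrapping happens internally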

llama_cpp/llama.py

+35 −41: 35 additions & 41 deletions
@@ -15,9 +15,7 @@ class LlamaCache:
     """Cache for a llama.cpp model."""
 
     def __init__(self, capacity_bytes: int = (2 << 30)):
-        self.cache_state: OrderedDict[
-            Tuple[llama_cpp.llama_token, ...], "LlamaState"
-        ] = OrderedDict()
+        self.cache_state: OrderedDict[Tuple[int, ...], "LlamaState"] = OrderedDict()
         self.capacity_bytes = capacity_bytes
 
     @property
@@ -26,8 +24,8 @@ def cache_size(self):
 
     def _find_longest_prefix_key(
         self,
-        key: Tuple[llama_cpp.llama_token, ...],
-    ) -> Optional[Tuple[llama_cpp.llama_token, ...]]:
+        key: Tuple[int, ...],
+    ) -> Optional[Tuple[int, ...]]:
         min_len = 0
         min_key = None
         keys = (
@@ -39,7 +37,7 @@ def _find_longest_prefix_key(
                 min_key = k
         return min_key
 
-    def __getitem__(self, key: Sequence[llama_cpp.llama_token]) -> "LlamaState":
+    def __getitem__(self, key: Sequence[int]) -> "LlamaState":
         key = tuple(key)
         _key = self._find_longest_prefix_key(key)
         if _key is None:
@@ -48,10 +46,10 @@ def __getitem__(self, key: Sequence[llama_cpp.llama_token]) -> "LlamaState":
         self.cache_state.move_to_end(_key)
         return value
 
-    def __contains__(self, key: Sequence[llama_cpp.llama_token]) -> bool:
+    def __contains__(self, key: Sequence[int]) -> bool:
         return self._find_longest_prefix_key(tuple(key)) is not None
 
-    def __setitem__(self, key: Sequence[llama_cpp.llama_token], value: "LlamaState"):
+    def __setitem__(self, key: Sequence[int], value: "LlamaState"):
         key = tuple(key)
         if key in self.cache_state:
             del self.cache_state[key]
@@ -63,7 +61,7 @@ def __setitem__(self, key: Sequence[llama_cpp.llama_token], value: "LlamaState")
 class LlamaState:
     def __init__(
         self,
-        eval_tokens: Deque[llama_cpp.llama_token],
+        eval_tokens: Deque[int],
         eval_logits: Deque[List[float]],
         llama_state,  # type: llama_cpp.Array[llama_cpp.c_uint8]
         llama_state_size: int,
@@ -141,7 +139,7 @@ def __init__(
 
         self.last_n_tokens_size = last_n_tokens_size
         self.n_batch = min(n_ctx, n_batch)
-        self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
+        self.eval_tokens: Deque[int] = deque(maxlen=n_ctx)
         self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx if logits_all else 1)
 
         self.cache: Optional[LlamaCache] = None
@@ -176,9 +174,7 @@ def __init__(
         if self.verbose:
             print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
 
-    def tokenize(
-        self, text: bytes, add_bos: bool = True
-    ) -> List[llama_cpp.llama_token]:
+    def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
         """Tokenize a string.
 
         Args:
@@ -197,7 +193,7 @@ def tokenize(
             self.ctx,
             text,
             tokens,
-            n_ctx,
+            llama_cpp.c_int(n_ctx),
             llama_cpp.c_bool(add_bos),
         )
         if int(n_tokens) < 0:
@@ -216,7 +212,7 @@ def tokenize(
             )
         return list(tokens[:n_tokens])
 
-    def detokenize(self, tokens: List[llama_cpp.llama_token]) -> bytes:
+    def detokenize(self, tokens: List[int]) -> bytes:
         """Detokenize a list of tokens.
 
         Args:
@@ -228,7 +224,9 @@ def detokenize(self, tokens: List[llama_cpp.llama_token]) -> bytes:
         assert self.ctx is not None
         output = b""
         for token in tokens:
-            output += llama_cpp.llama_token_to_str(self.ctx, token)
+            output += llama_cpp.llama_token_to_str(
+                self.ctx, llama_cpp.llama_token(token)
+            )
         return output
 
     def set_cache(self, cache: Optional[LlamaCache]):
@@ -244,7 +242,7 @@ def reset(self):
         self.eval_tokens.clear()
         self.eval_logits.clear()
 
-    def eval(self, tokens: Sequence[llama_cpp.llama_token]):
+    def eval(self, tokens: Sequence[int]):
         """Evaluate a list of tokens.
 
         Args:
@@ -458,7 +456,7 @@ def sample(
 
     def generate(
         self,
-        tokens: Sequence[llama_cpp.llama_token],
+        tokens: Sequence[int],
         top_k: int = 40,
         top_p: float = 0.95,
         temp: float = 0.80,
@@ -470,9 +468,7 @@ def generate(
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
-    ) -> Generator[
-        llama_cpp.llama_token, Optional[Sequence[llama_cpp.llama_token]], None
-    ]:
+    ) -> Generator[int, Optional[Sequence[int]], None]:
         """Create a generator of tokens from a prompt.
 
         Examples:
@@ -617,14 +613,14 @@ def _create_completion(
         assert self.ctx is not None
         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
         created: int = int(time.time())
-        completion_tokens: List[llama_cpp.llama_token] = []
+        completion_tokens: List[int] = []
         # Add blank space to start of prompt to match OG llama tokenizer
-        prompt_tokens: List[llama_cpp.llama_token] = self.tokenize(
-            b" " + prompt.encode("utf-8")
-        )
+        prompt_tokens: List[int] = self.tokenize(b" " + prompt.encode("utf-8"))
         text: bytes = b""
         returned_tokens: int = 0
-        stop = stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
+        stop = (
+            stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
+        )
         model_name: str = model if model is not None else self.model_path
 
         if self.verbose:
@@ -724,7 +720,9 @@ def _create_completion(
             for token in remaining_tokens:
                 token_end_position += len(self.detokenize([token]))
                 # Check if stop sequence is in the token
-                if token_end_position >= (remaining_length - first_stop_position - 1):
+                if token_end_position >= (
+                    remaining_length - first_stop_position - 1
+                ):
                     break
                 logprobs_or_none: Optional[CompletionLogprobs] = None
                 if logprobs is not None:
@@ -744,7 +742,7 @@ def _create_completion(
                         )
                     )
                     top_logprob = {
-                        self.detokenize([llama_cpp.llama_token(i)]).decode(
+                        self.detokenize([i]).decode(
                             "utf-8", errors="ignore"
                         ): logprob
                         for logprob, i in sorted_logprobs[:logprobs]
@@ -822,9 +820,7 @@ def _create_completion(
                     )
                 )
                 top_logprob = {
-                    self.detokenize([llama_cpp.llama_token(i)]).decode(
-                        "utf-8", errors="ignore"
-                    ): logprob
+                    self.detokenize([i]).decode("utf-8", errors="ignore"): logprob
                     for logprob, i in sorted_logprobs[:logprobs]
                 }
                 top_logprob.update({token_str: current_logprobs[int(token)]})
@@ -924,9 +920,7 @@ def _create_completion(
                 )
                 token_logprobs.append(sorted_logprobs[int(token)][0])
                 top_logprob: Optional[Dict[str, float]] = {
-                    self.detokenize([llama_cpp.llama_token(i)]).decode(
-                        "utf-8", errors="ignore"
-                    ): logprob
+                    self.detokenize([i]).decode("utf-8", errors="ignore"): logprob
                     for logprob, i in sorted_logprobs[:logprobs]
                 }
                 top_logprob.update({token_str: logprobs_token[int(token)]})
@@ -1188,7 +1182,9 @@ def create_chat_completion(
         Returns:
             Generated chat completion or a stream of chat completion chunks.
         """
-        stop = stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
+        stop = (
+            stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
+        )
         chat_history = "".join(
             f'### {"Human" if message["role"] == "user" else "Assistant"}:{message["content"]}'
             for message in messages
@@ -1296,17 +1292,17 @@ def load_state(self, state: LlamaState) -> None:
             raise RuntimeError("Failed to set llama state data")
 
     @staticmethod
-    def token_eos() -> llama_cpp.llama_token:
+    def token_eos() -> int:
         """Return the end-of-sequence token."""
         return llama_cpp.llama_token_eos()
 
     @staticmethod
-    def token_bos() -> llama_cpp.llama_token:
+    def token_bos() -> int:
         """Return the beginning-of-sequence token."""
         return llama_cpp.llama_token_bos()
 
     @staticmethod
-    def token_nl() -> llama_cpp.llama_token:
+    def token_nl() -> int:
         """Return the newline token."""
         return llama_cpp.llama_token_nl()
@@ -1317,9 +1313,7 @@ def logits_to_logprobs(logits: List[float]) -> List[float]:
         return [math.log(x / sum_exps) for x in exps]
 
     @staticmethod
-    def longest_token_prefix(
-        a: Sequence[llama_cpp.llama_token], b: Sequence[llama_cpp.llama_token]
-    ):
+    def longest_token_prefix(a: Sequence[int], b: Sequence[int]):
         longest_prefix = 0
         for _a, _b in zip(a, b):
             if _a == _b:
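Since LlamaCache and LlamaState now key and store plain int sequences as well, the token lists returned by tokenize() can be used directly as cache keys. A minimal sketch under the same placeholder-path assumption (not part of the commit):

from llama_cpp import Llama, LlamaCache

llm = Llama(model_path="./models/7B/ggml-model.bin")  # placeholder path, assumption
llm.set_cache(LlamaCache(capacity_bytes=2 << 30))     # capacity as in LlamaCache.__init__ above

key = tuple(llm.tokenize(b" Q: Name the planets in the solar system. A:"))  # Tuple[int, ...]
# Once a completion has stored state for this prefix, the same int tuple works
# with LlamaCache.__contains__ / __getitem__ (longest-prefix match), e.g.:
# key in llm.cache          ->  True
# state = llm.cache[key]    ->  LlamaState for the longest matching prefix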

0 commit comments
