Commit 3379dc4

Merge branch 'main' of github.com:abetlen/llama_cpp_python into main

2 parents 9522284 + 628e3fb
File tree

1 file changed: +9 -4 lines changed
llama_cpp/llama.py (+9 -4)
@@ -141,7 +141,9 @@ def __getitem__(self, key: Sequence[int]) -> "LlamaState":
         if _key is None:
             raise KeyError("Key not found")
         value: "LlamaState" = self.cache.pop(_key)  # type: ignore
-        self.cache.push(_key, side="front")  # type: ignore
+        # NOTE: This puts an integer as key in cache, which breaks
+        # Llama.longest_token_prefix(k, key) above since k is not a tuple of ints/tokens
+        # self.cache.push(_key, side="front")  # type: ignore
         return value

     def __contains__(self, key: Sequence[int]) -> bool:
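For context on the NOTE in the hunk above: this __getitem__ belongs to the disk-backed cache, whose self.cache is a diskcache.Cache. diskcache's Cache.push() does not re-order an existing entry; it stores its argument as a *value* under a freshly assigned integer key. A minimal sketch of that behavior, assuming a diskcache.Cache (the path is illustrative):

import diskcache

cache = diskcache.Cache("/tmp/llama_cache_demo")
cache[(1, 2, 3)] = "llama-state"   # normal entry: key is a tuple of token ids

# push() inserts its argument as a value under a new auto-assigned integer
# key, instead of moving the existing (1, 2, 3) entry to the front.
new_key = cache.push((1, 2, 3), side="front")
print(type(new_key))               # <class 'int'>

# Key iteration now yields that integer alongside the tuple keys, so a
# longest_token_prefix(k, key) scan over the keys would fail on the int.
print(list(cache))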
@@ -168,7 +170,7 @@ def __init__(
         eval_logits: Deque[List[float]],
         input_ids: npt.NDArray[np.intc],
         scores: npt.NDArray[np.single],
-        llama_state,  # type: llama_cpp.Array[llama_cpp.c_uint8]
+        llama_state: bytes,
         llama_state_size: int,
     ):
         self.eval_tokens = eval_tokens
@@ -1512,7 +1514,7 @@ def save_state(self) -> LlamaState:
             eval_logits=self.eval_logits.copy(),
             scores=self._scores.copy(),
             input_ids=self._input_ids.copy(),
-            llama_state=llama_state_compact,
+            llama_state=bytes(llama_state_compact),
             llama_state_size=n_bytes,
         )
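Together with the llama_state: bytes annotation in the earlier hunk, this stores the saved state as an immutable, picklable bytes copy rather than a live ctypes buffer. A minimal sketch of what bytes() does to a ctypes array; the buffer name is illustrative:

import ctypes

buf = (ctypes.c_uint8 * 4)(1, 2, 3, 4)  # stand-in for llama_state_compact
state_bytes = bytes(buf)                # independent, immutable copy
buf[0] = 99                             # later mutation of the ctypes array...
assert state_bytes[0] == 1              # ...does not leak into the stored state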
@@ -1523,7 +1525,10 @@ def load_state(self, state: LlamaState) -> None:
         self._scores = state.scores.copy()
         self._input_ids = state.input_ids.copy()
         state_size = state.llama_state_size
-        if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
+        LLamaStateArrayType = (llama_cpp.c_uint8 * state_size)
+        llama_state = LLamaStateArrayType.from_buffer_copy(state.llama_state)
+
+        if llama_cpp.llama_set_state_data(self.ctx, llama_state) != state_size:
             raise RuntimeError("Failed to set llama state data")

     def n_ctx(self) -> int:
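Because the state is now kept as bytes, load_state first rebuilds a writable ctypes array before handing it to the C API. A minimal round-trip sketch in plain ctypes, using stand-in data (no llama.cpp required):

import ctypes

state_bytes = bytes(range(8))                          # stand-in for state.llama_state
ArrayType = ctypes.c_uint8 * len(state_bytes)
llama_state = ArrayType.from_buffer_copy(state_bytes)  # writable ctypes copy
assert bytes(llama_state) == state_bytes               # lossless round trip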
