Commit 628e3fb

Merge pull request abetlen#370 from Okabintaro/fix-state-pickle

fix: Make LLamaState pickleable for disk cache

2 parents: 04d9218 + 5eb4ebb
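Why the fix is needed: ctypes array types such as (c_uint8 * n) are created on the fly, so pickle generally cannot serialize them by name, which broke the disk cache. A minimal standalone sketch (not from the repo) of the failure and of the bytes-based fix this commit applies:

import ctypes
import pickle

buf = (ctypes.c_uint8 * 4)(1, 2, 3, 4)

# Pickling the raw ctypes array fails: the dynamically created array type
# cannot be looked up by name when pickle tries to serialize it.
try:
    pickle.dumps(buf)
except Exception as exc:
    print(f"{type(exc).__name__}: {exc}")

# Converting the buffer to plain bytes first, as save_state() now does,
# makes the cached state trivially picklable.
payload = pickle.dumps(bytes(buf))
assert pickle.loads(payload) == b"\x01\x02\x03\x04"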

1 file changed: llama_cpp/llama.py (9 additions, 4 deletions)
@@ -141,7 +141,9 @@ def __getitem__(self, key: Sequence[int]) -> "LlamaState":
         if _key is None:
             raise KeyError("Key not found")
         value: "LlamaState" = self.cache.pop(_key)  # type: ignore
-        self.cache.push(_key, side="front")  # type: ignore
+        # NOTE: This puts an integer as key in cache, which breaks,
+        # Llama.longest_token_prefix(k, key) above since k is not a tuple of ints/tokens
+        # self.cache.push(_key, side="front")  # type: ignore
         return value

     def __contains__(self, key: Sequence[int]) -> bool:
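The NOTE above refers to the prefix scan the disk cache runs over its keys: diskcache's push() stores a value under a fresh integer key, so a later scan would hand an int to Llama.longest_token_prefix. A simplified paraphrase of that helper (a sketch, not the repo's exact code) shows why every cache key must stay a sequence of token ids:

from typing import Sequence

def longest_token_prefix(a: Sequence[int], b: Sequence[int]) -> int:
    # Count how many leading tokens the two key sequences share.
    n = 0
    for token_a, token_b in zip(a, b):
        if token_a != token_b:
            break
        n += 1
    return n

print(longest_token_prefix((1, 2, 3), (1, 2, 4)))  # 2, tuple keys compare fine
# longest_token_prefix(17, (1, 2, 4))  # TypeError: 'int' object is not iterable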
@@ -168,7 +170,7 @@ def __init__(
         eval_logits: Deque[List[float]],
         input_ids: npt.NDArray[np.intc],
         scores: npt.NDArray[np.single],
-        llama_state,  # type: llama_cpp.Array[llama_cpp.c_uint8]
+        llama_state: bytes,
         llama_state_size: int,
     ):
         self.eval_tokens = eval_tokens
@@ -1509,7 +1511,7 @@ def save_state(self) -> LlamaState:
             eval_logits=self.eval_logits.copy(),
             scores=self._scores.copy(),
             input_ids=self._input_ids.copy(),
-            llama_state=llama_state_compact,
+            llama_state=bytes(llama_state_compact),
             llama_state_size=n_bytes,
         )

@@ -1520,7 +1522,10 @@ def load_state(self, state: LlamaState) -> None:
         self._scores = state.scores.copy()
         self._input_ids = state.input_ids.copy()
         state_size = state.llama_state_size
-        if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
+        LLamaStateArrayType = (llama_cpp.c_uint8 * state_size)
+        llama_state = LLamaStateArrayType.from_buffer_copy(state.llama_state)
+
+        if llama_cpp.llama_set_state_data(self.ctx, llama_state) != state_size:
             raise RuntimeError("Failed to set llama state data")

     def n_ctx(self) -> int:
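Together, save_state() and load_state() now round-trip the llama state through plain bytes. A minimal standalone sketch of that conversion pair, using a dummy 4-byte buffer in place of a real llama context:

import ctypes

# What save_state() now stores: a plain-bytes copy of the ctypes buffer.
state_bytes = bytes((ctypes.c_uint8 * 4)(1, 2, 3, 4))

# What load_state() now does: rebuild a writable ctypes array of the right
# size so it can be passed to llama_set_state_data().
StateArrayType = ctypes.c_uint8 * len(state_bytes)
llama_state = StateArrayType.from_buffer_copy(state_bytes)

assert bytes(llama_state) == state_bytes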

0 commit comments
