Commit 43f2907

Support smaller state sizes

1 parent 1d47cce
File tree

1 file changed: +10 -4 lines changed

llama_cpp/llama.py: +10 -4 (10 additions & 4 deletions)
@@ -53,12 +53,14 @@ class LlamaState:
     def __init__(
         self,
         eval_tokens: Deque[llama_cpp.llama_token],
-        eval_logits: Deque[List[float]],
+        eval_logits: Deque[List[llama_cpp.c_float]],
         llama_state,
+        llama_state_size: llama_cpp.c_size_t,
     ):
         self.eval_tokens = eval_tokens
         self.eval_logits = eval_logits
         self.llama_state = llama_state
+        self.llama_state_size = llama_state_size
 
 
 class Llama:
@@ -950,19 +952,23 @@ def save_state(self) -> LlamaState:
         assert self.ctx is not None
         state_size = llama_cpp.llama_get_state_size(self.ctx)
         llama_state = (llama_cpp.c_uint8 * int(state_size))()
-        if llama_cpp.llama_copy_state_data(self.ctx, llama_state) != state_size:
+        n_bytes = llama_cpp.llama_copy_state_data(self.ctx, llama_state)
+        if int(n_bytes) > int(state_size):
             raise RuntimeError("Failed to copy llama state data")
+        llama_state_compact = (llama_cpp.c_uint8 * int(n_bytes))()
+        llama_cpp.ctypes.memmove(llama_state_compact, llama_state, int(n_bytes))
         return LlamaState(
             eval_tokens=self.eval_tokens.copy(),
             eval_logits=self.eval_logits.copy(),
-            llama_state=llama_state,
+            llama_state=llama_state_compact,
+            llama_state_size=n_bytes,
         )
 
     def load_state(self, state: LlamaState) -> None:
         assert self.ctx is not None
         self.eval_tokens = state.eval_tokens.copy()
         self.eval_logits = state.eval_logits.copy()
-        state_size = llama_cpp.llama_get_state_size(self.ctx)
+        state_size = state.llama_state_size
         if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
             raise RuntimeError("Failed to set llama state data")

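The change hinges on a detail of the llama.cpp C API: llama_get_state_size reports an upper bound on the serialized state, while llama_copy_state_data returns the number of bytes it actually wrote. The commit allocates the scratch buffer at the upper bound, then compacts it to the true size with ctypes.memmove and records that size on the LlamaState so load_state can validate against it. Below is a minimal standalone sketch of that compaction pattern, with a hypothetical fake_copy standing in for llama_copy_state_data and an invented payload and upper bound:

import ctypes

def fake_copy(dst) -> int:
    # Stand-in for llama_copy_state_data: write the serialized state
    # into dst and return the number of bytes actually used.
    payload = b"hello state"
    ctypes.memmove(dst, payload, len(payload))
    return len(payload)

upper_bound = 4096                          # analogue of llama_get_state_size()
scratch = (ctypes.c_uint8 * upper_bound)()  # allocate at the upper bound
n_bytes = fake_copy(scratch)                # bytes actually written
if n_bytes > upper_bound:
    raise RuntimeError("Failed to copy state data")

compact = (ctypes.c_uint8 * n_bytes)()      # right-sized buffer, as in save_state
ctypes.memmove(compact, scratch, n_bytes)   # keep only the bytes that matter
print(bytes(compact))                       # b'hello state'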
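For end-to-end context, here is a sketch of the save/restore round trip these methods support, assuming the llama-cpp-python API of this vintage (tokenize, eval, save_state, load_state); the model path and prompts are placeholders, not part of the commit:

from llama_cpp import Llama

llm = Llama(model_path="./models/7B/ggml-model.bin")  # hypothetical local model

llm.eval(llm.tokenize(b"The quick brown fox"))
state = llm.save_state()   # state.llama_state now holds n_bytes, not the upper bound

llm.eval(llm.tokenize(b" jumps over the lazy dog"))   # advance past the snapshot
llm.load_state(state)      # rewind tokens, logits, and C-side state to the snapshot

Before this commit every snapshot buffer was sized at the full llama_get_state_size upper bound; afterwards each LlamaState keeps only the bytes the C API actually produced, which is what makes the smaller state sizes of the commit title possible.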