Commit 197cf80

Add save/load state api for Llama class

1 parent: c4c332f
File tree: 1 file changed (+39 / -4 lines)

llama_cpp/llama.py: 39 additions & 4 deletions

@@ -4,7 +4,7 @@
 import time
 import math
 import multiprocessing
-from typing import List, Optional, Union, Generator, Sequence, Iterator
+from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque
 from collections import deque
 
 from . import llama_cpp
@@ -20,6 +20,18 @@ class LlamaCache:
     pass
 
 
+class LlamaState:
+    def __init__(
+        self,
+        eval_tokens: Deque[llama_cpp.llama_token],
+        eval_logits: Deque[List[float]],
+        llama_state,
+    ):
+        self.eval_tokens = eval_tokens
+        self.eval_logits = eval_logits
+        self.llama_state = llama_state
+
+
 class Llama:
     """High-level Python wrapper for a llama.cpp model."""
 
@@ -85,8 +97,8 @@ def __init__(
 
         self.last_n_tokens_size = last_n_tokens_size
         self.n_batch = min(n_ctx, n_batch)
-        self.eval_tokens: deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
-        self.eval_logits: deque[List[float]] = deque(maxlen=n_ctx)
+        self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
+        self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx)
 
         ### HACK: This is a hack to work around the fact that the llama.cpp API does not yet support
         ### saving and restoring state, this allows us to continue a completion if the last
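
One plausible reason for switching the annotations from the builtin deque to typing.Deque (an inference on my part, not stated in the commit): these annotations are evaluated at runtime, and subscripting collections.deque is only supported from Python 3.9, while typing.Deque also works on older interpreters. A minimal sketch, using only the standard library:

```python
# Hypothetical sketch of the annotation difference; not part of the diff above.
from collections import deque
from typing import Deque

tokens_ok: Deque[int] = deque(maxlen=4)     # works on older Pythons as well
# tokens_bad: deque[int] = deque(maxlen=4)  # TypeError before Python 3.9: deque is not subscriptable
```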
@@ -204,7 +216,10 @@ def eval(self, tokens: Sequence[llama_cpp.llama_token]):
             cols = int(n_vocab)
             rows = n_tokens
             logits_view = llama_cpp.llama_get_logits(self.ctx)
-            logits = [[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)]
+            logits = [
+                [logits_view[i * cols + j] for j in range(cols)]
+                for i in range(rows)
+            ]
             self.eval_logits.extend(logits)
 
     def sample(
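
For context on what this comprehension does (the reformatting does not change behaviour): llama_get_logits exposes one flat, row-major buffer of rows * cols floats, where row i holds the n_vocab logits for token i of the batch, so entry (i, j) sits at index i * cols + j. A small sketch of that indexing, with a plain Python list standing in for the ctypes float array:

```python
# Hypothetical stand-in for the flat, row-major logits buffer; values are made up.
rows, cols = 2, 5  # n_tokens, n_vocab
flat = [float(k) for k in range(rows * cols)]

# Same reshaping as in the diff: one list of `cols` logits per evaluated token.
logits = [
    [flat[i * cols + j] for j in range(cols)]
    for i in range(rows)
]
assert logits[1][3] == flat[1 * cols + 3]
```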
@@ -828,6 +843,26 @@ def __setstate__(self, state):
             verbose=state["verbose"],
         )
 
+    def save_state(self) -> LlamaState:
+        assert self.ctx is not None
+        state_size = llama_cpp.llama_get_state_size(self.ctx)
+        llama_state = (llama_cpp.c_uint8 * int(state_size))()
+        if llama_cpp.llama_copy_state_data(self.ctx, llama_state) != state_size:
+            raise RuntimeError("Failed to copy llama state data")
+        return LlamaState(
+            eval_tokens=self.eval_tokens.copy(),
+            eval_logits=self.eval_logits.copy(),
+            llama_state=llama_state,
+        )
+
+    def load_state(self, state: LlamaState) -> None:
+        assert self.ctx is not None
+        self.eval_tokens = state.eval_tokens.copy()
+        self.eval_logits = state.eval_logits.copy()
+        state_size = llama_cpp.llama_get_state_size(self.ctx)
+        if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
+            raise RuntimeError("Failed to set llama state data")
+
     @staticmethod
     def token_eos() -> llama_cpp.llama_token:
         """Return the end-of-sequence token."""

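A sketch of how the new API might be used end to end: snapshot the context after evaluating a shared prompt, generate from it, then rewind without re-evaluating the prompt. The model path and prompt below are placeholders, and the tokenize/eval calls come from the surrounding Llama class rather than from this diff:

```python
from llama_cpp import Llama

llm = Llama(model_path="./models/ggml-model.bin")  # placeholder path

# Evaluate a prompt once and snapshot the resulting state
# (eval_tokens, eval_logits and the raw llama.cpp state buffer).
prompt_tokens = llm.tokenize(b"Q: Name the planets in the solar system. A:")
llm.eval(prompt_tokens)
snapshot = llm.save_state()

# ... sample a completion, or otherwise advance the context ...

# Rewind to the snapshot and continue from the same point.
llm.load_state(snapshot)
```
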
0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.