@@ -141,7 +141,9 @@ def __getitem__(self, key: Sequence[int]) -> "LlamaState":
         if _key is None:
             raise KeyError("Key not found")
         value: "LlamaState" = self.cache.pop(_key)  # type: ignore
-        self.cache.push(_key, side="front")  # type: ignore
+        # NOTE: This puts an integer as key in cache, which breaks,
+        # Llama.longest_token_prefix(k, key) above since k is not a tuple of ints/tokens
+        # self.cache.push(_key, side="front")  # type: ignore
         return value
 
     def __contains__(self, key: Sequence[int]) -> bool:
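For context on the NOTE above: diskcache's `Cache.push()` stores its argument as a *value* under a freshly generated integer key, so the removed call quietly polluted the key space that the prefix search iterates over. A minimal sketch, assuming a plain `diskcache.Cache` keyed by token tuples (the cache path below is illustrative):

```python
import diskcache

cache = diskcache.Cache("/tmp/example_cache")  # illustrative path
cache[(1, 2, 3)] = "some LlamaState"           # normal set: the token tuple is the key

new_key = cache.push((1, 2, 3), side="front")  # push() generates a NEW integer key
print(new_key)                                  # e.g. 500000000000000 (an int, not a tuple)

# A later scan over cache.iterkeys() now yields that int, which
# Llama.longest_token_prefix(k, key) cannot treat as a token sequence.
print(list(cache.iterkeys()))
```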
@@ -168,7 +170,7 @@ def __init__(
         eval_logits: Deque[List[float]],
         input_ids: npt.NDArray[np.intc],
         scores: npt.NDArray[np.single],
-        llama_state,  # type: llama_cpp.Array[llama_cpp.c_uint8]
+        llama_state: bytes,
         llama_state_size: int,
     ):
         self.eval_tokens = eval_tokens
@@ -1512,7 +1514,7 @@ def save_state(self) -> LlamaState:
             eval_logits=self.eval_logits.copy(),
             scores=self._scores.copy(),
             input_ids=self._input_ids.copy(),
-            llama_state=llama_state_compact,
+            llama_state=bytes(llama_state_compact),
             llama_state_size=n_bytes,
         )
 
@@ -1523,7 +1525,10 @@ def load_state(self, state: LlamaState) -> None:
         self._scores = state.scores.copy()
         self._input_ids = state.input_ids.copy()
         state_size = state.llama_state_size
-        if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
+        LLamaStateArrayType = (llama_cpp.c_uint8 * state_size)
+        llama_state = LLamaStateArrayType.from_buffer_copy(state.llama_state)
+
+        if llama_cpp.llama_set_state_data(self.ctx, llama_state) != state_size:
             raise RuntimeError("Failed to set llama state data")
 
     def n_ctx(self) -> int:
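For reference, the `bytes` round trip introduced in `save_state`/`load_state` above reduces to plain ctypes; a standalone sketch under made-up sizes and data, with no llama.cpp involved:

```python
import ctypes

state_size = 8
# Pretend this buffer was filled by the C API (e.g. llama_copy_state_data)
src = (ctypes.c_uint8 * state_size)(*range(state_size))

llama_state = bytes(src)  # what save_state() now stores: a picklable bytes object

# load_state(): rebuild a ctypes array the C API can read from
LlamaStateArrayType = ctypes.c_uint8 * state_size
restored = LlamaStateArrayType.from_buffer_copy(llama_state)

assert bytes(restored) == llama_state
```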