1
- import ctypes
2
1
import pickle
2
+ import ctypes
3
3
import sys
4
4
from abc import ABC , abstractmethod
5
5
from collections import OrderedDict
6
6
from typing import Optional , Sequence , Tuple
7
7
8
8
import diskcache
9
- import numpy as np
10
9
import pytrie
11
10
12
11
import llama_cpp .llama
@@ -77,9 +76,9 @@ class LlamaRAMCache(BaseLlamaCache):
77
76
def __init__ (self , capacity_bytes : int = (2 << 30 )):
78
77
super ().__init__ (capacity_bytes )
79
78
self .capacity_bytes = capacity_bytes
80
- self .cache_state : OrderedDict [
81
- Tuple [ int , ...], "llama_cpp.llama.LlamaState"
82
- ] = OrderedDict ( )
79
+ self .cache_state : OrderedDict [Tuple [ int , ...], "llama_cpp.llama.LlamaState" ] = (
80
+ OrderedDict ()
81
+ )
83
82
84
83
@property
85
84
def cache_size (self ):
@@ -320,8 +319,8 @@ def reload_from_cache_state(
320
319
cls , model : "llama_cpp.llama.Llama" , state : "llama_cpp.llama.LlamaState"
321
320
) -> None :
322
321
"""
323
- Skip reloading logits and set last logits from llama.cpp context struct
324
- as the scores for last token of prompt .
322
+ Skip reloading logits (zero-out instead) unless `logits_all` or
323
+ otherwise needed .
325
324
"""
326
325
# pylint: disable=protected-access
327
326
@@ -349,17 +348,40 @@ def reload_from_cache_state(
349
348
# logits from llama.cpp struct
350
349
model .n_tokens = state .n_tokens
351
350
model .input_ids = state .input_ids .copy ()
352
- model .scores [:] = 0.0
351
+ model ._seed = state .seed
352
+
353
+ if model .scores .shape [0 ] < state .n_tokens :
354
+ raise StateReloadError (
355
+ f"Model context / batch size { model .scores .shape [0 ]} not large "
356
+ f"enough for saved state tokens { state .n_tokens } ."
357
+ )
353
358
354
359
state_size = state .llama_state_size
355
360
361
+ LlamaStateArrayType = ctypes .c_uint8 * state_size
362
+ llama_state = LlamaStateArrayType .from_buffer_copy (state .llama_state )
363
+
364
+ # Use non-deprecated llama_state_set_data over llama_set_state_data
365
+ if (
366
+ bytes_set := llama_cpp .llama_state_set_data (
367
+ model ._ctx .ctx , llama_state , ctypes .sizeof (llama_state )
368
+ ),
369
+ ) != state_size :
370
+ raise RuntimeError (
371
+ "Failed to set llama state data - mismatch between bytes set "
372
+ f"{ bytes_set } and state size { state_size } "
373
+ )
374
+
375
+ # No longer need to reload scores, since now use llama.cpp sampler.
376
+ # pylint: disable=pointless-string-statement
377
+ """
356
378
try:
357
379
llama_state_array_type = ctypes.c_uint8 * state_size
358
380
# Have to do from_buffer_copy since LlamaState.llama_state is
359
381
# non-mutable bytes, not mutable bytearray.
360
382
llama_state = llama_state_array_type.from_buffer_copy(state.llama_state)
361
- reloaded_state_size = llama_cpp .llama_set_state_data (
362
- model ._ctx .ctx , llama_state
383
+ reloaded_state_size = llama_cpp.llama_state_set_data (
384
+ model._ctx.ctx, llama_state, ctypes.sizeof(llama_state)
363
385
)
364
386
365
387
if reloaded_state_size != state_size:
@@ -394,3 +416,4 @@ def reload_from_cache_state(
394
416
395
417
except ValueError as e:
396
418
raise StateReloadError from e
419
+ """
0 commit comments