+import ctypes
 import pickle
 import sys
 from abc import ABC, abstractmethod
 from collections import OrderedDict
 from typing import Optional, Sequence, Tuple

 import diskcache
+import numpy as np
 import pytrie

 import llama_cpp.llama

 from .llama_types import *


+class StateReloadError(Exception):
+    """
+    Raised when state from the cache cannot be read by the current model.
+    """
+
+
 class BaseLlamaCache(ABC):
     """Base cache class for a llama.cpp model."""

@@ -48,6 +56,20 @@ def __setitem__(
     ) -> None:
         raise NotImplementedError

+    @classmethod
+    def reload_from_cache_state(
+        cls, model: "llama_cpp.llama.Llama", state: "llama_cpp.llama.LlamaState"
+    ) -> None:
+        """
+        Reload the state onto the model. Normally this is done with `load_state`
+        (since the state is created with the corresponding `save_state`), but some
+        caches may need special handling as an optimization.
+
+        Raises a StateReloadError if the state is not compatible with the model
+        (for example, if the model requires logits that the state does not include).
+        """
+        model.load_state(state)
+

 class LlamaRAMCache(BaseLlamaCache):
     """Cache for a llama.cpp model using RAM."""
@@ -223,6 +245,7 @@ def build_cache(
         capacity_bytes: int = 2 << 30,
         seed: Optional[int] = None,
         add_bos=True,
+        save_logits: bool = False,
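+        # save_logits: when False (the default), the per-token scores array is
+        # dropped from each cached state before it is written to disk; the
+        # last-token logits are recovered from the llama.cpp context on reload.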
     ) -> "LlamaStaticDiskCache":
         """
         Using model passed in, evaluates each prompt and stores LlamaState in cache.
@@ -246,6 +269,19 @@ def build_cache(
             print("LlamaStaticDiskCache.build_cache: eval", file=sys.stderr)
             model.eval(eval_toks)
             state = model.save_state()
+
+            if not save_logits:
+                if (
+                    model.context_params.logits_all
+                    or model.draft_model is not None
+                    or model.context_params.embeddings
+                ):
+                    # Erroring instead of falling back to just saving with scores
+                    raise ValueError(
+                        "Cannot save state without logits - model requires logits to sample."
+                    )
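+                # Dropping the scores matrix is what keeps the on-disk state
+                # small; reload_from_cache_state below restores the last-token
+                # logits from llama.cpp's own serialized state instead.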
+                state.scores = None
+
             cache._private_setitem(toks, state)  # pylint: disable=protected-access

         # Set up Trie for efficient prefix search
@@ -278,3 +314,83 @@ def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
     def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"):
         # Should this just be a warning?
         raise ValueError("Cannot set items in a static cache")
+
+    @classmethod
+    def reload_from_cache_state(
+        cls, model: "llama_cpp.llama.Llama", state: "llama_cpp.llama.LlamaState"
+    ) -> None:
+        """
+        Skip reloading logits and set the last logits from the llama.cpp context
+        struct as the scores for the last token of the prompt.
+        """
+        # pylint: disable=protected-access
+
+        # Check if model needs logits (draft model, log probs required, etc.)
+        model_needs_scores_to_reload = (
+            # May be overly pessimistic if we don't want embeddings for prompt tokens.
+            model.context_params.embeddings
+            or model.context_params.logits_all
+            # Same: is this really a hard requirement? We need token IDs from the
+            # draft model and all the logits from the base model to verify
+            # candidate tokens, but not for prompt tokens.
+            or model.draft_model is not None
+        )
+
+        if model_needs_scores_to_reload:
+            if state.scores is None:
+                raise StateReloadError(
+                    "Model requires logits to be reloaded, but static cache does not store logits"
+                )
+            else:
+                model.load_state(state)
+                return
+
+        # Case where we don't need logits from numpy and can just get the
+        # last-token logits from the llama.cpp struct
+        model.n_tokens = state.n_tokens
+        model.input_ids = state.input_ids.copy()
+        model.scores[:] = 0.0
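+        # Only the scores for the final prompt position are restored below;
+        # earlier rows stay zeroed because they were not saved in the cache.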
+
+        state_size = state.llama_state_size
+
+        try:
+            llama_state_array_type = ctypes.c_uint8 * state_size
+            # Have to do from_buffer_copy since LlamaState.llama_state is
+            # non-mutable bytes, not mutable bytearray.
+            llama_state = llama_state_array_type.from_buffer_copy(state.llama_state)
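+            # llama_set_state_data restores the serialized llama.cpp context
+            # state (KV cache and other internal data, including the last
+            # logits) directly, and returns the number of bytes it consumed.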
+            reloaded_state_size = llama_cpp.llama_set_state_data(
+                model._ctx.ctx, llama_state
+            )
+
+            if reloaded_state_size != state_size:
+                raise StateReloadError(
+                    "Failed to set llama state data - reloaded state size "
+                    f"{reloaded_state_size} does not match original size {state_size}"
+                )
+
+            # ctypes dtype, compatible w/ numpy through duck typing :scared:
+            dtype = llama_cpp.llama_cpp.llama_get_logits_ith.restype._type_
+
+            # If the model's scores dtype doesn't match the dtype from the
+            # signature, then we can't copy it.
+            if model.scores.dtype != dtype:
+                raise StateReloadError(
+                    f"Expected scores to be {dtype} but got "
+                    f"{model.scores.dtype} - are you running this in the future? Or the past?"
+                )
+
+            # Will raise a ValueError for null pointers
+            last_position_logits = np.array(
+                ctypes.cast(
+                    model._ctx.get_logits_ith(-1),
+                    ctypes.POINTER(dtype * model.n_vocab()),
+                ).contents,
+                # Otherwise this would be a view into the C array on the llama.cpp context
+                copy=True,
+                dtype=dtype,
+            )
+
+            model._scores[-1, :] = last_position_logits
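+            # At this point the model can sample the next token as if it had
+            # just evaluated the prompt, without the full scores matrix ever
+            # having been stored in the cache.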
+
+        except ValueError as e:
+            raise StateReloadError from e
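A rough usage sketch (not part of the diff above): assuming `cache` was built with
`LlamaStaticDiskCache.build_cache(...)` and `model` is a `llama_cpp.llama.Llama`
configured the same way as when the cache was built, the new reload path might be
exercised roughly as follows; the cache-miss exception type and the fallback logic
here are assumptions, not something shown in this diff.

    toks = model.tokenize(b"cached prompt prefix", add_bos=True)
    try:
        state = cache[toks]  # prefix lookup in the static cache
        LlamaStaticDiskCache.reload_from_cache_state(model, state)
    except (KeyError, StateReloadError):
        # Fall back to evaluating the prompt from scratch
        model.reset()
        model.eval(toks)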