Commit 6634017

Merge pull request #1 from tc-wolf/optimize_kv_cache_size
Optimize KV cache size
2 parents (121eaaa + 0967eda), commit 6634017

3 files changed: 375 additions, 10 deletions
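
The size win comes from no longer storing each cached entry's logits array (state.scores) when the static cache is built with save_logits=False (see the llama_cache.py diff below). A rough back-of-the-envelope sketch of what that array costs per cached state, using assumed n_ctx and n_vocab values rather than anything from the commit:

# Approximate size of the scores array that save_logits=False drops from each
# cached LlamaState. The n_ctx / n_vocab values below are assumed examples.
n_ctx, n_vocab = 4096, 32_000
scores_bytes = n_ctx * n_vocab * 4  # one float32 logit vector per context slot
print(f"~{scores_bytes / 2**20:.0f} MiB of logits per cache entry")  # ~500 MiB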

llama_cpp/llama.py

18 additions, 10 deletions
@@ -38,6 +38,7 @@
     LlamaDiskCache,  # type: ignore
     LlamaStaticDiskCache,  # type: ignore
     LlamaRAMCache,  # type: ignore
+    StateReloadError,  # type: ignore
 )
 
 import numpy as np
@@ -1234,16 +1235,23 @@ def logit_bias_processor(
                             file=sys.stderr,
                         )
 
-                    before = time.time()
-                    self.load_state(cache_item)
-                    after = time.time()
-                    if self.verbose:
-                        print("State loading took", round((after - before) * 1_000, 4), "ms", file=sys.stderr)
-                    if self.verbose:
-                        print(
-                            f"Llama._create_completion: cache hit with len {cache_prefix_len} / {len(prompt_tokens)}",
-                            file=sys.stderr,
-                        )
+                    try:
+                        before = time.time()
+                        self.cache.reload_from_cache_state(self, cache_item)
+                        after = time.time()
+                        if self.verbose:
+                            print("State loading took", round((after - before) * 1_000, 4), "ms", file=sys.stderr)
+                            print(
+                                f"Llama._create_completion: cache hit with len {cache_prefix_len} / {len(prompt_tokens)}",
+                                file=sys.stderr,
+                            )
+                    except StateReloadError as e:
+                        if self.verbose:
+                            print(
+                                f"Llama._create_completion: cache hit with len {cache_prefix_len} / {len(prompt_tokens)}, but failed to reload state: {e}",
+                                file=sys.stderr,
+                            )
+                            print("Falling back to re-evaluating prompt", file=sys.stderr)
                 elif self.verbose:
                     print(
                         f"Llama._create_completion: not reloading from cache, cache prefix len {cache_prefix_len} < eval prefix len {eval_prefix_len}",

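The call site above now dispatches through self.cache.reload_from_cache_state(self, cache_item) instead of calling self.load_state(cache_item) directly, so each cache class decides how to restore state and can refuse the fast path by raising StateReloadError, in which case the prompt is simply re-evaluated. A minimal sketch of a hypothetical subclass using that hook (not part of this commit; the base implementation in llama_cache.py below just calls model.load_state):

# Hypothetical example only: a cache that rejects the fast reload when the
# cached state carries no logits, forcing the completion code to fall back
# to re-evaluating the prompt.
from llama_cpp.llama_cache import LlamaRAMCache, StateReloadError


class PickyRAMCache(LlamaRAMCache):
    @classmethod
    def reload_from_cache_state(cls, model, state):
        if state.scores is None:
            raise StateReloadError("cached state has no logits stored")
        model.load_state(state)
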
llama_cpp/llama_cache.py

116 additions, 0 deletions
@@ -1,17 +1,25 @@
+import ctypes
 import pickle
 import sys
 from abc import ABC, abstractmethod
 from collections import OrderedDict
 from typing import Optional, Sequence, Tuple
 
 import diskcache
+import numpy as np
 import pytrie
 
 import llama_cpp.llama
 
 from .llama_types import *
 
 
+class StateReloadError(Exception):
+    """
+    Error for when state from cache cannot be read by current model.
+    """
+
+
 class BaseLlamaCache(ABC):
     """Base cache class for a llama.cpp model."""
 
@@ -48,6 +56,20 @@ def __setitem__(
     ) -> None:
         raise NotImplementedError
 
+    @classmethod
+    def reload_from_cache_state(
+        cls, model: "llama_cpp.llama.Llama", state: "llama_cpp.llama.LlamaState"
+    ) -> None:
+        """
+        Reload the state onto the model. Normally this is done with load_state
+        (as the state is created with the corresponding `save_state`), but some
+        caches may need special handling as an optimization.
+
+        Throws a StateReloadError if the state is not compatible with the model
+        (for example, if logits are required but were not saved).
+        """
+        model.load_state(state)
+
 
 class LlamaRAMCache(BaseLlamaCache):
     """Cache for a llama.cpp model using RAM."""
@@ -223,6 +245,7 @@ def build_cache(
         capacity_bytes: int = 2 << 30,
         seed: Optional[int] = None,
         add_bos=True,
+        save_logits: bool = False,
     ) -> "LlamaStaticDiskCache":
         """
         Using model passed in, evaluates each prompt and stores LlamaState in cache.
@@ -246,6 +269,19 @@
                 print("LlamaStaticDiskCache.build_cache: eval", file=sys.stderr)
             model.eval(eval_toks)
             state = model.save_state()
+
+            if not save_logits:
+                if (
+                    model.context_params.logits_all
+                    or model.draft_model is not None
+                    or model.context_params.embeddings
+                ):
+                    # Erroring instead of falling back to just saving with scores
+                    raise ValueError(
+                        "Cannot save state without logits - model requires logits to sample."
+                    )
+                state.scores = None
+
             cache._private_setitem(toks, state)  # pylint: disable=protected-access
 
             # Set up Trie for efficient prefix search
@@ -278,3 +314,83 @@ def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
     def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"):
         # Should this just be a warning?
         raise ValueError("Cannot set items in a static cache")
+
+    @classmethod
+    def reload_from_cache_state(
+        cls, model: "llama_cpp.llama.Llama", state: "llama_cpp.llama.LlamaState"
+    ) -> None:
+        """
+        Skip reloading logits and set last logits from llama.cpp context struct
+        as the scores for last token of prompt.
+        """
+        # pylint: disable=protected-access
+
+        # Check if model needs logits (draft model, log probs required, etc.)
+        model_needs_scores_to_reload = (
+            # May be overly pessimistic if don't want embeddings for prompt tokens.
+            model.context_params.embeddings
+            or model.context_params.logits_all
+            # Same: is this really a hard requirement? We need token IDs from
+            # draft model and all the logits from base model to do verification
+            # of candidate tokens, but not for prompt tokens.
+            or model.draft_model is not None
+        )
+
+        if model_needs_scores_to_reload:
+            if state.scores is None:
+                raise StateReloadError(
+                    "Model requires logits to be reloaded, but static cache does not store logits"
+                )
+            else:
+                model.load_state(state)
+                return
+
+        # Case where don't need logits from numpy and can just get last-token
+        # logits from llama.cpp struct
+        model.n_tokens = state.n_tokens
+        model.input_ids = state.input_ids.copy()
+        model.scores[:] = 0.0
+
+        state_size = state.llama_state_size
+
+        try:
+            llama_state_array_type = ctypes.c_uint8 * state_size
+            # Have to do from_buffer_copy since LlamaState.llama_state is
+            # non-mutable bytes, not mutable bytearray.
+            llama_state = llama_state_array_type.from_buffer_copy(state.llama_state)
+            reloaded_state_size = llama_cpp.llama_set_state_data(
+                model._ctx.ctx, llama_state
+            )
+
+            if reloaded_state_size != state_size:
+                raise StateReloadError(
+                    "Failed to set llama state data - reloaded state size "
+                    f"{reloaded_state_size} does not match original size {state_size}"
+                )
+
+            # cffi dtype, compatible w/ numpy through ducktyping :scared:
+            dtype = llama_cpp.llama_cpp.llama_get_logits_ith.restype._type_
+
+            # If model scores dtype doesn't match dtype from sig, then can't
+            # copy it.
+            if model.scores.dtype != dtype:
+                raise StateReloadError(
+                    f"Expected scores to be {dtype} but got "
+                    f"{model.scores.dtype} - are you running this in the future? Or the past?"
+                )
+
+            # Will have a ValueError for null pointers
+            last_position_logits = np.array(
+                ctypes.cast(
+                    model._ctx.get_logits_ith(-1),
+                    ctypes.POINTER(dtype * model.n_vocab()),
+                ).contents,
+                # Otherwise will be a view into C array on llama.cpp context
+                copy=True,
+                dtype=dtype,
+            )
+
+            model._scores[-1, :] = last_position_logits
+
+        except ValueError as e:
+            raise StateReloadError from e
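
Taken together, intended usage might look like the sketch below. Only the trailing keyword arguments of build_cache (capacity_bytes, seed, add_bos, save_logits) are visible in this diff; the leading arguments and their names (model, prompts, cache directory) are assumptions made for illustration, as are the file paths.

# Hypothetical end-to-end usage of the optimized static cache. Arguments
# marked "assumed" are not shown in this commit and may differ.
from llama_cpp import Llama
from llama_cpp.llama_cache import LlamaStaticDiskCache

llm = Llama(model_path="./model.gguf")  # placeholder path

cache = LlamaStaticDiskCache.build_cache(
    model=llm,                                 # assumed parameter name
    prompts=["You are a helpful assistant."],  # assumed parameter name
    cache_dir="./prompt_cache",                # assumed parameter name
    save_logits=False,  # from this commit: do not keep state.scores on disk
)
llm.set_cache(cache)

# On a prompt that shares the cached prefix, reload_from_cache_state restores
# the llama.cpp state and recovers only the last position's logits; if the
# model needs full logits (logits_all, embeddings, or a draft model), it
# raises StateReloadError and the prompt is re-evaluated instead.
out = llm.create_completion("You are a helpful assistant. Hello!", max_tokens=32)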
