Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit e5cccf4

Browse filesBrowse files
committed
Fixup reloading
- Don't add additional test in load_state for size, keep doing what upstream is doing there. - Don't reload numpy logits (scores) at all if not required - Comment out block for setting last logits from the llama.cpp data
1 parent 15bf3e8 commit e5cccf4
Copy full SHA for e5cccf4

File tree

Expand file treeCollapse file tree

2 files changed

+39
-13
lines changed
Filter options
Expand file treeCollapse file tree

2 files changed

+39
-13
lines changed

‎llama_cpp/llama.py

Copy file name to clipboardExpand all lines: llama_cpp/llama.py
+6-3Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2226,9 +2226,12 @@ def load_state(self, state: LlamaState) -> None:
22262226
llama_state = LLamaStateArrayType.from_buffer_copy(state.llama_state)
22272227

22282228
# Use non-deprecated llama_state_set_data over llama_set_state_data
2229-
if (ctypes.sizeof(llama_state) != state_size) or llama_cpp.llama_state_set_data(
2230-
self._ctx.ctx, llama_state, ctypes.sizeof(llama_state)
2231-
) != state_size:
2229+
if (
2230+
llama_cpp.llama_state_set_data(
2231+
self._ctx.ctx, llama_state, ctypes.sizeof(llama_state)
2232+
)
2233+
!= state_size
2234+
):
22322235
raise RuntimeError("Failed to set llama state data")
22332236

22342237
def n_ctx(self) -> int:

‎llama_cpp/llama_cache.py

Copy file name to clipboardExpand all lines: llama_cpp/llama_cache.py
+33-10Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
1-
import ctypes
21
import pickle
2+
import ctypes
33
import sys
44
from abc import ABC, abstractmethod
55
from collections import OrderedDict
66
from typing import Optional, Sequence, Tuple
77

88
import diskcache
9-
import numpy as np
109
import pytrie
1110

1211
import llama_cpp.llama
@@ -77,9 +76,9 @@ class LlamaRAMCache(BaseLlamaCache):
7776
def __init__(self, capacity_bytes: int = (2 << 30)):
7877
super().__init__(capacity_bytes)
7978
self.capacity_bytes = capacity_bytes
80-
self.cache_state: OrderedDict[
81-
Tuple[int, ...], "llama_cpp.llama.LlamaState"
82-
] = OrderedDict()
79+
self.cache_state: OrderedDict[Tuple[int, ...], "llama_cpp.llama.LlamaState"] = (
80+
OrderedDict()
81+
)
8382

8483
@property
8584
def cache_size(self):
@@ -320,8 +319,8 @@ def reload_from_cache_state(
320319
cls, model: "llama_cpp.llama.Llama", state: "llama_cpp.llama.LlamaState"
321320
) -> None:
322321
"""
323-
Skip reloading logits and set last logits from llama.cpp context struct
324-
as the scores for last token of prompt.
322+
Skip reloading logits (zero-out instead) unless `logits_all` or
323+
otherwise needed.
325324
"""
326325
# pylint: disable=protected-access
327326

@@ -349,17 +348,40 @@ def reload_from_cache_state(
349348
# logits from llama.cpp struct
350349
model.n_tokens = state.n_tokens
351350
model.input_ids = state.input_ids.copy()
352-
model.scores[:] = 0.0
351+
model._seed = state.seed
352+
353+
if model.scores.shape[0] < state.n_tokens:
354+
raise StateReloadError(
355+
f"Model context / batch size {model.scores.shape[0]} not large "
356+
f"enough for saved state tokens {state.n_tokens}."
357+
)
353358

354359
state_size = state.llama_state_size
355360

361+
LlamaStateArrayType = ctypes.c_uint8 * state_size
362+
llama_state = LlamaStateArrayType.from_buffer_copy(state.llama_state)
363+
364+
# Use non-deprecated llama_state_set_data over llama_set_state_data
365+
if (
366+
bytes_set := llama_cpp.llama_state_set_data(
367+
model._ctx.ctx, llama_state, ctypes.sizeof(llama_state)
368+
),
369+
) != state_size:
370+
raise RuntimeError(
371+
"Failed to set llama state data - mismatch between bytes set "
372+
f"{bytes_set} and state size {state_size}"
373+
)
374+
375+
# No longer need to reload scores, since now use llama.cpp sampler.
376+
# pylint: disable=pointless-string-statement
377+
"""
356378
try:
357379
llama_state_array_type = ctypes.c_uint8 * state_size
358380
# Have to do from_buffer_copy since LlamaState.llama_state is
359381
# non-mutable bytes, not mutable bytearray.
360382
llama_state = llama_state_array_type.from_buffer_copy(state.llama_state)
361-
reloaded_state_size = llama_cpp.llama_set_state_data(
362-
model._ctx.ctx, llama_state
383+
reloaded_state_size = llama_cpp.llama_state_set_data(
384+
model._ctx.ctx, llama_state, ctypes.sizeof(llama_state)
363385
)
364386
365387
if reloaded_state_size != state_size:
@@ -394,3 +416,4 @@ def reload_from_cache_state(
394416
395417
except ValueError as e:
396418
raise StateReloadError from e
419+
"""

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.