Commit 5de50b9

Update tests for cache
- Helper to get logits from llama.cpp context to numpy
- Model w/ and w/out logits_all
- Update tests as needed now that `model.scores` not set on reload
1 parent c9bf03a commit 5de50b9
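The first bullet above — a helper that pulls logits out of the llama.cpp context into numpy — comes down to wrapping a ctypes float pointer with `np.ctypeslib.as_array` and copying it before the underlying buffer is reused. Below is a minimal sketch of that pattern, run against a plain ctypes buffer rather than a real `llama_cpp` context; the names `logits_to_numpy` and `N_VOCAB` are illustrative only, not part of the library.

# Sketch of the ctypes -> numpy copy used by the new _get_logits test helper.
# A real llama_cpp context returns a float pointer from get_logits_ith(); here
# a plain ctypes buffer stands in so the example runs on its own.
import ctypes
from typing import Optional

import numpy as np
import numpy.typing as npt

N_VOCAB = 8  # stand-in for model.n_vocab()


def logits_to_numpy(logits, n_vocab: int) -> Optional[npt.NDArray]:
    """Wrap a ctypes float buffer as a (1, n_vocab) array and copy it."""
    if not logits:  # NULL pointer -> no logits available for this position
        return None
    # .copy() detaches the result from the lifetime of the ctypes buffer.
    return np.ctypeslib.as_array(logits, shape=(1, n_vocab)).copy()


buf = (ctypes.c_float * N_VOCAB)(*range(N_VOCAB))
ptr = ctypes.cast(buf, ctypes.POINTER(ctypes.c_float))
print(logits_to_numpy(ptr, N_VOCAB).shape)  # (1, 8)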

1 file changed · +135 −35 lines changed

tests/test_llama_cache.py

@@ -1,34 +1,75 @@
+# pylint: disable=redefined-outer-name,missing-function-docstring,missing-module-docstring
+import ctypes
 import os
+from random import seed
 import tempfile
 
+import numpy as np
+import numpy.typing as npt
 import pytest
 
 from llama_cpp.llama import Llama, LlamaState
 from llama_cpp.llama_cache import LlamaStaticDiskCache, StateReloadError
 
 
+def _get_logits(model: Llama) -> npt.NDArray:
+    """
+    Helper method to get logits and put into correct shape.
+
+    (Model returns the non-zero logits in batch, which may be different
+    depending on whether or not `logits_all` is True).
+
+    Makes a copy so that not limited by the lifetime of the CtypesArray.
+    """
+    # pylint: disable=protected-access
+    logits: ctypes.Array[ctypes.c_float] = model._ctx.get_logits_ith(-1)
+
+    # Return None if falsy (NULL)
+    if not logits:
+        return None
+
+    num_rows = 1
+    num_cols = model.n_vocab()
+
+    logits_np: npt.NDArray = np.ctypeslib.as_array(
+        logits, shape=(num_rows, num_cols)
+    ).copy()
+
+    return logits_np
+
+
 # Have to be careful to reset to good state when testing, but don't want to
 # recreate model each time.
-@pytest.fixture(scope="module")
-def small_model():
+def model_factory(**kwargs) -> Llama:
     model_filename = os.getenv("LLAMA_TEST_MODEL")
     if not model_filename:
         pytest.skip("LLAMA_TEST_MODEL environment variable is not set")
         return
 
     model_filename = os.path.expanduser(model_filename)
 
-    test_model = Llama(
-        model_filename,
+    default_args = dict(
         n_ctx=2_048,
         n_gpu_layers=0,
         offload_kqv=False,
         n_batch=512,
         embedding=False,
+        # Warning - since now uses llama.cpp sampler, no longer generates
+        # logits for each generated token unless this is True.
+        logits_all=False,
         verbose=False,
     )
 
-    system_prompt = r"""
+    default_args.update(kwargs)
+
+    test_model = Llama(model_filename, **default_args)
+
+    return test_model
+
+
+@pytest.fixture(scope="module")
+def system_prompt() -> str:
+    return r"""
 You are an advanced intelligence "Hal" aboard a spaceship. You are required to
 act as the primary interface between the ship and its crew. You can:
 * Provide information on the current status of the ship
@@ -43,27 +84,74 @@ def small_model():
 * Oxygen levels: normal
 """.strip()
 
-    user_prompt = "Hal, please open the airlocks."
+
+@pytest.fixture(scope="module")
+def user_prompt() -> str:
+    return "Hal, please open the airlocks."
+
+
+@pytest.fixture(scope="module")
+def small_model(system_prompt: str, user_prompt: str):
+    """
+    Create model and run prompt through it to make sure has logits for last
+    token generation (internally).
+
+    Logits on numpy array will be all zeros since `logits_all` is False.
+    """
+    model = model_factory()
 
     # Ingest prompt and create completion so that will have some state.
     # Last token of prompt + all tokens of generated completion will have
     # non-zero logits.
-    _ = test_model.create_chat_completion(
+    _ = model.create_chat_completion(
         [
-            {"role": "system", "text": system_prompt},
-            {"role": "user", "text": user_prompt},
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": user_prompt},
         ],
         seed=1234,
     )
 
-    assert test_model.n_tokens > 0
+    assert model.n_tokens > 0
 
-    # Have at least some scores, and last entry is non-zero
-    assert ~(test_model.scores == 0).all()
-    # pylint: disable=protected-access
-    assert (test_model._scores[-1, :] != 0.0).all()
+    # Have logits for last token
+    logits_np = _get_logits(model)
+    assert logits_np.shape == (1, model.n_vocab())
+    assert ~(logits_np == 0.0).all()
 
-    return test_model
+    assert (model.scores == 0.0).all()
+
+    return model
+
+
+@pytest.fixture(scope="module")
+def small_model_with_logits(system_prompt: str, user_prompt: str) -> Llama:
+    """
+    Create model with logits_all=True, needed for testing building/reloading cache
+    when Python-land logits are needed.
+    """
+    model = model_factory(logits_all=True)
+
+    # Ingest prompt and create completion so that will have some state.
+    # Last token of prompt + all tokens of generated completion will have
+    # non-zero logits.
+    _ = model.create_chat_completion(
+        [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": user_prompt},
+        ],
+        seed=1234,
+    )
+
+    assert model.n_tokens > 0
+
+    # Have logits for last token
+    logits_np = _get_logits(model)
+    assert logits_np.shape == (1, model.n_vocab())
+    assert ~(logits_np == 0.0).all()
+
+    assert ~(model.scores == 0.0).all()
+
+    return model
 
 
 @pytest.fixture(scope="module")
@@ -76,20 +164,23 @@ def llama_state(small_model) -> LlamaState:
 
 def test_reload_from_cache_state_success(small_model, llama_state: LlamaState):
     current_state = small_model.save_state()
-    old_score = small_model.scores.copy()
+    old_logits = _get_logits(small_model)
 
-    LlamaStaticDiskCache.reload_from_cache_state(small_model, llama_state)
-    new_state = small_model.save_state()
-    new_score = small_model.scores.copy()
+    # Create blank model
+    new_model = model_factory()
+    LlamaStaticDiskCache.reload_from_cache_state(new_model, llama_state)
 
-    assert (current_state.input_ids == new_state.input_ids).all()
+    assert (current_state.input_ids == new_model.input_ids).all()
+
+    assert current_state.n_tokens == new_model.n_tokens
+
+    # pylint: disable=protected-access
+    assert current_state.seed == new_model._seed
 
-    assert current_state.n_tokens == new_state.n_tokens
+    new_logits = _get_logits(new_model)
 
     # Logits for last token should match, others may not if n_batch < n_tokens
-    assert (
-        old_score[small_model.n_tokens - 1, :] == new_score[small_model.n_tokens - 1, :]
-    ).all()
+    assert (new_logits == old_logits).all()
 
 
 def test_reload_from_cache_state_state_reload_error(small_model, llama_state):
@@ -147,53 +238,62 @@ def test_disk_cache_e2e(small_model: Llama):
     assert ~(state2.input_ids == 0).all()
     assert (state2.input_ids == state.input_ids).all()
 
-    last_logits = small_model.scores[small_model.n_tokens - 1, :]
+    last_logits = _get_logits(small_model)
 
     LlamaStaticDiskCache.reload_from_cache_state(small_model, state)
 
-    last_logits2 = small_model.scores[small_model.n_tokens - 1, :]
+    last_logits2 = _get_logits(small_model)
 
     assert (last_logits == last_logits2).all()
 
 
 def test_cache_save_reload_scores_when_needed(
-    small_model: Llama,
+    small_model_with_logits: Llama,
 ):
     """
     When model requires it, can reload from state with scores.
     """
+    model_state_before_reload = small_model_with_logits.save_state()
+
     test_prompt = "this is a test prompt"
     with tempfile.TemporaryDirectory() as cache_dir:
         cache = LlamaStaticDiskCache.build_cache(
             cache_dir=cache_dir,
             prompts=[test_prompt],
-            model=small_model,
+            model=small_model_with_logits,
             capacity_bytes=2 << 30,
             add_bos=True,
             seed=1234,
             save_logits=True,
         )
 
-        llama_state = small_model.save_state()
+        llama_state = small_model_with_logits.save_state()
         cur_scores = llama_state.scores.copy()
         assert ~(cur_scores == 0.0).all()
+        assert llama_state.n_tokens > 0
 
         try:
-            small_model.context_params.logits_all = True
             state_from_cache = cache[
                 tuple(llama_state.input_ids[: llama_state.n_tokens].tolist())
             ]
             assert state_from_cache.scores is not None, "Scores should be saved."
-            LlamaStaticDiskCache.reload_from_cache_state(small_model, state_from_cache)
+            LlamaStaticDiskCache.reload_from_cache_state(
+                small_model_with_logits, state_from_cache
+            )
+
+            assert state_from_cache.n_tokens == small_model_with_logits.n_tokens
+            # pylint: disable=protected-access
+            assert state_from_cache.seed == small_model_with_logits._seed
+
             # Do I have to limit these to n_tokens?
             assert (state_from_cache.input_ids == llama_state.input_ids).all()
             assert (
-                cur_scores == small_model.scores[: small_model.n_tokens]
+                cur_scores
+                == small_model_with_logits.scores[: small_model_with_logits.n_tokens]
             ).all(), "Reloaded scores should match"
         finally:
-            small_model.scores[:] = 0.0
-            small_model.context_params.logits_all = False
-            small_model.reset()
+            # Reset if re-used in later tests
+            small_model_with_logits.load_state(model_state_before_reload)
 
 
 def test_cache_reload_errors_when_requires_scores_and_state_doesnt_have_it(
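For context, the flow these tests exercise is build-then-reload: `LlamaStaticDiskCache.build_cache` ingests each prompt and stores the resulting state (optionally with logits when `save_logits=True`), and `reload_from_cache_state` restores a cached state into a model. The sketch below shows that round trip; it assumes the `llama_cpp` fork that ships `LlamaStaticDiskCache`, a model path in `LLAMA_TEST_MODEL`, and that the cache is keyed by the tokenized prompt (inferred from the tests above, not from documented API).

# Rough sketch of the cache round trip covered by these tests. Assumes the
# llama_cpp fork providing LlamaStaticDiskCache and a model path in
# LLAMA_TEST_MODEL; keying the cache by the tokenized prompt is inferred
# from the tests, not from documented API.
import os
import tempfile

from llama_cpp.llama import Llama
from llama_cpp.llama_cache import LlamaStaticDiskCache

model = Llama(
    os.path.expanduser(os.environ["LLAMA_TEST_MODEL"]),
    n_ctx=2_048,
    logits_all=False,
    verbose=False,
)

prompt = "this is a test prompt"
with tempfile.TemporaryDirectory() as cache_dir:
    cache = LlamaStaticDiskCache.build_cache(
        cache_dir=cache_dir,
        prompts=[prompt],
        model=model,
        capacity_bytes=2 << 30,
        add_bos=True,
        seed=1234,
        save_logits=False,  # True only when a logits_all model needs scores back
    )

    # Look up the cached state by token ids and restore it into the model.
    key = tuple(model.tokenize(prompt.encode("utf-8"), add_bos=True))
    state = cache[key]
    LlamaStaticDiskCache.reload_from_cache_state(model, state)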
