+ # pylint: disable=redefined-outer-name,missing-function-docstring,missing-module-docstring
+ import ctypes
import os
+ from random import seed
import tempfile
+ from typing import Optional

+ import numpy as np
+ import numpy.typing as npt
import pytest

from llama_cpp.llama import Llama, LlamaState
from llama_cpp.llama_cache import LlamaStaticDiskCache, StateReloadError


+ def _get_logits(model: Llama) -> Optional[npt.NDArray]:
+     """
+     Helper to get the logits and put them into the correct shape.
+
+     (The model returns the non-zero logits for the batch, which may differ
+     depending on whether or not `logits_all` is True.)
+
+     Makes a copy so that the result is not limited by the lifetime of the
+     CtypesArray.
+     """
+     # pylint: disable=protected-access
+     logits: ctypes.Array[ctypes.c_float] = model._ctx.get_logits_ith(-1)
+
+     # Return None if falsy (NULL)
+     if not logits:
+         return None
+
+     num_rows = 1
+     num_cols = model.n_vocab()
+
+     logits_np: npt.NDArray = np.ctypeslib.as_array(
+         logits, shape=(num_rows, num_cols)
+     ).copy()
+
+     return logits_np
+
+
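+ # Illustrative only (hypothetical usage, not exercised by the tests below):
+ # callers are expected to handle the Optional return, e.g.
+ #     logits = _get_logits(model)
+ #     if logits is not None:
+ #         top_token = int(logits[0].argmax())
+
+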
# Have to be careful to reset to a good state when testing, but we don't want
# to recreate the model each time.
- @pytest.fixture(scope="module")
- def small_model():
+ def model_factory(**kwargs) -> Llama:
    model_filename = os.getenv("LLAMA_TEST_MODEL")
    if not model_filename:
        pytest.skip("LLAMA_TEST_MODEL environment variable is not set")
        return

    model_filename = os.path.expanduser(model_filename)

-     test_model = Llama(
-         model_filename,
+     default_args = dict(
        n_ctx=2_048,
        n_gpu_layers=0,
        offload_kqv=False,
        n_batch=512,
        embedding=False,
+         # Warning - since the llama.cpp sampler is now used, logits are no
+         # longer generated for each token unless this is True.
+         logits_all=False,
        verbose=False,
    )

-     system_prompt = r"""
+     default_args.update(kwargs)
+
+     test_model = Llama(model_filename, **default_args)
+
+     return test_model
+
+
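+ # model_factory accepts overrides for any of the default Llama kwargs above,
+ # e.g. (hypothetical): model_factory(logits_all=True, n_ctx=4_096).
+ # Fixtures that need per-token logits in the Python-side `scores` array pass
+ # logits_all=True; with the defaults, `scores` stays zeroed.
+
+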
+ @pytest.fixture(scope="module")
+ def system_prompt() -> str:
+     return r"""
You are an advanced intelligence "Hal" aboard a spaceship. You are required to
act as the primary interface between the ship and its crew. You can:
* Provide information on the current status of the ship
@@ -43,27 +84,74 @@ def small_model():
* Oxygen levels: normal
""".strip()

-     user_prompt = "Hal, please open the airlocks."
+
+
+ @pytest.fixture(scope="module")
+ def user_prompt() -> str:
+     return "Hal, please open the airlocks."
+
+
+ @pytest.fixture(scope="module")
+ def small_model(system_prompt: str, user_prompt: str):
+     """
+     Create the model and run a prompt through it to make sure it has logits
+     for the last generated token (internally).
+
+     Logits in the numpy `scores` array will be all zeros since `logits_all`
+     is False.
+     """
+     model = model_factory()

    # Ingest the prompt and create a completion so that the model has some state.
    # Last token of prompt + all tokens of generated completion will have
    # non-zero logits.
-     _ = test_model.create_chat_completion(
+     _ = model.create_chat_completion(
        [
-             {"role": "system", "text": system_prompt},
-             {"role": "user", "text": user_prompt},
+             {"role": "system", "content": system_prompt},
+             {"role": "user", "content": user_prompt},
        ],
        seed=1234,
    )

-     assert test_model.n_tokens > 0
+     assert model.n_tokens > 0

-     # Have at least some scores, and last entry is non-zero
-     assert ~(test_model.scores == 0).all()
-     # pylint: disable=protected-access
-     assert (test_model._scores[-1, :] != 0.0).all()
+     # Have logits for the last token
+     logits_np = _get_logits(model)
+     assert logits_np.shape == (1, model.n_vocab())
+     assert ~(logits_np == 0.0).all()

-     return test_model
+     assert (model.scores == 0.0).all()
+
+     return model
+
+
+ @pytest.fixture(scope="module")
+ def small_model_with_logits(system_prompt: str, user_prompt: str) -> Llama:
+     """
+     Create a model with logits_all=True, needed for testing building/reloading
+     the cache when Python-land logits are required.
+     """
+     model = model_factory(logits_all=True)
+
+     # Ingest the prompt and create a completion so that the model has some state.
+     # Last token of prompt + all tokens of generated completion will have
+     # non-zero logits.
+     _ = model.create_chat_completion(
+         [
+             {"role": "system", "content": system_prompt},
+             {"role": "user", "content": user_prompt},
+         ],
+         seed=1234,
+     )
+
+     assert model.n_tokens > 0
+
+     # Have logits for the last token
+     logits_np = _get_logits(model)
+     assert logits_np.shape == (1, model.n_vocab())
+     assert ~(logits_np == 0.0).all()
+
+     assert ~(model.scores == 0.0).all()
+
+     return model


@pytest.fixture(scope="module")
@@ -76,20 +164,23 @@ def llama_state(small_model) -> LlamaState:

def test_reload_from_cache_state_success(small_model, llama_state: LlamaState):
    current_state = small_model.save_state()
-     old_score = small_model.scores.copy()
+     old_logits = _get_logits(small_model)

-     LlamaStaticDiskCache.reload_from_cache_state(small_model, llama_state)
-     new_state = small_model.save_state()
-     new_score = small_model.scores.copy()
+     # Create blank model
+     new_model = model_factory()
+     LlamaStaticDiskCache.reload_from_cache_state(new_model, llama_state)

-     assert (current_state.input_ids == new_state.input_ids).all()
+     assert (current_state.input_ids == new_model.input_ids).all()
+
+     assert current_state.n_tokens == new_model.n_tokens
+
+     # pylint: disable=protected-access
+     assert current_state.seed == new_model._seed

-     assert current_state.n_tokens == new_state.n_tokens
+     new_logits = _get_logits(new_model)

    # Logits for last token should match, others may not if n_batch < n_tokens
-     assert (
-         old_score[small_model.n_tokens - 1, :] == new_score[small_model.n_tokens - 1, :]
-     ).all()
+     assert (new_logits == old_logits).all()


def test_reload_from_cache_state_state_reload_error(small_model, llama_state):
@@ -147,53 +238,62 @@ def test_disk_cache_e2e(small_model: Llama):
        assert ~(state2.input_ids == 0).all()
        assert (state2.input_ids == state.input_ids).all()

-         last_logits = small_model.scores[small_model.n_tokens - 1, :]
+         last_logits = _get_logits(small_model)

        LlamaStaticDiskCache.reload_from_cache_state(small_model, state)

-         last_logits2 = small_model.scores[small_model.n_tokens - 1, :]
+         last_logits2 = _get_logits(small_model)

        assert (last_logits == last_logits2).all()


def test_cache_save_reload_scores_when_needed(
-     small_model: Llama,
+     small_model_with_logits: Llama,
):
    """
    When the model requires it, we can reload from a state that includes scores.
    """
+     model_state_before_reload = small_model_with_logits.save_state()
+
    test_prompt = "this is a test prompt"
    with tempfile.TemporaryDirectory() as cache_dir:
        cache = LlamaStaticDiskCache.build_cache(
            cache_dir=cache_dir,
            prompts=[test_prompt],
-             model=small_model,
+             model=small_model_with_logits,
            capacity_bytes=2 << 30,  # 2 GiB
            add_bos=True,
            seed=1234,
            save_logits=True,
        )

-         llama_state = small_model.save_state()
+         llama_state = small_model_with_logits.save_state()
        cur_scores = llama_state.scores.copy()
        assert ~(cur_scores == 0.0).all()
+         assert llama_state.n_tokens > 0

        try:
-             small_model.context_params.logits_all = True
            state_from_cache = cache[
                tuple(llama_state.input_ids[: llama_state.n_tokens].tolist())
            ]
            assert state_from_cache.scores is not None, "Scores should be saved."
-             LlamaStaticDiskCache.reload_from_cache_state(small_model, state_from_cache)
+             LlamaStaticDiskCache.reload_from_cache_state(
+                 small_model_with_logits, state_from_cache
+             )
+
+             assert state_from_cache.n_tokens == small_model_with_logits.n_tokens
+             # pylint: disable=protected-access
+             assert state_from_cache.seed == small_model_with_logits._seed
+
            # Do I have to limit these to n_tokens?
            assert (state_from_cache.input_ids == llama_state.input_ids).all()
            assert (
-                 cur_scores == small_model.scores[: small_model.n_tokens]
+                 cur_scores
+                 == small_model_with_logits.scores[: small_model_with_logits.n_tokens]
            ).all(), "Reloaded scores should match"
        finally:
-             small_model.scores[:] = 0.0
-             small_model.context_params.logits_all = False
-             small_model.reset()
+             # Reset in case the model is re-used by later tests
+             small_model_with_logits.load_state(model_state_before_reload)


def test_cache_reload_errors_when_requires_scores_and_state_doesnt_have_it(
def test_cache_reload_errors_when_requires_scores_and_state_doesnt_have_it (
0 commit comments