Commit 3dc21b2

tests: Improve llama.cpp mock
committed · 1 parent 63fe137 · commit 3dc21b2

1 file changed

tests/test_llama.py (+92 −51: 92 additions, 51 deletions)
@@ -37,77 +37,106 @@ def test_llama_cpp_tokenization():
     assert tokens[-1] == llama.token_eos()
     assert tokens == [1, 15043, 2787, 2]
 
-
-def test_llama_patch(monkeypatch):
+    text = b""
+    tokens = llama.tokenize(text, add_bos=True, special=True)
+    assert tokens[-1] != llama.token_eos()
+    assert tokens == [llama.token_bos()]
+    assert text == llama.detokenize(tokens)
+
+
+@pytest.fixture
+def mock_llama(monkeypatch):
+    def setup_mock(llama: llama_cpp.Llama, output_text: str):
+        llama.reset()
+        n_vocab = llama.n_vocab()
+        output_tokens = llama.tokenize(
+            output_text.encode("utf-8"), add_bos=True, special=True
+        )
+        n = 0
+        last_n_tokens = 0
+
+        def mock_decode(ctx: llama_cpp.llama_context_p, batch: llama_cpp.llama_batch):
+            nonlocal n
+            nonlocal last_n_tokens
+            # Test some basic invariants of this mocking technique
+            assert ctx == llama._ctx.ctx
+            assert llama.n_tokens == n
+            assert batch.n_tokens > 0
+            n += batch.n_tokens
+            last_n_tokens = batch.n_tokens
+            return 0
+
+        def mock_get_logits(*args, **kwargs):
+            nonlocal last_n_tokens
+            size = n_vocab * last_n_tokens
+            return (llama_cpp.c_float * size)()
+
+        def mock_sample(*args, **kwargs):
+            nonlocal n
+            if n < len(output_tokens):
+                return output_tokens[n]
+            else:
+                return llama.token_eos()
+
+        monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
+        monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
+        monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_token", mock_sample)
+
+    return setup_mock
+
+
+def test_llama_patch(mock_llama):
     n_ctx = 128
     llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, n_ctx=n_ctx)
     n_vocab = llama_cpp.llama_n_vocab(llama._model.model)
     assert n_vocab == 32000
 
-    ## Set up mock function
-    def mock_decode(*args, **kwargs):
-        return 0
-
-    def mock_get_logits(*args, **kwargs):
-        size = n_vocab * n_ctx
-        return (llama_cpp.c_float * size)()
-
-    monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
-    monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
-
     text = "The quick brown fox"
-    text_tokens = llama.tokenize(text.encode("utf-8"), add_bos=True, special=True)
     output_text = " jumps over the lazy dog."
-    all_text_tokens = llama.tokenize((text + output_text).encode("utf-8"), add_bos=True, special=True)
-    output_tokens = all_text_tokens[len(text_tokens):]
-    token_eos = llama.token_eos()
-    n = 0
-
-    def mock_sample(*args, **kwargs):
-        nonlocal n
-        if n < len(output_tokens):
-            n += 1
-            return output_tokens[n - 1]
-        else:
-            return token_eos
-
-    monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_token", mock_sample)
+    all_text = text + output_text
 
+    ## Test basic completion from bos until eos
+    mock_llama(llama, all_text)
+    completion = llama.create_completion("", max_tokens=36)
+    assert completion["choices"][0]["text"] == all_text
+    assert completion["choices"][0]["finish_reason"] == "stop"
 
     ## Test basic completion until eos
-    n = 0 # reset
+    mock_llama(llama, all_text)
     completion = llama.create_completion(text, max_tokens=20)
     assert completion["choices"][0]["text"] == output_text
     assert completion["choices"][0]["finish_reason"] == "stop"
 
     ## Test streaming completion until eos
-    n = 0 # reset
+    mock_llama(llama, all_text)
     chunks = list(llama.create_completion(text, max_tokens=20, stream=True))
     assert "".join(chunk["choices"][0]["text"] for chunk in chunks) == output_text
     assert chunks[-1]["choices"][0]["finish_reason"] == "stop"
 
     ## Test basic completion until stop sequence
-    n = 0 # reset
+    mock_llama(llama, all_text)
     completion = llama.create_completion(text, max_tokens=20, stop=["lazy"])
     assert completion["choices"][0]["text"] == " jumps over the "
     assert completion["choices"][0]["finish_reason"] == "stop"
 
     ## Test streaming completion until stop sequence
-    n = 0 # reset
-    chunks = list(llama.create_completion(text, max_tokens=20, stream=True, stop=["lazy"]))
+    mock_llama(llama, all_text)
+    chunks = list(
+        llama.create_completion(text, max_tokens=20, stream=True, stop=["lazy"])
+    )
     assert (
         "".join(chunk["choices"][0]["text"] for chunk in chunks) == " jumps over the "
     )
     assert chunks[-1]["choices"][0]["finish_reason"] == "stop"
 
     ## Test basic completion until length
-    n = 0 # reset
+    mock_llama(llama, all_text)
     completion = llama.create_completion(text, max_tokens=2)
     assert completion["choices"][0]["text"] == " jumps"
     assert completion["choices"][0]["finish_reason"] == "length"
 
     ## Test streaming completion until length
-    n = 0 # reset
+    mock_llama(llama, all_text)
     chunks = list(llama.create_completion(text, max_tokens=2, stream=True))
     assert "".join(chunk["choices"][0]["text"] for chunk in chunks) == " jumps"
     assert chunks[-1]["choices"][0]["finish_reason"] == "length"
@@ -131,44 +160,55 @@ def test_llama_pickle():
     assert llama.detokenize(llama.tokenize(text)) == text
 
 
-def test_utf8(monkeypatch):
-    n_ctx = 512
-    llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, n_ctx=n_ctx, logits_all=True)
+def test_utf8(mock_llama, monkeypatch):
+    llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, logits_all=True)
+    n_ctx = llama.n_ctx()
     n_vocab = llama.n_vocab()
 
+    output_text = "😀"
+    output_tokens = llama.tokenize(
+        output_text.encode("utf-8"), add_bos=True, special=True
+    )
+    token_eos = llama.token_eos()
+    n = 0
+
+    def reset():
+        nonlocal n
+        llama.reset()
+        n = 0
+
     ## Set up mock function
-    def mock_decode(*args, **kwargs):
+    def mock_decode(ctx: llama_cpp.llama_context_p, batch: llama_cpp.llama_batch):
+        nonlocal n
+        assert batch.n_tokens > 0
+        assert llama.n_tokens == n
+        n += batch.n_tokens
         return 0
 
     def mock_get_logits(*args, **kwargs):
         size = n_vocab * n_ctx
         return (llama_cpp.c_float * size)()
 
-    monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
-    monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
-
-    output_text = "😀"
-    output_tokens = llama.tokenize(output_text.encode("utf-8"))
-    token_eos = llama.token_eos()
-    n = 0
-
     def mock_sample(*args, **kwargs):
         nonlocal n
-        if n < len(output_tokens):
-            n += 1
+        if n <= len(output_tokens):
             return output_tokens[n - 1]
         else:
             return token_eos
 
+    monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
+    monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
     monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_token", mock_sample)
 
     ## Test basic completion with utf8 multibyte
-    n = 0 # reset
+    # mock_llama(llama, output_text)
+    reset()
     completion = llama.create_completion("", max_tokens=4)
     assert completion["choices"][0]["text"] == output_text
 
     ## Test basic completion with incomplete utf8 multibyte
-    n = 0 # reset
+    # mock_llama(llama, output_text)
+    reset()
     completion = llama.create_completion("", max_tokens=1)
     assert completion["choices"][0]["text"] == ""
 
@@ -196,5 +236,6 @@ def test_llama_server():
         ],
     }
 
+
 def test_llama_cpp_version():
     assert llama_cpp.__version__
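
For reference, a minimal sketch of how the new mock_llama fixture is used (mirroring test_llama_patch above; not part of this commit). The fixture receives the full text from BOS onward, i.e. prompt plus expected continuation, and patches llama_decode, llama_get_logits, and llama_sample_token so that create_completion replays that continuation deterministically. MODEL is assumed to be the module-level model-path constant already defined in tests/test_llama.py, and the test name and strings below are only illustrative.

    import llama_cpp

    def test_completion_is_replayed(mock_llama):
        # Vocab-only load: only the tokenizer is exercised, decoding is mocked out.
        llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)

        prompt = "The quick brown fox"          # illustrative prompt
        expected = " jumps over the lazy dog."  # illustrative canned continuation

        # Hand the fixture prompt + continuation; it tokenizes the whole string
        # and the patched sampler returns those tokens one by one, then EOS.
        mock_llama(llama, prompt + expected)

        completion = llama.create_completion(prompt, max_tokens=20)
        assert completion["choices"][0]["text"] == expected
        assert completion["choices"][0]["finish_reason"] == "stop"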
