@@ -37,77 +37,106 @@ def test_llama_cpp_tokenization():
     assert tokens[-1] == llama.token_eos()
     assert tokens == [1, 15043, 2787, 2]
 
-
-def test_llama_patch(monkeypatch):
+    text = b""
+    tokens = llama.tokenize(text, add_bos=True, special=True)
+    assert tokens[-1] != llama.token_eos()
+    assert tokens == [llama.token_bos()]
+    assert text == llama.detokenize(tokens)
+
+
+@pytest.fixture
+def mock_llama(monkeypatch):
+    def setup_mock(llama: llama_cpp.Llama, output_text: str):
+        llama.reset()
+        n_vocab = llama.n_vocab()
+        output_tokens = llama.tokenize(
+            output_text.encode("utf-8"), add_bos=True, special=True
+        )
+        n = 0
+        last_n_tokens = 0
+
+        def mock_decode(ctx: llama_cpp.llama_context_p, batch: llama_cpp.llama_batch):
+            nonlocal n
+            nonlocal last_n_tokens
+            # Test some basic invariants of this mocking technique
+            assert ctx == llama._ctx.ctx
+            assert llama.n_tokens == n
+            assert batch.n_tokens > 0
+            n += batch.n_tokens
+            last_n_tokens = batch.n_tokens
+            return 0
+
+        def mock_get_logits(*args, **kwargs):
+            nonlocal last_n_tokens
+            size = n_vocab * last_n_tokens
+            return (llama_cpp.c_float * size)()
+
+        def mock_sample(*args, **kwargs):
+            nonlocal n
+            if n < len(output_tokens):
+                return output_tokens[n]
+            else:
+                return llama.token_eos()
+
+        monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
+        monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
+        monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_token", mock_sample)
+
+    return setup_mock
+
+
+def test_llama_patch(mock_llama):
     n_ctx = 128
     llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, n_ctx=n_ctx)
     n_vocab = llama_cpp.llama_n_vocab(llama._model.model)
     assert n_vocab == 32000
 
-    ## Set up mock function
-    def mock_decode(*args, **kwargs):
-        return 0
-
-    def mock_get_logits(*args, **kwargs):
-        size = n_vocab * n_ctx
-        return (llama_cpp.c_float * size)()
-
-    monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
-    monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
-
     text = "The quick brown fox"
-    text_tokens = llama.tokenize(text.encode("utf-8"), add_bos=True, special=True)
     output_text = " jumps over the lazy dog."
-    all_text_tokens = llama.tokenize((text + output_text).encode("utf-8"), add_bos=True, special=True)
-    output_tokens = all_text_tokens[len(text_tokens):]
-    token_eos = llama.token_eos()
-    n = 0
-
-    def mock_sample(*args, **kwargs):
-        nonlocal n
-        if n < len(output_tokens):
-            n += 1
-            return output_tokens[n - 1]
-        else:
-            return token_eos
-
-    monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_token", mock_sample)
+    all_text = text + output_text
 
+    ## Test basic completion from bos until eos
+    mock_llama(llama, all_text)
+    completion = llama.create_completion("", max_tokens=36)
+    assert completion["choices"][0]["text"] == all_text
+    assert completion["choices"][0]["finish_reason"] == "stop"
 
     ## Test basic completion until eos
-    n = 0  # reset
+    mock_llama(llama, all_text)
     completion = llama.create_completion(text, max_tokens=20)
     assert completion["choices"][0]["text"] == output_text
     assert completion["choices"][0]["finish_reason"] == "stop"
 
     ## Test streaming completion until eos
-    n = 0  # reset
+    mock_llama(llama, all_text)
     chunks = list(llama.create_completion(text, max_tokens=20, stream=True))
     assert "".join(chunk["choices"][0]["text"] for chunk in chunks) == output_text
     assert chunks[-1]["choices"][0]["finish_reason"] == "stop"
 
     ## Test basic completion until stop sequence
-    n = 0  # reset
+    mock_llama(llama, all_text)
     completion = llama.create_completion(text, max_tokens=20, stop=["lazy"])
     assert completion["choices"][0]["text"] == " jumps over the "
     assert completion["choices"][0]["finish_reason"] == "stop"
 
     ## Test streaming completion until stop sequence
-    n = 0  # reset
-    chunks = list(llama.create_completion(text, max_tokens=20, stream=True, stop=["lazy"]))
+    mock_llama(llama, all_text)
+    chunks = list(
+        llama.create_completion(text, max_tokens=20, stream=True, stop=["lazy"])
+    )
     assert (
         "".join(chunk["choices"][0]["text"] for chunk in chunks) == " jumps over the "
     )
     assert chunks[-1]["choices"][0]["finish_reason"] == "stop"
 
     ## Test basic completion until length
-    n = 0  # reset
+    mock_llama(llama, all_text)
     completion = llama.create_completion(text, max_tokens=2)
     assert completion["choices"][0]["text"] == " jumps"
     assert completion["choices"][0]["finish_reason"] == "length"
 
     ## Test streaming completion until length
-    n = 0  # reset
+    mock_llama(llama, all_text)
     chunks = list(llama.create_completion(text, max_tokens=2, stream=True))
     assert "".join(chunk["choices"][0]["text"] for chunk in chunks) == " jumps"
     assert chunks[-1]["choices"][0]["finish_reason"] == "length"
@@ -131,44 +160,55 @@ def test_llama_pickle():
     assert llama.detokenize(llama.tokenize(text)) == text
 
 
-def test_utf8(monkeypatch):
-    n_ctx = 512
-    llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, n_ctx=n_ctx, logits_all=True)
+def test_utf8(mock_llama, monkeypatch):
+    llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, logits_all=True)
+    n_ctx = llama.n_ctx()
     n_vocab = llama.n_vocab()
 
+    output_text = "😀"
+    output_tokens = llama.tokenize(
+        output_text.encode("utf-8"), add_bos=True, special=True
+    )
+    token_eos = llama.token_eos()
+    n = 0
+
+    def reset():
+        nonlocal n
+        llama.reset()
+        n = 0
+
     ## Set up mock function
-    def mock_decode(*args, **kwargs):
+    def mock_decode(ctx: llama_cpp.llama_context_p, batch: llama_cpp.llama_batch):
+        nonlocal n
+        assert batch.n_tokens > 0
+        assert llama.n_tokens == n
+        n += batch.n_tokens
         return 0
 
     def mock_get_logits(*args, **kwargs):
         size = n_vocab * n_ctx
         return (llama_cpp.c_float * size)()
 
-    monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
-    monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
-
-    output_text = "😀"
-    output_tokens = llama.tokenize(output_text.encode("utf-8"))
-    token_eos = llama.token_eos()
-    n = 0
-
     def mock_sample(*args, **kwargs):
         nonlocal n
-        if n < len(output_tokens):
-            n += 1
+        if n <= len(output_tokens):
             return output_tokens[n - 1]
         else:
             return token_eos
 
+    monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
+    monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
     monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_token", mock_sample)
 
     ## Test basic completion with utf8 multibyte
-    n = 0  # reset
+    # mock_llama(llama, output_text)
+    reset()
     completion = llama.create_completion("", max_tokens=4)
     assert completion["choices"][0]["text"] == output_text
 
     ## Test basic completion with incomplete utf8 multibyte
-    n = 0  # reset
+    # mock_llama(llama, output_text)
+    reset()
     completion = llama.create_completion("", max_tokens=1)
     assert completion["choices"][0]["text"] == ""
 
@@ -196,5 +236,6 @@ def test_llama_server():
         ],
     }
 
+
 def test_llama_cpp_version():
     assert llama_cpp.__version__
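The `mock_llama` fixture introduced above replaces the per-test monkeypatching of `llama_decode`, `llama_get_logits`, and `llama_sample_token` with one reusable setup that also verifies decode invariants. As a minimal sketch of how a test consumes it (the test name here is hypothetical; `MODEL`, the prompt, and the expected output mirror `test_llama_patch` above):

```python
# Sketch: consuming the mock_llama fixture from this diff.
# Assumes MODEL points at the vocab-only test model used elsewhere in this file.
def test_completion_with_mock(mock_llama):
    llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, n_ctx=128)
    # Re-arm the mock before each completion: it resets the model state and
    # patches the low-level API so the model deterministically "generates"
    # the given text, then eos.
    mock_llama(llama, "The quick brown fox jumps over the lazy dog.")
    completion = llama.create_completion("The quick brown fox", max_tokens=20)
    assert completion["choices"][0]["text"] == " jumps over the lazy dog."
    assert completion["choices"][0]["finish_reason"] == "stop"
```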