1
+ import ctypes
2
+
1
3
import pytest
4
+
2
5
import llama_cpp
3
6
4
7
MODEL = "./vendor/llama.cpp/models/ggml-vocab-llama.gguf"
@@ -36,19 +39,20 @@ def test_llama_cpp_tokenization():
36
39
37
40
38
41
def test_llama_patch (monkeypatch ):
39
- llama = llama_cpp .Llama (model_path = MODEL , vocab_only = True )
42
+ n_ctx = 128
43
+ llama = llama_cpp .Llama (model_path = MODEL , vocab_only = True , n_ctx = n_ctx )
40
44
n_vocab = llama_cpp .llama_n_vocab (llama ._model .model )
45
+ assert n_vocab == 32000
41
46
42
47
## Set up mock function
43
- def mock_eval (* args , ** kwargs ):
48
+ def mock_decode (* args , ** kwargs ):
44
49
return 0
45
50
46
51
def mock_get_logits (* args , ** kwargs ):
47
- return (llama_cpp .c_float * n_vocab )(
48
- * [llama_cpp .c_float (0 ) for _ in range (n_vocab )]
49
- )
52
+ size = n_vocab * n_ctx
53
+ return (llama_cpp .c_float * size )()
50
54
51
- monkeypatch .setattr ("llama_cpp.llama_cpp.llama_decode" , mock_eval )
55
+ monkeypatch .setattr ("llama_cpp.llama_cpp.llama_decode" , mock_decode )
52
56
monkeypatch .setattr ("llama_cpp.llama_cpp.llama_get_logits" , mock_get_logits )
53
57
54
58
output_text = " jumps over the lazy dog."
@@ -126,19 +130,19 @@ def test_llama_pickle():
126
130
127
131
128
132
def test_utf8 (monkeypatch ):
129
- llama = llama_cpp .Llama (model_path = MODEL , vocab_only = True )
133
+ n_ctx = 512
134
+ llama = llama_cpp .Llama (model_path = MODEL , vocab_only = True , n_ctx = n_ctx , logits_all = True )
130
135
n_vocab = llama .n_vocab ()
131
136
132
137
## Set up mock function
133
- def mock_eval (* args , ** kwargs ):
138
+ def mock_decode (* args , ** kwargs ):
134
139
return 0
135
140
136
141
def mock_get_logits (* args , ** kwargs ):
137
- return (llama_cpp .c_float * n_vocab )(
138
- * [llama_cpp .c_float (0 ) for _ in range (n_vocab )]
139
- )
142
+ size = n_vocab * n_ctx
143
+ return (llama_cpp .c_float * size )()
140
144
141
- monkeypatch .setattr ("llama_cpp.llama_cpp.llama_decode" , mock_eval )
145
+ monkeypatch .setattr ("llama_cpp.llama_cpp.llama_decode" , mock_decode )
142
146
monkeypatch .setattr ("llama_cpp.llama_cpp.llama_get_logits" , mock_get_logits )
143
147
144
148
output_text = "😀"
0 commit comments