import sys
import os
import ctypes
-from ctypes import c_int, c_float, c_char_p, c_void_p, c_bool, POINTER, Structure, Array, c_uint8, c_size_t
+from ctypes import (
+    c_int,
+    c_float,
+    c_char_p,
+    c_void_p,
+    c_bool,
+    POINTER,
+    Structure,
+    Array,
+    c_uint8,
+    c_size_t,
+)
import pathlib

+
# Load the library
def _load_shared_library(lib_base_name):
    # Determine the file extension based on the platform
@@ -22,10 +34,10 @@ def _load_shared_library(lib_base_name):
    # for llamacpp) and "llama" (default name for this repo)
    _lib_paths = [
        _base_path / f"lib{lib_base_name}{lib_ext}",
-        _base_path / f"{lib_base_name}{lib_ext}"
+        _base_path / f"{lib_base_name}{lib_ext}",
    ]

-    if ("LLAMA_CPP_LIB" in os.environ):
+    if "LLAMA_CPP_LIB" in os.environ:
        lib_base_name = os.environ["LLAMA_CPP_LIB"]
        _lib = pathlib.Path(lib_base_name)
        _base_path = _lib.parent.resolve()
@@ -43,7 +55,10 @@ def _load_shared_library(lib_base_name):
        except Exception as e:
            raise RuntimeError(f"Failed to load shared library '{_lib_path}': {e}")

-    raise FileNotFoundError(f"Shared library with base name '{lib_base_name}' not found")
+    raise FileNotFoundError(
+        f"Shared library with base name '{lib_base_name}' not found"
+    )
+

# Specify the base name of the shared library to load
_lib_base_name = "llama"
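Note the `LLAMA_CPP_LIB` escape hatch above: setting it redirects the loader to an explicit library file instead of the default `lib{base}{ext}` / `{base}{ext}` search names. A minimal sketch of using it, assuming the bindings module is importable as `llama_cpp` and with a purely hypothetical library path:

```python
import os

# Hypothetical path to a custom libllama build; this must be set *before*
# the bindings are imported, because the library is loaded at import time.
os.environ["LLAMA_CPP_LIB"] = "/opt/llama.cpp/libllama.so"

import llama_cpp  # the loader now resolves the path from LLAMA_CPP_LIB
```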
@@ -95,6 +110,10 @@ class llama_context_params(Structure):

llama_context_params_p = POINTER(llama_context_params)

+LLAMA_FTYPE_ALL_F32 = ctypes.c_int(0)
+LLAMA_FTYPE_MOSTLY_F16 = ctypes.c_int(1)  # except 1d tensors
+LLAMA_FTYPE_MOSTLY_Q4_0 = ctypes.c_int(2)  # except 1d tensors
+LLAMA_FTYPE_MOSTLY_Q4_1 = ctypes.c_int(3)  # except 1d tensors

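These new constants track llama.cpp's quantization-format values. They are defined as `ctypes.c_int` instances rather than plain ints, so Python code that needs the raw number (for comparisons or logging) reads `.value`; a tiny illustration:

```python
import ctypes

LLAMA_FTYPE_MOSTLY_Q4_0 = ctypes.c_int(2)  # as defined above

# c_int wrappers don't compare equal to plain ints; unwrap with .value.
assert LLAMA_FTYPE_MOSTLY_Q4_0.value == 2
```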
# Functions

@@ -106,18 +125,23 @@ def llama_context_default_params() -> llama_context_params:
_lib.llama_context_default_params.argtypes = []
_lib.llama_context_default_params.restype = llama_context_params

+
def llama_mmap_supported() -> c_bool:
    return _lib.llama_mmap_supported()

+
_lib.llama_mmap_supported.argtypes = []
_lib.llama_mmap_supported.restype = c_bool

+
def llama_mlock_supported() -> c_bool:
    return _lib.llama_mlock_supported()

+
_lib.llama_mlock_supported.argtypes = []
_lib.llama_mlock_supported.restype = c_bool

+
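Both capability probes take no arguments and report whether the loaded build supports memory-mapped weights and page locking. A quick runtime check, again assuming the module is imported as `llama_cpp`:

```python
# Probe the loaded libllama build's capabilities at runtime.
print("mmap supported: ", llama_cpp.llama_mmap_supported())
print("mlock supported:", llama_cpp.llama_mlock_supported())
```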
# Various functions for loading a ggml llama model.
# Allocate (almost) all memory needed for the model.
# Return NULL on failure
@@ -142,42 +166,49 @@ def llama_free(ctx: llama_context_p):

# TODO: not great API - very likely to change
# Returns 0 on success
-def llama_model_quantize(
-    fname_inp: bytes, fname_out: bytes, itype: c_int
-) -> c_int:
+def llama_model_quantize(fname_inp: bytes, fname_out: bytes, itype: c_int) -> c_int:
    return _lib.llama_model_quantize(fname_inp, fname_out, itype)


_lib.llama_model_quantize.argtypes = [c_char_p, c_char_p, c_int]
_lib.llama_model_quantize.restype = c_int

+
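Combined with the `LLAMA_FTYPE_*` constants added above, a quantization call looks roughly like the sketch below; the file names are hypothetical, and both paths must be `bytes` to match the `c_char_p` argtypes:

```python
# Quantize an f16 GGML model down to Q4_0 (hypothetical paths).
rc = llama_cpp.llama_model_quantize(
    b"./models/7B/ggml-model-f16.bin",   # fname_inp
    b"./models/7B/ggml-model-q4_0.bin",  # fname_out
    llama_cpp.LLAMA_FTYPE_MOSTLY_Q4_0,   # itype
)
assert rc == 0  # the C function returns 0 on success
```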
# Returns the KV cache that will contain the context for the
# ongoing prediction with the model.
def llama_get_kv_cache(ctx: llama_context_p):
    return _lib.llama_get_kv_cache(ctx)

+
_lib.llama_get_kv_cache.argtypes = [llama_context_p]
_lib.llama_get_kv_cache.restype = POINTER(c_uint8)

+
# Returns the size of the KV cache
def llama_get_kv_cache_size(ctx: llama_context_p) -> c_size_t:
    return _lib.llama_get_kv_cache_size(ctx)

+
_lib.llama_get_kv_cache_size.argtypes = [llama_context_p]
_lib.llama_get_kv_cache_size.restype = c_size_t

+
# Returns the number of tokens in the KV cache
def llama_get_kv_cache_token_count(ctx: llama_context_p) -> c_int:
    return _lib.llama_get_kv_cache_token_count(ctx)

+
_lib.llama_get_kv_cache_token_count.argtypes = [llama_context_p]
_lib.llama_get_kv_cache_token_count.restype = c_int


# Sets the KV cache containing the current context for the model
-def llama_set_kv_cache(ctx: llama_context_p, kv_cache, n_size: c_size_t, n_token_count: c_int):
+def llama_set_kv_cache(
+    ctx: llama_context_p, kv_cache, n_size: c_size_t, n_token_count: c_int
+):
    return _lib.llama_set_kv_cache(ctx, kv_cache, n_size, n_token_count)

+
_lib.llama_set_kv_cache.argtypes = [llama_context_p, POINTER(c_uint8), c_size_t, c_int]
_lib.llama_set_kv_cache.restype = None
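Taken together, the four KV-cache bindings support a snapshot/restore pattern for an ongoing prediction (with the caveat from the TODO above that this API is likely to change). A sketch, assuming an already-initialized `ctx`:

```python
import ctypes

# Snapshot: copy the live KV cache into a Python-owned buffer.
size = llama_cpp.llama_get_kv_cache_size(ctx)
n_tokens = llama_cpp.llama_get_kv_cache_token_count(ctx)
buf = (ctypes.c_uint8 * size)()
ctypes.memmove(buf, llama_cpp.llama_get_kv_cache(ctx), size)

# ... run further evaluation that mutates the cache ...

# Restore: hand the saved bytes and token count back to the context.
llama_cpp.llama_set_kv_cache(ctx, buf, size, n_tokens)
```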