@@ -234,6 +234,22 @@ class llama_context_params(Structure):
LLAMA_FTYPE_MOSTLY_Q6_K = c_int(18)


+# // model quantization parameters
+# typedef struct llama_model_quantize_params {
+#     int nthread;                 // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
+#     enum llama_ftype ftype;      // quantize to this llama_ftype
+#     bool allow_requantize;       // allow quantizing non-f32/f16 tensors
+#     bool quantize_output_tensor; // quantize output.weight
+# } llama_model_quantize_params;
+class llama_model_quantize_params(Structure):
+    _fields_ = [
+        ("nthread", c_int),
+        ("ftype", c_int),
+        ("allow_requantize", c_bool),
+        ("quantize_output_tensor", c_bool),
+    ]
+
+
# LLAMA_API struct llama_context_params llama_context_default_params();
def llama_context_default_params() -> llama_context_params:
    return _lib.llama_context_default_params()
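Aside (not part of the diff): because `llama_model_quantize_params` is a plain ctypes `Structure`, it can be constructed with keyword arguments matching `_fields_`, and `ctypes.sizeof` reflects the same layout (including padding) as the C struct. A minimal sketch, with the ftype value taken from the `LLAMA_FTYPE_MOSTLY_Q6_K` constant above:

```python
import ctypes
from ctypes import Structure, c_int, c_bool

# Same layout as the llama_model_quantize_params struct added above.
class llama_model_quantize_params(Structure):
    _fields_ = [
        ("nthread", c_int),
        ("ftype", c_int),
        ("allow_requantize", c_bool),
        ("quantize_output_tensor", c_bool),
    ]

# ctypes Structures accept keyword arguments matching _fields_:
params = llama_model_quantize_params(
    nthread=0,                   # <=0 -> std::thread::hardware_concurrency()
    ftype=18,                    # LLAMA_FTYPE_MOSTLY_Q6_K
    allow_requantize=False,
    quantize_output_tensor=True,
)
print(ctypes.sizeof(params))     # same size (with padding) as the C struct
```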
@@ -243,6 +259,15 @@ def llama_context_default_params() -> llama_context_params:
_lib.llama_context_default_params.restype = llama_context_params


+# LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params();
+def llama_model_quantize_default_params() -> llama_model_quantize_params:
+    return _lib.llama_model_quantize_default_params()
+
+
+_lib.llama_model_quantize_default_params.argtypes = []
+_lib.llama_model_quantize_default_params.restype = llama_model_quantize_params
+
+
# LLAMA_API bool llama_mmap_supported();
def llama_mmap_supported() -> bool:
    return _lib.llama_mmap_supported()
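Since the struct is returned by value, the usual pattern is presumably to start from these C-side defaults and override individual fields rather than zero-fill the struct by hand. A small sketch, assuming the bindings above are loaded:

```python
# Sketch: fetch the library's default quantize params, then override fields.
params = llama_model_quantize_default_params()
params.nthread = 0              # <=0 -> std::thread::hardware_concurrency()
params.ftype = 18               # LLAMA_FTYPE_MOSTLY_Q6_K, per the constant above
params.allow_requantize = False
params.quantize_output_tensor = True
```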
@@ -308,21 +333,24 @@ def llama_free(ctx: llama_context_p):
_lib.llama_free.restype = None


-# TODO: not great API - very likely to change
-# Returns 0 on success
-# nthread - how many threads to use. If <=0, will use std::thread::hardware_concurrency(), else the number given
+# // Returns 0 on success
# LLAMA_API int llama_model_quantize(
#         const char * fname_inp,
#         const char * fname_out,
-#         enum llama_ftype ftype,
-#         int nthread);
+#         const llama_model_quantize_params * params);
def llama_model_quantize(
-    fname_inp: bytes, fname_out: bytes, ftype: c_int, nthread: c_int
+    fname_inp: bytes,
+    fname_out: bytes,
+    params,  # type: POINTER(llama_model_quantize_params)  # type: ignore
) -> int:
-    return _lib.llama_model_quantize(fname_inp, fname_out, ftype, nthread)
+    return _lib.llama_model_quantize(fname_inp, fname_out, params)


-_lib.llama_model_quantize.argtypes = [c_char_p, c_char_p, c_int, c_int]
+_lib.llama_model_quantize.argtypes = [
+    c_char_p,
+    c_char_p,
+    POINTER(llama_model_quantize_params),
+]
_lib.llama_model_quantize.restype = c_int
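Putting the pieces together, the new signature takes a pointer to the params struct instead of the old `(ftype, nthread)` pair. A minimal end-to-end sketch; the model paths are hypothetical, and the import path assumes the low-level `llama_cpp` bindings shown in this diff:

```python
import ctypes
from llama_cpp import (  # assumed import path for the low-level bindings
    llama_model_quantize,
    llama_model_quantize_default_params,
)

params = llama_model_quantize_default_params()  # C-side defaults
params.nthread = 0  # <=0 -> std::thread::hardware_concurrency()
params.ftype = 18   # LLAMA_FTYPE_MOSTLY_Q6_K

ret = llama_model_quantize(
    b"./models/7B/ggml-model-f16.bin",   # hypothetical input path
    b"./models/7B/ggml-model-q6_k.bin",  # hypothetical output path
    ctypes.byref(params),                # matches POINTER(llama_model_quantize_params)
)
assert ret == 0  # the C API returns 0 on success
```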