6
6
import time
7
7
import json
8
8
import ctypes
9
+ import typing
9
10
import fnmatch
10
11
import multiprocessing
11
12
@@ -249,24 +250,26 @@ def __init__(
249
250
self ._kv_overrides_array [i ].key = k .encode ("utf-8" )
250
251
if isinstance (v , bool ):
251
252
self ._kv_overrides_array [i ].tag = llama_cpp .LLAMA_KV_OVERRIDE_TYPE_BOOL
252
- self ._kv_overrides_array [i ].value .bool_value = v
253
+ self ._kv_overrides_array [i ].value .val_bool = v
253
254
elif isinstance (v , int ):
254
255
self ._kv_overrides_array [i ].tag = llama_cpp .LLAMA_KV_OVERRIDE_TYPE_INT
255
- self ._kv_overrides_array [i ].value .int_value = v
256
+ self ._kv_overrides_array [i ].value .val_i64 = v
256
257
elif isinstance (v , float ):
257
258
self ._kv_overrides_array [i ].tag = llama_cpp .LLAMA_KV_OVERRIDE_TYPE_FLOAT
258
- self ._kv_overrides_array [i ].value .float_value = v
259
+ self ._kv_overrides_array [i ].value .val_f64 = v
259
260
elif isinstance (v , str ): # type: ignore
260
261
v_bytes = v .encode ("utf-8" )
261
262
if len (v_bytes ) > 128 : # TODO: Make this a constant
262
263
raise ValueError (f"Value for { k } is too long: { v } " )
263
264
v_bytes = v_bytes .ljust (128 , b"\0 " )
264
265
self ._kv_overrides_array [i ].tag = llama_cpp .LLAMA_KV_OVERRIDE_TYPE_STR
265
266
# copy min(v_bytes, 128) to str_value
267
+ address = typing .cast (int , ctypes .addressof (self ._kv_overrides_array [i ].value ) + llama_cpp .llama_model_kv_override_value .val_str .offset )
268
+ buffer_start = ctypes .cast (address , ctypes .POINTER (ctypes .c_char ))
266
269
ctypes .memmove (
267
- self . _kv_overrides_array [ i ]. value . str_value ,
270
+ buffer_start ,
268
271
v_bytes ,
269
- min ( len ( v_bytes ), 128 ) ,
272
+ 128 ,
270
273
)
271
274
else :
272
275
raise ValueError (f"Unknown value type for { k } : { v } " )
0 commit comments