@@ -79,6 +79,10 @@ def _load_shared_library(lib_base_name: str):
 
 # llama.h bindings
 
+GGML_USE_CUBLAS = hasattr(_lib, "ggml_init_cublas")
+GGML_CUDA_MAX_DEVICES = ctypes.c_int(16)
+LLAMA_MAX_DEVICES = GGML_CUDA_MAX_DEVICES if GGML_USE_CUBLAS else ctypes.c_int(1)
+
 # #define LLAMA_FILE_MAGIC_GGJT 0x67676a74u // 'ggjt'
 LLAMA_FILE_MAGIC_GGJT = ctypes.c_uint(0x67676A74)
 # #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
@@ -142,9 +146,12 @@ class llama_token_data_array(Structure):
 
 
 # struct llama_context_params {
-#     int n_ctx;        // text context
-#     int n_gpu_layers; // number of layers to store in VRAM
-#     int seed;         // RNG seed, -1 for random
+#     int n_ctx;                             // text context
+#     int n_batch;                           // prompt processing batch size
+#     int n_gpu_layers;                      // number of layers to store in VRAM
+#     int main_gpu;                          // the GPU that is used for scratch and small tensors
+#     float tensor_split[LLAMA_MAX_DEVICES]; // how to split layers across multiple GPUs
+#     int seed;                              // RNG seed, -1 for random
 
 #     bool f16_kv;     // use fp16 for KV cache
 #     bool logits_all; // the llama_eval() call computes all logits, not just the last one
@@ -153,7 +160,6 @@ class llama_token_data_array(Structure):
 #     bool use_mlock;  // force system to keep model in RAM
 #     bool embedding;  // embedding mode only
 
-
 #     // called with a progress value between 0 and 1, pass NULL to disable
 #     llama_progress_callback progress_callback;
 #     // context pointer passed to the progress callback
@@ -162,7 +168,10 @@ class llama_token_data_array(Structure):
 class llama_context_params(Structure):
     _fields_ = [
         ("n_ctx", c_int),
+        ("n_batch", c_int),
         ("n_gpu_layers", c_int),
+        ("main_gpu", c_int),
+        ("tensor_split", c_float * LLAMA_MAX_DEVICES.value),
         ("seed", c_int),
         ("f16_kv", c_bool),
         (
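
For context, a hedged usage sketch (not part of this commit) of how the new fields might be set from Python. It assumes the module already exposes the llama_context_default_params() and llama_init_from_file() bindings, and that the library was built with cuBLAS so tensor_split has more than one slot; the model path is a placeholder.

# Hypothetical sketch: splitting layers across two GPUs.
import llama_cpp

params = llama_cpp.llama_context_default_params()
params.n_batch = 512       # prompt processing batch size
params.n_gpu_layers = 40   # number of layers to offload to VRAM
params.main_gpu = 0        # GPU used for scratch and small tensors
# tensor_split is a fixed-size c_float array of length LLAMA_MAX_DEVICES;
# entries left untouched stay at 0.0.
params.tensor_split[0] = 0.6
params.tensor_split[1] = 0.4
ctx = llama_cpp.llama_init_from_file(b"ggml-model.bin", params)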