@@ -172,7 +172,9 @@ def llama_free(ctx: llama_context_p):
 # TODO: not great API - very likely to change
 # Returns 0 on success
 # nthread - how many threads to use. If <=0, will use std::thread::hardware_concurrency(), else the number given
-def llama_model_quantize(fname_inp: bytes, fname_out: bytes, ftype: c_int, nthread: c_int) -> c_int:
+def llama_model_quantize(
+    fname_inp: bytes, fname_out: bytes, ftype: c_int, nthread: c_int
+) -> c_int:
     return _lib.llama_model_quantize(fname_inp, fname_out, ftype, nthread)
 
 
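For reference, a minimal usage sketch (not part of the diff) of the reflowed quantize binding. It assumes the module is importable as `llama_cpp` and that ftype value 2 corresponds to LLAMA_FTYPE_MOSTLY_Q4_0 in this build of llama.cpp; paths are illustrative.

# Sketch: quantize an f16 GGML model down to 4-bit.
from ctypes import c_int

import llama_cpp

ret = llama_cpp.llama_model_quantize(
    b"./models/7B/ggml-model-f16.bin",   # fname_inp (illustrative path)
    b"./models/7B/ggml-model-q4_0.bin",  # fname_out (illustrative path)
    c_int(2),  # ftype: assumed LLAMA_FTYPE_MOSTLY_Q4_0
    c_int(0),  # nthread <= 0 -> std::thread::hardware_concurrency()
)
assert ret == 0  # returns 0 on success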
@@ -187,7 +189,10 @@ def llama_model_quantize(fname_inp: bytes, fname_out: bytes, ftype: c_int, nthre
 # will be applied on top of the previous one
 # Returns 0 on success
 def llama_apply_lora_from_file(
-    ctx: llama_context_p, path_lora: ctypes.c_char_p, path_base_model: ctypes.c_char_p, n_threads: c_int
+    ctx: llama_context_p,
+    path_lora: ctypes.c_char_p,
+    path_base_model: ctypes.c_char_p,
+    n_threads: c_int,
 ) -> c_int:
     return _lib.llama_apply_lora_from_file(ctx, path_lora, path_base_model, n_threads)
 
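A hedged sketch (not part of the diff) of calling the reflowed LoRA binding. It assumes `llama_context_default_params` and `llama_init_from_file` from the same module, and that passing NULL (None) for path_base_model applies the adapter on top of the weights already loaded in ctx; paths are illustrative.

import llama_cpp

params = llama_cpp.llama_context_default_params()
ctx = llama_cpp.llama_init_from_file(b"./models/7B/ggml-model-q4_0.bin", params)

ret = llama_cpp.llama_apply_lora_from_file(
    ctx,
    b"./lora/ggml-adapter-model.bin",  # path_lora (illustrative path)
    None,                              # path_base_model: None = apply to ctx as-is
    4,                                 # n_threads
)
assert ret == 0  # returns 0 on success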
@@ -235,6 +240,36 @@ def llama_set_kv_cache(
 _lib.llama_set_kv_cache.restype = None
 
 
+# Returns the size in bytes of the state (rng, logits, embedding and kv_cache)
+def llama_get_state_size(ctx: llama_context_p) -> c_size_t:
+    return _lib.llama_get_state_size(ctx)
+
+
+_lib.llama_get_state_size.argtypes = [llama_context_p]
+_lib.llama_get_state_size.restype = c_size_t
+
+
+# Copies the state to the specified destination address.
+# Destination needs to have allocated enough memory.
+# Returns the number of bytes copied
+def llama_copy_state_data(ctx: llama_context_p, dest) -> c_size_t:
+    return _lib.llama_copy_state_data(ctx, dest)
+
+
+_lib.llama_copy_state_data.argtypes = [llama_context_p, POINTER(c_uint8)]
+_lib.llama_copy_state_data.restype = c_size_t
+
+
+# Set the state reading from the specified address
+# Returns the number of bytes read
+def llama_set_state_data(ctx: llama_context_p, src) -> c_size_t:
+    return _lib.llama_set_state_data(ctx, src)
+
+
+_lib.llama_set_state_data.argtypes = [llama_context_p, POINTER(c_uint8)]
+_lib.llama_set_state_data.restype = c_size_t
+
+
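Taken together, the three new functions support snapshotting and rewinding a context. A minimal sketch (not part of the diff), assuming the module is importable as `llama_cpp` and `ctx` is an already-initialized llama_context_p:

from ctypes import POINTER, c_uint8, cast, create_string_buffer

import llama_cpp

# Ask how big the serialized state (rng, logits, embedding, kv_cache) is.
n_bytes = llama_cpp.llama_get_state_size(ctx)

# Allocate a buffer of that size and copy the state out of the context.
buf = create_string_buffer(n_bytes)
dest = cast(buf, POINTER(c_uint8))
copied = llama_cpp.llama_copy_state_data(ctx, dest)

# ... run more evals, then restore the snapshot to rewind the context.
read = llama_cpp.llama_set_state_data(ctx, dest)
assert read == copied  # both return the number of bytes processed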
 # Run the llama inference to obtain the logits and probabilities for the next token.
 # tokens + n_tokens is the provided batch of new tokens to process
 # n_past is the number of tokens to use from previous eval calls
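The context lines above document the eval entry point that consumes state like the kv_cache snapshotted earlier. A sketch of a call (not part of the diff), assuming the binding mirrors llama.cpp's llama_eval(ctx, tokens, n_tokens, n_past, n_threads) and that `llama_token` and `llama_tokenize` are exposed by the same module:

import llama_cpp

# Tokenize a prompt into a fixed-size llama_token array (size is illustrative).
tokens = (llama_cpp.llama_token * 512)()
n_tokens = llama_cpp.llama_tokenize(ctx, b"Hello, world", tokens, 512, True)

# n_past = 0: no tokens cached from previous eval calls yet.
ret = llama_cpp.llama_eval(ctx, tokens, n_tokens, 0, 4)
assert ret == 0  # returns 0 on success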