Commit 25b3494

Minor fix to tensor_split parameter

1 parent e6c67c8 · commit 25b3494
1 file changed (+13 -10): llama_cpp/llama.py
@@ -207,7 +207,6 @@ def __init__(
         n_ctx: int = 512,
         n_parts: int = -1,
         n_gpu_layers: int = 0,
-        tensor_split: list[float] = None,
         seed: int = 1337,
         f16_kv: bool = True,
         logits_all: bool = False,
@@ -221,6 +220,7 @@ def __init__(
         lora_base: Optional[str] = None,
         lora_path: Optional[str] = None,
         low_vram: bool = False,
+        tensor_split: Optional[List[float]] = None,
         verbose: bool = True,
     ):
         """Load a llama.cpp model from `model_path`.
@@ -241,6 +241,7 @@ def __init__(
             last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
             lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
             lora_path: Path to a LoRA file to apply to the model.
+            tensor_split: List of floats to split the model across multiple GPUs. If None, the model is not split.
             verbose: Print verbose output to stderr.

         Raises:
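With the new signature, tensor_split is optional and defaults to None. A minimal usage sketch of the parameter this commit touches; the model path, layer count, and split ratios below are illustrative, not taken from the commit:

    from llama_cpp import Llama

    llm = Llama(
        model_path="./models/7B/ggml-model.bin",  # illustrative path
        n_gpu_layers=40,
        tensor_split=[0.6, 0.4],  # ~60% of weights on GPU 0, ~40% on GPU 1
    )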
@@ -249,20 +250,13 @@ def __init__(
         Returns:
             A Llama instance.
         """
-        if tensor_split is None:
-            tensor_split = [0.0] * llama_cpp.LLAMA_MAX_DEVICES.value
-
-        #Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
-        FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES.value
-        c_tensor_split = FloatArray(*tensor_split)

         self.verbose = verbose
         self.model_path = model_path

         self.params = llama_cpp.llama_context_default_params()
         self.params.n_ctx = n_ctx
         self.params.n_gpu_layers = n_gpu_layers
-        self.params.tensor_split = c_tensor_split
         self.params.seed = seed
         self.params.f16_kv = f16_kv
         self.params.logits_all = logits_all
@@ -272,6 +266,15 @@ def __init__(
         self.params.embedding = embedding
         self.params.low_vram = low_vram

+        self.tensor_split = tensor_split
+        self._c_tensor_split = None
+
+        if self.tensor_split is not None:
+            #Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
+            FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES.value
+            self._c_tensor_split = FloatArray(*tensor_split)  # keep a reference to the array so it is not gc'd
+            self.params.tensor_split = self._c_tensor_split
+
         self.last_n_tokens_size = last_n_tokens_size
         self.n_batch = min(n_ctx, n_batch)
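Two things change in the constructor body: the conversion now runs only when the caller actually passes tensor_split (so None no longer overwrites the library default with an all-zeros array), and the resulting FloatArray is held on self, tying its lifetime to the Llama instance per the inline comment about garbage collection. A condensed, self-contained sketch of the new control flow; Params and LLAMA_MAX_DEVICES here are illustrative stand-ins, not the real llama_cpp objects:

    import ctypes
    from typing import List, Optional

    LLAMA_MAX_DEVICES = 2  # illustrative stub; the real value comes from llama_cpp

    class Params:  # stand-in for llama_context_params
        tensor_split = None

    class Model:
        def __init__(self, tensor_split: Optional[List[float]] = None):
            self.params = Params()
            self.tensor_split = tensor_split
            self._c_tensor_split = None
            if self.tensor_split is not None:
                # Convert and expand only when a split was actually requested.
                FloatArray = ctypes.c_float * LLAMA_MAX_DEVICES
                # Held on self so the buffer is not garbage-collected while
                # native code may still be reading it.
                self._c_tensor_split = FloatArray(*self.tensor_split)
                self.params.tensor_split = self._c_tensor_split

    m = Model(tensor_split=[0.6, 0.4])
    print(list(m.params.tensor_split))  # [0.6..., 0.4...] (c_float precision)
    m2 = Model()
    print(m2.params.tensor_split)       # None: library default left untouched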

@@ -1499,7 +1502,6 @@ def __getstate__(self):
             model_path=self.model_path,
             n_ctx=self.params.n_ctx,
             n_gpu_layers=self.params.n_gpu_layers,
-            tensor_split=self.params.tensor_split,
             seed=self.params.seed,
             f16_kv=self.params.f16_kv,
             logits_all=self.params.logits_all,
@@ -1513,6 +1515,7 @@ def __getstate__(self):
             n_threads=self.n_threads,
             lora_base=self.lora_base,
             lora_path=self.lora_path,
+            tensor_split=self.tensor_split,
             ### DEPRECATED ###
             n_parts=self.n_parts,
             ### DEPRECATED ###
@@ -1524,7 +1527,6 @@ def __setstate__(self, state):
             n_ctx=state["n_ctx"],
             n_parts=state["n_parts"],
             n_gpu_layers=state["n_gpu_layers"],
-            tensor_split=state["tensor_split"],
             seed=state["seed"],
             f16_kv=state["f16_kv"],
             logits_all=state["logits_all"],
@@ -1538,6 +1540,7 @@ def __setstate__(self, state):
             last_n_tokens_size=state["last_n_tokens_size"],
             lora_base=state["lora_base"],
             lora_path=state["lora_path"],
+            tensor_split=state["tensor_split"],
             verbose=state["verbose"],
         )
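The serialization hunks follow from the same change: __getstate__ now pickles the plain Python list held in self.tensor_split rather than the ctypes value on self.params, and __setstate__ feeds it back through __init__, which rebuilds the FloatArray. A self-contained sketch of that round-trip pattern; the Model class and LLAMA_MAX_DEVICES stub are illustrative, not the real Llama class:

    import ctypes
    import pickle
    from typing import List, Optional

    LLAMA_MAX_DEVICES = 2  # illustrative stub

    class Model:
        def __init__(self, tensor_split: Optional[List[float]] = None):
            self.tensor_split = tensor_split
            self._c_tensor_split = None
            if tensor_split is not None:
                self._c_tensor_split = (ctypes.c_float * LLAMA_MAX_DEVICES)(*tensor_split)

        def __getstate__(self):
            # Pickle the plain list, never the ctypes array.
            return dict(tensor_split=self.tensor_split)

        def __setstate__(self, state):
            # Rebuild the C-side state by re-running __init__, as Llama does.
            self.__init__(tensor_split=state["tensor_split"])

    m = pickle.loads(pickle.dumps(Model([0.6, 0.4])))
    print(m.tensor_split)           # [0.6, 0.4]
    print(list(m._c_tensor_split))  # rebuilt ctypes array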
