Commit e665b55

Merge pull request abetlen#523 from shouyiwang/tensor_split
Update tensor_split to match llama.cpp's change
2 parents: d3bf7db + 426dbfe
1 file changed: 4 additions, 5 deletions

llama_cpp/llama.py (+4 −5)
@@ -273,13 +273,12 @@ def __init__(
         self.params.low_vram = low_vram
 
         self.tensor_split = tensor_split
-        self._c_tensor_split = None
+        self._p_tensor_split = None
 
         if self.tensor_split is not None:
-            #Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
-            FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES.value
-            self._c_tensor_split = FloatArray(*tensor_split) # keep a reference to the array so it is not gc'd
-            self.params.tensor_split = self._c_tensor_split
+            FloatArray = (ctypes.c_float * len(self.tensor_split))(*self.tensor_split)
+            self._p_tensor_split = ctypes.POINTER(ctypes.c_float)(FloatArray)  # keep a reference to the array so it is not gc'd
+            self.params.tensor_split = self._p_tensor_split
 
         self.params.rope_freq_base = rope_freq_base
         self.params.rope_freq_scale = rope_freq_scale
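For readers unfamiliar with the ctypes idiom above: matching llama.cpp's change, tensor_split is now passed as a plain float pointer rather than a list padded to a fixed-size float[LLAMA_MAX_DEVICES] array, so the binding builds a C float array exactly as long as the user's list and hands over a pointer to it, keeping the array referenced so it is not garbage-collected while C code still reads it. The snippet below is a minimal, standalone sketch of that pattern, not code from this commit: DemoParams is a made-up structure, and it uses ctypes.cast to obtain the float* where the diff calls the POINTER(...) constructor.

```python
import ctypes

# Hypothetical stand-in for the real llama_context_params: only a
# tensor_split field is modeled, typed as a plain float* for illustration.
class DemoParams(ctypes.Structure):
    _fields_ = [("tensor_split", ctypes.POINTER(ctypes.c_float))]

def set_tensor_split(params, tensor_split):
    # Build a C float array exactly as long as the Python list.
    arr = (ctypes.c_float * len(tensor_split))(*tensor_split)
    # Expose it as a float*; ctypes.cast is used here instead of the
    # POINTER(...)(array) constructor call shown in the diff.
    params.tensor_split = ctypes.cast(arr, ctypes.POINTER(ctypes.c_float))
    # Hand the array back so the caller can keep it referenced, mirroring
    # the "keep a reference ... so it is not gc'd" comment in the diff:
    # the struct only stores a raw pointer into this buffer.
    return arr

params = DemoParams()
keep_alive = set_tensor_split(params, [0.6, 0.4])  # e.g. a 60/40 split across two GPUs
print(params.tensor_split[0], params.tensor_split[1])
```

Because the array is sized from the list itself, the same idea works for any number of devices, which is why the diff can drop the LLAMA_MAX_DEVICES padding used before.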
