Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit c4c440b

Browse filesBrowse files
committed
Fix tensor_split cli option
1 parent 203ede4 commit c4c440b
Copy full SHA for c4c440b

File tree

Expand file treeCollapse file tree

2 files changed

+16
-7
lines changed
Filter options
Expand file treeCollapse file tree

2 files changed

+16
-7
lines changed

‎llama_cpp/llama.py

Copy file name to clipboardExpand all lines: llama_cpp/llama.py
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ def __init__(
288288

289289
if self.tensor_split is not None:
290290
# Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
291-
FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES.value
291+
FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES
292292
self._c_tensor_split = FloatArray(
293293
*tensor_split
294294
) # keep a reference to the array so it is not gc'd

‎llama_cpp/server/__main__.py

Copy file name to clipboardExpand all lines: llama_cpp/server/__main__.py
+15-6Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,27 +23,35 @@
2323
"""
2424
import os
2525
import argparse
26-
from typing import Literal, Union
26+
from typing import List, Literal, Union
2727

2828
import uvicorn
2929

3030
from llama_cpp.server.app import create_app, Settings
3131

32-
def get_non_none_base_types(annotation):
33-
if not hasattr(annotation, "__args__"):
34-
return annotation
35-
return [arg for arg in annotation.__args__ if arg is not type(None)][0]
36-
3732
def get_base_type(annotation):
3833
if getattr(annotation, '__origin__', None) is Literal:
3934
return type(annotation.__args__[0])
4035
elif getattr(annotation, '__origin__', None) is Union:
4136
non_optional_args = [arg for arg in annotation.__args__ if arg is not type(None)]
4237
if non_optional_args:
4338
return get_base_type(non_optional_args[0])
39+
elif getattr(annotation, '__origin__', None) is list or getattr(annotation, '__origin__', None) is List:
40+
return get_base_type(annotation.__args__[0])
4441
else:
4542
return annotation
4643

44+
def contains_list_type(annotation) -> bool:
45+
origin = getattr(annotation, '__origin__', None)
46+
47+
if origin is list or origin is List:
48+
return True
49+
elif origin in (Literal, Union):
50+
return any(contains_list_type(arg) for arg in annotation.__args__)
51+
else:
52+
return False
53+
54+
4755
if __name__ == "__main__":
4856
parser = argparse.ArgumentParser()
4957
for name, field in Settings.model_fields.items():
@@ -53,6 +61,7 @@ def get_base_type(annotation):
5361
parser.add_argument(
5462
f"--{name}",
5563
dest=name,
64+
nargs="*" if contains_list_type(field.annotation) else None,
5665
type=get_base_type(field.annotation) if field.annotation is not None else str,
5766
help=description,
5867
)

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.