@@ -3,7 +3,7 @@
 import multiprocessing
 
 from typing import Optional, List, Literal, Union
-from pydantic import Field
+from pydantic import Field, root_validator
 from pydantic_settings import BaseSettings
 
 import llama_cpp
@@ -67,12 +67,12 @@ class ModelSettings(BaseSettings):
     n_threads: int = Field(
         default=max(multiprocessing.cpu_count() // 2, 1),
         ge=1,
-        description="The number of threads to use.",
+        description="The number of threads to use. Use -1 for max cpu threads",
     )
     n_threads_batch: int = Field(
         default=max(multiprocessing.cpu_count(), 1),
         ge=0,
-        description="The number of threads to use when batch processing.",
+        description="The number of threads to use when batch processing. Use -1 for max cpu threads",
     )
     rope_scaling_type: int = Field(
         default=llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
@@ -173,6 +173,16 @@ class ModelSettings(BaseSettings):
         default=True, description="Whether to print debug information."
     )
 
+    @root_validator(pre=True)  # pre=True to ensure this runs before any other validation
+    def set_dynamic_defaults(cls, values):
+        # If n_threads or n_threads_batch is -1, set it to multiprocessing.cpu_count()
+        cpu_count = multiprocessing.cpu_count()
+        if values.get('n_threads', 0) == -1:
+            values['n_threads'] = cpu_count
+        if values.get('n_threads_batch', 0) == -1:
+            values['n_threads_batch'] = cpu_count
+        return values
+
 
 class ServerSettings(BaseSettings):
     """Server settings used to configure the FastAPI and Uvicorn server."""