Commit 6c44a3f

feat: Add option to configure n_ubatch
1 parent 47d7a62 commit 6c44a3f

3 files changed: +9 -0 lines changed

llama_cpp/llama.py

5 additions, 0 deletions
@@ -75,6 +75,7 @@ def __init__(
         seed: int = llama_cpp.LLAMA_DEFAULT_SEED,
         n_ctx: int = 512,
         n_batch: int = 512,
+        n_ubatch: int = 512,
         n_threads: Optional[int] = None,
         n_threads_batch: Optional[int] = None,
         rope_scaling_type: Optional[
@@ -156,6 +157,7 @@ def __init__(
             seed: RNG seed, -1 for random
             n_ctx: Text context, 0 = from model
             n_batch: Prompt processing maximum batch size
+            n_ubatch: Physical batch size
             n_threads: Number of threads to use for generation
             n_threads_batch: Number of threads to use for batch processing
             rope_scaling_type: RoPE scaling type, from `enum llama_rope_scaling_type`. ref: https://github.com/ggerganov/llama.cpp/pull/2054
@@ -309,6 +311,7 @@ def __init__(
         self.context_params = llama_cpp.llama_context_default_params()
         self.context_params.n_ctx = n_ctx
         self.context_params.n_batch = self.n_batch
+        self.context_params.n_ubatch = min(self.n_batch, n_ubatch)
         self.context_params.n_threads = self.n_threads
         self.context_params.n_threads_batch = self.n_threads_batch
         self.context_params.rope_scaling_type = (
@@ -380,6 +383,7 @@ def __init__(
             self.n_batch = min(n_ctx, n_batch)
             self.context_params.n_ctx = self._model.n_ctx_train()
             self.context_params.n_batch = self.n_batch
+            self.context_params.n_ubatch = min(self.n_batch, n_ubatch)

         self._ctx = self._stack.enter_context(
             contextlib.closing(
@@ -2071,6 +2075,7 @@ def __getstate__(self):
             seed=self.context_params.seed,
             n_ctx=self.context_params.n_ctx,
             n_batch=self.n_batch,
+            n_ubatch=self.context_params.n_ubatch,
             n_threads=self.context_params.n_threads,
             n_threads_batch=self.context_params.n_threads_batch,
             rope_scaling_type=self.context_params.rope_scaling_type,

llama_cpp/server/model.py

1 addition, 0 deletions
@@ -249,6 +249,7 @@ def load_llama_from_model_settings(settings: ModelSettings) -> llama_cpp.Llama:
         seed=settings.seed,
         n_ctx=settings.n_ctx,
         n_batch=settings.n_batch,
+        n_ubatch=settings.n_ubatch,
         n_threads=settings.n_threads,
         n_threads_batch=settings.n_threads_batch,
         rope_scaling_type=settings.rope_scaling_type,

llama_cpp/server/settings.py

3 additions, 0 deletions
@@ -70,6 +70,9 @@ class ModelSettings(BaseSettings):
     n_batch: int = Field(
         default=512, ge=1, description="The batch size to use per eval."
     )
+    n_ubatch: int = Field(
+        default=512, ge=1, description="The physical batch size used by llama.cpp"
+    )
     n_threads: int = Field(
         default=max(multiprocessing.cpu_count() // 2, 1),
         ge=1,

0 commit comments
