Commit 8cd64d4

Add rms_eps_norm
1 parent e4431a6 commit 8cd64d4

1 file changed: +19 -4 lines changed
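
For background: rms_eps_norm appears to be the epsilon term used inside RMSNorm, the normalization layer in LLaMA-family models; exposing it lets callers match the value a given model was trained with. A minimal NumPy sketch of where the epsilon enters, for illustration only (the function name and shapes here are not part of llama.cpp):

import numpy as np

def rms_norm(x: np.ndarray, weight: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    # Scale x by the reciprocal root-mean-square over its last axis;
    # eps is added under the square root for numerical stability.
    rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
    return x / rms * weight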

llama_cpp/llama.py (19 additions & 4 deletions)
@@ -216,14 +216,15 @@ def __init__(
         embedding: bool = False,
         n_threads: Optional[int] = None,
         n_batch: int = 512,
-        n_gqa: Optional[int] = None, # must be 8 for llama2 70b
         last_n_tokens_size: int = 64,
         lora_base: Optional[str] = None,
         lora_path: Optional[str] = None,
         low_vram: bool = False,
         tensor_split: Optional[List[float]] = None,
         rope_freq_base: float = 10000.0,
         rope_freq_scale: float = 1.0,
+        n_gqa: Optional[int] = None, # (TEMPORARY) must be 8 for llama2 70b
+        rms_eps_norm: Optional[float] = None, # (TEMPORARY)
         verbose: bool = True,
     ):
         """Load a llama.cpp model from `model_path`.
@@ -261,8 +262,6 @@ def __init__(
 
         self.params = llama_cpp.llama_context_default_params()
         self.params.n_ctx = n_ctx
-        if n_gqa is not None:
-            self.params.n_gqa = n_gqa
         self.params.n_gpu_layers = n_gpu_layers
         self.params.seed = seed
         self.params.f16_kv = f16_kv
@@ -285,6 +284,12 @@ def __init__(
         self.params.rope_freq_base = rope_freq_base
         self.params.rope_freq_scale = rope_freq_scale
 
+        if n_gqa is not None:
+            self.params.n_gqa = n_gqa
+
+        if rms_eps_norm is not None:
+            self.params.rms_eps_norm = rms_eps_norm
+
         self.last_n_tokens_size = last_n_tokens_size
         self.n_batch = min(n_ctx, n_batch)
 
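With the three __init__ hunks above, both temporary keyword arguments flow through to the llama.cpp context params whenever they are not None. A hedged usage sketch (the model path is hypothetical and 1e-5 is only an assumed epsilon; use whatever your model requires):

from llama_cpp import Llama

llm = Llama(
    model_path="./models/llama-2-70b.ggmlv3.q4_0.bin",  # hypothetical path
    n_gqa=8,            # per the comment in the diff, must be 8 for llama2 70b
    rms_eps_norm=1e-5,  # assumed value; overrides the library default
)
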
@@ -1526,6 +1531,10 @@ def __getstate__(self):
             lora_base=self.lora_base,
             lora_path=self.lora_path,
             tensor_split=self.tensor_split,
+            ### TEMPORARY ###
+            n_gqa=self.params.n_gqa,
+            rms_eps_norm=self.params.rms_eps_norm,
+            ### TEMPORARY ###
             ### DEPRECATED ###
             n_parts=self.n_parts,
             ### DEPRECATED ###
@@ -1535,7 +1544,6 @@ def __setstate__(self, state):
         self.__init__(
             model_path=state["model_path"],
             n_ctx=state["n_ctx"],
-            n_parts=state["n_parts"],
             n_gpu_layers=state["n_gpu_layers"],
             seed=state["seed"],
             f16_kv=state["f16_kv"],
@@ -1551,7 +1559,14 @@ def __setstate__(self, state):
             lora_base=state["lora_base"],
             lora_path=state["lora_path"],
             tensor_split=state["tensor_split"],
+            n_gqa=state["n_gqa"],
+            ### TEMPORARY ###
+            rms_eps_norm=state["rms_eps_norm"],
             verbose=state["verbose"],
+            ### TEMPORARY ###
+            ### DEPRECATED ###
+            n_parts=state["n_parts"],
+            ### DEPRECATED ###
         )
 
     def save_state(self) -> LlamaState:
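
The __getstate__/__setstate__ hunks exist so the temporary parameters survive pickling: __getstate__ records n_gqa and rms_eps_norm from self.params, and __setstate__ hands them back to __init__. A short round-trip sketch, reusing the llm instance from the example above:

import pickle

# The restored instance is rebuilt via __init__ with the same n_gqa and
# rms_eps_norm values that __getstate__ captured.
restored = pickle.loads(pickle.dumps(llm))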
