Commit fcb8051

Use ggml_type instead of string for quantization

1 parent: 1a6a9a3

1 file changed: llama_cpp/llama.py (15 additions, 55 deletions)

@@ -105,11 +105,11 @@ def __init__(
         draft_model: Optional[LlamaDraftModel] = None,
         # Tokenizer Override
         tokenizer: Optional[BaseLlamaTokenizer] = None,
+        # KV cache quantization
+        type_k: Optional[int] = None,
+        type_v: Optional[int] = None,
         # Misc
         verbose: bool = True,
-        # KV cache quantization
-        type_k: str = 'f16',
-        type_v: str = 'f16',
         # Extra Params
         **kwargs,  # type: ignore
     ):

@@ -304,18 +304,10 @@ def __init__(
         self.context_params.embeddings = embedding  # TODO: Rename to embeddings
         self.context_params.offload_kqv = offload_kqv
         # KV cache quantization
-        kv_cache_type = {
-            'f32': 0,
-            'f16': 1,
-            'q8_0': 8,
-            'q4_0': 2,
-            'q4_1': 3,
-            'iq4_nl': 20,
-            'q5_0': 6,
-            'q5_1': 7
-        }
-        self.context_params.type_k = kv_cache_type[type_k]
-        self.context_params.type_v = kv_cache_type[type_v]
+        if type_k is not None:
+            self.context_params.type_k = type_k
+        if type_v is not None:
+            self.context_params.type_v = type_v
         # Sampling Params
         self.last_n_tokens_size = last_n_tokens_size
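
After this hunk, callers pass the raw ggml_type integer directly instead of a string key; when type_k/type_v are left as None, llama.cpp's default (f16) applies. A minimal usage sketch — the model path is hypothetical, and the integer values follow the mapping removed above (f32 = 0, f16 = 1, q8_0 = 8, ...):

    from llama_cpp import Llama

    # Sketch: quantize both halves of the KV cache to q8_0. The value 8
    # corresponds to GGML_TYPE_Q8_0, per the string-to-int mapping this
    # commit removes.
    llm = Llama(
        model_path="./models/model.gguf",  # hypothetical path
        type_k=8,  # q8_0 K cache
        type_v=8,  # q8_0 V cache
    )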

@@ -1741,6 +1733,7 @@ def __getstate__(self):
             n_threads=self.context_params.n_threads,
             n_threads_batch=self.context_params.n_threads_batch,
             rope_scaling_type=self.context_params.rope_scaling_type,
+            pooling_type=self.context_params.pooling_type,
             rope_freq_base=self.context_params.rope_freq_base,
             rope_freq_scale=self.context_params.rope_freq_scale,
             yarn_ext_factor=self.context_params.yarn_ext_factor,

@@ -1750,6 +1743,7 @@ def __getstate__(self):
             yarn_orig_ctx=self.context_params.yarn_orig_ctx,
             logits_all=self.context_params.logits_all,
             embedding=self.context_params.embeddings,
+            offload_kqv=self.context_params.offload_kqv,
             # Sampling Params
             last_n_tokens_size=self.last_n_tokens_size,
             # LoRA Params

@@ -1761,51 +1755,17 @@ def __getstate__(self):
             # Chat Format Params
             chat_format=self.chat_format,
             chat_handler=self.chat_handler,
+            # Speculative Decoding
+            draft_model=self.draft_model,
+            # KV cache quantization
+            type_k=self.context_params.type_k,
+            type_v=self.context_params.type_v,
             # Misc
             verbose=self.verbose,
         )
 
     def __setstate__(self, state):
-        self.__init__(
-            model_path=state["model_path"],
-            # Model Params
-            n_gpu_layers=state["n_gpu_layers"],
-            split_mode=state["split_mode"],
-            main_gpu=state["main_gpu"],
-            tensor_split=state["tensor_split"],
-            vocab_only=state["vocab_only"],
-            use_mmap=state["use_mmap"],
-            use_mlock=state["use_mlock"],
-            kv_overrides=state["kv_overrides"],
-            # Context Params
-            seed=state["seed"],
-            n_ctx=state["n_ctx"],
-            n_batch=state["n_batch"],
-            n_threads=state["n_threads"],
-            n_threads_batch=state["n_threads_batch"],
-            rope_freq_base=state["rope_freq_base"],
-            rope_freq_scale=state["rope_freq_scale"],
-            rope_scaling_type=state["rope_scaling_type"],
-            yarn_ext_factor=state["yarn_ext_factor"],
-            yarn_attn_factor=state["yarn_attn_factor"],
-            yarn_beta_fast=state["yarn_beta_fast"],
-            yarn_beta_slow=state["yarn_beta_slow"],
-            yarn_orig_ctx=state["yarn_orig_ctx"],
-            logits_all=state["logits_all"],
-            embedding=state["embedding"],
-            # Sampling Params
-            last_n_tokens_size=state["last_n_tokens_size"],
-            # LoRA Params
-            lora_base=state["lora_base"],
-            lora_path=state["lora_path"],
-            # Backend Params
-            numa=state["numa"],
-            # Chat Format Params
-            chat_format=state["chat_format"],
-            chat_handler=state["chat_handler"],
-            # Misc
-            verbose=state["verbose"],
-        )
+        self.__init__(**state)
 
     def save_state(self) -> LlamaState:
         assert self._ctx.ctx is not None
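
A note on the __setstate__ simplification: it relies on __getstate__ returning exactly the keyword arguments __init__ accepts, a correspondence the pooling_type, offload_kqv, draft_model, and type_k/type_v entries added in this commit complete. A pickle round-trip sketch under that assumption, where llm is any constructed Llama instance:

    import pickle

    # __getstate__ supplies the kwargs dict; __setstate__ replays it
    # through self.__init__(**state), reloading the model from disk.
    blob = pickle.dumps(llm)
    restored = pickle.loads(blob)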
