Commit f165048

Limour-dev and abetlen authored
feat: add support for KV cache quantization options (abetlen#1307)
* add KV cache quantization options abetlen#1220 abetlen#1305
* Add ggml_type
* Use ggml_type instead of string for quantization
* Add server support

Co-authored-by: Andrei Betlen <abetlen@gmail.com>
1 parent aa9f1ae commit f165048

File tree

4 files changed: 94 additions, 41 deletions
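In short, the commit threads two new optional parameters, type_k and type_v, from the server settings down to llama.cpp's llama_context_params, using the ggml_type integer constants added in llama_cpp/llama_cpp.py. A minimal usage sketch of the Python API after this change (the model path is hypothetical, and the constants are assumed to be re-exported at package level like the rest of the low-level bindings):

import llama_cpp

# Quantize both halves of the KV cache to q8_0 instead of the default f16.
# Leaving type_k/type_v unset (None) keeps llama.cpp's defaults.
llm = llama_cpp.Llama(
    model_path="./models/model.gguf",  # hypothetical path
    n_ctx=4096,
    type_k=llama_cpp.GGML_TYPE_Q8_0,
    type_v=llama_cpp.GGML_TYPE_Q8_0,
)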

llama_cpp/llama.py

Lines changed: 18 additions & 41 deletions
@@ -105,6 +105,9 @@ def __init__(
         draft_model: Optional[LlamaDraftModel] = None,
         # Tokenizer Override
         tokenizer: Optional[BaseLlamaTokenizer] = None,
+        # KV cache quantization
+        type_k: Optional[int] = None,
+        type_v: Optional[int] = None,
         # Misc
         verbose: bool = True,
         # Extra Params
@@ -172,6 +175,8 @@ def __init__(
             draft_model: Optional draft model to use for speculative decoding.
             tokenizer: Optional tokenizer to override the default tokenizer from llama.cpp.
             verbose: Print verbose output to stderr.
+            type_k: KV cache data type for K (default: f16)
+            type_v: KV cache data type for V (default: f16)

         Raises:
             ValueError: If the model path does not exist.
@@ -298,7 +303,11 @@ def __init__(
         )  # Must be set to True for speculative decoding
         self.context_params.embeddings = embedding  # TODO: Rename to embeddings
         self.context_params.offload_kqv = offload_kqv
-
+        # KV cache quantization
+        if type_k is not None:
+            self.context_params.type_k = type_k
+        if type_v is not None:
+            self.context_params.type_v = type_v
         # Sampling Params
         self.last_n_tokens_size = last_n_tokens_size

@@ -1724,6 +1733,7 @@ def __getstate__(self):
             n_threads=self.context_params.n_threads,
             n_threads_batch=self.context_params.n_threads_batch,
             rope_scaling_type=self.context_params.rope_scaling_type,
+            pooling_type=self.context_params.pooling_type,
             rope_freq_base=self.context_params.rope_freq_base,
             rope_freq_scale=self.context_params.rope_freq_scale,
             yarn_ext_factor=self.context_params.yarn_ext_factor,
@@ -1733,6 +1743,7 @@ def __getstate__(self):
             yarn_orig_ctx=self.context_params.yarn_orig_ctx,
             logits_all=self.context_params.logits_all,
             embedding=self.context_params.embeddings,
+            offload_kqv=self.context_params.offload_kqv,
             # Sampling Params
             last_n_tokens_size=self.last_n_tokens_size,
             # LoRA Params
@@ -1744,51 +1755,17 @@ def __getstate__(self):
             # Chat Format Params
             chat_format=self.chat_format,
             chat_handler=self.chat_handler,
+            # Speculative Decidng
+            draft_model=self.draft_model,
+            # KV cache quantization
+            type_k=self.context_params.type_k,
+            type_v=self.context_params.type_v,
             # Misc
             verbose=self.verbose,
         )

     def __setstate__(self, state):
-        self.__init__(
-            model_path=state["model_path"],
-            # Model Params
-            n_gpu_layers=state["n_gpu_layers"],
-            split_mode=state["split_mode"],
-            main_gpu=state["main_gpu"],
-            tensor_split=state["tensor_split"],
-            vocab_only=state["vocab_only"],
-            use_mmap=state["use_mmap"],
-            use_mlock=state["use_mlock"],
-            kv_overrides=state["kv_overrides"],
-            # Context Params
-            seed=state["seed"],
-            n_ctx=state["n_ctx"],
-            n_batch=state["n_batch"],
-            n_threads=state["n_threads"],
-            n_threads_batch=state["n_threads_batch"],
-            rope_freq_base=state["rope_freq_base"],
-            rope_freq_scale=state["rope_freq_scale"],
-            rope_scaling_type=state["rope_scaling_type"],
-            yarn_ext_factor=state["yarn_ext_factor"],
-            yarn_attn_factor=state["yarn_attn_factor"],
-            yarn_beta_fast=state["yarn_beta_fast"],
-            yarn_beta_slow=state["yarn_beta_slow"],
-            yarn_orig_ctx=state["yarn_orig_ctx"],
-            logits_all=state["logits_all"],
-            embedding=state["embedding"],
-            # Sampling Params
-            last_n_tokens_size=state["last_n_tokens_size"],
-            # LoRA Params
-            lora_base=state["lora_base"],
-            lora_path=state["lora_path"],
-            # Backend Params
-            numa=state["numa"],
-            # Chat Format Params
-            chat_format=state["chat_format"],
-            chat_handler=state["chat_handler"],
-            # Misc
-            verbose=state["verbose"],
-        )
+        self.__init__(**state)

     def save_state(self) -> LlamaState:
         assert self._ctx.ctx is not None
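The __setstate__ rewrite above works because __getstate__ now returns exactly the keyword arguments that __init__ accepts (including the new type_k and type_v), so restoring a serialized instance reduces to a single forwarding call. A minimal round-trip sketch under that assumption (hypothetical model path; restoring re-runs __init__ and therefore reloads the model):

import pickle

from llama_cpp import Llama

llm = Llama(model_path="./models/model.gguf", n_ctx=2048)  # hypothetical path

blob = pickle.dumps(llm)       # pickle calls __getstate__ -> dict of constructor kwargs
restored = pickle.loads(blob)  # pickle calls __setstate__ -> self.__init__(**state)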

llama_cpp/llama_cpp.py

Lines changed: 64 additions & 0 deletions
@@ -141,6 +141,70 @@ def byref(obj: CtypesCData, offset: Optional[int] = None) -> CtypesRef[CtypesCDa

 byref = ctypes.byref  # type: ignore

+# from ggml.h
+# // NOTE: always add types at the end of the enum to keep backward compatibility
+# enum ggml_type {
+#     GGML_TYPE_F32 = 0,
+#     GGML_TYPE_F16 = 1,
+#     GGML_TYPE_Q4_0 = 2,
+#     GGML_TYPE_Q4_1 = 3,
+#     // GGML_TYPE_Q4_2 = 4, support has been removed
+#     // GGML_TYPE_Q4_3 = 5, support has been removed
+#     GGML_TYPE_Q5_0 = 6,
+#     GGML_TYPE_Q5_1 = 7,
+#     GGML_TYPE_Q8_0 = 8,
+#     GGML_TYPE_Q8_1 = 9,
+#     GGML_TYPE_Q2_K = 10,
+#     GGML_TYPE_Q3_K = 11,
+#     GGML_TYPE_Q4_K = 12,
+#     GGML_TYPE_Q5_K = 13,
+#     GGML_TYPE_Q6_K = 14,
+#     GGML_TYPE_Q8_K = 15,
+#     GGML_TYPE_IQ2_XXS = 16,
+#     GGML_TYPE_IQ2_XS = 17,
+#     GGML_TYPE_IQ3_XXS = 18,
+#     GGML_TYPE_IQ1_S = 19,
+#     GGML_TYPE_IQ4_NL = 20,
+#     GGML_TYPE_IQ3_S = 21,
+#     GGML_TYPE_IQ2_S = 22,
+#     GGML_TYPE_IQ4_XS = 23,
+#     GGML_TYPE_I8 = 24,
+#     GGML_TYPE_I16 = 25,
+#     GGML_TYPE_I32 = 26,
+#     GGML_TYPE_I64 = 27,
+#     GGML_TYPE_F64 = 28,
+#     GGML_TYPE_IQ1_M = 29,
+#     GGML_TYPE_COUNT,
+# };
+GGML_TYPE_F32 = 0
+GGML_TYPE_F16 = 1
+GGML_TYPE_Q4_0 = 2
+GGML_TYPE_Q4_1 = 3
+GGML_TYPE_Q5_0 = 6
+GGML_TYPE_Q5_1 = 7
+GGML_TYPE_Q8_0 = 8
+GGML_TYPE_Q8_1 = 9
+GGML_TYPE_Q2_K = 10
+GGML_TYPE_Q3_K = 11
+GGML_TYPE_Q4_K = 12
+GGML_TYPE_Q5_K = 13
+GGML_TYPE_Q6_K = 14
+GGML_TYPE_Q8_K = 15
+GGML_TYPE_IQ2_XXS = 16
+GGML_TYPE_IQ2_XS = 17
+GGML_TYPE_IQ3_XXS = 18
+GGML_TYPE_IQ1_S = 19
+GGML_TYPE_IQ4_NL = 20
+GGML_TYPE_IQ3_S = 21
+GGML_TYPE_IQ2_S = 22
+GGML_TYPE_IQ4_XS = 23
+GGML_TYPE_I8 = 24
+GGML_TYPE_I16 = 25
+GGML_TYPE_I32 = 26
+GGML_TYPE_I64 = 27
+GGML_TYPE_F64 = 28
+GGML_TYPE_IQ1_M = 29
+GGML_TYPE_COUNT = 30

 # from ggml-backend.h
 # typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);
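The constants above mirror enum ggml_type from ggml.h one-to-one, so the integers passed as type_k/type_v are exactly what the C library expects. Callers who prefer names over raw integers can build a small lookup on top of them; the mapping and helper below are illustrative, not part of the library:

import llama_cpp

# Illustrative name -> ggml_type mapping for cache types commonly used
# for KV cache quantization.
KV_CACHE_TYPES = {
    "f32": llama_cpp.GGML_TYPE_F32,
    "f16": llama_cpp.GGML_TYPE_F16,
    "q8_0": llama_cpp.GGML_TYPE_Q8_0,
    "q4_0": llama_cpp.GGML_TYPE_Q4_0,
}


def kv_cache_type(name: str) -> int:
    """Resolve a human-readable cache type name to its ggml_type value."""
    return KV_CACHE_TYPES[name.lower()]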

llama_cpp/server/model.py

Lines changed: 3 additions & 0 deletions
@@ -175,6 +175,9 @@ def load_llama_from_model_settings(settings: ModelSettings) -> llama_cpp.Llama:
         chat_handler=chat_handler,
         # Speculative Decoding
         draft_model=draft_model,
+        # KV Cache Quantization
+        type_k=settings.type_k,
+        type_v=settings.type_v,
         # Tokenizer
         tokenizer=tokenizer,
         # Misc

llama_cpp/server/settings.py

Lines changed: 9 additions & 0 deletions
@@ -159,6 +159,15 @@ class ModelSettings(BaseSettings):
         default=10,
         description="Number of tokens to predict using the draft model.",
     )
+    # KV Cache Quantization
+    type_k: Optional[int] = Field(
+        default=None,
+        description="Type of the key cache quantization.",
+    )
+    type_v: Optional[int] = Field(
+        default=None,
+        description="Type of the value cache quantization.",
+    )
     # Misc
     verbose: bool = Field(
         default=True, description="Whether to print debug information."
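These two fields are what load_llama_from_model_settings forwards to the Llama constructor in the server/model.py hunk above, so KV cache quantization becomes configurable for the OpenAI-compatible server as well. A sketch of setting them programmatically, assuming ModelSettings' existing model field and passing the raw enum value 8 (GGML_TYPE_Q8_0):

from llama_cpp.server.settings import ModelSettings

settings = ModelSettings(
    model="./models/model.gguf",  # hypothetical path
    type_k=8,  # GGML_TYPE_Q8_0
    type_v=8,  # GGML_TYPE_Q8_0
)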

0 commit comments
