@@ -105,11 +105,11 @@ def __init__(
105
105
draft_model : Optional [LlamaDraftModel ] = None ,
106
106
# Tokenizer Override
107
107
tokenizer : Optional [BaseLlamaTokenizer ] = None ,
108
+ # KV cache quantization
109
+ type_k : Optional [int ] = None ,
110
+ type_v : Optional [int ] = None ,
108
111
# Misc
109
112
verbose : bool = True ,
110
- # KV cache quantization
111
- type_k : str = 'f16' ,
112
- type_v : str = 'f16' ,
113
113
# Extra Params
114
114
** kwargs , # type: ignore
115
115
):
@@ -304,18 +304,10 @@ def __init__(
304
304
self .context_params .embeddings = embedding # TODO: Rename to embeddings
305
305
self .context_params .offload_kqv = offload_kqv
306
306
# KV cache quantization
307
- kv_cache_type = {
308
- 'f32' : 0 ,
309
- 'f16' : 1 ,
310
- 'q8_0' : 8 ,
311
- 'q4_0' : 2 ,
312
- 'q4_1' : 3 ,
313
- 'iq4_nl' : 20 ,
314
- 'q5_0' : 6 ,
315
- 'q5_1' : 7
316
- }
317
- self .context_params .type_k = kv_cache_type [type_k ]
318
- self .context_params .type_v = kv_cache_type [type_v ]
307
+ if type_k is not None :
308
+ self .context_params .type_k = type_k
309
+ if type_v is not None :
310
+ self .context_params .type_v = type_v
319
311
# Sampling Params
320
312
self .last_n_tokens_size = last_n_tokens_size
321
313
@@ -1741,6 +1733,7 @@ def __getstate__(self):
1741
1733
n_threads = self .context_params .n_threads ,
1742
1734
n_threads_batch = self .context_params .n_threads_batch ,
1743
1735
rope_scaling_type = self .context_params .rope_scaling_type ,
1736
+ pooling_type = self .context_params .pooling_type ,
1744
1737
rope_freq_base = self .context_params .rope_freq_base ,
1745
1738
rope_freq_scale = self .context_params .rope_freq_scale ,
1746
1739
yarn_ext_factor = self .context_params .yarn_ext_factor ,
@@ -1750,6 +1743,7 @@ def __getstate__(self):
1750
1743
yarn_orig_ctx = self .context_params .yarn_orig_ctx ,
1751
1744
logits_all = self .context_params .logits_all ,
1752
1745
embedding = self .context_params .embeddings ,
1746
+ offload_kqv = self .context_params .offload_kqv ,
1753
1747
# Sampling Params
1754
1748
last_n_tokens_size = self .last_n_tokens_size ,
1755
1749
# LoRA Params
@@ -1761,51 +1755,17 @@ def __getstate__(self):
1761
1755
# Chat Format Params
1762
1756
chat_format = self .chat_format ,
1763
1757
chat_handler = self .chat_handler ,
1758
+ # Speculative Decidng
1759
+ draft_model = self .draft_model ,
1760
+ # KV cache quantization
1761
+ type_k = self .context_params .type_k ,
1762
+ type_v = self .context_params .type_v ,
1764
1763
# Misc
1765
1764
verbose = self .verbose ,
1766
1765
)
1767
1766
1768
1767
def __setstate__ (self , state ):
1769
- self .__init__ (
1770
- model_path = state ["model_path" ],
1771
- # Model Params
1772
- n_gpu_layers = state ["n_gpu_layers" ],
1773
- split_mode = state ["split_mode" ],
1774
- main_gpu = state ["main_gpu" ],
1775
- tensor_split = state ["tensor_split" ],
1776
- vocab_only = state ["vocab_only" ],
1777
- use_mmap = state ["use_mmap" ],
1778
- use_mlock = state ["use_mlock" ],
1779
- kv_overrides = state ["kv_overrides" ],
1780
- # Context Params
1781
- seed = state ["seed" ],
1782
- n_ctx = state ["n_ctx" ],
1783
- n_batch = state ["n_batch" ],
1784
- n_threads = state ["n_threads" ],
1785
- n_threads_batch = state ["n_threads_batch" ],
1786
- rope_freq_base = state ["rope_freq_base" ],
1787
- rope_freq_scale = state ["rope_freq_scale" ],
1788
- rope_scaling_type = state ["rope_scaling_type" ],
1789
- yarn_ext_factor = state ["yarn_ext_factor" ],
1790
- yarn_attn_factor = state ["yarn_attn_factor" ],
1791
- yarn_beta_fast = state ["yarn_beta_fast" ],
1792
- yarn_beta_slow = state ["yarn_beta_slow" ],
1793
- yarn_orig_ctx = state ["yarn_orig_ctx" ],
1794
- logits_all = state ["logits_all" ],
1795
- embedding = state ["embedding" ],
1796
- # Sampling Params
1797
- last_n_tokens_size = state ["last_n_tokens_size" ],
1798
- # LoRA Params
1799
- lora_base = state ["lora_base" ],
1800
- lora_path = state ["lora_path" ],
1801
- # Backend Params
1802
- numa = state ["numa" ],
1803
- # Chat Format Params
1804
- chat_format = state ["chat_format" ],
1805
- chat_handler = state ["chat_handler" ],
1806
- # Misc
1807
- verbose = state ["verbose" ],
1808
- )
1768
+ self .__init__ (** state )
1809
1769
1810
1770
def save_state (self ) -> LlamaState :
1811
1771
assert self ._ctx .ctx is not None
0 commit comments