Commit 3241af8

Merge branch 'abetlen:main' into main

2 parents: c37c34d + cfb7da9
13 files changed (+285, -46 lines)

‎CHANGELOG.md

9 additions & 0 deletions

@@ -7,6 +7,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ## [Unreleased]

+## [0.2.29]
+
+- feat: Update llama.cpp to ggerganov/llama.cpp@4483396751c79dea540808b9cb9238245d06da2b
+- feat: Add split_mode option by @abetlen in 84615adbc6855c8384807c42f0130f9a1763f99d
+- feat: Implement GGUF metadata KV overrides by @phiharri in #1011
+- fix: Avoid "LookupError: unknown encoding: ascii" when open() called in a destructor by @yieldthought in #1012
+- fix: Fix low_level_api_chat_cpp example to match current API by @aniljava in #1086
+- fix: Fix Pydantic model parsing by @DeNeutoy in #1087
+
 ## [0.2.28]

 - feat: Update llama.cpp to ggerganov/llama.cpp@6efb8eb30e7025b168f3fda3ff83b9b386428ad6

‎examples/low_level_api/common.py

1 addition & 1 deletion

@@ -106,7 +106,7 @@ def gpt_params_parse(argv = None):
     parser.add_argument("--mirostat_lr", type=float, default=0.1, help="Mirostat learning rate, parameter eta",dest="mirostat_eta")

     parser.add_argument("-m", "--model", type=str, default="./models/llama-7B/ggml-model.bin", help="model path",dest="model")
-    parser.add_argument("-p", "--prompt", type=str, default="", help="initial prompt",dest="prompt")
+    parser.add_argument("-p", "--prompt", type=str, default=None, help="initial prompt",dest="prompt")
     parser.add_argument("-f", "--file", type=str, default=None, help="file containing initial prompt to load",dest="file")
     parser.add_argument("--session", type=str, default=None, help="file to cache model state in (may be large!)",dest="path_session")
     parser.add_argument("--in-prefix", type=str, default="", help="string to prefix user inputs with", dest="input_prefix")
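Note: switching the default from "" to None lets callers detect whether a prompt was actually supplied; the chat example below uses exactly that check to fall back to its built-in prompt. A minimal, self-contained illustration with plain argparse (not this repo's parser; the fallback string is illustrative):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-p", "--prompt", type=str, default=None, dest="prompt")
args = parser.parse_args([])  # simulate running with no -p flag

# default=None makes "no prompt supplied" detectable; with default="" it
# would be indistinguishable from an explicitly empty prompt.
if args.prompt is None:
    args.prompt = "built-in default prompt"  # illustrative fallback
print(args.prompt)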

‎examples/low_level_api/low_level_api_chat_cpp.py

36 additions & 14 deletions

@@ -62,7 +62,7 @@ def __init__(self, params: GptParams) -> None:
         self.multibyte_fix = []

         # model load
-        self.lparams = llama_cpp.llama_context_default_params()
+        self.lparams = llama_cpp.llama_model_default_params()
         self.lparams.n_ctx = self.params.n_ctx
         self.lparams.n_parts = self.params.n_parts
         self.lparams.seed = self.params.seed
@@ -72,7 +72,11 @@ def __init__(self, params: GptParams) -> None:

         self.model = llama_cpp.llama_load_model_from_file(
             self.params.model.encode("utf8"), self.lparams)
-        self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.lparams)
+
+        # Context Params.
+        self.cparams = llama_cpp.llama_context_default_params()
+
+        self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.cparams)
         if (not self.ctx):
             raise RuntimeError(f"error: failed to load model '{self.params.model}'")

@@ -244,7 +248,7 @@ def __init__(self, params: GptParams) -> None:
     # tokenize a prompt
     def _tokenize(self, prompt, bos=True):
         _arr = (llama_cpp.llama_token * ((len(prompt) + 1) * 4))()
-        _n = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8", errors="ignore"), _arr, len(_arr), bos)
+        _n = llama_cpp.llama_tokenize(self.model, prompt.encode("utf8", errors="ignore"), len(prompt), _arr, len(_arr), bos, False)
         return _arr[:_n]

     def set_color(self, c):
@@ -304,7 +308,7 @@ def generate(self):
                self.n_past += n_eval"""

                if (llama_cpp.llama_eval(
-                    self.ctx, (llama_cpp.llama_token * len(self.embd))(*self.embd), len(self.embd), self.n_past, self.params.n_threads
+                    self.ctx, (llama_cpp.llama_token * len(self.embd))(*self.embd), len(self.embd), self.n_past
                ) != 0):
                    raise Exception("Failed to llama_eval!")

@@ -332,7 +336,7 @@ def generate(self):
                id = 0

                logits = llama_cpp.llama_get_logits(self.ctx)
-                n_vocab = llama_cpp.llama_n_vocab(self.ctx)
+                n_vocab = llama_cpp.llama_n_vocab(self.model)

                # Apply params.logit_bias map
                for key, value in self.params.logit_bias.items():
@@ -349,12 +353,20 @@ def generate(self):
                last_n_repeat = min(len(self.last_n_tokens), repeat_last_n, self.n_ctx)

                _arr = (llama_cpp.llama_token * last_n_repeat)(*self.last_n_tokens[len(self.last_n_tokens) - last_n_repeat:])
-                llama_cpp.llama_sample_repetition_penalty(self.ctx, candidates_p,
-                    _arr,
-                    last_n_repeat, llama_cpp.c_float(self.params.repeat_penalty))
-                llama_cpp.llama_sample_frequency_and_presence_penalties(self.ctx, candidates_p,
-                    _arr,
-                    last_n_repeat, llama_cpp.c_float(self.params.frequency_penalty), llama_cpp.c_float(self.params.presence_penalty))
+                llama_cpp.llama_sample_repetition_penalties(
+                    ctx=self.ctx,
+                    candidates=candidates_p,
+                    last_tokens_data = _arr,
+                    penalty_last_n = last_n_repeat,
+                    penalty_repeat = llama_cpp.c_float(self.params.repeat_penalty),
+                    penalty_freq = llama_cpp.c_float(self.params.frequency_penalty),
+                    penalty_present = llama_cpp.c_float(self.params.presence_penalty),
+                )
+
+                # NOT PRESENT IN CURRENT VERSION ?
+                # llama_cpp.llama_sample_frequency_and_presence_penalti(self.ctx, candidates_p,
+                #     _arr,
+                #     last_n_repeat, llama_cpp.c_float(self.params.frequency_penalty), llama_cpp.c_float(self.params.presence_penalty))

                if not self.params.penalize_nl:
                    logits[llama_cpp.llama_token_nl()] = nl_logit
@@ -473,7 +485,7 @@ def exit(self):
     def token_to_str(self, token_id: int) -> bytes:
         size = 32
         buffer = (ctypes.c_char * size)()
-        n = llama_cpp.llama_token_to_piece_with_model(
+        n = llama_cpp.llama_token_to_piece(
             self.model, llama_cpp.llama_token(token_id), buffer, size)
         assert n <= size
         return bytes(buffer[:n])
@@ -532,6 +544,9 @@ def interact(self):
            print(i,end="",flush=True)
        self.params.input_echo = False

+        # Using string instead of tokens to check for antiprompt,
+        # It is more reliable than tokens for interactive mode.
+        generated_str = ""
        while self.params.interactive:
            self.set_color(util.CONSOLE_COLOR_USER_INPUT)
            if (self.params.instruct):
@@ -546,6 +561,10 @@ def interact(self):
            try:
                for i in self.output():
                    print(i,end="",flush=True)
+                    generated_str += i
+                    for ap in self.params.antiprompt:
+                        if generated_str.endswith(ap):
+                            raise KeyboardInterrupt
            except KeyboardInterrupt:
                self.set_color(util.CONSOLE_COLOR_DEFAULT)
                if not self.params.instruct:
@@ -561,7 +580,7 @@ def interact(self):
     time_now = datetime.now()
     prompt = f"""Text transcript of a never ending dialog, where {USER_NAME} interacts with an AI assistant named {AI_NAME}.
 {AI_NAME} is helpful, kind, honest, friendly, good at writing and never fails to answer {USER_NAME}’s requests immediately and with details and precision.
-There are no annotations like (30 seconds passed...) or (to himself), just what {USER_NAME} and {AI_NAME} say aloud to each other.
+Transcript below contains only the recorded dialog between two, without any annotations like (30 seconds passed...) or (to himself), just what {USER_NAME} and {AI_NAME} say aloud to each other.
 The dialog lasts for years, the entirety of it is shared below. It's 10000 pages long.
 The transcript only includes text, it does not include markup like HTML and Markdown.

@@ -575,8 +594,11 @@ def interact(self):
 {AI_NAME}: A cat is a domestic species of small carnivorous mammal. It is the only domesticated species in the family Felidae.
 {USER_NAME}: Name a color.
 {AI_NAME}: Blue
-{USER_NAME}:"""
+{USER_NAME}: """
+
     params = gpt_params_parse()
+    if params.prompt is None and params.file is None:
+        params.prompt = prompt

     with LLaMAInteract(params) as m:
         m.interact()
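Taken together, these edits track the current llama.cpp bindings: model parameters and context parameters are now separate structs, llama_tokenize works on the model and takes the text length plus a trailing special flag, llama_eval drops the thread-count argument, and the two penalty calls are merged into llama_sample_repetition_penalties. A minimal sketch of the new model/context/tokenize sequence, using only calls that appear in this diff (the model path is illustrative; backend setup, sampling, and cleanup are omitted):

import llama_cpp

model_path = b"./models/llama-7B/ggml-model.bin"  # illustrative path

# Model-level parameters (GPU offload, vocab, mmap, ...)
mparams = llama_cpp.llama_model_default_params()
model = llama_cpp.llama_load_model_from_file(model_path, mparams)

# Context-level parameters (n_ctx, batch size, ...)
cparams = llama_cpp.llama_context_default_params()
ctx = llama_cpp.llama_new_context_with_model(model, cparams)
if not ctx:
    raise RuntimeError("failed to create context")

# llama_tokenize now takes the model, the text length, and a trailing
# `special` flag, as in the _tokenize change above.
prompt = "Hello"
buf = (llama_cpp.llama_token * ((len(prompt) + 1) * 4))()
n_tokens = llama_cpp.llama_tokenize(
    model, prompt.encode("utf8", errors="ignore"), len(prompt),
    buf, len(buf), True, False)
tokens = buf[:n_tokens]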

‎llama_cpp/__init__.py

1 addition & 1 deletion

@@ -1,4 +1,4 @@
 from .llama_cpp import *
 from .llama import *

-__version__ = "0.2.28"
+__version__ = "0.2.29"

‎llama_cpp/_utils.py

9 additions & 11 deletions

@@ -1,11 +1,15 @@
 import os
 import sys

+import sys, traceback
+
+# Avoid "LookupError: unknown encoding: ascii" when open() called in a destructor
+outnull_file = open(os.devnull, "w")
+errnull_file = open(os.devnull, "w")

 class suppress_stdout_stderr(object):
     # NOTE: these must be "saved" here to avoid exceptions when using
     # this context manager inside of a __del__ method
-    open = open
     sys = sys
     os = os

@@ -21,9 +25,6 @@ def __enter__(self):
         if not hasattr(self.sys.stdout, 'fileno') or not hasattr(self.sys.stderr, 'fileno'):
             return self # Return the instance without making changes

-        self.outnull_file = self.open(self.os.devnull, "w")
-        self.errnull_file = self.open(self.os.devnull, "w")
-
         self.old_stdout_fileno_undup = self.sys.stdout.fileno()
         self.old_stderr_fileno_undup = self.sys.stderr.fileno()

@@ -33,11 +34,11 @@ def __enter__(self):
         self.old_stdout = self.sys.stdout
         self.old_stderr = self.sys.stderr

-        self.os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
-        self.os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)
+        self.os.dup2(outnull_file.fileno(), self.old_stdout_fileno_undup)
+        self.os.dup2(errnull_file.fileno(), self.old_stderr_fileno_undup)

-        self.sys.stdout = self.outnull_file
-        self.sys.stderr = self.errnull_file
+        self.sys.stdout = outnull_file
+        self.sys.stderr = errnull_file
         return self

     def __exit__(self, *_):
@@ -54,6 +55,3 @@ def __exit__(self, *_):

         self.os.close(self.old_stdout_fileno)
         self.os.close(self.old_stderr_fileno)
-
-        self.outnull_file.close()
-        self.errnull_file.close()
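The substance of this change is that os.devnull is now opened once at import time, so the context manager never calls open() in __enter__; that is what avoids "LookupError: unknown encoding: ascii" when the manager runs inside a __del__ during interpreter shutdown. A hedged usage sketch (the wrapper class and its cleanup call are illustrative, not part of this repo):

from llama_cpp._utils import suppress_stdout_stderr

class NoisyNative:
    """Illustrative holder for a native handle that prints on free."""

    def free(self):
        print("freeing native resources")  # stand-in for a chatty C call

    def __del__(self):
        # Usable even while the interpreter is shutting down: the context
        # manager no longer needs to call open() here, so it cannot hit the
        # ascii-codec LookupError described in the changelog entry above.
        with suppress_stdout_stderr():
            self.free()

obj = NoisyNative()
del obj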

‎llama_cpp/llama.py

38 additions & 1 deletion

@@ -730,11 +730,13 @@ def __init__(
         *,
         # Model Params
         n_gpu_layers: int = 0,
+        split_mode: int = llama_cpp.LLAMA_SPLIT_LAYER,
         main_gpu: int = 0,
         tensor_split: Optional[List[float]] = None,
         vocab_only: bool = False,
         use_mmap: bool = True,
         use_mlock: bool = False,
+        kv_overrides: Optional[Dict[str, Union[bool, int, float]]] = None,
         # Context Params
         seed: int = llama_cpp.LLAMA_DEFAULT_SEED,
         n_ctx: int = 512,
@@ -798,11 +800,13 @@ def __init__(
         Args:
             model_path: Path to the model.
             n_gpu_layers: Number of layers to offload to GPU (-ngl). If -1, all layers are offloaded.
-            main_gpu: The GPU that is used for scratch and small tensors.
+            split_mode: How to split the model across GPUs. See llama_cpp.LLAMA_SPLIT_* for options.
+            main_gpu: main_gpu interpretation depends on split_mode: LLAMA_SPLIT_NONE: the GPU that is used for the entire model. LLAMA_SPLIT_ROW: the GPU that is used for small tensors and intermediate results. LLAMA_SPLIT_LAYER: ignored
             tensor_split: How split tensors should be distributed across GPUs. If None, the model is not split.
             vocab_only: Only load the vocabulary no weights.
             use_mmap: Use mmap if possible.
             use_mlock: Force the system to keep the model in RAM.
+            kv_overrides: Key-value overrides for the model.
             seed: RNG seed, -1 for random
             n_ctx: Text context, 0 = from model
             n_batch: Prompt processing maximum batch size
@@ -848,6 +852,7 @@ def __init__(
         self.model_params.n_gpu_layers = (
             0x7FFFFFFF if n_gpu_layers == -1 else n_gpu_layers
         ) # 0x7FFFFFFF is INT32 max, will be auto set to all layers
+        self.model_params.split_mode = split_mode
         self.model_params.main_gpu = main_gpu
         self.tensor_split = tensor_split
         self._c_tensor_split = None
@@ -866,6 +871,34 @@ def __init__(
         self.model_params.use_mmap = use_mmap if lora_path is None else False
         self.model_params.use_mlock = use_mlock

+        self.kv_overrides = kv_overrides
+        if kv_overrides is not None:
+            n_overrides = len(kv_overrides)
+            self._kv_overrides_array = llama_cpp.llama_model_kv_override * (n_overrides + 1)
+            self._kv_overrides_array_keys = []
+
+            for k, v in kv_overrides.items():
+                key_buf = ctypes.create_string_buffer(k.encode("utf-8"))
+                self._kv_overrides_array_keys.append(key_buf)
+                self._kv_overrides_array[i].key = key_buf
+                if isinstance(v, int):
+                    self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_INT
+                    self._kv_overrides_array[i].value.int_value = v
+                elif isinstance(v, float):
+                    self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_FLOAT
+                    self._kv_overrides_array[i].value.float_value = v
+                elif isinstance(v, bool):
+                    self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_BOOL
+                    self._kv_overrides_array[i].value.bool_value = v
+                else:
+                    raise ValueError(f"Unknown value type for {k}: {v}")
+
+            self._kv_overrides_array_sentinel_key = b'\0'
+
+            # null array sentinel
+            self._kv_overrides_array[n_overrides].key = self._kv_overrides_array_sentinel_key
+            self.model_params.kv_overrides = self._kv_overrides_array
+
         self.n_batch = min(n_ctx, n_batch) # ???
         self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
         self.n_threads_batch = n_threads_batch or max(
@@ -2143,11 +2176,13 @@ def __getstate__(self):
             model_path=self.model_path,
             # Model Params
             n_gpu_layers=self.model_params.n_gpu_layers,
+            split_mode=self.model_params.split_mode,
             main_gpu=self.model_params.main_gpu,
             tensor_split=self.tensor_split,
             vocab_only=self.model_params.vocab_only,
             use_mmap=self.model_params.use_mmap,
             use_mlock=self.model_params.use_mlock,
+            kv_overrides=self.kv_overrides,
             # Context Params
             seed=self.context_params.seed,
             n_ctx=self.context_params.n_ctx,
@@ -2185,11 +2220,13 @@ def __setstate__(self, state):
             model_path=state["model_path"],
             # Model Params
             n_gpu_layers=state["n_gpu_layers"],
+            split_mode=state["split_mode"],
             main_gpu=state["main_gpu"],
             tensor_split=state["tensor_split"],
             vocab_only=state["vocab_only"],
             use_mmap=state["use_mmap"],
             use_mlock=state["use_mlock"],
+            kv_overrides=state["kv_overrides"],
             # Context Params
             seed=state["seed"],
             n_ctx=state["n_ctx"],
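At the high-level API, both new features are plain constructor arguments. A usage sketch (the model path and the override key are illustrative, not taken from this diff):

import llama_cpp
from llama_cpp import Llama

llm = Llama(
    model_path="./models/example.gguf",           # illustrative path
    n_gpu_layers=-1,                              # -1 offloads all layers, per the docstring above
    split_mode=llama_cpp.LLAMA_SPLIT_LAYER,       # the default; see llama_cpp.LLAMA_SPLIT_* for options
    main_gpu=0,                                   # ignored under LLAMA_SPLIT_LAYER
    kv_overrides={"llama.context_length": 4096},  # illustrative GGUF metadata key override
)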
