Commit f4fe27f

Merge pull request #5 from tc-wolf/cli_arg_kv_cache_dump_prompt_ndjson
CLI Arg to Dump Formatted Prompt Into NDJSON
2 parents 9b631db + 8d76fc1 commit f4fe27f

File tree: 7 files changed, 116 additions and 27 deletions (+116 −27)

Makefile
3 additions, 2 deletions (+3 −2)
@@ -92,9 +92,10 @@ deploy.pyinstaller.mac:
 
	# This still builds with metal support (I think b/c GGML_NATIVE=ON). Not an
	# issue since can still run Q4_0 models w/ repacking support on CPU if `-ngl 0`.
+	CMAKE_BUILD_TYPE="Release" \
	CMAKE_ARGS="-DGGML_METAL=OFF -DGGML_LLAMAFILE=OFF -DGGML_BLAS=OFF \
-	            -DGGML_NATIVE=ON -DGGML_CPU_AARCH64=ON \
-	            -DCMAKE_BUILD_TYPE=Release" python3 -m pip install -v -e .[server,dev]
+	            -DGGML_NATIVE=ON -DGGML_CPU_AARCH64=ON" \
+	            python3 -m pip install -v -e .[server,pyinstaller]
	@server_path=$$(python -c 'import llama_cpp.server; print(llama_cpp.server.__file__)' | sed s/init/main/) ; \
	echo "Server path: $$server_path" ; \
	base_path=$$(python -c 'from llama_cpp._ggml import libggml_base_path; print(str(libggml_base_path))') ; \

llama_cpp/llama.py
8 additions, 0 deletions (+8 −0)
@@ -546,6 +546,12 @@ def free_lora_adapter():
 
         self._sampler = None
 
+        # Created formatted prompt path, used for storing formatted prompts as NDJSON
+        if (formatted_prompt_path := kwargs.get("formatted_prompt_path")) is not None:
+            self.formatted_prompt_path = formatted_prompt_path
+        else:
+            self.formatted_prompt_path = None
+
     @property
     def ctx(self) -> llama_cpp.llama_context_p:
         return self._ctx.ctx
@@ -2178,6 +2184,8 @@ def __getstate__(self):
             # Misc
             spm_infill=self.spm_infill,
             verbose=self.verbose,
+            # Path provided for prompt serialization, if any
+            formatted_prompt_path=self.formatted_prompt_path,
         )
 
     def __setstate__(self, state):
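
Note: the new path is read from **kwargs, so it can be passed straight to the Llama constructor, and __getstate__ now carries it across pickling. A minimal sketch, assuming a local GGUF file (the model path below is hypothetical, not part of this commit):

from llama_cpp import Llama

llm = Llama(
    model_path="./models/example-q4_0.gguf",   # hypothetical local model file
    formatted_prompt_path="./prompts.ndjson",  # new keyword, picked up via kwargs above
)
# A pickled or copied instance keeps the same formatted_prompt_path, so it
# continues appending to the same NDJSON file.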

llama_cpp/llama_chat_format.py
65 additions, 22 deletions (+65 −22)
@@ -6,6 +6,7 @@
 import ctypes
 import dataclasses
 import random
+import pathlib
 import string
 
 from contextlib import ExitStack
@@ -24,6 +25,7 @@
 
 import jinja2
 from jinja2.sandbox import ImmutableSandboxedEnvironment
+import filelock
 
 import numpy as np
 import numpy.typing as npt
@@ -279,11 +281,15 @@ def _convert_text_completion_logprobs_to_chat(
                     }
                     for top_token, top_logprob in top_logprobs.items()
                 ],
-            } for (token, logprob, top_logprobs) in zip(logprobs["tokens"], logprobs["token_logprobs"], logprobs["top_logprobs"])
+            }
+            for (token, logprob, top_logprobs) in zip(
+                logprobs["tokens"], logprobs["token_logprobs"], logprobs["top_logprobs"]
+            )
         ],
         "refusal": None,
     }
 
+
 def _convert_text_completion_to_chat(
     completion: llama_types.Completion,
 ) -> llama_types.ChatCompletion:
@@ -300,7 +306,9 @@ def _convert_text_completion_to_chat(
                     "role": "assistant",
                     "content": completion["choices"][0]["text"],
                 },
-                "logprobs": _convert_text_completion_logprobs_to_chat(completion["choices"][0]["logprobs"]),
+                "logprobs": _convert_text_completion_logprobs_to_chat(
+                    completion["choices"][0]["logprobs"]
+                ),
                 "finish_reason": completion["choices"][0]["finish_reason"],
             }
         ],
@@ -344,7 +352,9 @@ def _convert_text_completion_chunks_to_chat(
                     if chunk["choices"][0]["finish_reason"] is None
                     else {}
                 ),
-                "logprobs": _convert_text_completion_logprobs_to_chat(chunk["choices"][0]["logprobs"]),
+                "logprobs": _convert_text_completion_logprobs_to_chat(
+                    chunk["choices"][0]["logprobs"]
+                ),
                 "finish_reason": chunk["choices"][0]["finish_reason"],
             }
         ],
@@ -407,7 +417,9 @@ def _convert_completion_to_chat_function(
                     }
                 ],
             },
-            "logprobs": _convert_text_completion_logprobs_to_chat(completion["choices"][0]["logprobs"]),
+            "logprobs": _convert_text_completion_logprobs_to_chat(
+                completion["choices"][0]["logprobs"]
+            ),
             "finish_reason": "tool_calls",
         }
     ],
@@ -460,7 +472,9 @@ def _stream_response_to_function_stream(
                 {
                     "index": 0,
                     "finish_reason": None,
-                    "logprobs": _convert_text_completion_logprobs_to_chat(chunk["choices"][0]["logprobs"]),
+                    "logprobs": _convert_text_completion_logprobs_to_chat(
+                        chunk["choices"][0]["logprobs"]
+                    ),
                     "delta": {
                         "role": None,
                         "content": None,
@@ -497,7 +511,9 @@ def _stream_response_to_function_stream(
                 {
                     "index": 0,
                     "finish_reason": None,
-                    "logprobs": _convert_text_completion_logprobs_to_chat(chunk["choices"][0]["logprobs"]),
+                    "logprobs": _convert_text_completion_logprobs_to_chat(
+                        chunk["choices"][0]["logprobs"]
+                    ),
                     "delta": {
                         "role": None,
                         "content": None,
@@ -598,6 +614,19 @@ def chat_completion_handler(
             add_bos=not result.added_special,
             special=True,
         )
+
+        # Is there a way to ensure this is not set for production? This will
+        # slow down things at least a little (latency) because I/O is slow.
+        if llama.formatted_prompt_path is not None:
+            output_path = pathlib.Path(llama.formatted_prompt_path)
+
+            # We ensure that output path ends with .ndjson in pydantic validation.
+            lockfile_path = output_path.with_suffix(".lock")
+            with filelock.FileLock(str(lockfile_path)):
+                with output_path.open("a", encoding="utf-8") as f:
+                    json.dump({"prompt": result.prompt, "prompt_tokens": prompt}, f)
+                    f.write("\n")
+
         if result.stop is not None:
             stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
             rstop = result.stop if isinstance(result.stop, list) else [result.stop]
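
Each chat completion appends one JSON object per line (keys "prompt" and "prompt_tokens"), guarded by a .lock file so concurrent workers do not interleave writes. A minimal sketch for reading the dump back, assuming the server was configured with formatted_prompt_path="prompts.ndjson":

import json

with open("prompts.ndjson", encoding="utf-8") as f:
    for line in f:
        record = json.loads(line)
        print(record["prompt"])               # formatted prompt string handed to the model
        print(len(record["prompt_tokens"]))   # number of token ids after tokenization
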
@@ -695,7 +724,7 @@ def chat_completion_handler(
 
 
 def hf_autotokenizer_to_chat_formatter(
-    pretrained_model_name_or_path: Union[str, os.PathLike[str]]
+    pretrained_model_name_or_path: Union[str, os.PathLike[str]],
 ) -> ChatFormatter:
     # https://huggingface.co/docs/transformers/main/chat_templating
     # https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format
@@ -720,7 +749,7 @@ def format_autotokenizer(
 
 
 def hf_autotokenizer_to_chat_completion_handler(
-    pretrained_model_name_or_path: Union[str, os.PathLike[str]]
+    pretrained_model_name_or_path: Union[str, os.PathLike[str]],
 ) -> LlamaChatCompletionHandler:
     chat_formatter = hf_autotokenizer_to_chat_formatter(pretrained_model_name_or_path)
     return chat_formatter_to_chat_completion_handler(chat_formatter)
@@ -1790,7 +1819,9 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
                     }
                 ],
             },
-            "logprobs": _convert_text_completion_logprobs_to_chat(completion["choices"][0]["logprobs"]),
+            "logprobs": _convert_text_completion_logprobs_to_chat(
+                completion["choices"][0]["logprobs"]
+            ),
             "finish_reason": "tool_calls",
         }
     ],
@@ -2202,7 +2233,9 @@ def generate_streaming(tools, functions, function_call, prompt):
            choices=[
                {
                    "index": 0,
-                    "logprobs": _convert_text_completion_logprobs_to_chat(chunk["choices"][0]["logprobs"]),
+                    "logprobs": _convert_text_completion_logprobs_to_chat(
+                        chunk["choices"][0]["logprobs"]
+                    ),
                    "delta": {
                        "role": None,
                        "content": None,
@@ -2304,7 +2337,9 @@ def generate_streaming(tools, functions, function_call, prompt):
            choices=[
                {
                    "index": 0,
-                    "logprobs": _convert_text_completion_logprobs_to_chat(chunk["choices"][0]["logprobs"]),
+                    "logprobs": _convert_text_completion_logprobs_to_chat(
+                        chunk["choices"][0]["logprobs"]
+                    ),
                    "delta": {
                        "role": "assistant",
                        "content": None,
@@ -2342,7 +2377,9 @@ def generate_streaming(tools, functions, function_call, prompt):
            choices=[
                {
                    "index": 0,
-                    "logprobs": _convert_text_completion_logprobs_to_chat(chunk["choices"][0]["logprobs"]),
+                    "logprobs": _convert_text_completion_logprobs_to_chat(
+                        chunk["choices"][0]["logprobs"]
+                    ),
                    "delta": {
                        "role": "assistant",
                        "content": buffer.pop(0),
@@ -2365,7 +2402,9 @@ def generate_streaming(tools, functions, function_call, prompt):
            choices=[
                {
                    "index": 0,
-                    "logprobs": _convert_text_completion_logprobs_to_chat(chunk["choices"][0]["logprobs"]),
+                    "logprobs": _convert_text_completion_logprobs_to_chat(
+                        chunk["choices"][0]["logprobs"]
+                    ),
                    "delta": {
                        "role": "assistant",
                        "content": (
@@ -2451,7 +2490,9 @@ def generate_streaming(tools, functions, function_call, prompt):
            choices=[
                {
                    "index": 0,
-                    "logprobs": _convert_text_completion_logprobs_to_chat(chunk["choices"][0]["logprobs"]),
+                    "logprobs": _convert_text_completion_logprobs_to_chat(
+                        chunk["choices"][0]["logprobs"]
+                    ),
                    "delta": {
                        "role": None,
                        "content": None,
@@ -2685,7 +2726,9 @@ def generate_streaming(tools, functions, function_call, prompt):
            choices=[
                {
                    "index": 0,
-                    "logprobs": _convert_text_completion_logprobs_to_chat(completion["choices"][0]["logprobs"]),
+                    "logprobs": _convert_text_completion_logprobs_to_chat(
+                        completion["choices"][0]["logprobs"]
+                    ),
                    "message": {
                        "role": "assistant",
                        "content": None if content == "" else content,
@@ -2795,9 +2838,7 @@ def _embed_image_bytes(self, image_bytes: bytes, n_threads_batch: int = 1):
         embed = self._llava_cpp.llava_image_embed_make_with_bytes(
             self.clip_ctx,
             n_threads_batch,
-            (ctypes.c_uint8 * len(image_bytes)).from_buffer(
-                bytearray(image_bytes)
-            ),
+            (ctypes.c_uint8 * len(image_bytes)).from_buffer(bytearray(image_bytes)),
             len(image_bytes),
         )
         self._last_image_embed = embed
@@ -2869,7 +2910,6 @@ def __call__(
         if self.verbose:
             print(text, file=sys.stderr)
 
-
         # Evaluate prompt
         llama.reset()
         llama._ctx.kv_cache_clear()
@@ -2885,7 +2925,9 @@ def __call__(
                 llama.eval(tokens)
             else:
                 image_bytes = self.load_image(value)
-                embed = self._embed_image_bytes(image_bytes, llama.context_params.n_threads_batch)
+                embed = self._embed_image_bytes(
+                    image_bytes, llama.context_params.n_threads_batch
+                )
                 if llama.n_tokens + embed.contents.n_image_pos > llama.n_ctx():
                     raise ValueError(
                         f"Prompt exceeds n_ctx: {llama.n_tokens + embed.contents.n_image_pos} > {llama.n_ctx()}"
@@ -3404,7 +3446,6 @@ class MiniCPMv26ChatHandler(Llava15ChatHandler):
         "{% endif %}"
         "{% endif %}"
         "{% endfor %}"
-
         "{% for content in message['content'] %}"
         "{% if content.type == 'text' %}"
         "{{ content.text }}"
@@ -3817,7 +3858,9 @@ def chatml_function_calling(
            {
                "finish_reason": "tool_calls",
                "index": 0,
-                "logprobs": _convert_text_completion_logprobs_to_chat(completion["choices"][0]["logprobs"]),
+                "logprobs": _convert_text_completion_logprobs_to_chat(
+                    completion["choices"][0]["logprobs"]
+                ),
                "message": {
                    "role": "assistant",
                    "content": None,

llama_cpp/server/model.py
3 additions, 0 deletions (+3 −0)
@@ -223,6 +223,9 @@ def load_llama_from_model_settings(settings: ModelSettings) -> llama_cpp.Llama:
     import functools
 
     kwargs = {}
+    # Move this here so that works w/ llama_cpp.Llama.from_pretrained as
+    # well as 'normal' constructor.
+    kwargs["formatted_prompt_path"] = settings.formatted_prompt_path
 
     if settings.hf_model_repo_id is not None:
         create_fn = functools.partial(
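
Because the kwarg is added to kwargs before either construction path runs, it reaches both the plain constructor and the from_pretrained path. A sketch of the latter, assuming Llama.from_pretrained forwards extra keyword arguments to the constructor (repo id and filename below are placeholders):

from llama_cpp import Llama

llm = Llama.from_pretrained(
    repo_id="example-org/example-model-GGUF",  # placeholder Hugging Face repo
    filename="*q4_0.gguf",                     # placeholder filename pattern
    formatted_prompt_path="prompts.ndjson",    # forwarded through **kwargs
)
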

llama_cpp/server/settings.py
8 additions, 0 deletions (+8 −0)
@@ -188,6 +188,14 @@ class ModelSettings(BaseSettings):
         default=None,
         description="Type of the value cache quantization.",
     )
+
+    # Path to store formatted prompts as NDJSON
+    formatted_prompt_path: Optional[str] = Field(
+        default=None,
+        pattern=r".*\.ndjson$",
+        description="Output path to store formatted prompts as NDJSON.",
+    )
+
     # Misc
     verbose: bool = Field(
         default=True, description="Whether to print debug information."
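
The pattern constraint means pydantic rejects any path that does not end in ".ndjson" at settings-load time, which tests/test_settings.py (below) exercises. A quick sketch of the expected behaviour, using a dummy model name:

from pydantic import ValidationError
from llama_cpp.server.settings import ModelSettings

ModelSettings(model="foo", formatted_prompt_path="prompts.ndjson")   # accepted
try:
    ModelSettings(model="foo", formatted_prompt_path="prompts.txt")  # rejected by the pattern
except ValidationError as exc:
    print(exc)
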

pyproject.toml
5 additions, 3 deletions (+5 −3)
@@ -15,11 +15,11 @@ dependencies = [
     "diskcache>=5.6.1",
     "jinja2>=2.11.3",
     "PyTrie>=0.4.0",
+    "filelock>=3.18.0",
 ]
-requires-python = ">=3.8"
+requires-python = ">=3.9"
 classifiers = [
     "Programming Language :: Python :: 3",
-    "Programming Language :: Python :: 3.8",
     "Programming Language :: Python :: 3.9",
     "Programming Language :: Python :: 3.10",
     "Programming Language :: Python :: 3.11",
@@ -56,9 +56,11 @@ dev = [
     "httpx>=0.24.1",
     "pandas>=2.2.1",
     "tqdm>=4.66.2",
+]
+pyinstaller = [
     "pyinstaller>=6.11.1",
 ]
-all = ["llama_cpp_python[server,test,dev]"]
+all = ["llama_cpp_python[server,test,dev,pyinstaller]"]
 
 [tool.scikit-build]
 wheel.packages = ["llama_cpp"]

tests/test_settings.py
24 additions, 0 deletions (+24 −0)
@@ -0,0 +1,24 @@
+import pytest
+
+from llama_cpp.server.settings import ModelSettings
+from pydantic import ValidationError
+
+# Required to pass in model name
+DUMMY_MODEL_NAME = "foo"
+
+
+def test_formatted_prompt_path_default_none():
+    m = ModelSettings(model=DUMMY_MODEL_NAME)
+    assert m.formatted_prompt_path is None
+
+
+def test_validation_error_if_prompt_path_not_endswith_ndjson():
+    with pytest.raises(
+        ValidationError, match=r"String should match pattern '.*\\.ndjson\$'"
+    ):
+        ModelSettings(model=DUMMY_MODEL_NAME, formatted_prompt_path="invalid_path.txt")
+
+
+def test_formatted_prompt_path_works_if_endswith_ndjson():
+    m = ModelSettings(model=DUMMY_MODEL_NAME, formatted_prompt_path="valid_path.ndjson")
+    assert m.formatted_prompt_path == "valid_path.ndjson"

0 commit comments
