Commit aab74f0

damian0815 and abetlen authored
Multimodal Support (Llava 1.5) (abetlen#821)
* llava v1.5 integration
* Point llama.cpp to fork
* Add llava shared library target
* Fix type
* Update llama.cpp
* Add llava api
* Revert changes to llama and llama_cpp
* Update llava example
* Add types for new gpt-4-vision-preview api
* Fix typo
* Update llama.cpp
* Update llama_types to match OpenAI v1 API
* Update ChatCompletionFunction type
* Reorder request parameters
* More API type fixes
* Even More Type Updates
* Add parameter for custom chat_handler to Llama class
* Fix circular import
* Convert to absolute imports
* Fix
* Fix pydantic Jsontype bug
* Accept list of prompt tokens in create_completion
* Add llava1.5 chat handler
* Add Multimodal notebook
* Clean up examples
* Add server docs

---------

Co-authored-by: Andrei Betlen <abetlen@gmail.com>
1 parent 56171cf commit aab74f0

File tree

10 files changed: +796 −102 lines

‎CMakeLists.txt

19 additions, 0 deletions
@@ -41,4 +41,23 @@ if (LLAMA_BUILD)
         FILES $<TARGET_RUNTIME_DLLS:llama>
         DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/llama_cpp
     )
+    add_subdirectory(vendor/llama.cpp/examples/llava)
+    set_target_properties(llava_shared PROPERTIES OUTPUT_NAME "llava")
+    install(
+        TARGETS llava_shared
+        LIBRARY DESTINATION ${SKBUILD_PLATLIB_DIR}/llama_cpp
+        RUNTIME DESTINATION ${SKBUILD_PLATLIB_DIR}/llama_cpp
+        ARCHIVE DESTINATION ${SKBUILD_PLATLIB_DIR}/llama_cpp
+        FRAMEWORK DESTINATION ${SKBUILD_PLATLIB_DIR}/llama_cpp
+        RESOURCE DESTINATION ${SKBUILD_PLATLIB_DIR}/llama_cpp
+    )
+    # Temporary fix for https://github.com/scikit-build/scikit-build-core/issues/374
+    install(
+        TARGETS llava_shared
+        LIBRARY DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/llama_cpp
+        RUNTIME DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/llama_cpp
+        ARCHIVE DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/llama_cpp
+        FRAMEWORK DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/llama_cpp
+        RESOURCE DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/llama_cpp
+    )
 endif()
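The install rules above build the llava example from the vendored llama.cpp and place a `llava` shared library next to the Python package. As a rough illustration only, here is a minimal sketch of how such a library could be located and loaded at runtime; the commit's actual binding module is not shown in this excerpt, and the file-name fallbacks are assumptions based on the `OUTPUT_NAME "llava"` property:

```python
import ctypes
import pathlib
import sys

import llama_cpp  # the CMake rules copy the shared library into this package directory

# Platform-dependent shared-library suffix for a target named "llava".
_suffix = {"win32": ".dll", "darwin": ".dylib"}.get(sys.platform, ".so")
_lib_path = pathlib.Path(llama_cpp.__file__).parent / f"libllava{_suffix}"
if not _lib_path.exists():  # Windows builds usually drop the "lib" prefix
    _lib_path = _lib_path.with_name(f"llava{_suffix}")

# Raises OSError if the library was not built/installed by the rules above.
llava = ctypes.CDLL(str(_lib_path))
print("loaded", _lib_path)
```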

‎docs/server.md

77 additions, 0 deletions
@@ -0,0 +1,77 @@
+# OpenAI Compatible Server
+
+`llama-cpp-python` offers an OpenAI API compatible web server.
+
+This web server can be used to serve local models and easily connect them to existing clients.
+
+## Setup
+
+### Installation
+
+The server can be installed by running the following command:
+
+```bash
+pip install llama-cpp-python[server]
+```
+
+### Running the server
+
+The server can then be started by running the following command:
+
+```bash
+python3 -m llama_cpp.server --model <model_path>
+```
+
+### Server options
+
+For a full list of options, run:
+
+```bash
+python3 -m llama_cpp.server --help
+```
+
+NOTE: All server options are also available as environment variables. For example, `--model` can be set with the `MODEL` environment variable.
+
+## Guides
+
+### Multi-modal Models
+
+`llama-cpp-python` supports the llava1.5 family of multi-modal models, which allow the language model to
+read information from both text and images.
+
+You'll first need to download one of the available multi-modal models in GGUF format:
+
+- [llava1.5 7b](https://huggingface.co/mys/ggml_llava-v1.5-7b)
+- [llava1.5 13b](https://huggingface.co/mys/ggml_llava-v1.5-13b)
+
+Then, when you run the server, you'll also need to specify the path to the CLIP model used for image embedding:
+
+```bash
+python3 -m llama_cpp.server --model <model_path> --clip-model-path <clip_model_path>
+```
+
+Then you can use the OpenAI API as normal:
+
+```python3
+from openai import OpenAI
+
+client = OpenAI(base_url="http://<host>:<port>/v1", api_key="sk-xxx")
+response = client.chat.completions.create(
+    model="gpt-4-vision-preview",
+    messages=[
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": "<image_url>"
+                    },
+                },
+                {"type": "text", "text": "What does the image say"},
+            ],
+        }
+    ],
+)
+print(response)
+```
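If the image lives on disk rather than at a public URL, it can be inlined as a base64 data URL, as the bundled notebook below does for a remote image. This is a minimal sketch, assuming the server above is running on localhost:8000 and that `image.png` is a local file you supply:

```python
import base64

from openai import OpenAI


def image_to_data_url(path: str) -> str:
    # Encode a local PNG as a data URL so the server does not need to fetch anything.
    with open(path, "rb") as f:
        return "data:image/png;base64," + base64.b64encode(f.read()).decode("utf-8")


client = OpenAI(base_url="http://localhost:8000/v1", api_key="sk-xxx")
response = client.chat.completions.create(
    model="gpt-4-vision-preview",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": image_to_data_url("image.png")}},
                {"type": "text", "text": "What does the image say"},
            ],
        }
    ],
)
print(response.choices[0].message.content)
```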

‎examples/notebooks/Multimodal.ipynb

84 additions, 0 deletions
@@ -0,0 +1,84 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "ChatCompletion(id='chatcmpl-65a710ba-41d1-4d0a-a124-a44b2b4a0189', choices=[Choice(finish_reason='stop', index=0, message=ChatCompletionMessage(content=' The image reads \"LlamaC++.\"', role='assistant', function_call=None, tool_calls=None))], created=1699413274, model='gpt-4-vision-preview', object='chat.completion', system_fingerprint=None, usage=CompletionUsage(completion_tokens=10, prompt_tokens=624, total_tokens=634))\n"
+     ]
+    }
+   ],
+   "source": [
+    "from openai import OpenAI\n",
+    "\n",
+    "import urllib.request\n",
+    "import base64\n",
+    "\n",
+    "def get_data_url(url):\n",
+    "    return \"data:image/png;base64,\" + base64.b64encode(urllib.request.urlopen(url).read()).decode(\"utf-8\")\n",
+    "\n",
+    "client = OpenAI(base_url=\"http://100.64.159.73:8000/v1\", api_key=\"sk-1234\")\n",
+    "response = client.chat.completions.create(\n",
+    "    model=\"gpt-4-vision-preview\",\n",
+    "    messages=[\n",
+    "        {\n",
+    "            \"role\": \"user\",\n",
+    "            \"content\": [\n",
+    "                {\n",
+    "                    \"type\": \"image_url\",\n",
+    "                    \"image_url\": {\n",
+    "                        \"url\": get_data_url(\"https://user-images.githubusercontent.com/1991296/230134379-7181e485-c521-4d23-a0d6-f7b3b61ba524.png\"),\n",
+    "                        # \"url\": \"https://user-images.githubusercontent.com/1991296/230134379-7181e485-c521-4d23-a0d6-f7b3b61ba524.png\",\n",
+    "                    },\n",
+    "                },\n",
+    "                {\"type\": \"text\", \"text\": \"What does the image say\"},\n",
+    "            ],\n",
+    "        }\n",
+    "    ],\n",
+    ")\n",
+    "print(response)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "![](https://user-images.githubusercontent.com/1991296/230134379-7181e485-c521-4d23-a0d6-f7b3b61ba524.png)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": ".venv",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.5+"
+  },
+  "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}

‎llama_cpp/llama.py

29 additions, 14 deletions
@@ -21,9 +21,9 @@
 import diskcache
 import ctypes
 
-from . import llama_cpp
 from .llama_types import *
 from .llama_grammar import LlamaGrammar
+import llama_cpp.llama_cpp as llama_cpp
 import llama_cpp.llama_chat_format as llama_chat_format
 
 import numpy as np
@@ -752,6 +752,7 @@ def __init__(
         numa: bool = False,
         # Chat Format Params
         chat_format: str = "llama-2",
+        chat_handler: Optional[llama_chat_format.LlamaChatCompletionHandler] = None,
         # Misc
         verbose: bool = True,
         # Extra Params
@@ -784,6 +785,7 @@ def __init__(
             lora_path: Path to a LoRA file to apply to the model.
             numa: Enable NUMA support. (NOTE: The initial value of this parameter is used for the remainder of the program as this value is set in llama_backend_init)
             chat_format: String specifying the chat format to use when calling create_chat_completion.
+            chat_handler: Optional chat handler to use when calling create_chat_completion.
             verbose: Print verbose output to stderr.
 
         Raises:
@@ -910,6 +912,7 @@ def __init__(
             print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
 
         self.chat_format = chat_format
+        self.chat_handler = chat_handler
 
         self._n_vocab = self.n_vocab()
         self._n_ctx = self.n_ctx()
@@ -1231,7 +1234,7 @@ def create_embedding(
         else:
             inputs = input
 
-        data: List[EmbeddingData] = []
+        data: List[Embedding] = []
         total_tokens = 0
         for index, input in enumerate(inputs):
             tokens = self.tokenize(input.encode("utf-8"), special=True)
@@ -1276,7 +1279,7 @@ def embed(self, input: str) -> List[float]:
 
     def _create_completion(
         self,
-        prompt: str,
+        prompt: Union[str, List[int]],
         suffix: Optional[str] = None,
         max_tokens: int = 16,
         temperature: float = 0.8,
@@ -1297,7 +1300,9 @@ def _create_completion(
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
-    ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
+    ) -> Union[
+        Iterator[CreateCompletionResponse], Iterator[CreateCompletionStreamResponse]
+    ]:
         assert self._ctx is not None
         assert suffix is None or suffix.__class__ is str
 
@@ -1309,7 +1314,7 @@ def _create_completion(
             self.tokenize(prompt.encode("utf-8"), special=True)
             if prompt != ""
             else [self.token_bos()]
-        )
+        ) if isinstance(prompt, str) else prompt
         text: bytes = b""
         returned_tokens: int = 0
         stop = (
@@ -1322,7 +1327,7 @@ def _create_completion(
 
         if len(prompt_tokens) >= self._n_ctx:
             raise ValueError(
-                f"Requested tokens ({len(prompt_tokens)}) exceed context window of {llama_cpp.llama_n_ctx(self._ctx)}"
+                f"Requested tokens ({len(prompt_tokens)}) exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
             )
 
         if max_tokens <= 0:
@@ -1732,7 +1737,7 @@ def _create_completion(
 
     def create_completion(
         self,
-        prompt: str,
+        prompt: Union[str, List[int]],
         suffix: Optional[str] = None,
         max_tokens: int = 128,
         temperature: float = 0.8,
@@ -1753,7 +1758,7 @@ def create_completion(
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
-    ) -> Union[Completion, Iterator[CompletionChunk]]:
+    ) -> Union[CreateCompletionResponse, Iterator[CreateCompletionStreamResponse]]:
         """Generate text from a prompt.
 
         Args:
@@ -1800,7 +1805,7 @@ def create_completion(
             grammar=grammar,
         )
         if stream:
-            chunks: Iterator[CompletionChunk] = completion_or_chunks
+            chunks: Iterator[CreateCompletionStreamResponse] = completion_or_chunks
             return chunks
         completion: Completion = next(completion_or_chunks)  # type: ignore
         return completion
@@ -1828,7 +1833,7 @@ def __call__(
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
-    ) -> Union[Completion, Iterator[CompletionChunk]]:
+    ) -> Union[CreateCompletionResponse, Iterator[CreateCompletionStreamResponse]]:
         """Generate text from a prompt.
 
         Args:
@@ -1879,7 +1884,9 @@ def create_chat_completion(
         self,
         messages: List[ChatCompletionRequestMessage],
         functions: Optional[List[ChatCompletionFunction]] = None,
-        function_call: Optional[Union[str, ChatCompletionFunctionCall]] = None,
+        function_call: Optional[ChatCompletionRequestFunctionCall] = None,
+        tools: Optional[List[ChatCompletionTool]] = None,
+        tool_choice: Optional[ChatCompletionToolChoiceOption] = None,
         temperature: float = 0.2,
         top_p: float = 0.95,
         top_k: int = 40,
@@ -1896,7 +1903,9 @@ def create_chat_completion(
         model: Optional[str] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
-    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+    ) -> Union[
+        CreateChatCompletionResponse, Iterator[CreateChatCompletionStreamResponse]
+    ]:
         """Generate a chat completion from a list of messages.
 
         Args:
@@ -1912,12 +1921,16 @@ def create_chat_completion(
         Returns:
             Generated chat completion or a stream of chat completion chunks.
         """
-        handler = llama_chat_format.get_chat_completion_handler(self.chat_format)
+        handler = self.chat_handler or llama_chat_format.get_chat_completion_handler(
+            self.chat_format
+        )
         return handler(
-            self,
+            llama=self,
             messages=messages,
             functions=functions,
             function_call=function_call,
+            tools=tools,
+            tool_choice=tool_choice,
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
@@ -1974,6 +1987,7 @@ def __getstate__(self):
             numa=self.numa,
             # Chat Format Params
             chat_format=self.chat_format,
+            chat_handler=self.chat_handler,
             # Misc
             verbose=self.verbose,
         )
@@ -2015,6 +2029,7 @@ def __setstate__(self, state):
             numa=state["numa"],
             # Chat Format Params
             chat_format=state["chat_format"],
+            chat_handler=state["chat_handler"],
             # Misc
             verbose=state["verbose"],
         )
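For context, here is a minimal sketch of how the new `chat_handler` parameter is meant to be used for local (non-server) multimodal inference. The handler class name `Llava15ChatHandler`, its `clip_model_path` argument, and the file names are assumptions based on the commit description ("Add llava1.5 chat handler"); they are not shown in this excerpt:

```python
from llama_cpp import Llama
from llama_cpp.llama_chat_format import Llava15ChatHandler  # assumed name, not shown in this diff

# Assumed paths; substitute your own llava GGUF model and CLIP projector files.
chat_handler = Llava15ChatHandler(clip_model_path="mmproj-model-f16.gguf")
llm = Llama(
    model_path="ggml-model-q4_k.gguf",
    chat_handler=chat_handler,  # the new parameter added in this commit
    n_ctx=2048,                 # leave room for the image embedding tokens
)

response = llm.create_chat_completion(
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "https://example.com/image.png"}},
                {"type": "text", "text": "What does the image say"},
            ],
        }
    ],
)
print(response["choices"][0]["message"]["content"])
```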

0 commit comments
