Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 44558cb

Browse filesBrowse files
committed
misc: llava_cpp use ctypes function decorator for binding
1 parent 8383a9e commit 44558cb
Copy full SHA for 44558cb

File tree

Expand file treeCollapse file tree

1 file changed

+34
-28
lines changed
Filter options
Expand file treeCollapse file tree

1 file changed

+34
-28
lines changed

‎llama_cpp/llava_cpp.py

Copy file name to clipboardExpand all lines: llama_cpp/llava_cpp.py
+34-28Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import sys
22
import os
33
import ctypes
4+
import functools
45
from ctypes import (
56
c_bool,
67
c_char_p,
@@ -13,7 +14,7 @@
1314
Structure,
1415
)
1516
import pathlib
16-
from typing import List, Union, NewType, Optional
17+
from typing import List, Union, NewType, Optional, TypeVar, Callable, Any
1718

1819
import llama_cpp.llama_cpp as llama_cpp
1920

@@ -76,6 +77,31 @@ def _load_shared_library(lib_base_name: str):
7677
# Load the library
7778
_libllava = _load_shared_library(_libllava_base_name)
7879

80+
# ctypes helper
81+
82+
F = TypeVar("F", bound=Callable[..., Any])
83+
84+
def ctypes_function_for_shared_library(lib: ctypes.CDLL):
85+
def ctypes_function(
86+
name: str, argtypes: List[Any], restype: Any, enabled: bool = True
87+
):
88+
def decorator(f: F) -> F:
89+
if enabled:
90+
func = getattr(lib, name)
91+
func.argtypes = argtypes
92+
func.restype = restype
93+
functools.wraps(f)(func)
94+
return func
95+
else:
96+
return f
97+
98+
return decorator
99+
100+
return ctypes_function
101+
102+
103+
ctypes_function = ctypes_function_for_shared_library(_libllava)
104+
79105

80106
################################################
81107
# llava.h
@@ -97,49 +123,35 @@ class llava_image_embed(Structure):
97123

98124
# /** sanity check for clip <-> llava embed size match */
99125
# LLAVA_API bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip);
126+
@ctypes_function("llava_validate_embed_size", [llama_cpp.llama_context_p_ctypes, clip_ctx_p_ctypes], c_bool)
100127
def llava_validate_embed_size(ctx_llama: llama_cpp.llama_context_p, ctx_clip: clip_ctx_p, /) -> bool:
101128
...
102129

103-
llava_validate_embed_size = _libllava.llava_validate_embed_size
104-
llava_validate_embed_size.argtypes = [llama_cpp.llama_context_p_ctypes, clip_ctx_p_ctypes]
105-
llava_validate_embed_size.restype = c_bool
106130

107131
# /** build an image embed from image file bytes */
108132
# LLAVA_API struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length);
133+
@ctypes_function("llava_image_embed_make_with_bytes", [clip_ctx_p_ctypes, c_int, POINTER(c_uint8), c_int], POINTER(llava_image_embed))
109134
def llava_image_embed_make_with_bytes(ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_bytes: bytes, image_bytes_length: Union[c_int, int], /) -> "_Pointer[llava_image_embed]":
110135
...
111136

112-
llava_image_embed_make_with_bytes = _libllava.llava_image_embed_make_with_bytes
113-
llava_image_embed_make_with_bytes.argtypes = [clip_ctx_p_ctypes, c_int, POINTER(c_uint8), c_int]
114-
llava_image_embed_make_with_bytes.restype = POINTER(llava_image_embed)
115-
116137
# /** build an image embed from a path to an image filename */
117138
# LLAVA_API struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path);
139+
@ctypes_function("llava_image_embed_make_with_filename", [clip_ctx_p_ctypes, c_int, c_char_p], POINTER(llava_image_embed))
118140
def llava_image_embed_make_with_filename(ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_path: bytes, /) -> "_Pointer[llava_image_embed]":
119141
...
120142

121-
llava_image_embed_make_with_filename = _libllava.llava_image_embed_make_with_filename
122-
llava_image_embed_make_with_filename.argtypes = [clip_ctx_p_ctypes, c_int, c_char_p]
123-
llava_image_embed_make_with_filename.restype = POINTER(llava_image_embed)
124-
125143
# LLAVA_API void llava_image_embed_free(struct llava_image_embed * embed);
126144
# /** free an embedding made with llava_image_embed_make_* */
145+
@ctypes_function("llava_image_embed_free", [POINTER(llava_image_embed)], None)
127146
def llava_image_embed_free(embed: "_Pointer[llava_image_embed]", /):
128147
...
129148

130-
llava_image_embed_free = _libllava.llava_image_embed_free
131-
llava_image_embed_free.argtypes = [POINTER(llava_image_embed)]
132-
llava_image_embed_free.restype = None
133-
134149
# /** write the image represented by embed into the llama context with batch size n_batch, starting at context pos n_past. on completion, n_past points to the next position in the context after the image embed. */
135150
# LLAVA_API bool llava_eval_image_embed(struct llama_context * ctx_llama, const struct llava_image_embed * embed, int n_batch, int * n_past);
151+
@ctypes_function("llava_eval_image_embed", [llama_cpp.llama_context_p_ctypes, POINTER(llava_image_embed), c_int, POINTER(c_int)], c_bool)
136152
def llava_eval_image_embed(ctx_llama: llama_cpp.llama_context_p, embed: "_Pointer[llava_image_embed]", n_batch: Union[c_int, int], n_past: "_Pointer[c_int]", /) -> bool:
137153
...
138154

139-
llava_eval_image_embed = _libllava.llava_eval_image_embed
140-
llava_eval_image_embed.argtypes = [llama_cpp.llama_context_p_ctypes, POINTER(llava_image_embed), c_int, POINTER(c_int)]
141-
llava_eval_image_embed.restype = c_bool
142-
143155

144156
################################################
145157
# clip.h
@@ -148,18 +160,12 @@ def llava_eval_image_embed(ctx_llama: llama_cpp.llama_context_p, embed: "_Pointe
148160

149161
# /** load mmproj model */
150162
# CLIP_API struct clip_ctx * clip_model_load (const char * fname, int verbosity);
163+
@ctypes_function("clip_model_load", [c_char_p, c_int], clip_ctx_p_ctypes)
151164
def clip_model_load(fname: bytes, verbosity: Union[c_int, int], /) -> Optional[clip_ctx_p]:
152165
...
153166

154-
clip_model_load = _libllava.clip_model_load
155-
clip_model_load.argtypes = [c_char_p, c_int]
156-
clip_model_load.restype = clip_ctx_p_ctypes
157-
158167
# /** free mmproj model */
159168
# CLIP_API void clip_free(struct clip_ctx * ctx);
169+
@ctypes_function("clip_free", [clip_ctx_p_ctypes], None)
160170
def clip_free(ctx: clip_ctx_p, /):
161171
...
162-
163-
clip_free = _libllava.clip_free
164-
clip_free.argtypes = [clip_ctx_p_ctypes]
165-
clip_free.restype = None

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.