|
| 1 | +import sys |
| 2 | +import os |
| 3 | +import ctypes |
| 4 | +from ctypes import ( |
| 5 | + c_bool, |
| 6 | + c_char_p, |
| 7 | + c_int, |
| 8 | + c_int8, |
| 9 | + c_int32, |
| 10 | + c_uint8, |
| 11 | + c_uint32, |
| 12 | + c_size_t, |
| 13 | + c_float, |
| 14 | + c_double, |
| 15 | + c_void_p, |
| 16 | + POINTER, |
| 17 | + _Pointer, # type: ignore |
| 18 | + Structure, |
| 19 | + Array, |
| 20 | +) |
| 21 | +import pathlib |
| 22 | +from typing import List, Union |
| 23 | + |
| 24 | +import llama_cpp.llama_cpp as llama_cpp |
| 25 | + |
| 26 | +# Load the library |
| 27 | +def _load_shared_library(lib_base_name: str): |
| 28 | + # Construct the paths to the possible shared library names |
| 29 | + _base_path = pathlib.Path(os.path.abspath(os.path.dirname(__file__))) |
| 30 | + # Searching for the library in the current directory under the name "libllama" (default name |
| 31 | + # for llamacpp) and "llama" (default name for this repo) |
| 32 | + _lib_paths: List[pathlib.Path] = [] |
| 33 | + # Determine the file extension based on the platform |
| 34 | + if sys.platform.startswith("linux"): |
| 35 | + _lib_paths += [ |
| 36 | + _base_path / f"lib{lib_base_name}.so", |
| 37 | + ] |
| 38 | + elif sys.platform == "darwin": |
| 39 | + _lib_paths += [ |
| 40 | + _base_path / f"lib{lib_base_name}.so", |
| 41 | + _base_path / f"lib{lib_base_name}.dylib", |
| 42 | + ] |
| 43 | + elif sys.platform == "win32": |
| 44 | + _lib_paths += [ |
| 45 | + _base_path / f"{lib_base_name}.dll", |
| 46 | + _base_path / f"lib{lib_base_name}.dll", |
| 47 | + ] |
| 48 | + else: |
| 49 | + raise RuntimeError("Unsupported platform") |
| 50 | + |
| 51 | + if "LLAMA_CPP_LIB" in os.environ: |
| 52 | + lib_base_name = os.environ["LLAMA_CPP_LIB"] |
| 53 | + _lib = pathlib.Path(lib_base_name) |
| 54 | + _base_path = _lib.parent.resolve() |
| 55 | + _lib_paths = [_lib.resolve()] |
| 56 | + |
| 57 | + cdll_args = dict() # type: ignore |
| 58 | + # Add the library directory to the DLL search path on Windows (if needed) |
| 59 | + if sys.platform == "win32" and sys.version_info >= (3, 8): |
| 60 | + os.add_dll_directory(str(_base_path)) |
| 61 | + if "CUDA_PATH" in os.environ: |
| 62 | + os.add_dll_directory(os.path.join(os.environ["CUDA_PATH"], "bin")) |
| 63 | + os.add_dll_directory(os.path.join(os.environ["CUDA_PATH"], "lib")) |
| 64 | + cdll_args["winmode"] = ctypes.RTLD_GLOBAL |
| 65 | + |
| 66 | + # Try to load the shared library, handling potential errors |
| 67 | + for _lib_path in _lib_paths: |
| 68 | + if _lib_path.exists(): |
| 69 | + try: |
| 70 | + return ctypes.CDLL(str(_lib_path), **cdll_args) |
| 71 | + except Exception as e: |
| 72 | + raise RuntimeError(f"Failed to load shared library '{_lib_path}': {e}") |
| 73 | + |
| 74 | + raise FileNotFoundError( |
| 75 | + f"Shared library with base name '{lib_base_name}' not found" |
| 76 | + ) |
| 77 | + |
| 78 | + |
# Base name of the shared library to load (no "lib" prefix, no extension)
_libllava_base_name = "llava"

# Load the library at import time. Importing this module therefore raises
# (FileNotFoundError / RuntimeError from _load_shared_library) when the
# library is missing or cannot be loaded.
_libllava = _load_shared_library(_libllava_base_name)
| 84 | + |
| 85 | + |
| 86 | +################################################ |
| 87 | +# llava.h |
| 88 | +################################################ |
| 89 | + |
# struct clip_ctx;
# Only a forward declaration of the C struct is mirrored here, so a clip
# context is handled as an untyped pointer.
clip_ctx_p = c_void_p
| 92 | + |
| 93 | +# struct llava_image_embed { |
| 94 | +# float * embed; |
| 95 | +# int n_image_pos; |
| 96 | +# }; |
class llava_image_embed(Structure):
    """ctypes mirror of the C ``struct llava_image_embed``."""

    # Field order and types follow the C declaration:
    #   float * embed;
    #   int     n_image_pos;
    _fields_ = [("embed", POINTER(c_float)), ("n_image_pos", c_int)]
| 102 | + |
| 103 | +# /** sanity check for clip <-> llava embed size match */ |
| 104 | +# LLAVA_API bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip); |
def llava_validate_embed_size(ctx_llama: llama_cpp.llama_context_p, ctx_clip: clip_ctx_p) -> bool:
    """Sanity check for clip <-> llava embed size match.

    Thin wrapper around the C function of the same name.

    Args:
        ctx_llama: llama context handle.
        ctx_clip: clip context handle.

    Returns:
        The ``bool`` result of the C call.
    """
    return _libllava.llava_validate_embed_size(ctx_llama, ctx_clip)

# C prototype: bool (const llama_context *, const clip_ctx *)
_libllava.llava_validate_embed_size.argtypes = [llama_cpp.llama_context_p, clip_ctx_p]
_libllava.llava_validate_embed_size.restype = c_bool
| 110 | + |
| 111 | +# /** build an image embed from image file bytes */ |
| 112 | +# LLAVA_API struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length); |
def llava_image_embed_make_with_bytes(ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_bytes: Union[bytes, bytearray, "Array[c_uint8]"], image_bytes_length: Union[c_int, int]) -> "_Pointer[llava_image_embed]":
    """Build an image embed from image file bytes.

    Thin wrapper around the C function of the same name.

    Args:
        ctx_clip: clip context handle.
        n_threads: Thread count forwarded to the C function.
        image_bytes: The raw bytes of an image file. ``bytes`` and
            ``bytearray`` values are copied into a ``c_uint8`` array before
            the call; ctypes arrays/pointers are passed through unchanged.
        image_bytes_length: Number of bytes in ``image_bytes``.

    Returns:
        Pointer to a ``llava_image_embed``. The C header says embeds made by
        the ``llava_image_embed_make_*`` functions are released with
        ``llava_image_embed_free``.
    """
    if isinstance(image_bytes, (bytes, bytearray)):
        # The argtype is POINTER(c_uint8), which does not accept Python
        # bytes objects directly. A c_uint8 array is accepted and is passed
        # as a pointer to its first element.
        image_bytes = (c_uint8 * len(image_bytes)).from_buffer_copy(image_bytes)
    return _libllava.llava_image_embed_make_with_bytes(ctx_clip, n_threads, image_bytes, image_bytes_length)

# C prototype: struct llava_image_embed * (struct clip_ctx *, int, const unsigned char *, int)
_libllava.llava_image_embed_make_with_bytes.argtypes = [clip_ctx_p, c_int, POINTER(c_uint8), c_int]
_libllava.llava_image_embed_make_with_bytes.restype = POINTER(llava_image_embed)
| 118 | + |
| 119 | +# /** build an image embed from a path to an image filename */ |
| 120 | +# LLAVA_API struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path); |
def llava_image_embed_make_with_filename(ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_path: bytes) -> "_Pointer[llava_image_embed]":
    """Build an image embed from a path to an image filename.

    Thin wrapper around the C function of the same name.

    Args:
        ctx_clip: clip context handle.
        n_threads: Thread count forwarded to the C function.
        image_path: Path of the image file, as ``bytes`` (passed as a C
            ``const char *``).

    Returns:
        Pointer to a ``llava_image_embed``. The C header says embeds made by
        the ``llava_image_embed_make_*`` functions are released with
        ``llava_image_embed_free``.
    """
    return _libllava.llava_image_embed_make_with_filename(ctx_clip, n_threads, image_path)

# C prototype: struct llava_image_embed * (struct clip_ctx *, int, const char *)
_libllava.llava_image_embed_make_with_filename.argtypes = [clip_ctx_p, c_int, c_char_p]
_libllava.llava_image_embed_make_with_filename.restype = POINTER(llava_image_embed)
| 126 | + |
| 127 | +# LLAVA_API void llava_image_embed_free(struct llava_image_embed * embed); |
| 128 | +# /** free an embedding made with llava_image_embed_make_* */ |
def llava_image_embed_free(embed: "_Pointer[llava_image_embed]"):
    """Free an embedding made with ``llava_image_embed_make_*``.

    Thin wrapper around the C function of the same name.

    Args:
        embed: Pointer to the ``llava_image_embed`` to release.
    """
    return _libllava.llava_image_embed_free(embed)

# C prototype: void (struct llava_image_embed *)
_libllava.llava_image_embed_free.argtypes = [POINTER(llava_image_embed)]
_libllava.llava_image_embed_free.restype = None
| 134 | + |
| 135 | +# /** write the image represented by embed into the llama context with batch size n_batch, starting at context pos n_past. on completion, n_past points to the next position in the context after the image embed. */ |
| 136 | +# LLAVA_API bool llava_eval_image_embed(struct llama_context * ctx_llama, const struct llava_image_embed * embed, int n_batch, int * n_past); |
def llava_eval_image_embed(ctx_llama: llama_cpp.llama_context_p, embed: "_Pointer[llava_image_embed]", n_batch: Union[c_int, int], n_past: Union["_Pointer[c_int]", c_int]) -> bool:
    """Write the image represented by ``embed`` into the llama context.

    Per the C header: the embed is written with batch size ``n_batch``,
    starting at context position ``n_past``; on completion ``n_past`` points
    to the next position in the context after the image embed.

    Args:
        ctx_llama: llama context handle.
        embed: Pointer to the image embed to evaluate.
        n_batch: Batch size forwarded to the C function.
        n_past: In/out position. The C parameter is ``int *``, so pass a
            pointer to a ``c_int`` or a ``c_int`` instance (ctypes passes
            the latter by reference). A plain Python ``int`` does not match
            the declared argtype.

    Returns:
        The ``bool`` result of the C call.
    """
    return _libllava.llava_eval_image_embed(ctx_llama, embed, n_batch, n_past)

# C prototype: bool (struct llama_context *, const struct llava_image_embed *, int, int *)
_libllava.llava_eval_image_embed.argtypes = [llama_cpp.llama_context_p, POINTER(llava_image_embed), c_int, POINTER(c_int)]
_libllava.llava_eval_image_embed.restype = c_bool
| 142 | + |
| 143 | + |
| 144 | +################################################ |
| 145 | +# clip.h |
| 146 | +################################################ |
| 147 | + |
| 148 | + |
| 149 | +# struct clip_vision_hparams { |
| 150 | +# int32_t image_size; |
| 151 | +# int32_t patch_size; |
| 152 | +# int32_t hidden_size; |
| 153 | +# int32_t n_intermediate; |
| 154 | +# int32_t projection_dim; |
| 155 | +# int32_t n_head; |
| 156 | +# int32_t n_layer; |
| 157 | +# float eps; |
| 158 | +# }; |
class clip_vision_hparams(Structure):
    """ctypes mirror of the C ``struct clip_vision_hparams``."""

    # Seven consecutive int32_t members, in C declaration order, followed by
    # the single float member ``eps``.
    _fields_ = [
        (int32_member, c_int32)
        for int32_member in (
            "image_size",
            "patch_size",
            "hidden_size",
            "n_intermediate",
            "projection_dim",
            "n_head",
            "n_layer",
        )
    ] + [("eps", c_float)]
| 170 | + |
| 171 | +# /** load mmproj model */ |
| 172 | +# CLIP_API struct clip_ctx * clip_model_load(const char * fname, const int verbosity); |
def clip_model_load(fname: bytes, verbosity: Union[c_int, int]) -> clip_ctx_p:
    """Load mmproj model.

    Thin wrapper around the C function of the same name.

    Args:
        fname: Path of the model file, as ``bytes`` (passed as a C
            ``const char *``).
        verbosity: Verbosity level forwarded to the C function.

    Returns:
        The clip context handle. Because the restype is ``c_void_p``, ctypes
        hands the value back as a plain Python ``int`` address, or ``None``
        for a NULL pointer.
    """
    return _libllava.clip_model_load(fname, verbosity)

# C prototype: struct clip_ctx * (const char *, const int)
_libllava.clip_model_load.argtypes = [c_char_p, c_int]
_libllava.clip_model_load.restype = clip_ctx_p
| 178 | + |
| 179 | +# /** free mmproj model */ |
| 180 | +# CLIP_API void clip_free(struct clip_ctx * ctx); |
def clip_free(ctx: clip_ctx_p):
    """Free mmproj model.

    Thin wrapper around the C function of the same name.

    Args:
        ctx: clip context handle to release.
    """
    return _libllava.clip_free(ctx)

# C prototype: void (struct clip_ctx *)
_libllava.clip_free.argtypes = [clip_ctx_p]
_libllava.clip_free.restype = None
| 186 | + |
| 187 | +# size_t clip_embd_nbytes(const struct clip_ctx * ctx); |
| 188 | +# int clip_n_patches(const struct clip_ctx * ctx); |
| 189 | +# int clip_n_mmproj_embd(const struct clip_ctx * ctx); |
| 190 | + |
| 191 | +# // RGB uint8 image |
| 192 | +# struct clip_image_u8 { |
| 193 | +# int nx; |
| 194 | +# int ny; |
| 195 | +# uint8_t * data = NULL; |
| 196 | +# size_t size; |
| 197 | +# }; |
| 198 | + |
| 199 | +# // RGB float32 image (NHWC) |
| 200 | +# // Memory layout: RGBRGBRGB... |
| 201 | +# struct clip_image_f32 { |
| 202 | +# int nx; |
| 203 | +# int ny; |
| 204 | +# float * data = NULL; |
| 205 | +# size_t size; |
| 206 | +# }; |
| 207 | + |
| 208 | +# struct clip_image_u8_batch { |
| 209 | +# struct clip_image_u8 * data; |
| 210 | +# size_t size; |
| 211 | +# }; |
| 212 | + |
| 213 | +# struct clip_image_f32_batch { |
| 214 | +# struct clip_image_f32 * data; |
| 215 | +# size_t size; |
| 216 | +# }; |
| 217 | + |
| 218 | +# struct clip_image_u8 * make_clip_image_u8(); |
| 219 | +# struct clip_image_f32 * make_clip_image_f32(); |
| 220 | +# CLIP_API void clip_image_u8_free(clip_image_u8 * img); |
| 221 | +# CLIP_API void clip_image_f32_free(clip_image_f32 * img); |
| 222 | +# CLIP_API bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img); |
| 223 | +# /** interpret bytes as an image file with length bytes_length, and use the result to populate img */ |
| 224 | +# CLIP_API bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, struct clip_image_u8 * img); |
| 225 | + |
| 226 | +# bool clip_image_preprocess(const struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32 * res, const bool pad2square); |
| 227 | +# bool clip_image_encode(const struct clip_ctx * ctx, const int n_threads, struct clip_image_f32 * img, float * vec); |
| 228 | + |
| 229 | +# bool clip_image_batch_encode(const struct clip_ctx * ctx, const int n_threads, const struct clip_image_f32_batch * imgs, |
| 230 | +# float * vec); |
| 231 | + |
| 232 | +# bool clip_model_quantize(const char * fname_inp, const char * fname_out, const int itype); |
0 commit comments