Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 05f2efd

Browse filesBrowse files
authored
Expose zero-shot labeling to Python (#73)
* rm __pycache__ * gitignore __pycache__ * gitignore dist * Upd usearch in image-search example * Upd usearch in image-search example * Implement ZSL in clip lib * Use new ZSL API in examples * Expose ZSL in Python * Upd readme in Python bindings * Bump version in Python bindings
1 parent c9c02cb commit 05f2efd
Copy full SHA for 05f2efd

File tree

Expand file treeCollapse file tree

9 files changed

+156
-62
lines changed
Filter options
Expand file treeCollapse file tree

9 files changed

+156
-62
lines changed

‎clip.cpp

Copy file name to clipboardExpand all lines: clip.cpp
+31Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1518,6 +1518,37 @@ bool softmax_with_sorting(float * arr, const int length, float * sorted_scores,
15181518
return true;
15191519
}
15201520

1521+
bool clip_zero_shot_label_image(struct clip_ctx * ctx, const int n_threads, const struct clip_image_u8 * input_img,
1522+
const char ** labels, const size_t n_labels, float * scores, int * indices) {
1523+
// load the image
1524+
clip_image_f32 img_res;
1525+
1526+
const int vec_dim = clip_get_vision_hparams(ctx)->projection_dim;
1527+
1528+
clip_image_preprocess(ctx, input_img, &img_res);
1529+
1530+
float img_vec[vec_dim];
1531+
if (!clip_image_encode(ctx, n_threads, &img_res, img_vec, false)) {
1532+
return false;
1533+
}
1534+
1535+
// encode texts and compute similarities
1536+
float txt_vec[vec_dim];
1537+
float similarities[n_labels];
1538+
1539+
for (int i = 0; i < n_labels; i++) {
1540+
const auto & text = labels[i];
1541+
auto tokens = clip_tokenize(ctx, text);
1542+
clip_text_encode(ctx, n_threads, &tokens, txt_vec, false);
1543+
similarities[i] = clip_similarity_score(img_vec, txt_vec, vec_dim);
1544+
}
1545+
1546+
// apply softmax and sort scores
1547+
softmax_with_sorting(similarities, n_labels, scores, indices);
1548+
1549+
return true;
1550+
}
1551+
15211552
bool image_normalize(const clip_image_u8 * img, clip_image_f32 * res) {
15221553
if (img->nx != 224 || img->ny != 224) {
15231554
printf("%s: long input shape: %d x %d\n", __func__, img->nx, img->ny);

‎clip.h

Copy file name to clipboardExpand all lines: clip.h
+2Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ bool clip_compare_text_and_image(const struct clip_ctx * ctx, const int n_thread
9898
const struct clip_image_u8 * image, float * score);
9999
float clip_similarity_score(const float * vec1, const float * vec2, const int vec_dim);
100100
bool softmax_with_sorting(float * arr, const int length, float * sorted_scores, int * indices);
101+
bool clip_zero_shot_label_image(struct clip_ctx * ctx, const int n_threads, const struct clip_image_u8 * input_img,
102+
const char ** labels, const size_t n_labels, float * scores, int * indices);
101103

102104
#ifdef __cplusplus
103105
}

‎examples/common-clip.h

Copy file name to clipboardExpand all lines: examples/common-clip.h
+13-4Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,22 @@ std::map<std::string, std::vector<std::string>> get_dir_keyed_files(const std::s
1515

1616
bool is_image_file_extension(const std::string & path);
1717

18-
struct app_params {
19-
int32_t n_threads = std::min(4, (int32_t)std::thread::hardware_concurrency());
18+
#include <algorithm>
19+
#include <string>
20+
#include <vector>
2021

21-
std::string model = "models/ggml-model-f16.bin";
22+
// Parameters shared by the example programs; filled in by app_params_parse().
struct app_params {
    // Thread count for the examples. Defaults to the hardware thread count
    // capped at 4, and never below 1.
    int32_t n_threads;
    // Path of the model file to load.
    std::string model;
    std::vector<std::string> image_paths;
    std::vector<std::string> texts;
    // Verbosity level; defaults to 1.
    int verbose;

    app_params()
        // std::thread::hardware_concurrency() may return 0 when the value is
        // not computable, which would otherwise yield a default of 0 threads.
        : n_threads(std::max<int32_t>(
              1, std::min<int32_t>(4, static_cast<int32_t>(std::thread::hardware_concurrency())))),
          model("models/ggml-model-f16.bin"),
          verbose(1) {}
};
2635

2736
bool app_params_parse(int argc, char ** argv, app_params & params);

‎examples/image-search/CMakeLists.txt

Copy file name to clipboardExpand all lines: examples/image-search/CMakeLists.txt
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ set(CXX_STANDARD_REQUIRED ON)
55
include(FetchContent)
66
FetchContent_Declare(usearch
77
GIT_REPOSITORY https://github.com/unum-cloud/usearch.git
8-
GIT_TAG v0.20.0
8+
GIT_TAG v2.5.0
99
)
1010
FetchContent_MakeAvailable(usearch)
1111

‎examples/python_bindings/README.md

Copy file name to clipboardExpand all lines: examples/python_bindings/README.md
+29-14Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,20 @@ def compare_text_and_image(
159159
- `image_path` (str): The path to the image file for comparison.
160160
- `n_threads` (int, optional): The number of CPU threads to use for encoding (default is the number of CPU cores).
161161

162-
#### 8. `__del__`
162+
#### 8. `zero_shot_label_image`
163+
164+
```python
165+
def zero_shot_label_image(
166+
self, image_path: str, labels: List[str], n_threads: int = os.cpu_count()
167+
) -> Tuple[List[float], List[int]]:
168+
```
169+
170+
- **Description**: Zero-shot labels an image with given candidate labels, returning a tuple of sorted scores and indices.
171+
- `image_path` (str): The path to the image file to be labelled.
172+
- `labels` (List[str]): A list of candidate labels to be scored.
173+
- `n_threads` (int, optional): The number of CPU threads to use for encoding (default is the number of CPU cores).
174+
175+
#### 9. `__del__`
163176

164177
```python
165178
def __del__(self):
@@ -175,17 +188,19 @@ A basic example can be found in the [clip.cpp examples](https://github.com/monat
175188

176189
```
177190
python example_main.py --help
178-
usage: clip [-h] -m MODEL [-v VERBOSITY] -t TEXT -i IMAGE
179-
180-
optional arguments:
181-
-h, --help show this help message and exit
182-
-m MODEL, --model MODEL
183-
path to GGML file
184-
-v VERBOSITY, --verbosity VERBOSITY
185-
Level of verbosity. 0 = minimum, 2 = maximum
186-
-t TEXT, --text TEXT text to encode
187-
-i IMAGE, --image IMAGE
188-
path to an image file
189-
```
191+
usage: clip [-h] -m MODEL [-fn FILENAME] [-v VERBOSITY] -t TEXT [TEXT ...] -i IMAGE
192+
193+
optional arguments:
194+
-h, --help show this help message and exit
195+
-m MODEL, --model MODEL
196+
path to GGML file or repo_id
197+
-fn FILENAME, --filename FILENAME
198+
path to GGML file in the Hugging face repo
199+
-v VERBOSITY, --verbosity VERBOSITY
200+
Level of verbosity. 0 = minimum, 2 = maximum
201+
-t TEXT [TEXT ...], --text TEXT [TEXT ...]
202+
text to encode. Multiple values allowed. In this case, apply zero-shot labeling
203+
-i IMAGE, --image IMAGE
204+
path to an image file
205+
```
190206
191-
Bindings to the DLL are implemented in `clip_cpp/clip.py` and

‎examples/python_bindings/clip_cpp/clip.py

Copy file name to clipboardExpand all lines: examples/python_bindings/clip_cpp/clip.py
+41-1Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import platform
44
from glob import glob
55
from pathlib import Path
6-
from typing import List, Dict, Any, Optional
6+
from typing import List, Dict, Any, Optional, Tuple
77

88
from .file_download import ModelInfo, model_download, model_info
99

@@ -167,6 +167,18 @@ class ClipContext(ctypes.Structure):
167167
]
168168
clip_similarity_score.restype = ctypes.c_float
169169

170+
# ctypes binding for the C function clip_zero_shot_label_image(ctx, n_threads,
# input_img, labels, n_labels, scores, indices), which returns bool.
clip_zero_shot_label_image = clip_lib.clip_zero_shot_label_image
clip_zero_shot_label_image.argtypes = [
    ctypes.POINTER(ClipContext),  # ctx
    ctypes.c_int,  # n_threads
    ctypes.POINTER(ClipImageU8),  # input_img
    ctypes.POINTER(ctypes.c_char_p),  # labels
    ctypes.c_size_t,  # n_labels: the C parameter is an unsigned size_t
    ctypes.POINTER(ctypes.c_float),  # scores (output)
    ctypes.POINTER(ctypes.c_int),  # indices (output)
]
clip_zero_shot_label_image.restype = ctypes.c_bool
181+
170182
softmax_with_sorting = clip_lib.softmax_with_sorting
171183
softmax_with_sorting.argtypes = [
172184
ctypes.POINTER(ctypes.c_float),
@@ -369,6 +381,34 @@ def compare_text_and_image(
369381

370382
return score.value
371383

384+
def zero_shot_label_image(
    self, image_path: str, labels: List[str], n_threads: int = os.cpu_count()
) -> Tuple[List[float], List[int]]:
    """Zero-shot label an image with the given candidate labels.

    Args:
        image_path: Path to the image file to be labelled.
        labels: Candidate labels to score; at least two are required.
        n_threads: Number of CPU threads passed to the native call.

    Returns:
        A tuple ``(scores, indices)`` of two lists with ``len(labels)``
        entries each, as written by the native
        ``clip_zero_shot_label_image`` function.

    Raises:
        ValueError: If fewer than two labels are given.
        RuntimeError: If the image cannot be loaded or the native call fails.
    """
    n_labels = len(labels)
    if n_labels < 2:
        raise ValueError(
            "You must pass at least 2 labels for zero-shot image labeling"
        )

    # Build the char* array under its own name so the caller's list is not
    # shadowed by the ctypes array.
    c_labels = (ctypes.c_char_p * n_labels)(
        *[ctypes.c_char_p(label.encode("utf8")) for label in labels]
    )

    image_ptr = make_clip_image_u8()
    if not clip_image_load_from_file(image_path.encode("utf8"), image_ptr):
        raise RuntimeError(f"Could not load image {image_path}")

    # Output buffers filled by the native call.
    scores = (ctypes.c_float * n_labels)()
    indices = (ctypes.c_int * n_labels)()
    if not clip_zero_shot_label_image(
        self.ctx, n_threads, image_ptr, c_labels, n_labels, scores, indices
    ):
        raise RuntimeError("Could not zero-shot label image")

    return [scores[i] for i in range(n_labels)], [
        indices[i] for i in range(n_labels)
    ]
411+
372412
def __del__(self):
373413
if hasattr(self, "ctx"):
374414
clip_free(self.ctx)

‎examples/python_bindings/example_main.py

Copy file name to clipboardExpand all lines: examples/python_bindings/example_main.py
+25-14Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,28 +5,39 @@
55
if __name__ == "__main__":
66
ap = argparse.ArgumentParser(prog="clip")
77
ap.add_argument("-m", "--model", help="path to GGML file or repo_id", required=True)
8-
ap.add_argument("-fn", "--filename", help="path to GGML file in the Hugging face repo", required=False)
8+
ap.add_argument(
9+
"-fn",
10+
"--filename",
11+
help="path to GGML file in the Hugging face repo",
12+
required=False,
13+
)
914
ap.add_argument(
1015
"-v",
1116
"--verbosity",
1217
type=int,
1318
help="Level of verbosity. 0 = minimum, 2 = maximum",
1419
default=0,
1520
)
16-
ap.add_argument("-t", "--text", help="text to encode", required=True)
21+
ap.add_argument(
22+
"-t",
23+
"--text",
24+
help="text to encode. Multiple values allowed. In this case, apply zero-shot labeling",
25+
nargs="+",
26+
type=str,
27+
required=True,
28+
)
1729
ap.add_argument("-i", "--image", help="path to an image file", required=True)
1830
args = ap.parse_args()
1931

2032
clip = Clip(args.model, args.verbosity)
21-
22-
tokens = clip.tokenize(args.text)
23-
text_embed = clip.encode_text(tokens)
24-
25-
image_embed = clip.load_preprocess_encode_image(args.image)
26-
27-
score = clip.calculate_similarity(text_embed, image_embed)
28-
29-
# Alternatively, you can just do:
30-
# score = clip.compare_text_and_image(text, image_path)
31-
32-
print(f"Similarity score: {score}")
33+
if len(args.text) == 1:
34+
score = clip.compare_text_and_image(args.text[0], args.image)
35+
36+
print(f"Similarity score: {score}")
37+
else:
38+
sorted_scores, sorted_indices = clip.zero_shot_label_image(
39+
args.image, args.text
40+
)
41+
for ind, score in zip(sorted_indices, sorted_scores):
42+
label = args.text[ind]
43+
print(f"{label}: {score:.4f}")

‎examples/python_bindings/pyproject.toml

Copy file name to clipboardExpand all lines: examples/python_bindings/pyproject.toml
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "clip_cpp"
3-
version = "0.4.1"
3+
version = "0.4.2"
44
description = "CLIP inference with no big dependencies as PyTorch, TensorFlow, Numpy"
55
authors = ["Yusuf Sarıgöz <yusufsarigoz@gmail.com>"]
66
packages = [{ include = "clip_cpp" }]

‎examples/zsl.cpp

Copy file name to clipboardExpand all lines: examples/zsl.cpp
+13-27Lines changed: 13 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,16 @@ int main(int argc, char ** argv) {
1010
return 1;
1111
}
1212

13-
int n_labels = params.texts.size();
13+
const size_t n_labels = params.texts.size();
1414
if (n_labels < 2) {
1515
printf("%s: You must specify at least 2 texts for zero-shot labeling\n", __func__);
1616
}
1717

18+
const char * labels[n_labels];
19+
for (size_t i = 0; i < n_labels; ++i) {
20+
labels[i] = params.texts[i].c_str();
21+
}
22+
1823
auto ctx = clip_model_load(params.model.c_str(), params.verbose);
1924
if (!ctx) {
2025
printf("%s: Unable to load model from %s", __func__, params.model.c_str());
@@ -23,40 +28,21 @@ int main(int argc, char ** argv) {
2328

2429
// load the image
2530
const auto & img_path = params.image_paths[0].c_str();
26-
clip_image_u8 img0;
27-
clip_image_f32 img_res;
28-
if (!clip_image_load_from_file(img_path, &img0)) {
31+
clip_image_u8 input_img;
32+
if (!clip_image_load_from_file(img_path, &input_img)) {
2933
fprintf(stderr, "%s: failed to load image from '%s'\n", __func__, img_path);
3034
return 1;
3135
}
3236

33-
const int vec_dim = clip_get_vision_hparams(ctx)->projection_dim;
34-
35-
clip_image_preprocess(ctx, &img0, &img_res);
36-
37-
float img_vec[vec_dim];
38-
if (!clip_image_encode(ctx, params.n_threads, &img_res, img_vec, false)) {
37+
float sorted_scores[n_labels];
38+
int sorted_indices[n_labels];
39+
if (!clip_zero_shot_label_image(ctx, params.n_threads, &input_img, labels, n_labels, sorted_scores, sorted_indices)) {
40+
fprintf(stderr, "Unable to apply ZSL\n");
3941
return 1;
4042
}
4143

42-
// encode texts and compute similarities
43-
float txt_vec[vec_dim];
44-
float similarities[n_labels];
45-
46-
for (int i = 0; i < n_labels; i++) {
47-
const auto & text = params.texts[i].c_str();
48-
auto tokens = clip_tokenize(ctx, text);
49-
clip_text_encode(ctx, params.n_threads, &tokens, txt_vec, false);
50-
similarities[i] = clip_similarity_score(img_vec, txt_vec, vec_dim);
51-
}
52-
53-
// apply softmax and sort scores
54-
float sorted_scores[n_labels];
55-
int indices[n_labels];
56-
softmax_with_sorting(similarities, n_labels, sorted_scores, indices);
57-
5844
for (int i = 0; i < n_labels; i++) {
59-
auto label = params.texts[indices[i]].c_str();
45+
auto label = labels[sorted_indices[i]];
6046
float score = sorted_scores[i];
6147
printf("%s = %1.4f\n", label, score);
6248
}

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.