Commit 62a45d1

rerank : cleanup + comments

Parent: 6916ed1
5 files changed: 27 additions, 14 deletions

‎examples/embedding/embedding.cpp

1 addition, 1 deletion

@@ -236,7 +236,7 @@ int main(int argc, char ** argv) {
             }
         } else if (pooling_type == LLAMA_POOLING_TYPE_RANK) {
             for (int j = 0; j < n_embd_count; j++) {
-                LOG("rank score %d: %8.3f\n", j, emb[j * n_embd]);
+                LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]);
             }
         } else {
             // print the first part of the embeddings or for a single prompt, the full embedding
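Note: with RANK pooling the example treats emb as n_embd_count rows of n_embd floats and reads only the first float of each row as the score, which is why the loop indexes emb[j * n_embd]. A minimal standalone sketch of that access pattern, with hypothetical dimensions and scores (not taken from the example):

#include <cstdio>
#include <vector>

int main() {
    // hypothetical dimensions: 3 ranked prompts, model width 4;
    // with RANK pooling only the first float of each row is read as the score
    const int n_embd_count = 3;
    const int n_embd       = 4;

    // illustrative buffer: one n_embd-sized row per prompt, score at offset 0
    std::vector<float> emb(n_embd_count * n_embd, 0.0f);
    emb[0 * n_embd] =  0.75f;
    emb[1 * n_embd] = -2.10f;
    emb[2 * n_embd] =  5.30f;

    for (int j = 0; j < n_embd_count; j++) {
        printf("rerank score %d: %8.3f\n", j, emb[j * n_embd]);
    }

    return 0;
}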

‎examples/server/server.cpp

11 additions, 5 deletions

@@ -1419,7 +1419,7 @@ struct server_context {
         queue_results.send(res);
     }
 
-    void send_rank(const server_slot & slot, const llama_batch & batch) {
+    void send_rerank(const server_slot & slot, const llama_batch & batch) {
         server_task_result res;
         res.id    = slot.id_task;
         res.error = false;
@@ -1440,19 +1440,19 @@ struct server_context {
 
                 res.data = json {
                     {"index", slot.index},
-                    {"rank",  -1e6},
+                    {"score", -1e6},
                 };
 
                 continue;
             }
 
             res.data = json {
                 {"index", slot.index},
-                {"rank",  embd[0]},
+                {"score", embd[0]},
             };
         }
 
-        SLT_DBG(slot, "sending rank, res = '%s'\n", res.data.dump().c_str());
+        SLT_DBG(slot, "sending rerank result, res = '%s'\n", res.data.dump().c_str());
 
         queue_results.send(res);
     }
@@ -1493,6 +1493,9 @@ struct server_context {
         else if (prompt.is_array()) {
             std::vector<json> prompts = prompt;
             if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+                // prompts[0] is the question
+                // the rest are the answers/documents
+                SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) prompts.size() - 1);
                 for (size_t i = 1; i < prompts.size(); i++) {
                     json qd;
                     qd.push_back(prompts[0]);
@@ -1501,6 +1504,7 @@ struct server_context {
                     create_task(data, true, qd);
                 }
             } else {
+                SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) prompts.size());
                 for (size_t i = 0; i < prompts.size(); i++) {
                     const auto & e = prompts[i];
                     if (e.is_string() || json_is_array_of_numbers(e)) {
@@ -1965,6 +1969,7 @@ struct server_context {
         // track if this is an embedding or non-embedding batch
         // if we've added sampled tokens above, we are in non-embedding mode
         // -1: none, 0: non-embedding, 1: embedding
+        // TODO: make enum
         int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
 
         // next, batch any pending prompts without exceeding n_batch
@@ -2133,6 +2138,7 @@ struct server_context {
                     slot.n_prompt_tokens_processed = 0;
                 }
 
+                // non-causal tasks require to fit the entire prompt in the physical batch
                 if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
                     // cannot fit the prompt in the current batch - will try next iter
                     if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
@@ -2318,7 +2324,7 @@ struct server_context {
                 }
 
                 if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
-                    send_rank(slot, batch_view);
+                    send_rerank(slot, batch_view);
                     slot.release();
                     slot.i_batch = -1;
                     continue; // continue loop of slots
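Note: per the new comments, a rerank prompt array is interpreted as prompts[0] = question and the remaining entries = answers/documents, with one task created per (question, document) pair. A minimal sketch of that pairing, assuming nlohmann::json (the json alias used by the server) and illustrative strings; each resulting qd array corresponds to one create_task(data, true, qd) call in the hunk above:

#include <nlohmann/json.hpp>

#include <cstdio>
#include <vector>

using json = nlohmann::json;

int main() {
    // illustrative input: prompts[0] is the question, the rest are the answers/documents
    std::vector<json> prompts = {
        "what is panda?",
        "hi",
        "it's a bear",
        "The giant panda is a bear species endemic to China.",
    };

    // mirror the pairing done when creating rerank tasks:
    // one (question, document) pair per document, i.e. prompts.size() - 1 tasks
    for (size_t i = 1; i < prompts.size(); i++) {
        json qd;
        qd.push_back(prompts[0]);
        qd.push_back(prompts[i]);
        printf("task %zu: %s\n", i - 1, qd.dump().c_str());
    }

    return 0;
}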

‎examples/server/utils.hpp

1 addition, 1 deletion

@@ -553,7 +553,7 @@ static json format_response_rerank(const json & request, const json & ranks) {
     for (const auto & rank : ranks) {
         data.push_back(json{
             {"index",           i++},
-            {"relevance_score", json_value(rank, "rank", 0.0)},
+            {"relevance_score", json_value(rank, "score", 0.0)},
         });
     }
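Note: after the rename, format_response_rerank reads the "score" field that send_rerank now emits and exposes it as "relevance_score". A minimal sketch of that mapping with made-up scores, using nlohmann::json and .value() in place of the server's json_value helper:

#include <nlohmann/json.hpp>

#include <cstdio>

using json = nlohmann::json;

int main() {
    // illustrative per-task results, shaped like the res.data objects sent by send_rerank
    json ranks = json::array();
    ranks.push_back(json{ {"index", 0}, {"score", -4.221} });
    ranks.push_back(json{ {"index", 1}, {"score",  7.503} });

    // mirror the loop in format_response_rerank, with .value() standing in for json_value
    json data = json::array();
    int i = 0;
    for (const auto & rank : ranks) {
        data.push_back(json{
            {"index",           i++},
            {"relevance_score", rank.value("score", 0.0)},
        });
    }

    printf("%s\n", data.dump(2).c_str());

    return 0;
}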

‎include/llama.h

6 additions, 5 deletions

@@ -192,7 +192,7 @@ extern "C" {
         LLAMA_POOLING_TYPE_MEAN = 1,
         LLAMA_POOLING_TYPE_CLS  = 2,
         LLAMA_POOLING_TYPE_LAST = 3,
-        LLAMA_POOLING_TYPE_RANK = 4,
+        LLAMA_POOLING_TYPE_RANK = 4, // used by reranking models to attach the classification head to the graph
     };
 
     enum llama_attention_type {
@@ -202,9 +202,9 @@ extern "C" {
     };
 
     enum llama_split_mode {
-        LLAMA_SPLIT_MODE_NONE    = 0, // single GPU
-        LLAMA_SPLIT_MODE_LAYER   = 1, // split layers and KV across GPUs
-        LLAMA_SPLIT_MODE_ROW     = 2, // split rows across GPUs
+        LLAMA_SPLIT_MODE_NONE  = 0, // single GPU
+        LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs
+        LLAMA_SPLIT_MODE_ROW   = 2, // split rows across GPUs
     };
 
     // TODO: simplify (https://github.com/ggerganov/llama.cpp/pull/9294#pullrequestreview-2286561979)
@@ -872,7 +872,8 @@ extern "C" {
 
     // Get the embeddings for a sequence id
     // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
-    // shape: [n_embd] (1-dimensional)
+    // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[1] with the rank of the sequence
+    // otherwise: float[n_embd] (1-dimensional)
     LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
 
     //
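Note: following the updated comment, a caller using LLAMA_POOLING_TYPE_RANK reads the per-sequence score as a single float via llama_get_embeddings_seq. A minimal sketch of such a helper (assuming a context created with a reranking model and RANK pooling, and a sequence that has already been decoded; llama_pooling_type is the existing getter in llama.h):

#include "llama.h"

// returns true and writes the score if RANK pooling is active and the sequence has data
static bool get_rerank_score(struct llama_context * ctx, llama_seq_id seq_id, float & score) {
    if (llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_RANK) {
        return false;
    }

    // with RANK pooling the per-sequence "embedding" is float[1]: the rerank score
    const float * emb = llama_get_embeddings_seq(ctx, seq_id);
    if (emb == NULL) {
        return false;
    }

    score = emb[0];
    return true;
}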

‎src/llama.cpp

8 additions, 2 deletions

@@ -17009,7 +17009,7 @@ static int llama_decode_internal(
                 } break;
             case LLAMA_POOLING_TYPE_RANK:
                 {
-                    // extract the rank score - a single float per sequence
+                    // extract the rerank score - a single float per sequence
                     auto & embd_seq_out = lctx.embd_seq;
 
                     for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
@@ -17211,7 +17211,6 @@ static int llama_encode_internal(
             case LLAMA_POOLING_TYPE_MEAN:
             case LLAMA_POOLING_TYPE_CLS:
             case LLAMA_POOLING_TYPE_LAST:
-            case LLAMA_POOLING_TYPE_RANK:
                 {
                     // extract sequence embeddings
                     auto & embd_seq_out = lctx.embd_seq;
@@ -17228,6 +17227,13 @@ static int llama_encode_internal(
                         ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
                     }
                 } break;
+            case LLAMA_POOLING_TYPE_RANK:
+                {
+                    // TODO: this likely should be the same logic as in llama_decoder_internal, but better to
+                    //       wait for an encoder model that requires this pooling type in order to test it
+                    //       https://github.com/ggerganov/llama.cpp/pull/9510
+                    GGML_ABORT("RANK pooling not implemented yet");
+                }
             case LLAMA_POOLING_TYPE_UNSPECIFIED:
                 {
                     GGML_ABORT("unknown pooling type");
