Commit d17a809

llama : support multiple classifier outputs and labels (#13940)

1 parent: 1caae7f
File tree: 6 files changed, +101 −24 lines
examples/embedding/embedding.cpp (+17, −2)

@@ -236,9 +236,24 @@ int main(int argc, char ** argv) {
                 LOG("\n");
             }
         } else if (pooling_type == LLAMA_POOLING_TYPE_RANK) {
+            const uint32_t n_cls_out = llama_model_n_cls_out(model);
+            std::vector<std::string> cls_out_labels;
+
+            for (uint32_t i = 0; i < n_cls_out; i++) {
+                const char * label = llama_model_cls_label(model, i);
+                const std::string label_i(label == nullptr ? "" : label);
+                cls_out_labels.emplace_back(label_i.empty() ? std::to_string(i) : label_i);
+            }
+
             for (int j = 0; j < n_embd_count; j++) {
-                // NOTE: if you change this log - update the tests in ci/run.sh
-                LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]);
+                for (uint32_t i = 0; i < n_cls_out; i++) {
+                    // NOTE: if you change this log - update the tests in ci/run.sh
+                    if (n_cls_out == 1) {
+                        LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]);
+                    } else {
+                        LOG("rerank score %d: %8.3f [%s]\n", j, emb[j * n_embd + i], cls_out_labels[i].c_str());
+                    }
+                }
             }
         } else {
             // print the first part of the embeddings or for a single prompt, the full embedding
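With more than one classifier output, the loop above emits one log line per (prompt, output) pair, with the label appended in brackets. For a hypothetical two-label model the output would look like this (scores and label names invented for illustration):

    rerank score 0:    0.123 [LABEL_0]
    rerank score 0:   -0.456 [LABEL_1]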

include/llama.h (+8, −1)

@@ -514,6 +514,13 @@ extern "C" {
     // Get the model's RoPE frequency scaling factor
     LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
 
+    // Returns the number of classifier outputs (only valid for classifier models)
+    // Undefined behavior for non-classifier models
+    LLAMA_API uint32_t llama_model_n_cls_out(const struct llama_model * model);
+
+    // Returns label of classifier output by index (<n_cls_out). Returns nullptr if no label provided
+    LLAMA_API const char * llama_model_cls_label(const struct llama_model * model, uint32_t i);
+
     LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_vocab * vocab);
 
     LLAMA_API int32_t llama_vocab_n_tokens(const struct llama_vocab * vocab);

@@ -992,7 +999,7 @@ extern "C" {
 
     // Get the embeddings for a sequence id
     // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
-    // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[1] with the rank of the sequence
+    // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[n_cls_out] with the rank(s) of the sequence
     // otherwise: float[n_embd] (1-dimensional)
     LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
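For orientation, a minimal caller sketch of the new API, under these assumptions: `model` and `ctx` are already loaded, a batch has been processed with LLAMA_POOLING_TYPE_RANK pooling, and the printf formatting is purely illustrative (not part of the patch):

    // Sketch: read the per-label rerank scores for sequence 0.
    const uint32_t n_cls_out = llama_model_n_cls_out(model);  // classifier models only
    const float *  scores    = llama_get_embeddings_seq(ctx, 0); // float[n_cls_out] under RANK pooling

    if (scores != nullptr) {
        for (uint32_t i = 0; i < n_cls_out; i++) {
            const char * label = llama_model_cls_label(model, i); // may be nullptr
            printf("output %u (%s): %.3f\n", i, label ? label : "?", scores[i]);
        }
    }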

src/llama-context.cpp (+4, −3)

@@ -839,16 +839,17 @@ int llama_context::encode(llama_batch & inp_batch) {
                 } break;
             case LLAMA_POOLING_TYPE_RANK:
                 {
-                    // extract the rerank score - a single float per sequence
+                    // extract the rerank score - n_cls_out floats per sequence
                     auto & embd_seq_out = embd_seq;
+                    const uint32_t n_cls_out = hparams.n_cls_out;
 
                     for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
                         const llama_seq_id seq_id = ubatch.seq_id[s][0];
                         if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
                             continue;
                         }
-                        embd_seq_out[seq_id].resize(1);
-                        ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
+                        embd_seq_out[seq_id].resize(n_cls_out);
+                        ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_id)*sizeof(float), n_cls_out*sizeof(float));
                     }
                 } break;
             case LLAMA_POOLING_TYPE_UNSPECIFIED:
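Under RANK pooling the pooled tensor is packed as n_cls_out contiguous floats per sequence, so the async copy above reads n_cls_out*sizeof(float) bytes starting at byte offset n_cls_out*seq_id*sizeof(float). A minimal indexing sketch on a hypothetical host-side copy of the tensor (`t_embd_host` is an invented name, not part of the patch):

    // t_embd_host: float[n_seqs * n_cls_out]; row s holds the scores of sequence s.
    const float * seq_scores = t_embd_host + n_cls_out * seq_id; // start of this sequence's row
    const float   score_i    = seq_scores[i];                    // classifier output i, with i < n_cls_out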

src/llama-model-loader.cpp (+42, −17)

@@ -288,61 +288,84 @@ namespace GGUFMeta {
 
     template<typename T>
     bool llama_model_loader::get_arr(const std::string & key, std::vector<T> & result, bool required) {
-        const int kid = gguf_find_key(meta.get(), key.c_str());
+        const gguf_context * ctx = meta.get();
+        const int kid = gguf_find_key(ctx, key.c_str());
 
-        if (kid < 0 || gguf_get_kv_type(meta.get(), kid) != GGUF_TYPE_ARRAY) {
+        if (kid < 0 || gguf_get_kv_type(ctx, kid) != GGUF_TYPE_ARRAY) {
             if (required) {
                 throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
             }
             return false;
         }
 
         struct GGUFMeta::ArrayInfo arr_info =
-            GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid);
+            GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(ctx, kid);
 
         switch (arr_info.gt) {
             case GGUF_TYPE_UINT32:
-            case GGUF_TYPE_INT32:   GGML_ASSERT((std::is_same<T,  int32_t>::value) ||
-                                                (std::is_same<T, uint32_t>::value)); break;
-            case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
+            case GGUF_TYPE_INT32:   GGML_ASSERT((std::is_same<T,  int32_t>::value) ||
+                                                (std::is_same<T, uint32_t>::value)); break;
+            case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
+            case GGUF_TYPE_STRING:  GGML_ASSERT((std::is_same<T, std::string>::value)); break;
             default:
-                throw std::runtime_error(format("%s is not a float32/uint32/int32 array", key.c_str()));
+                throw std::runtime_error(format("%s is not a string/float32/uint32/int32 array", key.c_str()));
         }
 
-        result.resize(arr_info.length);
-        result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length);
+        if constexpr (std::is_same<T, std::string>::value) {
+            const size_t n_items = gguf_get_arr_n(ctx, kid);
+            result.clear();
+
+            for (size_t i = 0; i < n_items; i++) {
+                const T value = gguf_get_arr_str(ctx, kid, i);
+                result.emplace_back(value);
+            }
+        } else {
+            result.resize(arr_info.length);
+            result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length);
+        }
 
         return true;
     }
 
     template<typename T, size_t N_MAX>
     bool llama_model_loader::get_arr(const std::string & key, std::array<T, N_MAX> & result, bool required) {
-        const int kid = gguf_find_key(meta.get(), key.c_str());
+        const gguf_context * ctx = meta.get();
+        const int kid = gguf_find_key(ctx, key.c_str());
 
-        if (kid < 0 || gguf_get_kv_type(meta.get(), kid) != GGUF_TYPE_ARRAY) {
+        if (kid < 0 || gguf_get_kv_type(ctx, kid) != GGUF_TYPE_ARRAY) {
             if (required) {
                 throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
             }
             return false;
         }
 
         struct GGUFMeta::ArrayInfo arr_info =
-            GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid);
+            GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(ctx, kid);
 
         switch (arr_info.gt) {
             case GGUF_TYPE_UINT32:
-            case GGUF_TYPE_INT32:   GGML_ASSERT((std::is_same<T,  int32_t>::value) ||
-                                                (std::is_same<T, uint32_t>::value)); break;
-            case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
+            case GGUF_TYPE_INT32:   GGML_ASSERT((std::is_same<T,  int32_t>::value) ||
+                                                (std::is_same<T, uint32_t>::value)); break;
+            case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
+            case GGUF_TYPE_STRING:  GGML_ASSERT((std::is_same<T, std::string>::value)); break;
             default:
-                throw std::runtime_error(format("%s is not a float32/uint32/int32 array", key.c_str()));
+                throw std::runtime_error(format("%s is not a string/float32/uint32/int32 array", key.c_str()));
         }
 
         if (arr_info.length > N_MAX) {
             throw std::runtime_error(format("array length %u for key %s exceeds max %u", (uint32_t) arr_info.length, key.c_str(), (uint32_t) N_MAX));
         }
 
-        std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin());
+        if constexpr (std::is_same<T, std::string>::value) {
+            const size_t n_items = gguf_get_arr_n(ctx, kid);
+
+            for (size_t i = 0; i < n_items; i++) {
+                const T value = gguf_get_arr_str(ctx, kid, i);
+                result[i] = value;
+            }
+        } else {
+            std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin());
+        }
 
         return true;
     }

@@ -352,6 +375,8 @@ namespace GGUFMeta {
         return get_arr(llm_kv(kid), result, required);
     }
 
+    template bool llama_model_loader::get_arr<std::vector<std::string>>(enum llm_kv kid, std::vector<std::string> & result, bool required);
+
     template<typename T>
     bool llama_model_loader::get_key(const std::string & key, T & result, bool required) {
         auto it = kv_overrides.find(key);
src/llama-model.cpp (+27, −1)

@@ -543,6 +543,12 @@ void llama_model::load_hparams(llama_model_loader & ml) {
     uint32_t n_vocab = 0;
     ml.get_key(LLM_KV_VOCAB_SIZE, n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, n_vocab, false);
 
+    // for classifier models
+    ml.get_arr(LLM_KV_CLASSIFIER_OUTPUT_LABELS, classifier_labels, false);
+    if (!classifier_labels.empty()) {
+        hparams.n_cls_out = classifier_labels.size();
+    }
+
     // arch-specific KVs
     switch (arch) {
         case LLM_ARCH_LLAMA:

@@ -686,7 +692,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
                 ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
                 ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
-                ml.get_arr_n(LLM_KV_CLASSIFIER_OUTPUT_LABELS, hparams.n_cls_out, false);
 
                 switch (hparams.n_layer) {
                     case 3:

@@ -4362,6 +4367,15 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state);
         LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank);
         LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms);
+
+        if (!classifier_labels.empty()) {
+            LLAMA_LOG_INFO("%s: n_cls_out = %u\n", __func__, hparams.n_cls_out);
+
+            size_t i = 0;
+            for (auto label : classifier_labels) {
+                LLAMA_LOG_INFO("%s: cls_label[%2zu] = %s\n", __func__, i++, label.c_str());
+            }
+        }
     }
 
     LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str());

@@ -13602,6 +13616,18 @@ int32_t llama_model_n_swa(const llama_model * model) {
     return model->hparams.n_swa;
 }
 
+uint32_t llama_model_n_cls_out(const struct llama_model * model) {
+    return model->hparams.n_cls_out;
+}
+
+const char * llama_model_cls_label(const struct llama_model * model, uint32_t i) {
+    if (i < model->classifier_labels.size()) {
+        return model->classifier_labels[i].c_str();
+    }
+
+    return nullptr;
+}
+
 // deprecated
 int32_t llama_n_ctx_train(const llama_model * model) {
     return llama_model_n_ctx_train(model);
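As a quick sanity sketch of the accessor contract added above, for a hypothetical classifier model whose GGUF metadata carried two labels (label names invented for illustration):

    // llama_model_n_cls_out(model)    -> 2
    // llama_model_cls_label(model, 0) -> "positive"
    // llama_model_cls_label(model, 1) -> "negative"
    // llama_model_cls_label(model, 2) -> nullptr (index out of range)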

src/llama-model.h (+3, −0)

@@ -329,6 +329,9 @@ struct llama_model {
     llama_hparams hparams = {};
     llama_vocab vocab;
 
+    // for classifier models
+    std::vector<std::string> classifier_labels;
+
     struct ggml_tensor * tok_embd = nullptr;
     struct ggml_tensor * type_embd = nullptr;
     struct ggml_tensor * pos_embd = nullptr;
