Commit e83ba3e

llama : add support for jina-reranker-v2 (ggml-org#13900)

1 parent 2b13162 · commit e83ba3e

5 files changed: +119 -72 lines changed

‎convert_hf_to_gguf.py

+85 -35 (85 additions, 35 deletions)

@@ -3782,44 +3782,93 @@ def _xlmroberta_set_vocab(self) -> None:
         from sentencepiece import sentencepiece_model_pb2 as model
 
         tokenizer_path = self.dir_model / 'sentencepiece.bpe.model'
+
+        tokenizer_json = {}
+        tokenizer_config_json = {}
         if not tokenizer_path.is_file():
-            raise FileNotFoundError(f"File not found: {tokenizer_path}")
+            tokenizer_path = self.dir_model / 'tokenizer.json'
+            tokenizer_config_path = self.dir_model / 'tokenizer_config.json'
 
-        sentencepiece_model = model.ModelProto()  # pyright: ignore[reportAttributeAccessIssue]
-        sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
-        assert sentencepiece_model.trainer_spec.model_type == 1  # UNIGRAM
+            if not tokenizer_path.is_file():
+                raise FileNotFoundError(f"File not found: {tokenizer_path}")
 
-        add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
-        remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces
-        precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap
+            from base64 import b64decode
+            from transformers import AutoTokenizer
+            tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
 
-        tokenizer = SentencePieceProcessor()
-        tokenizer.LoadFromFile(str(tokenizer_path))
+            with open(tokenizer_path, "r", encoding="utf-8") as fp:
+                tokenizer_json = json.load(fp)
 
-        vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
+            if tokenizer_config_path.is_file():
+                with open(tokenizer_config_path, "r", encoding="utf-8") as fp:
+                    tokenizer_config_json = json.load(fp)
+
+            add_prefix = tokenizer.add_prefix_space
+            remove_whitespaces = tokenizer.clean_up_tokenization_spaces
+            precompiled_charsmap = b64decode(tokenizer_json["normalizer"]["precompiled_charsmap"])
+
+            vocab_size = self.hparams.get("vocab_size", tokenizer.vocab_size)
+        else:
+            sentencepiece_model = model.ModelProto()  # pyright: ignore[reportAttributeAccessIssue]
+            sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
+            assert sentencepiece_model.trainer_spec.model_type == 1  # UNIGRAM
+
+            add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
+            remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces
+            precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap
+
+            tokenizer = SentencePieceProcessor()
+            tokenizer.LoadFromFile(str(tokenizer_path))
+
+            vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
 
         tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
         scores: list[float] = [-10000.0] * vocab_size
         toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size
 
-        for token_id in range(tokenizer.vocab_size()):
-            piece = tokenizer.IdToPiece(token_id)
-            text = piece.encode("utf-8")
-            score = tokenizer.GetScore(token_id)
+        if isinstance(tokenizer, SentencePieceProcessor):
+            for token_id in range(tokenizer.vocab_size()):
+                piece = tokenizer.IdToPiece(token_id)
+                text = piece.encode("utf-8")
+                score = tokenizer.GetScore(token_id)
 
-            toktype = SentencePieceTokenTypes.NORMAL
-            if tokenizer.IsUnknown(token_id):
-                toktype = SentencePieceTokenTypes.UNKNOWN
-            elif tokenizer.IsControl(token_id):
-                toktype = SentencePieceTokenTypes.CONTROL
-            elif tokenizer.IsUnused(token_id):
-                toktype = SentencePieceTokenTypes.UNUSED
-            elif tokenizer.IsByte(token_id):
-                toktype = SentencePieceTokenTypes.BYTE
+                toktype = SentencePieceTokenTypes.NORMAL
+                if tokenizer.IsUnknown(token_id):
+                    toktype = SentencePieceTokenTypes.UNKNOWN
+                elif tokenizer.IsControl(token_id):
+                    toktype = SentencePieceTokenTypes.CONTROL
+                elif tokenizer.IsUnused(token_id):
+                    toktype = SentencePieceTokenTypes.UNUSED
+                elif tokenizer.IsByte(token_id):
+                    toktype = SentencePieceTokenTypes.BYTE
 
-            tokens[token_id] = text
-            scores[token_id] = score
-            toktypes[token_id] = toktype
+                tokens[token_id] = text
+                scores[token_id] = score
+                toktypes[token_id] = toktype
+        else:
+            added_vocab = tokenizer.get_added_vocab()
+            unk_token = tokenizer_config_json.get("unk_token")
+            unk_token_id = added_vocab.get(unk_token, tokenizer_json["model"].get("unk_id", 3))
+
+            for token_id in range(vocab_size):
+                piece = tokenizer._convert_id_to_token(token_id)
+                text = piece.encode("utf-8")
+                score = tokenizer_json["model"]["vocab"][token_id][1]
+
+                toktype = SentencePieceTokenTypes.NORMAL
+                if token_id == unk_token_id:
+                    toktype = SentencePieceTokenTypes.UNKNOWN
+                elif token_id in tokenizer.all_special_ids:
+                    toktype = SentencePieceTokenTypes.CONTROL
+                elif token_id in added_vocab.values():
+                    toktype = SentencePieceTokenTypes.USER_DEFINED
+                # No reliable way to detect this, but jina doesn't have any
+                # elif tokenizer.IsByte(token_id):
+                #     toktype = SentencePieceTokenTypes.BYTE
+
+                tokens[token_id] = text
+                scores[token_id] = score
+                toktypes[token_id] = toktype
 
         if vocab_size > len(tokens):
             pad_count = vocab_size - len(tokens)

@@ -3829,15 +3878,16 @@ def _xlmroberta_set_vocab(self) -> None:
             scores.append(-1000.0)
             toktypes.append(SentencePieceTokenTypes.UNUSED)
 
-        # realign tokens (see HF tokenizer code)
-        tokens = [b'<s>', b'<pad>', b'</s>', b'<unk>'] + tokens[3:-1]
-        scores = [0.0, 0.0, 0.0, 0.0] + scores[3:-1]
-        toktypes = [
-            SentencePieceTokenTypes.CONTROL,
-            SentencePieceTokenTypes.CONTROL,
-            SentencePieceTokenTypes.CONTROL,
-            SentencePieceTokenTypes.UNKNOWN,
-        ] + toktypes[3:-1]
+        if isinstance(tokenizer, SentencePieceProcessor):
+            # realign tokens (see HF tokenizer code)
+            tokens = [b'<s>', b'<pad>', b'</s>', b'<unk>'] + tokens[3:-1]
+            scores = [0.0, 0.0, 0.0, 0.0] + scores[3:-1]
+            toktypes = [
+                SentencePieceTokenTypes.CONTROL,
+                SentencePieceTokenTypes.CONTROL,
+                SentencePieceTokenTypes.CONTROL,
+                SentencePieceTokenTypes.UNKNOWN,
+            ] + toktypes[3:-1]
 
         self.gguf_writer.add_tokenizer_model("t5")
         self.gguf_writer.add_tokenizer_pre("default")
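Note on the hunk above: `_xlmroberta_set_vocab` no longer requires a `sentencepiece.bpe.model` file. When it is absent (as with jina-reranker-v2's HF release, which provides `tokenizer.json` instead), the vocab, scores and token types are rebuilt from `tokenizer.json` / `tokenizer_config.json` via `transformers.AutoTokenizer`. A minimal standalone sketch of that fallback path, outside the converter class and with a hypothetical local model directory, could look like this:

# Standalone sketch of the tokenizer.json fallback used above -- illustrative only,
# not the converter itself. `model_dir` is a hypothetical local checkout path.
import json
from base64 import b64decode
from pathlib import Path

from transformers import AutoTokenizer

model_dir = Path("./jina-reranker-v2")  # hypothetical path

if (model_dir / "sentencepiece.bpe.model").is_file():
    raise SystemExit("sentencepiece.bpe.model present; the original SentencePiece path applies")

# Fallback: recover the same metadata from the HF fast-tokenizer files.
tokenizer = AutoTokenizer.from_pretrained(model_dir)
tokenizer_json = json.loads((model_dir / "tokenizer.json").read_text(encoding="utf-8"))

add_prefix = tokenizer.add_prefix_space
remove_whitespaces = tokenizer.clean_up_tokenization_spaces
precompiled_charsmap = b64decode(tokenizer_json["normalizer"]["precompiled_charsmap"])
print(add_prefix, remove_whitespaces, len(precompiled_charsmap))

vocab = tokenizer_json["model"]["vocab"]   # Unigram model: list of [piece, score] pairs
added_vocab = tokenizer.get_added_vocab()  # added/special tokens -> ids

# Same per-token walk as the converter, printed instead of written to GGUF.
for token_id in range(min(5, len(vocab))):
    piece = tokenizer._convert_id_to_token(token_id)
    score = vocab[token_id][1]
    print(token_id, piece, score)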

‎gguf-py/gguf/constants.py

+1 -0 (1 addition, 0 deletions)

@@ -1036,6 +1036,7 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.POS_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
         MODEL_TENSOR.ATTN_OUT_NORM,
+        MODEL_TENSOR.ATTN_QKV,
         MODEL_TENSOR.ATTN_Q,
         MODEL_TENSOR.ATTN_K,
         MODEL_TENSOR.ATTN_V,

‎gguf-py/gguf/tensor_mapping.py

+2 -0 (2 additions, 0 deletions)

@@ -157,6 +157,7 @@ class TensorNameMap:
             "h.{bid}.attn.c_attn",                                  # gpt2
             "transformer.h.{bid}.mixer.Wqkv",                       # phi2
             "encoder.layers.{bid}.attn.Wqkv",                       # nomic-bert
+            "encoder.layers.{bid}.mixer.Wqkv",                      # jina
             "model.layers.{bid}.self_attn.qkv_proj",                # phi3
             "encoder.layers.{bid}.self_attention.query_key_value",  # chatglm
             "transformer.layers.{bid}.attn.qkv_proj",               # openelm

@@ -224,6 +225,7 @@ class TensorNameMap:
             "model.layers.layers.{bid}.self_attn.o_proj",                  # plamo
             "model.layers.{bid}.attention.wo",                             # internlm2
             "encoder.layers.{bid}.attn.out_proj",                          # nomic-bert
+            "encoder.layers.{bid}.mixer.out_proj",                         # jina
             "transformer.decoder_layer.{bid}.multi_head_attention.linear", # Grok
             "transformer.blocks.{bid}.norm_attn_norm.attn.out_proj",       # dbrx
             "encoder.layers.{bid}.self_attention.dense",                   # chatglm

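The two entries added above route jina-reranker-v2's fused attention tensors (encoder.layers.{bid}.mixer.Wqkv and encoder.layers.{bid}.mixer.out_proj) to the existing ATTN_QKV and attention-output slots, with {bid} standing for the block index. A toy illustration of how such a template resolves, using a plain dict in place of the real TensorNameMap and assumed GGUF names where noted:

# Toy resolver for the "{bid}" templates added above; the real logic lives in
# gguf-py's TensorNameMap, this dict is only for illustration.
templates = {
    "encoder.layers.{bid}.mixer.Wqkv":     "blk.{bid}.attn_qkv",     # jina -> LLM_TENSOR_ATTN_QKV
    "encoder.layers.{bid}.mixer.out_proj": "blk.{bid}.attn_output",  # jina -> attention output (assumed GGUF name)
}

def map_name(hf_name: str, n_blocks: int = 24):
    # Try every block index for every template until one matches.
    for bid in range(n_blocks):
        for src, dst in templates.items():
            if hf_name == src.format(bid=bid):
                return dst.format(bid=bid)
    return None

print(map_name("encoder.layers.3.mixer.Wqkv"))      # blk.3.attn_qkv
print(map_name("encoder.layers.3.mixer.out_proj"))  # blk.3.attn_output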
‎src/llama-arch.cpp

+1 -0 (1 addition, 0 deletions)

@@ -450,6 +450,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_TOKEN_TYPES,    "token_types" },
             { LLM_TENSOR_POS_EMBD,       "position_embd" },
             { LLM_TENSOR_ATTN_OUT_NORM,  "blk.%d.attn_output_norm" },
+            { LLM_TENSOR_ATTN_QKV,       "blk.%d.attn_qkv" },
             { LLM_TENSOR_ATTN_Q,         "blk.%d.attn_q" },
             { LLM_TENSOR_ATTN_K,         "blk.%d.attn_k" },
             { LLM_TENSOR_ATTN_V,         "blk.%d.attn_v" },

‎src/llama-model.cpp

+30 -37 (30 additions, 37 deletions)

@@ -2132,7 +2132,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 for (int i = 0; i < n_layer; ++i) {
                     auto & layer = layers[i];
 
-                    if (arch == LLM_ARCH_BERT) {
+                    layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED);
+                    layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED);
+
+                    if (!layer.wqkv) {
                         layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
                         layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i),   {n_embd}, 0);

@@ -2141,12 +2144,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 
                         layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
                         layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i),   {n_embd_gqa}, 0);
-                    } else {
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
-                    }
-
-                    if (arch == LLM_ARCH_NOMIC_BERT_MOE) {
-                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0);
                     }
 
                     layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);

@@ -5910,48 +5907,44 @@ struct llm_build_bert : public llm_graph_context {
             ggml_tensor * Vcur;
 
             // self-attention
-            if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_JINA_BERT_V2) {
-                Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq);
-
-                if (model.layers[il].attn_q_norm) {
-                    Qcur = build_norm(Qcur,
-                            model.layers[il].attn_q_norm,
-                            model.layers[il].attn_q_norm_b,
-                            LLM_NORM, il);
-                }
-
-                Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk);
-
-                if (model.layers[il].attn_k_norm) {
-                    Kcur = build_norm(Kcur,
-                            model.layers[il].attn_k_norm,
-                            model.layers[il].attn_k_norm_b,
-                            LLM_NORM, il);
-                }
-
-                Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv);
-
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
-            } else {
-                // compute Q and K and RoPE them
+            if (model.layers[il].wqkv) {
                 cur = build_lora_mm(model.layers[il].wqkv, cur);
                 cb(cur, "wqkv", il);
 
-                if (model.arch == LLM_ARCH_NOMIC_BERT_MOE) {
+                if (model.layers[il].bqkv) {
                     cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
                     cb(cur, "bqkv", il);
                 }
 
                 Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
                 Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
                 Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+            } else {
+                Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq);
+                Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk);
+                Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv);
+            }
 
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+            if (model.layers[il].attn_q_norm) {
+                Qcur = build_norm(Qcur,
+                        model.layers[il].attn_q_norm,
+                        model.layers[il].attn_q_norm_b,
+                        LLM_NORM, il);
+            }
+
+            if (model.layers[il].attn_k_norm) {
+                Kcur = build_norm(Kcur,
+                        model.layers[il].attn_k_norm,
+                        model.layers[il].attn_k_norm_b,
+                        LLM_NORM, il);
+            }
+
+            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
+            // RoPE
+            if (model.arch == LLM_ARCH_NOMIC_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE) {
                 Qcur = ggml_rope_ext(
                         ctx0, Qcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
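In the rewritten attention block of llm_build_bert, the branch now keys on whether a fused wqkv tensor (and optional bqkv bias) is present rather than on the architecture, and the fused projection output is sliced into Q, K and V with 2-D views at column offsets 0, n_embd, and n_embd + n_embd_gqa. A small NumPy sketch of the same slicing arithmetic, with toy sizes and illustrative names only:

# NumPy sketch of the fused-QKV split performed with ggml_view_2d above.
# Sizes are toy values; names mirror the C++ but the data is random.
import numpy as np

n_embd, n_embd_gqa, n_tokens = 8, 4, 3  # n_embd_gqa < n_embd when KV heads are grouped
rng = np.random.default_rng(0)

# One row per token: [ Q (n_embd) | K (n_embd_gqa) | V (n_embd_gqa) ]
qkv = rng.standard_normal((n_tokens, n_embd + 2 * n_embd_gqa)).astype(np.float32)

# Same offsets as the ggml views: 0, n_embd, n_embd + n_embd_gqa.
Qcur = qkv[:, :n_embd]
Kcur = qkv[:, n_embd:n_embd + n_embd_gqa]
Vcur = qkv[:, n_embd + n_embd_gqa:]

assert Qcur.shape == (n_tokens, n_embd)
assert Kcur.shape == Vcur.shape == (n_tokens, n_embd_gqa)

Because the branch is presence-driven, BERT-style checkpoints with separate Q/K/V tensors and jina-style checkpoints with a fused Wqkv now share the same graph code, with the Q/K norm, reshape and RoPE handling applied uniformly afterwards.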

0 commit comments