Commit 2891c8a

iamlemec, cebtenzzre, and ggerganov authored
Add support for BERT embedding models (#5423)

* BERT model graph construction (build_bert)
* WordPiece tokenizer (llm_tokenize_wpm)
* Add flag for non-causal attention models
* Allow for models that only output embeddings
* Support conversion of BERT models to GGUF
* Based on prior work by @xyzhang626 and @skeskinen

---------

Co-authored-by: Jared Van Bortel <jared@nomic.ai>
Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
1 parent 97a3365 commit 2891c8a
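The WordPiece tokenizer itself (llm_tokenize_wpm) is implemented in C++ inside llama.cpp and is not among the files shown below. As a rough illustration only, here is a minimal Python sketch of the greedy longest-match-first segmentation that WordPiece performs (simplified: no lowercasing, punctuation splitting, or max-word-length cap; the vocabulary is made up):

def wordpiece(word: str, vocab: set[str], unk: str = "[UNK]") -> list[str]:
    # greedy longest-match-first: take the longest vocab entry at each position;
    # non-initial pieces carry the "##" continuation prefix
    pieces, start = [], 0
    while start < len(word):
        end = len(word)
        while start < end:
            sub = ("##" if start > 0 else "") + word[start:end]
            if sub in vocab:
                pieces.append(sub)
                break
            end -= 1
        else:
            return [unk]  # no vocab entry matched at this position
        start = end
    return pieces

print(wordpiece("unaffable", {"un", "##aff", "##able"}))  # ['un', '##aff', '##able']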

File tree: 8 files changed, +616 −52 lines

.flake8 (1 addition, 0 deletions)

@@ -1,2 +1,3 @@
 [flake8]
 max-line-length = 125
+ignore = W503

convert-hf-to-gguf.py (94 additions, 0 deletions)

@@ -209,6 +209,8 @@ def from_model_architecture(model_architecture):
             return InternLM2Model
         if model_architecture == "MiniCPMForCausalLM":
             return MiniCPMModel
+        if model_architecture == "BertModel":
+            return BertModel
         return Model

     def _is_model_safetensors(self) -> bool:

@@ -264,6 +266,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
             return gguf.MODEL_ARCH.INTERNLM2
         if arch == "MiniCPMForCausalLM":
             return gguf.MODEL_ARCH.MINICPM
+        if arch == "BertModel":
+            return gguf.MODEL_ARCH.BERT

         raise NotImplementedError(f'Architecture "{arch}" not supported!')

@@ -1629,6 +1633,96 @@ def write_tensors(self):
             self.post_write_tensors(tensor_map, name, data_torch)


+class BertModel(Model):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.block_count = self.hparams["num_hidden_layers"]
+
+    def set_gguf_parameters(self):
+        # TODO(cebtenzzre): merge with parent class
+        self.gguf_writer.add_name(self.dir_model.name)
+        self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
+        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
+        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
+        self.gguf_writer.add_block_count(self.block_count)
+        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
+        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])
+        self.gguf_writer.add_causal_attention(False)
+        self.gguf_writer.add_file_type(self.ftype)
+
+    def set_vocab(self):
+        path = self.dir_model
+        added_tokens_path = self.dir_model if self.dir_model.exists() else None
+
+        # use huggingface vocab to get all tokens
+        vocab = HfVocab(path, added_tokens_path)
+        tokens, scores, toktypes = zip(*vocab.all_tokens())
+        assert len(tokens) == vocab.vocab_size
+
+        # we need this to validate the size of the token_type embeddings
+        # though currently we are passing all zeros to the token_type embeddings
+        n_token_types = len(set(toktypes))
+        self.gguf_writer.add_token_type_count(n_token_types)
+
+        # convert to phantom space vocab
+        def phantom(tok, typ):
+            if tok.startswith(b"[") and tok.endswith(b"]"):
+                return tok
+            if tok.startswith(b"##"):
+                return tok[2:]
+            return b"\xe2\x96\x81" + tok
+        tokens = [phantom(t, y) for t, y in zip(tokens, toktypes)]
+
+        # set up bos and eos tokens (cls and sep)
+        self.gguf_writer.add_bos_token_id(vocab.tokenizer.cls_token_id)
+        self.gguf_writer.add_eos_token_id(vocab.tokenizer.sep_token_id)
+
+        # add vocab to gguf
+        self.gguf_writer.add_tokenizer_model("bert")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_scores(scores)
+        self.gguf_writer.add_token_types(toktypes)
+
+        # handle special tokens
+        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
+        special_vocab.add_to_gguf(self.gguf_writer)
+
+    def write_tensors(self):
+        tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
+        tensors = dict(self.get_tensors())
+        for name, data_torch in tensors.items():
+            # we are only using BERT for embeddings so we don't need the pooling layer
+            if name in ("embeddings.position_ids", "pooler.dense.weight", "pooler.dense.bias"):
+                continue  # we don't need these
+
+            # map tensor names
+            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
+            if new_name is None:
+                print(f"Can not map tensor {name!r}")
+                sys.exit()
+
+            data = data_torch.squeeze().numpy()
+            n_dims = len(data.shape)
+            new_dtype: type[np.floating[Any]]
+
+            if (
+                self.ftype == 1 and name.endswith(".weight") and n_dims == 2
+                and name != "embeddings.token_type_embeddings.weight"  # not used with get_rows, must be F32
+            ):
+                # if f16 desired, convert any float32 2-dim weight tensors to float16
+                new_dtype = np.float16
+            else:
+                # if f32 desired, convert any float16 to float32
+                new_dtype = np.float32
+
+            print(f"{new_name}, n_dims = {n_dims}, {data_torch.dtype} --> {new_dtype}")
+
+            if data.dtype != new_dtype:
+                data = data.astype(new_dtype)
+
+            self.gguf_writer.add_tensor(new_name, data)
+
+
 ###### CONVERSION LOGIC ######
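The phantom helper in the hunk above rewrites the WordPiece vocabulary into the SentencePiece-style convention used elsewhere in llama.cpp: word-initial pieces gain a leading U+2581 (▁, the "phantom space"), continuation pieces drop their ## marker, and bracketed specials pass through untouched. A standalone sketch of the same logic (the typ argument is unused in the original too; the sample tokens are illustrative):

def phantom(tok: bytes) -> bytes:
    if tok.startswith(b"[") and tok.endswith(b"]"):
        return tok                    # specials like [CLS] and [SEP] pass through
    if tok.startswith(b"##"):
        return tok[2:]                # continuation pieces lose the ## marker
    return b"\xe2\x96\x81" + tok      # word-initial pieces gain U+2581 (phantom space)

for tok in (b"[CLS]", b"hello", b"##ing"):
    print(tok, "->", phantom(tok))
# b'[CLS]' -> b'[CLS]'
# b'hello' -> b'\xe2\x96\x81hello'
# b'##ing' -> b'ing'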

examples/embedding/embedding.cpp (11 additions, 1 deletion)

@@ -87,7 +87,17 @@ int main(int argc, char ** argv) {
     }

     const int n_embd = llama_n_embd(model);
-    const auto * embeddings = llama_get_embeddings(ctx);
+    auto * embeddings = llama_get_embeddings(ctx);
+
+    // l2-normalize embeddings
+    float norm = 0;
+    for (int i = 0; i < n_embd; i++) {
+        norm += embeddings[i] * embeddings[i];
+    }
+    norm = sqrt(norm);
+    for (int i = 0; i < n_embd; i++) {
+        embeddings[i] /= norm;
+    }

     for (int i = 0; i < n_embd; i++) {
         printf("%f ", embeddings[i]);
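A side note on the normalization added above: once every embedding has unit L2 norm, the cosine similarity between two embeddings is simply their dot product. A minimal NumPy sketch with made-up vectors:

import numpy as np

def l2_normalize(v: np.ndarray) -> np.ndarray:
    # same operation as the C++ loops above: divide by the Euclidean norm
    return v / np.sqrt(np.sum(v * v))

a = l2_normalize(np.array([0.3, -1.2, 0.8]))
b = l2_normalize(np.array([0.1, -0.9, 1.1]))

# for unit-norm vectors, the dot product equals the cosine similarity
print(float(np.dot(a, b)))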

gguf-py/gguf/constants.py (25 additions, 18 deletions)

@@ -50,6 +50,7 @@ class Attention:
         VALUE_LENGTH      = "{arch}.attention.value_length"
         LAYERNORM_EPS     = "{arch}.attention.layer_norm_epsilon"
         LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
+        CAUSAL            = "{arch}.attention.causal"

     class Rope:
         DIMENSION_COUNT = "{arch}.rope.dimension_count"

@@ -60,22 +61,23 @@ class Rope:
         SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"

     class Tokenizer:
-        MODEL         = "tokenizer.ggml.model"
-        LIST          = "tokenizer.ggml.tokens"
-        TOKEN_TYPE    = "tokenizer.ggml.token_type"
-        SCORES        = "tokenizer.ggml.scores"
-        MERGES        = "tokenizer.ggml.merges"
-        BOS_ID        = "tokenizer.ggml.bos_token_id"
-        EOS_ID        = "tokenizer.ggml.eos_token_id"
-        UNK_ID        = "tokenizer.ggml.unknown_token_id"
-        SEP_ID        = "tokenizer.ggml.seperator_token_id"
-        PAD_ID        = "tokenizer.ggml.padding_token_id"
-        ADD_BOS       = "tokenizer.ggml.add_bos_token"
-        ADD_EOS       = "tokenizer.ggml.add_eos_token"
-        ADD_PREFIX    = "tokenizer.ggml.add_space_prefix"
-        HF_JSON       = "tokenizer.huggingface.json"
-        RWKV          = "tokenizer.rwkv.world"
-        CHAT_TEMPLATE = "tokenizer.chat_template"
+        MODEL            = "tokenizer.ggml.model"
+        LIST             = "tokenizer.ggml.tokens"
+        TOKEN_TYPE       = "tokenizer.ggml.token_type"
+        TOKEN_TYPE_COUNT = "tokenizer.ggml.token_type_count"  # for BERT-style token types
+        SCORES           = "tokenizer.ggml.scores"
+        MERGES           = "tokenizer.ggml.merges"
+        BOS_ID           = "tokenizer.ggml.bos_token_id"
+        EOS_ID           = "tokenizer.ggml.eos_token_id"
+        UNK_ID           = "tokenizer.ggml.unknown_token_id"
+        SEP_ID           = "tokenizer.ggml.seperator_token_id"
+        PAD_ID           = "tokenizer.ggml.padding_token_id"
+        ADD_BOS          = "tokenizer.ggml.add_bos_token"
+        ADD_EOS          = "tokenizer.ggml.add_eos_token"
+        ADD_PREFIX       = "tokenizer.ggml.add_space_prefix"
+        HF_JSON          = "tokenizer.huggingface.json"
+        RWKV             = "tokenizer.rwkv.world"
+        CHAT_TEMPLATE    = "tokenizer.chat_template"


 #

@@ -122,6 +124,7 @@ class MODEL_TENSOR(IntEnum):
     ATTN_OUT      = auto()
     ATTN_NORM     = auto()
     ATTN_NORM_2   = auto()
+    ATTN_OUT_NORM = auto()
     ATTN_ROT_EMBD = auto()
     FFN_GATE_INP  = auto()
     FFN_NORM      = auto()

@@ -134,6 +137,7 @@ class MODEL_TENSOR(IntEnum):
     FFN_UP_EXP     = auto()
     ATTN_Q_NORM    = auto()
     ATTN_K_NORM    = auto()
+    LAYER_OUT_NORM = auto()


 MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {

@@ -178,6 +182,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
     MODEL_TENSOR.ATTN_Q_NORM:   "blk.{bid}.attn_q_norm",
     MODEL_TENSOR.ATTN_K_NORM:   "blk.{bid}.attn_k_norm",
+    MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm",
     MODEL_TENSOR.FFN_GATE_INP:  "blk.{bid}.ffn_gate_inp",
     MODEL_TENSOR.FFN_NORM:      "blk.{bid}.ffn_norm",
     MODEL_TENSOR.FFN_GATE:      "blk.{bid}.ffn_gate",

@@ -187,6 +192,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.FFN_GATE_EXP:   "blk.{bid}.ffn_gate.{xid}",
     MODEL_TENSOR.FFN_DOWN_EXP:   "blk.{bid}.ffn_down.{xid}",
     MODEL_TENSOR.FFN_UP_EXP:     "blk.{bid}.ffn_up.{xid}",
+    MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
 }

 MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {

@@ -262,17 +268,18 @@ class MODEL_TENSOR(IntEnum):
     ],
     MODEL_ARCH.BERT: [
         MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.TOKEN_EMBD_NORM,
         MODEL_TENSOR.TOKEN_TYPES,
         MODEL_TENSOR.POS_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
-        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_OUT_NORM,
         MODEL_TENSOR.ATTN_Q,
         MODEL_TENSOR.ATTN_K,
         MODEL_TENSOR.ATTN_V,
         MODEL_TENSOR.ATTN_OUT,
-        MODEL_TENSOR.FFN_NORM,
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.LAYER_OUT_NORM,
     ],
     MODEL_ARCH.MPT: [
         MODEL_TENSOR.TOKEN_EMBD,
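For concreteness, the new key and tensor-name entries above are format-string templates that expand per architecture and per block. A standalone sketch of the strings they produce for BERT (block 0 chosen arbitrarily; templates copied from the Keys and TENSOR_NAMES entries in this diff):

CAUSAL           = "{arch}.attention.causal"
TOKEN_TYPE_COUNT = "tokenizer.ggml.token_type_count"   # not arch-specific
ATTN_OUT_NORM    = "blk.{bid}.attn_output_norm"
LAYER_OUT_NORM   = "blk.{bid}.layer_output_norm"

print(CAUSAL.format(arch="bert"))    # bert.attention.causal
print(TOKEN_TYPE_COUNT)              # tokenizer.ggml.token_type_count
print(ATTN_OUT_NORM.format(bid=0))   # blk.0.attn_output_norm
print(LAYER_OUT_NORM.format(bid=0))  # blk.0.layer_output_norm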

gguf-py/gguf/gguf_writer.py (6 additions, 0 deletions)

@@ -357,6 +357,9 @@ def add_layer_norm_eps(self, value: float) -> None:
     def add_layer_norm_rms_eps(self, value: float) -> None:
         self.add_float32(Keys.Attention.LAYERNORM_RMS_EPS.format(arch=self.arch), value)

+    def add_causal_attention(self, value: bool) -> None:
+        self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value)
+
     def add_rope_dimension_count(self, count: int) -> None:
         self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count)

@@ -387,6 +390,9 @@ def add_token_merges(self, merges: Sequence[str] | Sequence[bytes] | Sequence[bytearray]) -> None:
     def add_token_types(self, types: Sequence[TokenType] | Sequence[int]) -> None:
         self.add_array(Keys.Tokenizer.TOKEN_TYPE, types)

+    def add_token_type_count(self, value: int) -> None:
+        self.add_uint32(Keys.Tokenizer.TOKEN_TYPE_COUNT, value)
+
     def add_token_scores(self, scores: Sequence[float]) -> None:
         self.add_array(Keys.Tokenizer.SCORES, scores)
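A hypothetical usage sketch of the two new writer methods (the output path and values are made up; a real conversion, like BertModel.set_gguf_parameters above, also writes the remaining hparams, the vocab, and the tensors before closing the file):

from gguf import GGUFWriter

writer = GGUFWriter("bert.gguf", "bert")  # (path, arch), as convert-hf-to-gguf.py calls it
writer.add_causal_attention(False)        # bool KV:   bert.attention.causal
writer.add_token_type_count(2)            # uint32 KV: tokenizer.ggml.token_type_count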

gguf-py/gguf/tensor_mapping.py (10 additions, 3 deletions)

@@ -30,6 +30,7 @@ class TensorNameMap:
         # Normalization of token embeddings
         MODEL_TENSOR.TOKEN_EMBD_NORM: (
             "word_embeddings_layernorm",  # bloom
+            "embeddings.LayerNorm",       # bert
         ),

         # Position embeddings

@@ -54,7 +55,6 @@ class TensorNameMap:
             "transformer.ln_f",                        # gpt2 gpt-j falcon
             "model.norm",                              # llama-hf baichuan internlm2
             "norm",                                    # llama-pth
-            "embeddings.LayerNorm",                    # bert
             "transformer.norm_f",                      # mpt
             "ln_f",                                    # refact bloom qwen gpt2
             "language_model.encoder.final_layernorm",  # persimmon

@@ -79,7 +79,6 @@ class TensorNameMap:
             "transformer.h.{bid}.ln_mlp",                            # falcon40b
             "model.layers.{bid}.input_layernorm",                    # llama-hf
             "layers.{bid}.attention_norm",                           # llama-pth
-            "encoder.layer.{bid}.attention.output.LayerNorm",        # bert
             "language_model.encoder.layers.{bid}.input_layernorm",   # persimmon
             "model.layers.{bid}.ln1",                                # yi
             "h.{bid}.ln_1",                                          # gpt2

@@ -155,6 +154,11 @@ class TensorNameMap:
             "model.layers.{bid}.attention.wo",  # internlm2
         ),

+        # Attention output norm
+        MODEL_TENSOR.ATTN_OUT_NORM: (
+            "encoder.layer.{bid}.attention.output.LayerNorm",  # bert
+        ),
+
         # Rotary embeddings
         MODEL_TENSOR.ATTN_ROT_EMBD: (
             "model.layers.{bid}.self_attn.rotary_emb.inv_freq",  # llama-hf

@@ -171,7 +175,6 @@ class TensorNameMap:
             "transformer.blocks.{bid}.norm_2",                               # mpt
             "model.layers.{bid}.post_attention_layernorm",                   # llama-hf
             "layers.{bid}.ffn_norm",                                         # llama-pth
-            "encoder.layer.{bid}.output.LayerNorm",                          # bert
             "language_model.encoder.layers.{bid}.post_attention_layernorm",  # persimmon
             "model.layers.{bid}.ln2",                                        # yi
             "h.{bid}.ln_2",                                                  # gpt2

@@ -266,6 +269,10 @@ class TensorNameMap:
         MODEL_TENSOR.ROPE_FREQS: (
             "language_model.encoder.layers.{bid}.self_attention.rotary_emb.inv_freq",  # persimmon
         ),
+
+        MODEL_TENSOR.LAYER_OUT_NORM: (
+            "encoder.layer.{bid}.output.LayerNorm",  # bert
+        )
     }

     mapping: dict[str, tuple[MODEL_TENSOR, str]]
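With these changes, BERT's two per-block LayerNorms resolve to the new dedicated slots (ATTN_OUT_NORM, LAYER_OUT_NORM) instead of the generic pre-attention/pre-FFN norms, matching BERT's post-attention and post-FFN norm placement. A quick sketch using the same gguf-py calls that convert-hf-to-gguf.py uses above (block count 12 is arbitrary):

import gguf

tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.BERT, 12)
for hf_name in ("encoder.layer.0.attention.output.LayerNorm.weight",
                "encoder.layer.0.output.LayerNorm.weight"):
    print(hf_name, "->", tmap.get_name(hf_name, try_suffixes=(".weight", ".bias")))
# encoder.layer.0.attention.output.LayerNorm.weight -> blk.0.attn_output_norm.weight
# encoder.layer.0.output.LayerNorm.weight -> blk.0.layer_output_norm.weight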
