Commit f4dea7d

simonJJJ and ggerganov authored
llama : add qwen2moe (ggml-org#6074)
* support qwen2moe
* fix-review
* metal : support unary ops for nelements % 4 != 0
* metal : require contiguousness for float4 unary kernels
* metal : require contiguousness for float4 unary kernels (cont)
* fix-review
* names : for brevity "SHARED_EXP" -> "SHEXP"
* llama : reuse build_moe_ffn()
* llama : add model type name

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
1 parent 8a56075 commit f4dea7d

File tree

7 files changed: +537 -101 lines changed

‎convert-hf-to-gguf.py

+99 lines changed: 99 additions & 0 deletions
@@ -1700,6 +1700,105 @@ class Qwen2Model(Model):
     model_arch = gguf.MODEL_ARCH.QWEN2
 
 
+@Model.register("Qwen2MoeForCausalLM")
+class Qwen2MoeModel(Model):
+    model_arch = gguf.MODEL_ARCH.QWEN2MOE
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        if (n_experts := self.hparams.get("num_experts")) is not None:
+            self.gguf_writer.add_expert_count(n_experts)
+
+    def write_tensors(self):
+        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
+        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
+        n_experts = self.hparams.get("num_experts")
+        experts = dict()
+        for name, data_torch in self.get_tensors():
+            # we don't need these
+            if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
+                continue
+
+            old_dtype = data_torch.dtype
+
+            # convert any unsupported data types to float32
+            if data_torch.dtype not in (torch.float16, torch.float32):
+                data_torch = data_torch.to(torch.float32)
+
+            data = data_torch.squeeze().numpy()
+
+            # process the experts separately
+            if name.find("experts") != -1:
+                experts[name] = data
+                if len(experts) >= n_experts * 3:
+                    # merge the experts into a single 3d tensor
+                    for bid in range(block_count):
+                        for w_name in ["down_proj", "gate_proj", "up_proj"]:
+                            full = True
+                            for xid in range(n_experts):
+                                ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+                                if ename not in experts:
+                                    full = False
+                                    break
+                            if not full:
+                                continue
+
+                            datas = []
+                            for xid in range(n_experts):
+                                ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+                                datas.append(experts[ename])
+                                del experts[ename]
+
+                            data = np.stack(datas, axis=0)
+                            data_dtype = data.dtype
+
+                            if self.ftype == 0 and data_dtype == np.float16:
+                                data = data.astype(np.float32)
+
+                            if self.ftype == 1 and data_dtype == np.float32:
+                                data = data.astype(np.float16)
+
+                            merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
+
+                            new_name = tensor_map.get_name(merged_name, try_suffixes=(".weight", ".bias"))
+                            if new_name is None:
+                                print(f"Can not map tensor {name!r}")
+                                sys.exit()
+
+                            print(f"{new_name}, n_dims = {len(data.shape)}, shape = {data.shape} --> {data.dtype}")
+
+                            self.gguf_writer.add_tensor(new_name, data)
+                continue
+
+            # map tensor names
+            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
+            if new_name is None:
+                print(f"Can not map tensor {name!r}")
+                sys.exit()
+
+            n_dims = len(data.shape)
+            data_dtype = data.dtype
+
+            # if f32 desired, convert any float16 to float32
+            if self.ftype == 0 and data_dtype == np.float16:
+                data = data.astype(np.float32)
+
+            # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
+            if self.ftype == 1 and data_dtype == np.float16 and (n_dims == 1 or new_name.endswith("_norm.weight")):
+                data = data.astype(np.float32)
+
+            # if f16 desired, convert any float32 2-dim weight tensors to float16
+            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
+                data = data.astype(np.float16)
+
+            print(f"{new_name}, n_dims = {n_dims}, shape = {data.shape}, {old_dtype} --> {data.dtype}")
+
+            self.gguf_writer.add_tensor(new_name, data)
+
+        if len(experts) > 0:
+            raise ValueError(f"Unprocessed experts: {experts.keys()}")
+
+
 @Model.register("GPT2LMHeadModel")
 class GPT2Model(Model):
     model_arch = gguf.MODEL_ARCH.GPT2
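Note on the conversion logic above: write_tensors() buffers each per-expert projection weight and, once all experts for a layer are available, stacks them into a single 3D tensor per projection (down_proj, gate_proj, up_proj) before writing to GGUF. A minimal NumPy sketch of that stacking step, using hypothetical shapes and only layer 0's up_proj weights:

import numpy as np

# hypothetical sizes, for illustration only
n_experts, n_embd, n_ff_exp = 4, 8, 16

# per-expert tensors as they appear in the checkpoint:
#   model.layers.0.mlp.experts.{xid}.up_proj.weight
experts = {
    f"model.layers.0.mlp.experts.{xid}.up_proj.weight":
        np.random.rand(n_ff_exp, n_embd).astype(np.float32)
    for xid in range(n_experts)
}

# stack along a new leading expert dimension, as write_tensors() does with np.stack(datas, axis=0)
merged = np.stack(
    [experts[f"model.layers.0.mlp.experts.{xid}.up_proj.weight"] for xid in range(n_experts)],
    axis=0,
)

merged_name = "model.layers.0.mlp.experts.up_proj.weight"
print(merged_name, merged.shape)  # (4, 16, 8) = (n_experts, n_ff_exp, n_embd)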

‎ggml-metal.m

+42 -15 lines changed: 42 additions & 15 deletions
@@ -41,8 +41,11 @@
     GGML_METAL_KERNEL_TYPE_TANH,
     GGML_METAL_KERNEL_TYPE_RELU,
     GGML_METAL_KERNEL_TYPE_GELU,
+    GGML_METAL_KERNEL_TYPE_GELU_4,
     GGML_METAL_KERNEL_TYPE_GELU_QUICK,
+    GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
     GGML_METAL_KERNEL_TYPE_SILU,
+    GGML_METAL_KERNEL_TYPE_SILU_4,
     GGML_METAL_KERNEL_TYPE_SOFT_MAX,
     GGML_METAL_KERNEL_TYPE_SOFT_MAX_4,
     GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
@@ -473,8 +476,11 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
@@ -1178,6 +1184,9 @@ static enum ggml_status ggml_metal_graph_compute(
                 } break;
             case GGML_OP_UNARY:
                 switch (ggml_get_unary_op(gf->nodes[i])) {
+                    // we are not taking into account the strides, so for now require contiguous tensors
+                    GGML_ASSERT(ggml_is_contiguous(src0));
+
                     case GGML_UNARY_OP_TANH:
                         {
                             id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline;
@@ -1204,42 +1213,60 @@ static enum ggml_status ggml_metal_graph_compute(
                         } break;
                     case GGML_UNARY_OP_GELU:
                         {
-                            id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
+                            int64_t n = ggml_nelements(dst);
+
+                            id<MTLComputePipelineState> pipeline = nil;
+
+                            if (n % 4 == 0) {
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_4].pipeline;
+                                n /= 4;
+                            } else {
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
+                            }
 
                             [encoder setComputePipelineState:pipeline];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
 
-                            const int64_t n = ggml_nelements(dst);
-                            GGML_ASSERT(n % 4 == 0);
-
-                            [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         } break;
                     case GGML_UNARY_OP_GELU_QUICK:
                         {
-                            id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
+                            int64_t n = ggml_nelements(dst);
+
+                            id<MTLComputePipelineState> pipeline = nil;
+
+                            if (n % 4 == 0) {
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline;
+                                n /= 4;
+                            } else {
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
+                            }
 
                             [encoder setComputePipelineState:pipeline];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
 
-                            const int64_t n = ggml_nelements(dst);
-                            GGML_ASSERT(n % 4 == 0);
-
-                            [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         } break;
                     case GGML_UNARY_OP_SILU:
                         {
-                            id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
+                            int64_t n = ggml_nelements(dst);
+
+                            id<MTLComputePipelineState> pipeline = nil;
+
+                            if (n % 4 == 0) {
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_4].pipeline;
+                                n /= 4;
+                            } else {
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
+                            }
 
                             [encoder setComputePipelineState:pipeline];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
 
-                            const int64_t n = ggml_nelements(dst);
-                            GGML_ASSERT(n % 4 == 0);
-
-                            [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         } break;
                     default:
                         {
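Note on the dispatch change above: the hard GGML_ASSERT(n % 4 == 0) is removed; instead, when the destination's element count is divisible by 4 the vectorized float4 kernel is dispatched over n/4 threads, otherwise the new scalar kernel is dispatched over n threads. Contiguity is still required because the kernels ignore strides. A rough Python sketch of the selection logic, with kernel names standing in for the Metal pipeline objects:

def pick_unary_kernel(op: str, n_elements: int) -> tuple[str, int]:
    # mirror the logic added in ggml_metal_graph_compute():
    # prefer the float4 kernel when the element count divides evenly by 4,
    # otherwise fall back to the scalar kernel added in this commit
    if n_elements % 4 == 0:
        return f"kernel_{op}_4", n_elements // 4  # one float4 per thread
    return f"kernel_{op}", n_elements             # one float per thread

print(pick_unary_kernel("gelu", 4096))  # ('kernel_gelu_4', 1024)
print(pick_unary_kernel("silu", 4095))  # ('kernel_silu', 4095)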

‎ggml-metal.metal

+26 lines changed: 26 additions & 0 deletions
@@ -242,6 +242,15 @@ constant float GELU_QUICK_COEF = -1.702f;
 constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
 
 kernel void kernel_gelu(
+    device const float * src0,
+    device       float * dst,
+    uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+
+    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
+kernel void kernel_gelu_4(
     device const float4 * src0,
     device       float4 * dst,
     uint tpig[[thread_position_in_grid]]) {
@@ -255,6 +264,15 @@ kernel void kernel_gelu(
 }
 
 kernel void kernel_gelu_quick(
+    device const float * src0,
+    device       float * dst,
+    uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+
+    dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
+}
+
+kernel void kernel_gelu_quick_4(
     device const float4 * src0,
     device       float4 * dst,
     uint tpig[[thread_position_in_grid]]) {
@@ -264,6 +282,14 @@ kernel void kernel_gelu_quick(
 }
 
 kernel void kernel_silu(
+    device const float * src0,
+    device       float * dst,
+    uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+    dst[tpig] = x / (1.0f + exp(-x));
+}
+
+kernel void kernel_silu_4(
     device const float4 * src0,
     device       float4 * dst,
     uint tpig[[thread_position_in_grid]]) {
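For reference, the scalar kernels added above compute the same activations as the existing float4 kernels, one element per thread. A NumPy sketch of the three formulas as written in the shader; GELU_COEF_A is defined outside the shown hunk, and the standard 0.044715 value is assumed here:

import numpy as np

GELU_COEF_A     = 0.044715   # assumed; defined in ggml-metal.metal outside the shown hunk
GELU_QUICK_COEF = -1.702
SQRT_2_OVER_PI  = 0.79788456080286535587989211986876

def gelu(x):        # kernel_gelu / kernel_gelu_4: tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(SQRT_2_OVER_PI * x * (1.0 + GELU_COEF_A * x * x)))

def gelu_quick(x):  # kernel_gelu_quick: sigmoid approximation
    return x * (1.0 / (1.0 + np.exp(GELU_QUICK_COEF * x)))

def silu(x):        # kernel_silu
    return x / (1.0 + np.exp(-x))

x = np.linspace(-3.0, 3.0, 7)
print(gelu(x))
print(gelu_quick(x))
print(silu(x))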

0 commit comments
