Commit c446b2e

vulkan: Submit once enough matmul work has been recorded (ggml-org#12406)
I've been seeing significantly worse tg performance with flash attention enabled than with it disabled, and it seems to be related to the submit heuristic. Change the heuristic to track how many bytes of weight matrix data the recorded matmuls use, flush every 100MB, and ramp the threshold up after the first few submits. This seems to resolve the issue, and also increases perf for non-FA a bit.
1 parent d84635b commit c446b2e
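
The threshold rule from the commit message can be sketched in isolation. The helper below is hypothetical (`submit_thresholds` is not a ggml function); it only mirrors the arithmetic in the change to ggml-vulkan.cpp, assuming the threshold starts at the smaller of 100 MB and 1/40 of the graph's total matmul weight bytes and doubles after each of the first three submits.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of ggml): returns the byte threshold in effect
// for each of the first n_submits submits, following the rule in this commit.
std::vector<uint64_t> submit_thresholds(uint64_t total_mat_mul_bytes, int n_submits) {
    // Start at 100 MB, or 1/40 of the total matmul weight bytes if that is smaller.
    uint64_t per_submit = std::min<uint64_t>(100ull * 1000 * 1000, total_mat_mul_bytes / 40u);

    std::vector<uint64_t> thresholds;
    for (int submit_count = 0; submit_count < n_submits; submit_count++) {
        thresholds.push_back(per_submit);
        // The threshold doubles after each of the first three submits, then stays fixed.
        if (submit_count < 3) {
            per_submit *= 2;
        }
    }
    return thresholds;
}
```

For a graph whose matmul weights total 8 GB, this yields 100 MB, 200 MB, 400 MB, and then 800 MB from the fourth submit on. For a graph totalling 400 MB, the first threshold is 10 MB, so smaller models submit earlier.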
‎ggml/src/ggml-vulkan/ggml-vulkan.cpp

21 additions & 11 deletions
@@ -8436,8 +8436,12 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
     VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
     ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
 
+    uint64_t total_mat_mul_bytes = 0;
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false);
+        if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
+            total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
+        }
     }
     if (ctx->device->need_compiles) {
         ggml_vk_load_shaders(ctx->device);
@@ -8458,17 +8462,27 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
     bool first_node_in_batch = true; // true if next node will be first node in a batch
     int submit_node_idx = 0; // index to first node in a batch
 
-    // Submit work every nodes_per_submit nodes to overlap CPU cmdbuffer generation with GPU execution.
-    // Start with a smaller count to get work submitted right away, and increase it after each submit.
-    int nodes_per_submit = 20;
+    // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.
+    // Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB
+    // (and scaled down based on model size, so smaller models submit earlier).
+    // Also submit at least every 100 nodes, in case there are workloads without as much matmul.
+    int nodes_per_submit = 100;
     int submitted_nodes = 0;
     int submit_count = 0;
+    uint64_t mul_mat_bytes = 0;
+    uint64_t mul_mat_bytes_per_submit = std::min(uint64_t(100*1000*1000), total_mat_mul_bytes / 40u);
     for (int i = 0; i < cgraph->n_nodes; i++) {
         if (first_node_in_batch) {
             submit_node_idx = i;
         }
 
-        bool submit = (submitted_nodes >= nodes_per_submit) || (i == last_node);
+        if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
+            mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
+        }
+
+        bool submit = (submitted_nodes >= nodes_per_submit) ||
+                      (mul_mat_bytes >= mul_mat_bytes_per_submit) ||
+                      (i == last_node);
 
         bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, submit);
 
@@ -8485,13 +8499,9 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
         if (submit) {
             first_node_in_batch = true;
             submitted_nodes = 0;
-            switch (submit_count) {
-            case 0:
-                nodes_per_submit = 50;
-                break;
-            default:
-                nodes_per_submit = 100;
-                break;
+            mul_mat_bytes = 0;
+            if (submit_count < 3) {
+                mul_mat_bytes_per_submit *= 2;
             }
             submit_count++;
         }
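
To see how the two thresholds interact, the batching loop can be replayed without Vulkan. This is a minimal sketch under stated assumptions, not the backend code: `Node` and `plan_submits` are hypothetical names, every node is counted as enqueued work (the line that increments `submitted_nodes` is outside the hunks shown), and `last_node` is taken to be the final index of the graph.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a graph node: only the fields the heuristic reads.
struct Node {
    bool     is_mat_mul;    // stands for op == GGML_OP_MUL_MAT or GGML_OP_MUL_MAT_ID
    uint64_t weight_bytes;  // stands for ggml_nbytes(node->src[0])
};

// Returns the index of the last node in each submitted batch.
std::vector<int> plan_submits(const std::vector<Node> & nodes) {
    // Pre-pass: total weight bytes across all matmul nodes.
    uint64_t total_mat_mul_bytes = 0;
    for (const Node & n : nodes) {
        if (n.is_mat_mul) {
            total_mat_mul_bytes += n.weight_bytes;
        }
    }

    const int nodes_per_submit = 100;
    uint64_t mul_mat_bytes_per_submit =
        std::min<uint64_t>(100ull * 1000 * 1000, total_mat_mul_bytes / 40u);
    uint64_t mul_mat_bytes = 0;
    int submitted_nodes = 0;
    int submit_count = 0;
    const int last_node = static_cast<int>(nodes.size()) - 1;

    std::vector<int> batch_ends;
    for (int i = 0; i <= last_node; i++) {
        if (nodes[i].is_mat_mul) {
            mul_mat_bytes += nodes[i].weight_bytes;
        }
        const bool submit = (submitted_nodes >= nodes_per_submit) ||
                            (mul_mat_bytes >= mul_mat_bytes_per_submit) ||
                            (i == last_node);
        submitted_nodes++;  // assumption: every node is enqueued
        if (submit) {
            batch_ends.push_back(i);
            submitted_nodes = 0;
            mul_mat_bytes = 0;
            if (submit_count < 3) {
                mul_mat_bytes_per_submit *= 2;  // ramp up over the first three submits
            }
            submit_count++;
        }
    }
    return batch_ends;
}
```

With 80 matmul nodes of 10 MB each (800 MB total), the initial threshold is 20 MB, so the batches close after 2, 4, 8, and then 16 matmuls, with the final two nodes flushed by the last-node check. In a graph with almost no matmul work, the node-count limit takes over instead.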
