Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 75422e8

Browse filesBrowse files
authored
graph : normalize Q, K, V shapes + sync cross attention (ggml-org#12449)
* graph : normalize Q, K, V shapes and add comments ggml-ci * context : synchronize before getting cross attention data * model : fix command-r attention norm check
1 parent bb115d2 commit 75422e8
Copy full SHA for 75422e8

File tree

Expand file treeCollapse file tree

4 files changed

+403
-247
lines changed
Filter options
Expand file treeCollapse file tree

4 files changed

+403
-247
lines changed

‎src/llama-context.cpp

Copy file name to clipboardExpand all lines: src/llama-context.cpp
+2Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,6 +1143,8 @@ int llama_context::encode(llama_batch & inp_batch) {
11431143
if (model.arch == LLM_ARCH_T5 && t_embd) {
11441144
//cross.t_embd = t_embd;
11451145

1146+
synchronize();
1147+
11461148
cross.n_embd = t_embd->ne[0];
11471149
cross.n_enc = t_embd->ne[1];
11481150
cross.v_embd.resize(cross.n_embd*cross.n_enc);

‎src/llama-graph.cpp

Copy file name to clipboardExpand all lines: src/llama-graph.cpp
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1378,7 +1378,7 @@ ggml_tensor * llm_graph_context::build_attn(
13781378
// note: storing RoPE-ed version of K in the KV cache
13791379
ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));
13801380

1381-
assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);
1381+
v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens);
13821382

13831383
ggml_tensor * v_cache_view = nullptr;
13841384

‎src/llama-graph.h

Copy file name to clipboardExpand all lines: src/llama-graph.h
+12-12Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -487,9 +487,9 @@ struct llm_graph_context {
487487

488488
ggml_tensor * build_attn_mha(
489489
ggml_cgraph * gf,
490-
ggml_tensor * q,
491-
ggml_tensor * k,
492-
ggml_tensor * v,
490+
ggml_tensor * q, // [n_embd_head_q, n_tokens, n_head_q]
491+
ggml_tensor * k, // [n_embd_head_k, n_tokens, n_head_k]
492+
ggml_tensor * v, // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false)
493493
ggml_tensor * kq_b,
494494
ggml_tensor * kq_mask,
495495
bool v_trans,
@@ -502,9 +502,9 @@ struct llm_graph_context {
502502
ggml_cgraph * gf,
503503
ggml_tensor * wo,
504504
ggml_tensor * wo_b,
505-
ggml_tensor * q_cur,
506-
ggml_tensor * k_cur,
507-
ggml_tensor * v_cur,
505+
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
506+
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
507+
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
508508
ggml_tensor * kq_b,
509509
float kq_scale,
510510
int il) const;
@@ -516,9 +516,9 @@ struct llm_graph_context {
516516
ggml_cgraph * gf,
517517
ggml_tensor * wo,
518518
ggml_tensor * wo_b,
519-
ggml_tensor * q_cur,
520-
ggml_tensor * k_cur,
521-
ggml_tensor * v_cur,
519+
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
520+
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
521+
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
522522
ggml_tensor * kq_b,
523523
float kq_scale,
524524
int il) const;
@@ -530,9 +530,9 @@ struct llm_graph_context {
530530
ggml_cgraph * gf,
531531
ggml_tensor * wo,
532532
ggml_tensor * wo_b,
533-
ggml_tensor * q_cur,
534-
ggml_tensor * k_cur,
535-
ggml_tensor * v_cur,
533+
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
534+
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
535+
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
536536
ggml_tensor * kq_b,
537537
float kq_scale,
538538
int il) const;

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.