Commit 48d9f25

fix padding, GQA
1 parent: 8d56dad

2 files changed: 5 additions, 5 deletions

ggml-cuda.cu

3 additions & 3 deletions
@@ -7552,9 +7552,9 @@ static __global__ void flash_attn_ext_f16(
     __builtin_assume(tid < nthreads);
     constexpr int D_padded = D + 8; // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts.

-    const float * Q_f = (const float *) (Q + nb02*blockIdx.y + ncols*nb01*blockIdx.x);
-    const half * K_h = (const half *) (K + nb12*blockIdx.y);
-    const half * V_h = (const half *) (V + nb12*blockIdx.y); // K and V have same shape
+    const float * Q_f = (const float *) (Q + nb02* blockIdx.y + ncols*nb01*blockIdx.x);
+    const half * K_h = (const half *) (K + nb12*(blockIdx.y % ne12));
+    const half * V_h = (const half *) (V + nb12*(blockIdx.y % ne12)); // K and V have same shape
     const half2 * mask2 = (half2 *) mask + ncols*ne11*blockIdx.x/2;

     const int stride_Q = nb01 / sizeof(float);
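The GQA part of the fix: with grouped-query attention the K and V tensors carry fewer heads (ne12) than Q, so indexing them directly with the query-head index (blockIdx.y) reads past the available K/V heads; the modulo wraps each query head onto the K/V head it shares. A minimal sketch of that wrap-around mapping, separate from the kernel itself — the head counts below are illustrative assumptions, not values taken from any model:

// Sketch only: how a GQA kernel maps query heads onto the smaller set of K/V heads.
#include <cassert>
#include <cstdio>

int main() {
    const int n_head_q  = 32; // query heads; the kernel walks these via blockIdx.y
    const int n_head_kv = 8;  // K/V heads; plays the role of ne12 in the kernel
    assert(n_head_q % n_head_kv == 0); // GQA models use an integer group size

    for (int iq = 0; iq < n_head_q; ++iq) {
        const int ikv = iq % n_head_kv; // same wrap-around as blockIdx.y % ne12
        printf("Q head %2d reads K/V head %d\n", iq, ikv);
    }
    return 0;
}

Before this change the kernel indexed K and V with blockIdx.y directly, which is only valid when the K/V head count equals the query head count.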

llama.cpp

2 additions & 2 deletions
@@ -9166,7 +9166,7 @@ static int llama_decode_internal(
             // a heuristic, to avoid attending the full cache if it is not yet utilized
             // after enough generations, the benefit from this heuristic disappears
             // if we start defragmenting the cache, the benefit from this will be more important
-            kv_self.n = std::min(kv_self.size, std::max(128u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 128)));
+            kv_self.n = std::min(kv_self.size, std::max(256u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 256)));
             //kv_self.n = llama_kv_cache_cell_max(kv_self);
         }
     }
@@ -13083,7 +13083,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;

     // this is necessary due to kv_self.n being padded later during inference
-    cparams.n_ctx = GGML_PAD(cparams.n_ctx, 32);
+    cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256);

     // with causal attention, the batch size is limited by the context size
     cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
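The padding part of the fix: both llama.cpp changes raise the rounding granularity to 256, so the attended KV range (kv_self.n) and the context size (cparams.n_ctx) stay multiples of 256 rather than 128 or 32; the motivation is inferred from the commit message "fix padding", not stated in the diff. A standalone sketch of the rounding arithmetic — pad_to is a hypothetical helper mirroring what GGML_PAD does (round up to the next multiple, assuming a power-of-two pad size), and the sizes are chosen only for illustration:

// Sketch of the padding arithmetic used above (not llama.cpp code).
#include <algorithm>
#include <cstdint>
#include <cstdio>

static uint32_t pad_to(uint32_t x, uint32_t n) {
    return (x + n - 1) & ~(n - 1); // round x up to a multiple of n (n must be a power of two)
}

int main() {
    const uint32_t kv_size  = 4096; // total KV cache size in cells (illustrative)
    const uint32_t cell_max = 300;  // highest KV cell in use, as llama_kv_cache_cell_max() would report

    // old behaviour: pad the attended KV range to a multiple of 128
    const uint32_t n_old = std::min(kv_size, std::max(128u, pad_to(cell_max, 128)));
    // new behaviour: pad to a multiple of 256
    const uint32_t n_new = std::min(kv_size, std::max(256u, pad_to(cell_max, 256)));

    printf("cell_max=%u -> kv_self.n old=%u new=%u\n", cell_max, n_old, n_new); // 300 -> 384 vs 512
    return 0;
}

Padding n_ctx with the same constant at context creation keeps the later kv_self.n rounding from ever exceeding the allocated cache.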
