Commit ee19a4a

fix KV cache padding, NaN from INFINITY (#6438)

1 parent c63dfdf

2 files changed: 11 additions, 8 deletions

ggml-cuda/fattn.cu

9 additions, 6 deletions

@@ -4,6 +4,7 @@
 #include <mma.h>
 
 #define FATTN_KQ_STRIDE 256
+#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
 
 template<int D, int parallel_blocks> // D == head size
 __launch_bounds__(((D + WARP_SIZE - 1) / WARP_SIZE)*WARP_SIZE, 1)
@@ -59,13 +60,13 @@ static __global__ void flash_attn_vec_ext_f16(
     KQ[tid] = -INFINITY;
     half2 * KQ2 = (half2 *) KQ;
 
-    half kqmax = -INFINITY;
+    half kqmax = -HALF_MAX_HALF;
     half kqsum = 0.0f;
 
     __shared__ half kqmax_shared[WARP_SIZE];
     __shared__ half kqsum_shared[WARP_SIZE];
     if (threadIdx.y == 0) {
-        kqmax_shared[threadIdx.x] = -INFINITY;
+        kqmax_shared[threadIdx.x] = -HALF_MAX_HALF;
         kqsum_shared[threadIdx.x] = 0.0f;
     }
     __syncthreads();
@@ -139,7 +140,7 @@ static __global__ void flash_attn_vec_ext_f16(
         if (tid < D) {
 #pragma unroll
             for (int k0 = 0; k0 < D; k0 += 2) {
-                if (256 % D != 0 && k_VKQ_0 + k0 >= ne11) {
+                if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) {
                     break;
                 }
 
@@ -253,9 +254,9 @@ static __global__ void flash_attn_ext_f16(
     __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
     half2 * KQ2 = (half2 *) KQ;
 
-    half2 KQ_rowsum[ncols/nwarps] = {{0.0f, 0.0f}};
-    half2 KQ_max[ncols/nwarps] = {{-INFINITY, -INFINITY}};
-    half2 KQ_max_scale[ncols/nwarps] = {{0.0f, 0.0f}};
+    half2 KQ_rowsum[ncols/nwarps] = {{ 0.0f, 0.0f}};
+    half2 KQ_max[ncols/nwarps] = {{-HALF_MAX_HALF, -HALF_MAX_HALF}};
+    half2 KQ_max_scale[ncols/nwarps] = {{ 0.0f, 0.0f}};
 
     __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
     half2 * VKQ2 = (half2 *) VKQ;
@@ -578,6 +579,8 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
                 "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
 
+    GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
+
     ggml_cuda_set_device(ctx.device);
 
     const cudaStream_t main_stream = ctx.stream();
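A note on why the fattn.cu change works: the kernels track a running KQ maximum and rescale previously accumulated values by exp(old_max - new_max) when the maximum moves. With the maximum initialized to -INFINITY, that subtraction evaluates -INFINITY - (-INFINITY) = NaN while both values are still at the initializer (for example, fully masked positions), and the NaN propagates into the sums. Below is a minimal host-side C++ sketch of the failure mode; it uses float in place of half, and rescale_factor is an illustrative helper, not code from the kernel:

// Sketch of the online-softmax rescaling step; assumes a factor of
// exp(old_max - new_max) is applied to previously accumulated values,
// as in typical FlashAttention-style kernels.
#include <cmath>
#include <cstdio>

static float rescale_factor(float old_max, float new_max) {
    return std::exp(old_max - new_max); // NaN if both are -INFINITY
}

int main() {
    const float neg_inf       = -INFINITY;
    const float half_max_half = 65504.0f / 2.0f; // mirrors HALF_MAX_HALF from the patch

    // Before any KQ value has been seen, old_max == new_max == the initializer.
    std::printf("init -INFINITY     : %f\n", rescale_factor(neg_inf, neg_inf));               // nan (inf - inf)
    std::printf("init -HALF_MAX_HALF: %f\n", rescale_factor(-half_max_half, -half_max_half)); // 1.0 (exp(0))
    return 0;
}

With a finite initializer the factor stays at 1.0 until a real score raises the maximum, so empty or fully masked rows no longer poison the accumulators.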

llama.cpp

2 additions, 2 deletions

@@ -9973,7 +9973,7 @@ static int llama_decode_internal(
             // a heuristic, to avoid attending the full cache if it is not yet utilized
             // after enough generations, the benefit from this heuristic disappears
             // if we start defragmenting the cache, the benefit from this will be more important
-            kv_self.n = std::min(kv_self.size, std::max(128u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 128)));
+            kv_self.n = std::min(kv_self.size, std::max(256u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 256)));
             //kv_self.n = llama_kv_cache_cell_max(kv_self);
         }
     }
@@ -13909,7 +13909,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
 
     // this is necessary due to kv_self.n being padded later during inference
-    cparams.n_ctx = GGML_PAD(cparams.n_ctx, 32);
+    cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256);
 
     // with causal attention, the batch size is limited by the context size
     cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
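The llama.cpp edits exist to satisfy the new assertion in ggml_cuda_flash_attn_ext: the slice of the KV cache that the kernel attends over (kv_self.n) must be a multiple of FATTN_KQ_STRIDE = 256, so the padding heuristic moves from 128 to 256, and n_ctx is padded to 256 so the std::min clamp against kv_self.size also lands on a multiple of 256. A small sketch of the arithmetic, assuming GGML_PAD rounds its first argument up to a multiple of the second (pad_to is a stand-in and the numbers are hypothetical):

#include <algorithm>
#include <cstdio>

// Stand-in for GGML_PAD (assumed semantics: round x up to a multiple of n).
static unsigned pad_to(unsigned x, unsigned n) {
    return ((x + n - 1) / n) * n;
}

int main() {
    const unsigned kv_size  = 4096; // hypothetical cache size (n_ctx, itself now padded to 256)
    const unsigned cell_max = 300;  // hypothetical highest KV cache cell in use

    // Old heuristic: pad to 128 -> 384; 384 % 256 != 0 would trip the new assert.
    const unsigned n_old = std::min(kv_size, std::max(128u, pad_to(cell_max, 128)));
    // New heuristic: pad to 256 -> 512, a valid multiple of FATTN_KQ_STRIDE.
    const unsigned n_new = std::min(kv_size, std::max(256u, pad_to(cell_max, 256)));

    std::printf("old kv_self.n = %u (%% 256 = %u)\n", n_old, n_old % 256);
    std::printf("new kv_self.n = %u (%% 256 = %u)\n", n_new, n_new % 256);
    return 0;
}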
