#include <mma.h>

#define FATTN_KQ_STRIDE 256
+#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.

template<int D, int parallel_blocks> // D == head size
__launch_bounds__(((D + WARP_SIZE - 1) / WARP_SIZE)*WARP_SIZE, 1)
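
Aside (not part of the patch): a minimal host-side C++ sketch of the NaN hazard the new comment refers to. The online-softmax update rescales the running row sum by exp(old_max - new_max); if both maxima still sit at their -INFINITY initial value (e.g. a fully masked tile), the subtraction is inf - inf == NaN and the NaN propagates through exp() into the accumulator, while a large finite sentinel such as -HALF_MAX_HALF keeps the difference finite. fp32 is used here as a stand-in for the kernel's fp16 arithmetic.

    #include <cmath>
    #include <cstdio>

    int main() {
        const float neg_inf  = -INFINITY;
        const float sentinel = -65504.0f/2.0f; // fp32 analogue of -HALF_MAX_HALF

        // Rescale factor applied when the running maximum is updated:
        printf("exp(-inf - (-inf)) = %f\n", expf(neg_inf  - neg_inf));  // nan -> poisons the row sum
        printf("exp(sent - sent)   = %f\n", expf(sentinel - sentinel)); // 1.0 -> harmless
        return 0;
    }
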
@@ -59,13 +60,13 @@ static __global__ void flash_attn_vec_ext_f16(
    KQ[tid] = -INFINITY;
    half2 * KQ2 = (half2 *) KQ;

-   half kqmax = -INFINITY;
+   half kqmax = -HALF_MAX_HALF;
    half kqsum = 0.0f;

    __shared__ half kqmax_shared[WARP_SIZE];
    __shared__ half kqsum_shared[WARP_SIZE];
    if (threadIdx.y == 0) {
-       kqmax_shared[threadIdx.x] = -INFINITY;
+       kqmax_shared[threadIdx.x] = -HALF_MAX_HALF;
        kqsum_shared[threadIdx.x] = 0.0f;
    }
    __syncthreads();
@@ -139,7 +140,7 @@ static __global__ void flash_attn_vec_ext_f16(
    if (tid < D) {
#pragma unroll
        for (int k0 = 0; k0 < D; k0 += 2) {
-           if (256 % D != 0 && k_VKQ_0 + k0 >= ne11) {
+           if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) {
                break;
            }
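
Aside (the names in the sketch below are illustrative, not from the file): the point of spelling the guard as FATTN_KQ_STRIDE % D != 0 is that D is a template parameter, so the left operand of the && is evaluated at compile time; for head sizes that divide the stride evenly the whole bounds check folds to false and vanishes from the unrolled loop, and using the macro keeps that condition tied to the actual stride instead of a stray literal 256. A self-contained illustration:

    #include <cstdio>

    constexpr int kFattnKQStride = 256; // same value as FATTN_KQ_STRIDE above

    template <int D>
    constexpr bool tail_check_kept() {
        // Mirrors the left operand of the guard in the kernel: a compile-time constant.
        return kFattnKQStride % D != 0;
    }

    int main() {
        printf("D = 128: tail check kept? %d\n", tail_check_kept<128>()); // 0 -> branch eliminated
        printf("D =  80: tail check kept? %d\n", tail_check_kept< 80>()); // 1 -> runtime check stays
        return 0;
    }
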
@@ -253,9 +254,9 @@ static __global__ void flash_attn_ext_f16(
    __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
    half2 * KQ2 = (half2 *) KQ;

-   half2 KQ_rowsum[ncols/nwarps] = {{0.0f, 0.0f}};
-   half2 KQ_max[ncols/nwarps] = {{-INFINITY, -INFINITY}};
-   half2 KQ_max_scale[ncols/nwarps] = {{0.0f, 0.0f}};
+   half2 KQ_rowsum[ncols/nwarps] = {{ 0.0f,  0.0f}};
+   half2 KQ_max[ncols/nwarps] = {{-HALF_MAX_HALF, -HALF_MAX_HALF}};
+   half2 KQ_max_scale[ncols/nwarps] = {{ 0.0f,  0.0f}};

    __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
    half2 * VKQ2 = (half2 *) VKQ;
@@ -578,6 +579,8 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
    GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
                "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");

+   GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
+
    ggml_cuda_set_device(ctx.device);

    const cudaStream_t main_stream = ctx.stream();
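
Aside on the new assertion (the helper below is hypothetical, not llama.cpp API): the kernels walk K and V in tiles of FATTN_KQ_STRIDE rows, so the caller has to hand in a KV length K->ne[1] that is already rounded up to a multiple of that stride; the assert makes a missing padding step fail loudly up front. A sketch of the rounding the contract implies:

    #include <cstdio>

    // Hypothetical helper: round a KV length up to the next multiple of the stride.
    static int round_up(int n_kv, int stride) {
        return ((n_kv + stride - 1) / stride) * stride;
    }

    int main() {
        const int stride = 256;  // FATTN_KQ_STRIDE from the diff above
        const int n_kv   = 1000; // hypothetical current KV cache length
        printf("n_kv = %d -> padded to %d\n", n_kv, round_up(n_kv, stride)); // 1000 -> 1024
        return 0;
    }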