Commit fb085fa

cuda : fix to F16 scalars + tune warps for RTX 2060

1 parent 2c04bee
2 files changed: 61 additions, 47 deletions

ggml-cuda.cu (49 additions, 45 deletions)

@@ -6491,8 +6491,8 @@ static __global__ void flash_attn_ext_f16(
         __syncthreads();

         {
-            float S[Q];
-            float M[Q];
+            half S[Q];
+            half M[Q];

             for(int i = 0; i < Q; i++) {
                 S[i] = 0.0f;
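
The running softmax state per query row, the sum S and the maximum M, now lives in half precision rather than float. For reference, a minimal sketch of the online-softmax update this state feeds (the standalone helper is hypothetical, names mirror the hunks below, and native half arithmetic assumes sm_53 or newer):

    #include <cuda_fp16.h>

    // Hypothetical illustration of one online-softmax step on half scalars.
    // For a new score s: M' = max(M, s); the old sum is rescaled by
    // exp(M - M') so every accumulated term stays relative to the new maximum.
    __device__ void online_softmax_step(half & S, half & M, half s) {
        const half M_new = __hmax(M, s);
        // expf() takes a float: the half difference is widened, then the
        // result is rounded back to half, matching the kernel's mixed arithmetic
        const half ms = __hisinf(M) ? __float2half(0.0f)
                                    : __float2half(expf(__half2float(M - M_new)));
        const half vs = __hisinf(s) ? __float2half(0.0f)
                                    : __float2half(expf(__half2float(s - M_new)));
        S = S*ms + vs;
        M = M_new;
    }
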
@@ -6579,67 +6579,68 @@ static __global__ void flash_attn_ext_f16(
             }

             // used to detect blocks full of -INF
-            float smax = -INFINITY;
+            half smax = -INFINITY;

             // online softmax
             if (C == 32) {
                 for (int64_t j = 0; j < Q; ++j) {
                     const int64_t p = lane_id;

-                    const float m = M[j];
-                    const float s = __half2float(ss[j*T + p]);
+                    const half m = M[j];
+                    const half s = ss[j*T + p];

-                    smax = warp_reduce_max(max(smax, s));
-                    M[j] = warp_reduce_max(max(M[j], s));
+                    smax = warp_reduce_max(__hmax(smax, s));
+                    M[j] = warp_reduce_max(__hmax(M[j], s));

-                    const float ms = m == -INFINITY ? 0.0f : expf(m - M[j]);
-                    const float vs = s == -INFINITY ? 0.0f : expf(s - M[j]);
+                    const half ms = __hisinf(m) ? 0.0f : expf(m - M[j]);
+                    const half vs = __hisinf(s) ? 0.0f : expf(s - M[j]);

                     S[j] = S[j]*ms + warp_reduce_sum(vs);

                     // create a QxQ diagonal matrix for rescaling the output
                     if (p == j) {
-                        ss[j*T + C + j] = __float2half(ms);
+                        ss[j*T + C + j] = ms;
                     }

                     // the P matrix from the paper (Q rows, C columns)
-                    ss[j*T + p] = __float2half(vs);
+                    ss[j*T + p] = vs;
                 }
             } else {
                 for (int64_t j = 0; j < Q; ++j) {
-                    const float m = M[j];
+                    const half m = M[j];

                     for (int64_t p = lane_id; p < C; p += NW) {
-                        const float s = __half2float(ss[j*T + p]);
+                        const half s = ss[j*T + p];

-                        smax = warp_reduce_max(max(smax, s));
-                        M[j] = warp_reduce_max(max(M[j], s));
+                        smax = warp_reduce_max(__hmax(smax, s));
+                        M[j] = warp_reduce_max(__hmax(M[j], s));
                     }

-                    const float ms = m == -INFINITY ? 0.0f : expf(m - M[j]);
+                    const half ms = __hisinf(m) ? 0.0f : expf(m - M[j]);

                     S[j] = S[j]*ms;

                     // create a QxQ diagonal matrix for rescaling the output
                     if (lane_id == j) {
-                        ss[j*T + C + j] = __float2half(ms);
+                        ss[j*T + C + j] = ms;
                     }

                     for (int64_t p = lane_id; p < C; p += NW) {
-                        const float s = __half2float(ss[j*T + p]);
+                        const half s = ss[j*T + p];

-                        const float vs = s == -INFINITY ? 0.0f : expf(s - M[j]);
+                        const half vs = __hisinf(s) ? 0.0f : expf(s - M[j]);

                         S[j] = S[j] + warp_reduce_sum(vs);

                         // the P matrix from the paper (Q rows, C columns)
-                        ss[j*T + p] = __float2half(vs);
+                        ss[j*T + p] = vs;
                     }
                 }
             }

+
             // skip -INF blocks
-            if (smax == -INFINITY) {
+            if (__hisinf(smax)) {
                 continue;
             }
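
A subtlety in the __hisinf() conversion above: the intrinsic reports either infinity, not just -INFINITY, so it matches the old comparisons only because these scores can never reach +INF (masked positions are set to -INF, everything else stays finite). A tiny sketch of the contract, as a hypothetical helper:

    #include <cuda_fp16.h>

    // __hisinf(x) returns -1 for -inf, +1 for +inf, and 0 otherwise, so a
    // nonzero result used as a boolean is equivalent to `x == -INFINITY`
    // whenever +inf cannot occur, as with the masked attention scores here.
    __device__ bool is_masked_out(half x) {
        return __hisinf(x) != 0;
    }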

@@ -6686,16 +6687,16 @@ static __global__ void flash_attn_ext_f16(
             // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
             for (int64_t j = 0; j < Q; ++j) {
                 if (lane_id == 0) {
-                    ss[j*T + 0] = __float2half(S[j]);
-                    ss[j*T + 1] = __float2half(M[j]);
+                    ss[j*T + 0] = S[j];
+                    ss[j*T + 1] = M[j];
                 }
             }
         }

         // reduce the warps sequentially
         for (int64_t sg = 1; sg < num_warps; ++sg) {
-            float S = 0.0f;
-            float M = -INFINITY;
+            half S = 0.0f;
+            half M = -INFINITY;

             __syncthreads();

@@ -6713,25 +6714,25 @@ static __global__ void flash_attn_ext_f16(
         // the first simdgroup accumulates the results from the other simdgroups
         if (warp_id == 0) {
             for (int64_t j = 0; j < Q; ++j) {
-                const float S0 = __half2float(ss[j*T + 0]);
-                const float S1 = __half2float(ss[j*T + sg*SH + 0]);
+                const half S0 = ss[j*T + 0];
+                const half S1 = ss[j*T + sg*SH + 0];

-                const float M0 = __half2float(ss[j*T + 1]);
-                const float M1 = __half2float(ss[j*T + sg*SH + 1]);
+                const half M0 = ss[j*T + 1];
+                const half M1 = ss[j*T + sg*SH + 1];

-                M = max(M0, M1);
+                M = __hmax(M0, M1);

-                const float ms0 = M0 == -INFINITY ? 0.0f : expf(M0 - M);
-                const float ms1 = M1 == -INFINITY ? 0.0f : expf(M1 - M);
+                const half ms0 = __hisinf(M0) ? 0.0f : expf(M0 - M);
+                const half ms1 = __hisinf(M1) ? 0.0f : expf(M1 - M);

                 S = S0*ms0 + S1*ms1;

                 if (lane_id == 0) {
-                    ss[j*T + 0] = __float2half(S);
-                    ss[j*T + 1] = __float2half(M);
+                    ss[j*T + 0] = S;
+                    ss[j*T + 1] = M;

-                    ss[j*T + C + j        ] = __float2half(ms0);
-                    ss[j*T + C + j + sg*SH] = __float2half(ms1);
+                    ss[j*T + C + j        ] = ms0;
+                    ss[j*T + C + j + sg*SH] = ms1;
                 }
             }


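This is the standard merge of two online-softmax partial states: take the larger maximum, rescale each partial sum to it, then add. A minimal sketch of the rule the hunk implements, as a hypothetical standalone helper:

    #include <cuda_fp16.h>

    // Merge partial states (S0, M0) and (S1, M1) into (S, M). Rescaling each
    // sum by exp(Mi - M) makes the result identical to a single pass over the
    // union of both warps' KV blocks.
    __device__ void merge_softmax_state(half & S, half & M,
                                        half S0, half M0, half S1, half M1) {
        M = __hmax(M0, M1);
        const half ms0 = __hisinf(M0) ? __float2half(0.0f)
                                      : __float2half(expf(__half2float(M0 - M)));
        const half ms1 = __hisinf(M1) ? __float2half(0.0f)
                                      : __float2half(expf(__half2float(M1 - M)));
        S = S0*ms0 + S1*ms1;
    }
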
@@ -6774,10 +6775,10 @@ static __global__ void flash_attn_ext_f16(
     // final rescale with 1/S and store to global memory
     if (warp_id == 0) {
         for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) {
-            const float S = __half2float(ss[j*T + 0]);
+            const half S = ss[j*T + 0];

             for (int64_t i = lane_id; i < D; i += NW) {
-                dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i]) / S;
+                dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S);
             }
         }
     }
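
Note the subtle change in the store: the division by S now happens in half precision before widening, rather than in float after widening. A minimal sketch of the two orderings (helper names are hypothetical):

    #include <cuda_fp16.h>

    // Before: widen to float first, then divide (float-precision quotient).
    __device__ float rescale_f32(half v, half S) { return __half2float(v) / S; }

    // After: divide in half, then widen (quotient rounded to half first).
    __device__ float rescale_f16(half v, half S) { return __half2float(v / S); }
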
@@ -10930,12 +10931,15 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     float scale;
     memcpy(&scale, KQV->op_params, sizeof(float));

-    const int nqpb = 16; // queries per block
-    const int ncpw = 32; // cache values per warp (does not work for other values)
+#define NQPB 16
+#define NCPW 32
+
+    const int nqpb = NQPB; // queries per block
+    const int ncpw = NCPW; // cache values per warp (does not work for other values)

     const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much?
     // TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why
-    const int nwarps = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, nwarps_max)) : 4;
+    const int nwarps = Q->ne[1] <= nqpb ? MAX(2, MIN(K->ne[1]/ncpw, nwarps_max)) : 2;

     dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
     dim3 block_dim(32, nwarps, 1);
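
The warp count minimum drops from 4 to 2, the value tuned on an RTX 2060; long KV ranges with a small query batch still scale up to nwarps_max. A host-side sketch of the selection for a few hypothetical shapes:

    #include <algorithm>
    #include <cstdio>

    // Mirrors the nwarps expression above: n_q = Q->ne[1] (queries),
    // n_kv = K->ne[1] (KV cache length), constants as defined above.
    static int pick_nwarps(int n_q, int n_kv) {
        const int nqpb = 16, ncpw = 32, nwarps_max = 8;
        return n_q <= nqpb ? std::max(2, std::min(n_kv/ncpw, nwarps_max)) : 2;
    }

    int main() {
        printf("%d\n", pick_nwarps( 1,   64)); // 2: short KV, little to split
        printf("%d\n", pick_nwarps( 1, 4096)); // 8: capped by nwarps_max
        printf("%d\n", pick_nwarps(32, 4096)); // 2: more than one query block
        return 0;
    }
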
@@ -10945,7 +10949,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     switch (Q->ne[0])
     {
         case 16:
-            flash_attn_ext_f16<16, 16, 32>
+            flash_attn_ext_f16<16, NQPB, NCPW>
                 <<<blocks_num, block_dim, shmem, main_stream>>> (
                     (const char *) src0_extra->data_device[g_main_device], // Query
                     (const char *) src1_extra->data_device[g_main_device], // Key
@@ -10962,7 +10966,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 );
             break;
         case 64:
-            flash_attn_ext_f16<64, 16, 32>
+            flash_attn_ext_f16<64, NQPB, NCPW>
                 <<<blocks_num, block_dim, shmem, main_stream>>> (
                     (const char *) src0_extra->data_device[g_main_device], // Query
                     (const char *) src1_extra->data_device[g_main_device], // Key
@@ -10979,7 +10983,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 );
             break;
         case 80:
-            flash_attn_ext_f16<80, 16, 32>
+            flash_attn_ext_f16<80, NQPB, NCPW>
                 <<<blocks_num, block_dim, shmem, main_stream>>> (
                     (const char *) src0_extra->data_device[g_main_device], // Query
                     (const char *) src1_extra->data_device[g_main_device], // Key
@@ -10996,7 +11000,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 );
             break;
         case 128:
-            flash_attn_ext_f16<128, 16, 32>
+            flash_attn_ext_f16<128, NQPB, NCPW>
                 <<<blocks_num, block_dim, shmem, main_stream>>> (
                     (const char *) src0_extra->data_device[g_main_device], // Query
                     (const char *) src1_extra->data_device[g_main_device], // Key
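
The NQPB/NCPW macros exist because the kernel's template arguments must be compile-time constants, so the instantiations above stay in lockstep with the launch-side nqpb/ncpw arithmetic. A constexpr would serve equally well; a hypothetical equivalent:

    // Sketch without the preprocessor: constexpr values are valid template
    // arguments and keep launch math and instantiation in sync the same way.
    constexpr int kNQPB = 16; // queries per block
    constexpr int kNCPW = 32; // cache values per warp

    // flash_attn_ext_f16<128, kNQPB, kNCPW><<<blocks_num, block_dim, shmem, main_stream>>>(/* ... */);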

tests/test-backend-ops.cpp (12 additions, 2 deletions)

@@ -572,9 +572,19 @@ struct test_case {
         // duplicate the op
         size_t target_size = ggml_backend_is_cpu(backend) ? 1ULL << 33 : 1ULL << 35; // 8 GB CPU, 32 GB GPU
         int n_runs = std::min((size_t)gf->size - gf->n_nodes, target_size / op_size(out)) + 1;
+#if 1
         for (int i = 1; i < n_runs; i++) {
             gf->nodes[gf->n_nodes++] = out;
         }
+#else
+        n_runs = 1000;
+        int n_nodes = gf->n_nodes;
+        for (int i = 1; i < n_runs; i++) {
+            for (int j = 0; j < n_nodes; j++) {
+                gf->nodes[gf->n_nodes++] = gf->nodes[j];
+            }
+        }
+#endif

         // calculate memory
         size_t mem = n_runs * op_size(out);
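
The enabled branch duplicates only the op under test until a memory budget is hit; the disabled alternative would instead replay the entire graph 1000 times, which is closer to real inference but times more than the single op. A small sketch of how the duplicate count falls out of the budget (sizes are hypothetical):

    #include <algorithm>
    #include <cstdio>

    int main() {
        const size_t target_size = 1ULL << 35;  // 32 GB budget (GPU case)
        const size_t op_bytes    = 64ULL << 20; // assume the op touches 64 MB
        const size_t free_nodes  = 4096;        // assume free graph node slots
        const int n_runs = (int)std::min(free_nodes, target_size / op_bytes) + 1;
        printf("n_runs = %d\n", n_runs);        // 513: original + 512 duplicates
        return 0;
    }
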
@@ -2199,8 +2209,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_pad());
     test_cases.emplace_back(new test_leaky_relu());

-#if 0
-    for (int hs : { 64, 80, 96, 112, 128, 256, }) {
+#if 1
+    for (int hs : { 64, 80, 128, }) {
         for (int nh : { 32, }) {
             for (int kv : { 512, 1024, 2048, 4096, }) {
                 for (int nb : { 1, 2, 4, 8, 512, 1024, 2048, }) {
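
With the sweep switched on, the head sizes are trimmed to the subset this kernel's dispatch covers (64, 80, 128; 96, 112, and 256 are dropped). The four nested loops then enumerate 3 x 1 x 4 x 7 = 84 FlashAttention cases; a quick hypothetical tally:

    #include <cstdio>

    int main() {
        const int hs[] = { 64, 80, 128 };                 // head sizes
        const int nh[] = { 32 };                          // attention heads
        const int kv[] = { 512, 1024, 2048, 4096 };       // KV cache lengths
        const int nb[] = { 1, 2, 4, 8, 512, 1024, 2048 }; // batch sizes
        const size_t n = (sizeof(hs)/sizeof(*hs)) * (sizeof(nh)/sizeof(*nh))
                       * (sizeof(kv)/sizeof(*kv)) * (sizeof(nb)/sizeof(*nb));
        printf("%zu cases\n", n);                         // 84
        return 0;
    }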

0 commit comments