Commit d79d8f3

vulkan: multi-row k quants (ggml-org#10846)

* multi row k quant shaders!
* better row selection
* more row choices
* readjust row selection
* rm_kq=2 by default

1 parent: d283d02

6 files changed: +477 −372 lines

ggml/src/ggml-vulkan/ggml-vulkan.cpp (+43 −38 lines)
Large diffs are not rendered by default.
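The host-side changes that pick the per-workgroup row count live in this file, but the diff is not rendered here. For orientation only: the shaders below declare NUM_ROWS as specialization constant 1, and a Vulkan host supplies such constants through VkSpecializationInfo when building the compute pipeline. The sketch below is a minimal, hypothetical illustration of that mechanism, not the actual ggml-vulkan.cpp code; VkSpecializationMapEntry and VkSpecializationInfo are real Vulkan types, while the struct and helper names are invented:

```cpp
#include <vulkan/vulkan.h>
#include <cstddef>
#include <cstdint>

// Hypothetical container for the two constants the shaders below declare:
// constant_id 0 = BLOCK_SIZE, constant_id 1 = NUM_ROWS (rm_kq, 2 by default per this commit).
struct MulMatVecSpec {
    uint32_t block_size;
    uint32_t num_rows;
};

// Map each field to its constant_id; the returned info would be passed as
// VkPipelineShaderStageCreateInfo::pSpecializationInfo at pipeline creation.
// The caller must keep `spec` and `entries` alive until the pipeline is built.
static VkSpecializationInfo make_spec_info(const MulMatVecSpec &spec,
                                           VkSpecializationMapEntry entries[2]) {
    entries[0] = { /*constantID=*/0, offsetof(MulMatVecSpec, block_size), sizeof(uint32_t) };
    entries[1] = { /*constantID=*/1, offsetof(MulMatVecSpec, num_rows),   sizeof(uint32_t) };

    VkSpecializationInfo info{};
    info.mapEntryCount = 2;
    info.pMapEntries   = entries;
    info.dataSize      = sizeof(MulMatVecSpec);
    info.pData         = &spec;
    return info;
}
```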

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp (+77 −57 lines)
@@ -6,21 +6,15 @@
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
 layout (constant_id = 0) const uint BLOCK_SIZE = 32;
+layout (constant_id = 1) const uint NUM_ROWS = 1;
 
-shared FLOAT_TYPE tmp[BLOCK_SIZE];
-
-void main() {
-    const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
-
-    if (row >= p.stride_d) {
-        return;
-    }
+shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
 
+void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     uint a_offset, b_offset, d_offset;
     get_offsets(a_offset, b_offset, d_offset);
 
     const uint num_blocks_per_row = p.ncols / QUANT_K;
-    const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
 
     // 16 threads are used to process each block
     const uint it_size = gl_WorkGroupSize.x/16;
@@ -38,15 +32,15 @@ void main() {
     const uint s_offset = 8*v_im;
     const uint y_offset = 128*v_im + l0;
 
-    FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp
+    FLOAT_TYPE temp[NUM_ROWS];
+
+    [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
+        temp[i] = FLOAT_TYPE(0);
+    }
 
     [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
         const uint y_idx = i * QUANT_K + y_offset;
 
-        f16vec2 d = data_a[ib0 + i].d;
-        const FLOAT_TYPE dall = d.x;
-        const FLOAT_TYPE dmin = d.y;
-
         B_TYPE_VEC2 b0 = data_b_v2[(b_offset + y_idx) / 2 + 0];
         B_TYPE_VEC2 b16 = data_b_v2[(b_offset + y_idx) / 2 + 8];
         B_TYPE_VEC2 b32 = data_b_v2[(b_offset + y_idx) / 2 + 16];
@@ -56,58 +50,84 @@
         B_TYPE_VEC2 b96 = data_b_v2[(b_offset + y_idx) / 2 + 48];
         B_TYPE_VEC2 b112 = data_b_v2[(b_offset + y_idx) / 2 + 56];
 
-        uint32_t s0_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 0];
-        uint32_t s4_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 1];
-
-        uint32_t s0_lo4_u32 = s0_u32 & 0x0F0F0F0F;
-        uint32_t s0_hi4_u32 = (s0_u32 >> 4) & 0x0F0F0F0F;
-        uint32_t s4_lo4_u32 = s4_u32 & 0x0F0F0F0F;
-        uint32_t s4_hi4_u32 = (s4_u32 >> 4) & 0x0F0F0F0F;
-
-        uvec4 s0_lo4 = uvec4(unpack8(s0_lo4_u32));
-        uvec4 s4_lo4 = uvec4(unpack8(s4_lo4_u32));
-        uvec4 s0_hi4 = uvec4(unpack8(s0_hi4_u32));
-        uvec4 s4_hi4 = uvec4(unpack8(s4_hi4_u32));
-
-        uint16_t qs0_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 0];
-        uint16_t qs16_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 8];
-        uvec2 qs0 = uvec2(unpack8(qs0_u16));
-        uvec2 qs16 = uvec2(unpack8(qs16_u16));
-
-        FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
-        FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
-        [[unroll]] for (int l = 0; l < 2; ++l) {
-            sum1 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l] >> 0) & 3),
-                   fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3),
-                   fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l] >> 2) & 3),
-                   fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_lo4[3]) * FLOAT_TYPE((qs16[l] >> 2) & 3),
-                   fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_lo4[0]) * FLOAT_TYPE((qs0[l] >> 4) & 3),
-                   fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_lo4[1]) * FLOAT_TYPE((qs16[l] >> 4) & 3),
-                   fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_lo4[2]) * FLOAT_TYPE((qs0[l] >> 6) & 3),
-                   fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_lo4[3]) * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1))))))));
-            sum2 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_hi4[0]),
-                   fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_hi4[1]),
-                   fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_hi4[2]),
-                   fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_hi4[3]),
-                   fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_hi4[0]),
-                   fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_hi4[1]),
-                   fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_hi4[2]),
-                   fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_hi4[3]), sum2))))))));
+        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+            const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
+            f16vec2 d = data_a[ib0 + i].d;
+            const FLOAT_TYPE dall = d.x;
+            const FLOAT_TYPE dmin = d.y;
+
+            uint32_t s0_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 0];
+            uint32_t s4_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 1];
+
+            uint32_t s0_lo4_u32 = s0_u32 & 0x0F0F0F0F;
+            uint32_t s0_hi4_u32 = (s0_u32 >> 4) & 0x0F0F0F0F;
+            uint32_t s4_lo4_u32 = s4_u32 & 0x0F0F0F0F;
+            uint32_t s4_hi4_u32 = (s4_u32 >> 4) & 0x0F0F0F0F;
+
+            uvec4 s0_lo4 = uvec4(unpack8(s0_lo4_u32));
+            uvec4 s4_lo4 = uvec4(unpack8(s4_lo4_u32));
+            uvec4 s0_hi4 = uvec4(unpack8(s0_hi4_u32));
+            uvec4 s4_hi4 = uvec4(unpack8(s4_hi4_u32));
+
+            uint16_t qs0_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 0];
+            uint16_t qs16_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 8];
+            uvec2 qs0 = uvec2(unpack8(qs0_u16));
+            uvec2 qs16 = uvec2(unpack8(qs16_u16));
+
+            FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
+            FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
+            [[unroll]] for (int l = 0; l < 2; ++l) {
+                sum1 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l] >> 0) & 3),
+                       fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3),
+                       fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l] >> 2) & 3),
+                       fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_lo4[3]) * FLOAT_TYPE((qs16[l] >> 2) & 3),
+                       fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_lo4[0]) * FLOAT_TYPE((qs0[l] >> 4) & 3),
+                       fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_lo4[1]) * FLOAT_TYPE((qs16[l] >> 4) & 3),
+                       fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_lo4[2]) * FLOAT_TYPE((qs0[l] >> 6) & 3),
+                       fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_lo4[3]) * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1))))))));
+                sum2 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_hi4[0]),
+                       fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_hi4[1]),
+                       fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_hi4[2]),
+                       fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_hi4[3]),
+                       fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_hi4[0]),
+                       fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_hi4[1]),
+                       fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_hi4[2]),
+                       fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_hi4[3]), sum2))))))));
+            }
+            temp[n] = fma(dall, sum1, fma(-dmin, sum2, temp[n]));
         }
-        temp = fma(dall, sum1, fma(-dmin, sum2, temp));
     }
 
-    tmp[gl_LocalInvocationID.x] = temp;
-
     // sum up partial sums and write back result
+    [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+        tmpsh[n][tid] = temp[n];
+    }
     barrier();
-    [[unroll]] for (uint s = gl_WorkGroupSize.x/2; s > 0; s >>= 1) {
+    [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
         if (tid < s) {
-            tmp[tid] += tmp[tid + s];
+            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+                tmpsh[n][tid] += tmpsh[n][tid + s];
+            }
         }
         barrier();
     }
     if (tid == 0) {
-        data_d[d_offset + row] = D_TYPE(tmp[0]);
+        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+            data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
+        }
+    }
+}
+
+void main() {
+    const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
+
+    // do NUM_ROWS at a time, unless there aren't enough remaining rows
+    if (first_row + NUM_ROWS <= p.stride_d) {
+        compute_outputs(first_row, NUM_ROWS);
+    } else {
+        if (first_row >= p.stride_d) {
+            return;
+        }
+        compute_outputs(first_row, p.stride_d - first_row);
     }
 }
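Because each workgroup now produces NUM_ROWS outputs, the host needs roughly 1/NUM_ROWS as many workgroups along the row dimension, with main() above trimming the final partial group. A hedged sketch of that dispatch arithmetic follows; the names are illustrative and the real logic lives in the unrendered ggml-vulkan.cpp diff:

```cpp
#include <cstdint>

// Round-up division: enough workgroups that groups * num_rows_per_wg >= nrows.
// The shader's main() handles the tail via compute_outputs(first_row, remaining).
static uint32_t row_workgroups(uint32_t nrows, uint32_t num_rows_per_wg) {
    return (nrows + num_rows_per_wg - 1) / num_rows_per_wg;
}

// e.g. nrows = 4097 with rm_kq = 2 -> 2049 groups; the last one computes a single row.
```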

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp (+62 −42 lines)
@@ -6,21 +6,15 @@
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
 layout (constant_id = 0) const uint BLOCK_SIZE = 32;
+layout (constant_id = 1) const uint NUM_ROWS = 1;
 
-shared FLOAT_TYPE tmp[BLOCK_SIZE];
-
-void main() {
-    const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
-
-    if (row >= p.stride_d) {
-        return;
-    }
+shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
 
+void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     uint a_offset, b_offset, d_offset;
     get_offsets(a_offset, b_offset, d_offset);
 
     const uint num_blocks_per_row = p.ncols / QUANT_K;
-    const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
 
     // 16 threads are used to process each block
    const uint it_size = gl_WorkGroupSize.x/16;
@@ -35,19 +29,21 @@ void main() {
 
     const uint8_t m = uint8_t(1 << (4 * v_im));
 
-    const uint l0 = 2*v_in; // 0...15
+    const uint l0 = 2*v_in; // 0...15
     const uint q_offset = 32*v_im + l0;
     const uint y_offset = 128*v_im + l0;
 
-    FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp
+    FLOAT_TYPE temp[NUM_ROWS];
+
+    [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
+        temp[i] = FLOAT_TYPE(0);
+    }
 
     const uint s_shift = 4 * v_im;
 
     [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
         const uint y_idx = i * QUANT_K + y_offset;
 
-        const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
-
         B_TYPE_VEC2 b0 = data_b_v2[(b_offset + y_idx) / 2 + 0];
         B_TYPE_VEC2 b16 = data_b_v2[(b_offset + y_idx) / 2 + 8];
         B_TYPE_VEC2 b32 = data_b_v2[(b_offset + y_idx) / 2 + 16];
@@ -57,44 +53,68 @@
         B_TYPE_VEC2 b96 = data_b_v2[(b_offset + y_idx) / 2 + 48];
         B_TYPE_VEC2 b112 = data_b_v2[(b_offset + y_idx) / 2 + 56];
 
-        uint16_t s0_16 = data_a_packed16[ib0 + i].scales[0];
-        uint16_t s2_16 = data_a_packed16[ib0 + i].scales[1];
-        uint16_t s4_16 = data_a_packed16[ib0 + i].scales[2];
-        uint16_t s6_16 = data_a_packed16[ib0 + i].scales[3];
-        uint16_t s8_16 = data_a_packed16[ib0 + i].scales[4];
-        uint16_t s10_16 = data_a_packed16[ib0 + i].scales[5];
-        u8vec2 s0 = unpack8(s0_16);
-        u8vec2 s2 = unpack8(s2_16);
-        u8vec2 s4 = unpack8(s4_16);
-        u8vec2 s6 = unpack8(s6_16);
-        u8vec2 s8 = unpack8(s8_16);
-        u8vec2 s10 = unpack8(s10_16);
-
-        FLOAT_TYPE sum = FLOAT_TYPE(0.0);
-        [[unroll]] for (int l = 0; l < 2; ++l) {
-            sum = fma(FLOAT_TYPE(b0[l]) * FLOAT_TYPE(int8_t(((s0[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4)),
-                  fma(FLOAT_TYPE(b32[l]) * FLOAT_TYPE(int8_t(((s2[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)),
-                  fma(FLOAT_TYPE(b64[l]) * FLOAT_TYPE(int8_t(((s4[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4)),
-                  fma(FLOAT_TYPE(b96[l]) * FLOAT_TYPE(int8_t(((s6[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 3)) != 0) ? 0 : 4)),
-                  fma(FLOAT_TYPE(b16[l]) * FLOAT_TYPE(int8_t(((s0[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4)),
-                  fma(FLOAT_TYPE(b48[l]) * FLOAT_TYPE(int8_t(((s2[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)),
-                  fma(FLOAT_TYPE(b80[l]) * FLOAT_TYPE(int8_t(((s4[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)),
-                  fma(FLOAT_TYPE(b112[l]) * FLOAT_TYPE(int8_t(((s6[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4)), sum))))))));
+        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+            const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
+            const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
+
+            uint16_t s0_16 = data_a_packed16[ib0 + i].scales[0];
+            uint16_t s2_16 = data_a_packed16[ib0 + i].scales[1];
+            uint16_t s4_16 = data_a_packed16[ib0 + i].scales[2];
+            uint16_t s6_16 = data_a_packed16[ib0 + i].scales[3];
+            uint16_t s8_16 = data_a_packed16[ib0 + i].scales[4];
+            uint16_t s10_16 = data_a_packed16[ib0 + i].scales[5];
+            u8vec2 s0 = unpack8(s0_16);
+            u8vec2 s2 = unpack8(s2_16);
+            u8vec2 s4 = unpack8(s4_16);
+            u8vec2 s6 = unpack8(s6_16);
+            u8vec2 s8 = unpack8(s8_16);
+            u8vec2 s10 = unpack8(s10_16);
+
+            FLOAT_TYPE sum = FLOAT_TYPE(0.0);
+            [[unroll]] for (int l = 0; l < 2; ++l) {
+                sum = fma(FLOAT_TYPE(b0[l]) * FLOAT_TYPE(int8_t(((s0[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4)),
+                      fma(FLOAT_TYPE(b32[l]) * FLOAT_TYPE(int8_t(((s2[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)),
+                      fma(FLOAT_TYPE(b64[l]) * FLOAT_TYPE(int8_t(((s4[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4)),
+                      fma(FLOAT_TYPE(b96[l]) * FLOAT_TYPE(int8_t(((s6[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 3)) != 0) ? 0 : 4)),
+                      fma(FLOAT_TYPE(b16[l]) * FLOAT_TYPE(int8_t(((s0[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4)),
+                      fma(FLOAT_TYPE(b48[l]) * FLOAT_TYPE(int8_t(((s2[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)),
+                      fma(FLOAT_TYPE(b80[l]) * FLOAT_TYPE(int8_t(((s4[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)),
+                      fma(FLOAT_TYPE(b112[l]) * FLOAT_TYPE(int8_t(((s6[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4)), sum))))))));
+            }
+            temp[n] = fma(d, sum, temp[n]);
         }
-        temp = fma(d, sum, temp);
     }
 
-    tmp[gl_LocalInvocationID.x] = temp;
-
     // sum up partial sums and write back result
+    [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+        tmpsh[n][tid] = temp[n];
+    }
     barrier();
-    [[unroll]] for (uint s = gl_WorkGroupSize.x/2; s > 0; s >>= 1) {
+    [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
         if (tid < s) {
-            tmp[tid] += tmp[tid + s];
+            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+                tmpsh[n][tid] += tmpsh[n][tid + s];
+            }
         }
         barrier();
     }
     if (tid == 0) {
-        data_d[d_offset + row] = D_TYPE(tmp[0]);
+        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+            data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
+        }
+    }
+}
+
+void main() {
+    const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
+
+    // do NUM_ROWS at a time, unless there aren't enough remaining rows
+    if (first_row + NUM_ROWS <= p.stride_d) {
+        compute_outputs(first_row, NUM_ROWS);
+    } else {
+        if (first_row >= p.stride_d) {
+            return;
+        }
+        compute_outputs(first_row, p.stride_d - first_row);
    }
 }
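Both shaders finish with the same reduction shape: each invocation stores one partial sum per row into tmpsh, then a tree reduction halves the active stride until tmpsh[n][0] holds the dot product for row n. Below is a scalar C++ model of that pattern, for intuition only; the BLOCK_SIZE/NUM_ROWS values are examples, and the sequential loops stand in for the shader's parallel invocations and barrier()s:

```cpp
#include <cstdio>
#include <vector>

// CPU model of the shader's shared-memory reduction: NUM_ROWS independent
// tree reductions over BLOCK_SIZE partial sums, mirroring tmpsh[n][tid].
int main() {
    const unsigned BLOCK_SIZE = 32, NUM_ROWS = 2;
    // tmpsh[n][tid]: the partial sum each "thread" produced for row n
    // (all 1.0f here so the expected result is easy to check).
    std::vector<std::vector<float>> tmpsh(NUM_ROWS,
                                          std::vector<float>(BLOCK_SIZE, 1.0f));

    for (unsigned s = BLOCK_SIZE / 2; s > 0; s >>= 1) {   // stride halves each pass
        for (unsigned tid = 0; tid < s; ++tid) {          // the shader's "if (tid < s)"
            for (unsigned n = 0; n < NUM_ROWS; ++n) {
                tmpsh[n][tid] += tmpsh[n][tid + s];       // pairwise fold
            }
        }
        // barrier() in the shader; nothing needed sequentially
    }
    for (unsigned n = 0; n < NUM_ROWS; ++n) {
        std::printf("row %u sum = %f\n", n, tmpsh[n][0]); // expect 32.0 for each row
    }
    return 0;
}
```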

0 commit comments
