Commit ebd062b

cuda : use 512 threads for soft_max instead of 32
1 parent 580fe20 commit ebd062b

File tree: 1 file changed (+34 -17)

ggml-cuda.cu: 34 additions & 17 deletions
@@ -443,6 +443,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 #define CUDA_SCALE_BLOCK_SIZE 256
 #define CUDA_CLAMP_BLOCK_SIZE 256
 #define CUDA_ROPE_BLOCK_SIZE 256
+#define CUDA_SOFT_MAX_BLOCK_SIZE 512
 #define CUDA_ALIBI_BLOCK_SIZE 32
 #define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
 #define CUDA_QUANTIZE_BLOCK_SIZE 256
@@ -4717,45 +4718,59 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
     dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
 }
 
-// the CUDA soft max implementation differs from the CPU implementation
-// instead of doubles floats are used
+// TODO: maybe can be improved with some warp-based primitives
 static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
-    const int rowx = blockDim.x*blockIdx.x + threadIdx.x;
+    const int tid  = threadIdx.x;
+    const int rowx = blockIdx.x;
     const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
-    const int block_size = blockDim.y;
-    const int tid = threadIdx.y;
 
-    float max_val = -INFINITY;
+    const int block_size = blockDim.x;
+
+    __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE];
+
+    buf[tid] = -INFINITY;
 
     for (int col = tid; col < ncols; col += block_size) {
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;
-        max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f));
+        buf[tid] = max(buf[tid], x[ix]*scale + (y ? y[iy] : 0.0f));
     }
 
+    __syncthreads();
+
     // find the max value in the block
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        max_val = max(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32));
+    for (int i = block_size/2; i > 0; i >>= 1) {
+        if (tid < i) {
+            buf[tid] = max(buf[tid], buf[tid + i]);
+        }
+        __syncthreads();
     }
 
     float tmp = 0.f;
 
     for (int col = tid; col < ncols; col += block_size) {
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;
-        const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val);
+        const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - buf[0]);
         tmp += val;
         dst[ix] = val;
     }
 
+    __syncthreads();
+
+    buf[tid] = tmp;
+
+    __syncthreads();
+
     // sum up partial sums
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    for (int i = block_size/2; i > 0; i >>= 1) {
+        if (tid < i) {
+            buf[tid] += buf[tid + i];
+        }
+        __syncthreads();
     }
 
-    const float inv_tmp = 1.f / tmp;
+    const float inv_tmp = 1.f / buf[0];
 
     for (int col = tid; col < ncols; col += block_size) {
         const int i = rowx*ncols + col;
@@ -5796,7 +5811,9 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
 }
 
 static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
-    const dim3 block_dims(1, WARP_SIZE, 1);
+    int nth = WARP_SIZE;
+    while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
+    const dim3 block_dims(nth, 1, 1);
     const dim3 block_nums(nrows_x, 1, 1);
     soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
 }
@@ -6853,7 +6870,7 @@ inline void ggml_cuda_op_soft_max(
 
     const int64_t ne00 = src0->ne[0];
    const int64_t nrows_x = ggml_nrows(src0);
-    const int64_t nrows_y = src1 ? ggml_nrows(src1) : 0;
+    const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1;
 
     float scale = 1.0f;
     memcpy(&scale, dst->op_params, sizeof(float));
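
What changes above is the reduction strategy: previously each row of the soft max was handled by a single 32-thread warp via __shfl_xor_sync, while the new kernel lets a whole block of up to CUDA_SOFT_MAX_BLOCK_SIZE (512) threads cooperate on a row through shared-memory tree reductions, first for the row maximum and then for the sum of exponentials, with the launcher growing the thread count in powers of two from WARP_SIZE up to that cap. The following is a minimal standalone sketch of the same pattern, not the ggml-cuda.cu code itself; the kernel name soft_max_rows, the SOFT_MAX_BLOCK_SIZE macro, and the test sizes in main are invented for the illustration.

// Minimal standalone sketch of the block-wide soft max pattern used by this
// commit: one block per row, a shared buffer, and two tree reductions.
// All names and sizes here are hypothetical; build with nvcc.
#include <cstdio>
#include <cmath>
#include <vector>
#include <cuda_runtime.h>

#define SOFT_MAX_BLOCK_SIZE 512 // stand-in for CUDA_SOFT_MAX_BLOCK_SIZE

// one block per row; blockDim.x must be a power of two (see the launch below)
static __global__ void soft_max_rows(const float * x, float * dst, const int ncols) {
    const int tid        = threadIdx.x;
    const int row        = blockIdx.x;
    const int block_size = blockDim.x;

    __shared__ float buf[SOFT_MAX_BLOCK_SIZE];

    // each thread folds its strided slice of the row into buf[tid]
    buf[tid] = -INFINITY;
    for (int col = tid; col < ncols; col += block_size) {
        buf[tid] = max(buf[tid], x[row*ncols + col]);
    }
    __syncthreads();

    // tree reduction: buf[0] ends up holding the row maximum
    for (int i = block_size/2; i > 0; i >>= 1) {
        if (tid < i) {
            buf[tid] = max(buf[tid], buf[tid + i]);
        }
        __syncthreads();
    }
    const float max_val = buf[0];
    __syncthreads(); // buf is reused for the sum below

    // exponentiate relative to the maximum, accumulating per-thread partial sums
    float tmp = 0.0f;
    for (int col = tid; col < ncols; col += block_size) {
        const float val = expf(x[row*ncols + col] - max_val);
        dst[row*ncols + col] = val;
        tmp += val;
    }

    // second tree reduction: buf[0] ends up holding the row sum
    buf[tid] = tmp;
    __syncthreads();
    for (int i = block_size/2; i > 0; i >>= 1) {
        if (tid < i) {
            buf[tid] += buf[tid + i];
        }
        __syncthreads();
    }
    const float inv_sum = 1.0f / buf[0];

    for (int col = tid; col < ncols; col += block_size) {
        dst[row*ncols + col] *= inv_sum;
    }
}

int main() {
    const int nrows = 4;
    const int ncols = 1000;

    std::vector<float> h_x(nrows*ncols);
    for (int i = 0; i < nrows*ncols; ++i) {
        h_x[i] = (float)(i % 7) - 3.0f; // arbitrary test values
    }

    float * d_x   = nullptr;
    float * d_dst = nullptr;
    cudaMalloc(&d_x,   nrows*ncols*sizeof(float));
    cudaMalloc(&d_dst, nrows*ncols*sizeof(float));
    cudaMemcpy(d_x, h_x.data(), nrows*ncols*sizeof(float), cudaMemcpyHostToDevice);

    // same thread-count selection as the new launcher: start at one warp and
    // double until the row is covered or the block-size cap is reached
    int nth = 32;
    while (nth < ncols && nth < SOFT_MAX_BLOCK_SIZE) nth *= 2;
    soft_max_rows<<<nrows, nth>>>(d_x, d_dst, ncols);

    std::vector<float> h_dst(nrows*ncols);
    cudaMemcpy(h_dst.data(), d_dst, nrows*ncols*sizeof(float), cudaMemcpyDeviceToHost);

    for (int r = 0; r < nrows; ++r) {
        float sum = 0.0f;
        for (int c = 0; c < ncols; ++c) sum += h_dst[r*ncols + c];
        printf("row %d sums to %.6f\n", r, sum); // each row should sum to ~1
    }

    cudaFree(d_x);
    cudaFree(d_dst);
    return 0;
}

The tree reductions assume the block size is a power of two, which the doubling loop guarantees. The separate change of the nrows_y fallback from 0 to 1 presumably keeps the rowy = rowx % nrows_y broadcast well defined when no mask tensor y is given, since with nrows_y == 0 that expression would be a modulo by zero.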
