Commit 1b8fb81

ggml: aarch64: Implement SVE F32 kernels for vector functions (ggml-org#13843)

* F32-Mamba-SVE
* F32-Mamba-SVE
* Resolve test errors-1
* Resolve test errors-2
* F32-vec-SVE
* F32-vec-SVE
* F32-vec-SVE

1 parent 53ae306 · commit 1b8fb81
File tree

4 files changed: +513 -138 lines changed

ggml/src/ggml-cpu/ops.cpp: 143 additions & 72 deletions
@@ -7641,8 +7641,8 @@ static void ggml_compute_forward_ssm_scan_f32(
     const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
     const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
     const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
-    float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
-    float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
+    float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+    float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}

     // use the output as the source for the next token-wise iterations
     if (i2 > 0) { s0 = s; }
@@ -8070,6 +8070,14 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
     #define GGML_F32X_MUL GGML_F32x16_MUL
     #define GGML_F32X_FMA GGML_F32x16_FMA
     #define WKV_VECTOR_SIZE 16
+#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
+    #define GGML_F32X GGML_F32xt
+    #define GGML_F32X_SET1 GGML_F32xt_SET1
+    #define GGML_F32X_LOAD GGML_F32xt_LOAD
+    #define GGML_F32X_STORE GGML_F32xt_STORE
+    #define GGML_F32X_MUL GGML_F32xt_MUL
+    #define GGML_F32X_FMA GGML_F32xt_FMA
+    #define WKV_VECTOR_SIZE 8
 #elif defined(__ARM_NEON) && defined(__aarch64__)
     #define GGML_F32X GGML_F32x4
     #define GGML_F32X_SET1 GGML_F32x4_SET1
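For orientation, the GGML_F32xt names selected above are the SVE counterparts of the existing fixed-width F32 wrapper macros. The sketch below shows one plausible way such wrappers can be expressed with ACLE SVE intrinsics and an all-true predicate; it is illustrative only (SKETCH_ names are hypothetical), and the real GGML_F32xt definitions live in ggml-cpu's SIMD mapping header, which is not part of this hunk.

// Illustrative sketch only: hypothetical SVE-backed F32 wrappers in the style
// of the GGML_F32xt macros referenced by this diff; the actual definitions may differ.
#include <arm_sve.h>

#define SKETCH_F32xt                 svfloat32_t
#define SKETCH_F32xt_SET1(x)         svdup_n_f32(x)
#define SKETCH_F32xt_LOAD(p)         svld1_f32(svptrue_b32(), (const float *)(p))
#define SKETCH_F32xt_STORE(p, v)     svst1_f32(svptrue_b32(), (float *)(p), (v))
#define SKETCH_F32xt_MUL(a, b)       svmul_f32_x(svptrue_b32(), (a), (b))
#define SKETCH_F32xt_FMA(acc, a, b)  svmla_f32_x(svptrue_b32(), (acc), (a), (b)) // acc + a*b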
@@ -8080,8 +8088,14 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
     #define WKV_VECTOR_SIZE 4
 #endif

+    int wkv_vector_size;
 #ifdef WKV_VECTOR_SIZE
-    const int64_t vec_count = head_size / WKV_VECTOR_SIZE;
+    #if defined(__ARM_FEATURE_SVE)
+        wkv_vector_size = svcntw();
+    #else
+        wkv_vector_size = WKV_VECTOR_SIZE;
+    #endif
+    const int64_t vec_count = head_size / wkv_vector_size;

     for (int64_t t = 0; t < T; t++) {
         size_t t_offset = t * t_stride;
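A minimal standalone illustration of the runtime-width selection added above: unlike NEON or AVX, the SVE lane count is a hardware property, so the kernel queries it with svcntw() at run time instead of baking the compile-time WKV_VECTOR_SIZE constant into the loop (the constant remains as the fallback for fixed-width ISAs). The head_size value and fallback width below are placeholders for demonstration.

// Sketch: how a runtime SVE lane count drives the vectorized iteration count.
#if defined(__ARM_FEATURE_SVE)
#include <arm_sve.h>
#endif
#include <stdio.h>

int main(void) {
    int lanes;
#if defined(__ARM_FEATURE_SVE)
    lanes = (int) svcntw();   // f32 lanes per SVE vector: 4 on 128-bit HW, 8 on 256-bit, ...
#else
    lanes = 4;                // stand-in for the fixed-width (e.g. NEON) path
#endif
    const long head_size = 64;                  // illustrative; in ggml this comes from the tensor shape
    const long vec_count = head_size / lanes;   // full vectors per row; any remainder is handled scalar
    printf("lanes = %d, vec_count = %ld\n", lanes, vec_count);
    return 0;
}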
@@ -8111,7 +8125,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
             GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);

             for (int64_t j = 0; j < vec_count; j++) {
-                size_t base_j = j * WKV_VECTOR_SIZE;
+                size_t base_j = j * wkv_vector_size;
                 size_t t_h_j_offset = t_h_offset + base_j;
                 size_t h_2d_i_j_offset = h_2d_i_offset + base_j;

@@ -8136,7 +8150,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
             }

             // Handle remaining elements, this will not be used.
-            for (int64_t j = vec_count * WKV_VECTOR_SIZE; j < head_size; j++) {
+            for (int64_t j = vec_count * wkv_vector_size; j < head_size; j++) {
                 size_t t_h_j_offset = t_h_offset + j;
                 size_t h_2d_i_j_offset = h_2d_i_offset + j;
                 float v_val = v[t_h_j_offset];
@@ -8272,6 +8286,14 @@ static void ggml_compute_forward_gla_f32(
     #define GGML_F32X_MUL GGML_F32x16_MUL
     #define GGML_F32X_FMA GGML_F32x16_FMA
     #define GLA_VECTOR_SIZE 16
+#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
+    #define GGML_F32X GGML_F32xt
+    #define GGML_F32X_SET1 GGML_F32xt_SET1
+    #define GGML_F32X_LOAD GGML_F32xt_LOAD
+    #define GGML_F32X_STORE GGML_F32xt_STORE
+    #define GGML_F32X_MUL GGML_F32xt_MUL
+    #define GGML_F32X_FMA GGML_F32xt_FMA
+    #define GLA_VECTOR_SIZE 8
 #elif defined(__ARM_NEON) && defined(__aarch64__)
     #define GGML_F32X GGML_F32x4
     #define GGML_F32X_SET1 GGML_F32x4_SET1
@@ -8282,8 +8304,14 @@ static void ggml_compute_forward_gla_f32(
     #define GLA_VECTOR_SIZE 4
 #endif

+    int gla_vector_size;
 #ifdef GLA_VECTOR_SIZE
-    const int64_t vec_count = head_size / GLA_VECTOR_SIZE;
+    #if defined(__ARM_FEATURE_SVE)
+        gla_vector_size = svcntw();
+    #else
+        gla_vector_size = GLA_VECTOR_SIZE;
+    #endif
+    const int64_t vec_count = head_size / gla_vector_size;

     for (int64_t t = 0; t < T; t++) {
         size_t t_offset = t * t_stride;
@@ -8310,7 +8338,7 @@ static void ggml_compute_forward_gla_f32(
             GGML_F32X g_vec = GGML_F32X_SET1(g_val);

             for (int64_t j = 0; j < vec_count; j++) {
-                size_t base_j = j * GLA_VECTOR_SIZE;
+                size_t base_j = j * gla_vector_size;
                 size_t t_h_j_offset = t_h_offset + base_j;
                 size_t h_2d_i_j_offset = h_2d_i_offset + base_j;

@@ -8334,7 +8362,7 @@ static void ggml_compute_forward_gla_f32(
             }

             // Handle remaining elements, this will not be used.
-            for (int64_t j = vec_count * GLA_VECTOR_SIZE; j < head_size; j++) {
+            for (int64_t j = vec_count * gla_vector_size; j < head_size; j++) {
                 size_t t_h_j_offset = t_h_offset + j;
                 size_t h_2d_i_j_offset = h_2d_i_offset + j;
                 float v_val = v[t_h_j_offset];
@@ -8443,83 +8471,126 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
     int64_t h_stride_2d = head_size * head_size;

 #if defined(GGML_SIMD)
-    for (int64_t t = 0; t < T; t++) {
-        int64_t t_offset = t * t_stride;
-        int64_t state_offset = head_size * C * (t / (T / n_seqs));
-        float * state_cur = state + state_offset;
-        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
-
-        for (int64_t h = h_start; h < h_end; h++) {
-            int64_t h_offset = h * h_stride;
-            int64_t t_h_offset = t_offset + h_offset;
-            int64_t h_2d_offset = h * h_stride_2d;
-
-            for (int64_t ii = 0; ii < head_size; ii++) {
-                int64_t t_h_i_offset = t_h_offset + ii;
-                int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
-
-                GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
+    #if defined(__ARM_FEATURE_SVE)
+        // scalar Route to scalar implementation       //TODO: Write SVE code
+        for (int64_t t = 0; t < T; t++) {
+            int64_t t_offset = t * t_stride;
+            int64_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                int64_t h_offset = h * h_stride;
+                int64_t t_h_offset = t_offset + h_offset;
+                int64_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    int64_t t_h_i_offset = t_h_offset + i;
+                    int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float v_val = v[t_h_i_offset];
+
+                    float sa = 0, result = 0;
+                    for (int64_t j = 0; j < head_size; j++) {
+                        sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
+                    }

-                float sa = 0;
-                {
-                    GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
-                    GGML_F32_VEC ax[GGML_F32_ARR];
-                    GGML_F32_VEC ay[GGML_F32_ARR];
-                    for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
-                        for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
-                            ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
-                            ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
-                            sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
-                        }
+                    for (int64_t j = 0; j < head_size; j++) {
+                        int64_t t_h_j_offset = t_h_offset + j;
+                        int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                        float r_val = r[t_h_j_offset];
+                        float w_val = w[t_h_j_offset];
+                        float k_val = k[t_h_j_offset];
+                        float b_val = b[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+                        result += state_cur[h_2d_i_j_offset] * r_val;
                     }
-                    GGML_F32_VEC_REDUCE(sa, sum);
+                    dst_data[t_h_i_offset] = result;
                 }
+            }
+        }
+    #else
+        for (int64_t t = 0; t < T; t++) {
+            int64_t t_offset = t * t_stride;
+            int64_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                int64_t h_offset = h * h_stride;
+                int64_t t_h_offset = t_offset + h_offset;
+                int64_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t ii = 0; ii < head_size; ii++) {
+                    int64_t t_h_i_offset = t_h_offset + ii;
+                    int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
+
+                    GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
+
+                    float sa = 0;
+                    {
+                        GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+                        GGML_F32_VEC ax[GGML_F32_ARR];
+                        GGML_F32_VEC ay[GGML_F32_ARR];
+                        for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
+                            for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+                                ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
+                                ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
+                                sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
+                            }
+                        }
+                        GGML_F32_VEC_REDUCE(sa, sum);
+                    }

-                GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);
+                    GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);

-                int64_t j = 0;
-                GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
-                for (; j < head_size; j += GGML_F32_STEP) {
-                    for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
-                        int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
-                        int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
+                    int64_t j = 0;
+                    GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+                    for (; j < head_size; j += GGML_F32_STEP) {
+                        for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+                            int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
+                            int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;

-                        GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
-                        GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
-                        GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
-                        GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
+                            GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
+                            GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
+                            GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
+                            GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);

-                        k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
+                            k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);

-                        GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
-                        // kv + s * decay + sa * b
-                        state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
-                        state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
-                        GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
+                            GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
+                            // kv + s * decay + sa * b
+                            state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
+                            state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
+                            GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);

-                        result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
+                            result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
+                        }
+                    }
+                    GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
+
+                    // There shouldn't be left-overs though.
+                    for (; j < head_size; j++) {
+                        int64_t t_h_j_offset = t_h_offset + j;
+                        int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                        float r_val = r[t_h_j_offset];
+                        float w_val = w[t_h_j_offset];
+                        float k_val = k[t_h_j_offset];
+                        float b_val = b[t_h_j_offset];
+                        float kv_val = v[t_h_i_offset] * k_val;
+
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+                        dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
                     }
-                }
-                GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
-
-                // There shouldn't be left-overs though.
-                for (; j < head_size; j++) {
-                    int64_t t_h_j_offset = t_h_offset + j;
-                    int64_t h_2d_i_j_offset = h_2d_i_offset + j;
-
-                    float r_val = r[t_h_j_offset];
-                    float w_val = w[t_h_j_offset];
-                    float k_val = k[t_h_j_offset];
-                    float b_val = b[t_h_j_offset];
-                    float kv_val = v[t_h_i_offset] * k_val;
-
-                    float prev_state_val = state_prev[h_2d_i_j_offset];
-                    state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
-                    dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
                 }
             }
         }
-    }
+    #endif
 #else
     for (int64_t t = 0; t < T; t++) {
         int64_t t_offset = t * t_stride;
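In the wkv7 hunk above, the SVE build is simply routed to a plain scalar loop; the commit leaves a TODO for a real SVE version. Purely as a hypothetical sketch (not part of this commit, and not the authors' planned approach), the sa accumulation could later be expressed as a predicated SVE dot product along these lines:

// Hypothetical sketch of a predicated SVE dot product for the sa accumulation;
// the commit itself keeps this path scalar (see the TODO above).
#include <arm_sve.h>
#include <stdint.h>

static float sve_dot_f32(const float * a, const float * b, int64_t n) {
    svfloat32_t acc = svdup_n_f32(0.0f);
    for (int64_t j = 0; j < n; j += (int64_t) svcntw()) {
        const svbool_t    pg = svwhilelt_b32(j, n);   // active lanes; masks off the tail
        const svfloat32_t va = svld1_f32(pg, a + j);
        const svfloat32_t vb = svld1_f32(pg, b + j);
        acc = svmla_f32_m(pg, acc, va, vb);           // acc += va * vb on active lanes
    }
    return svaddv_f32(svptrue_b32(), acc);            // horizontal reduction to a scalar
}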
