Commit dd8ba93

ggml: aarch64: Implement SVE F32 kernels for Mamba Sequential Scan Algorithm (ggml-org#13882)
* F32-Mamba-Seq_Scan-SVE
* Fix formatting
* ggml : missing space

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
1 parent 66c9206 commit dd8ba93
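For orientation, both code paths in the diff below compute the same per-token selective-scan step: softplus the time step, decay the previous state by exp(dt * A), inject B * x * dt, and reduce the new state against C. Here is a minimal scalar sketch of one such step for a single inner channel; it simply mirrors the reference loop kept in the #else branch, and the helper name ssm_scan_step_ref is hypothetical, not part of the commit:

    // Sketch only: one selective-scan step for one inner channel over nc state elements.
    #include <math.h>

    static float ssm_scan_step_ref(float * s, const float * A, const float * B,
                                   const float * C, float x, float dt, int nc) {
        float dt_soft_plus = dt <= 20.0f ? log1pf(expf(dt)) : dt;  // softplus(dt), linear past 20
        float x_dt = x * dt_soft_plus;
        float y = 0.0f;
        for (int i0 = 0; i0 < nc; ++i0) {
            float state = s[i0] * expf(dt_soft_plus * A[i0]) + B[i0] * x_dt;  // state = prev_state*dA + dB*x
            y += state * C[i0];                                               // y = rowwise_dotprod(state, C)
            s[i0] = state;
        }
        return y;
    }

The SVE path added in ops.cpp vectorizes the i0 (d_state) loop of this step; the exponential over a whole vector of dt*A values is supplied by the new exp_ps_sve helper in vec.h.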

2 files changed: +110 -30 lines

‎ggml/src/ggml-cpu/ops.cpp

74 additions & 30 deletions
@@ -7633,39 +7633,83 @@ static void ggml_compute_forward_ssm_scan_f32(
     const int ir1 = MIN(ir0 + dr, nr);
     const int ir  = ir1 - ir0;

-    for (int i3 = 0; i3 < n_s; ++i3) {
-        for (int i2 = 0; i2 < n_t; ++i2) {
-            const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
-            const float * x  = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
-            const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
-            const float * A  = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
-            const float * B  = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
-            const float * C  = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
-            float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
-            float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
-
-            // use the output as the source for the next token-wise iterations
-            if (i2 > 0) { s0 = s; }
-
-            // d_inner
-            for (int i1 = 0; i1 < ir; ++i1) {
-                // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
-                float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
-                float x_dt = x[i1] * dt_soft_plus;
-                float sumf = 0.0f;
-                // d_state
-                for (int i0 = 0; i0 < nc; ++i0) {
-                    int i = i0 + i1*nc;
-                    // state = prev_state * dA + dB * x
-                    float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
-                    // y = rowwise_dotprod(state, C)
-                    sumf += state * C[i0];
-                    s[i] = state;
-                }
-                y[i1] = sumf;
-            }
-        }
-    }
+#ifdef __ARM_FEATURE_SVE
+    for (int i3 = 0; i3 < n_s; ++i3) {
+        for (int i2 = 0; i2 < n_t; ++i2) {
+            const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
+            const float * x  = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
+            const float * A  = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
+            const float * B  = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
+            const float * C  = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
+            float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
+
+            // use the output as the source for the next token-wise iterations
+            if (i2 > 0) { s0 = s; }
+
+            // d_inner
+            for (int i1 = 0; i1 < ir; ++i1) {
+                float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
+                float x_dt = x[i1] * dt_soft_plus;
+                svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
+                svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
+                svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
+
+                for (int64_t k = 0; k < nc; k += svcntw()) {
+                    svfloat32_t vA  = GGML_F32_VEC_LOAD(&A[i1*nc + k]);
+                    svfloat32_t vB  = GGML_F32_VEC_LOAD(&B[k]);
+                    svfloat32_t vC  = GGML_F32_VEC_LOAD(&C[k]);
+                    svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[i1*nc + k]);
+
+                    svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
+                    t1 = exp_ps_sve(svptrue_b32(), t1);
+                    svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
+
+                    vs0 = GGML_F32_VEC_FMA(vs0, t1, t2);
+                    r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
+
+                    GGML_F32_VEC_STORE(&s[i1*nc + k], vs0);
+                }
+                y[i1] = GGML_F32xt_REDUCE_ONE(r1_vector);
+            }
+        }
+    }
+#else
+    for (int i3 = 0; i3 < n_s; ++i3) {
+        for (int i2 = 0; i2 < n_t; ++i2) {
+            const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
+            const float * x  = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
+            const float * A  = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
+            const float * B  = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
+            const float * C  = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
+            float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
+
+            // use the output as the source for the next token-wise iterations
+            if (i2 > 0) { s0 = s; }
+
+            // d_inner
+            for (int i1 = 0; i1 < ir; ++i1) {
+                // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
+                float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
+                float x_dt = x[i1] * dt_soft_plus;
+                float sumf = 0.0f;
+                // d_state
+                for (int i0 = 0; i0 < nc; ++i0) {
+                    int i = i0 + i1*nc;
+                    // state = prev_state * dA + dB * x
+                    float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+                    // y = rowwise_dotprod(state, C)
+                    sumf += state * C[i0];
+                    s[i] = state;
+                }
+                y[i1] = sumf;
+            }
+        }
+    }
+#endif
 }

 void ggml_compute_forward_ssm_scan(
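The vectorized d_state loop above strides by svcntw() (the number of 32-bit lanes) over each state row and finishes with a horizontal reduction into y[i1]. As a standalone illustration of that load / FMA / reduce pattern, here is the same idiom written with svwhilelt predication, which also covers lengths that are not a multiple of the vector width; this sketch and the name dot_f32_sve are not part of the commit:

    // Sketch only: predicated SVE dot product showing the load / FMA / horizontal-reduce
    // pattern the kernel relies on. Compile for AArch64 with SVE, e.g. -march=armv8-a+sve.
    #include <arm_sve.h>
    #include <stdint.h>

    static float dot_f32_sve(const float * a, const float * b, int64_t n) {
        svfloat32_t acc = svdup_n_f32(0.0f);
        int64_t k = 0;
        svbool_t pg = svwhilelt_b32_s64(k, n);         // lanes where k + lane < n
        while (svptest_any(svptrue_b32(), pg)) {
            svfloat32_t va = svld1_f32(pg, a + k);     // inactive lanes read as zero
            svfloat32_t vb = svld1_f32(pg, b + k);
            acc = svmla_f32_m(pg, acc, va, vb);        // acc += a*b on active lanes only
            k += (int64_t) svcntw();
            pg = svwhilelt_b32_s64(k, n);
        }
        return svaddv_f32(svptrue_b32(), acc);         // horizontal sum across lanes
    }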

‎ggml/src/ggml-cpu/vec.h

36 additions & 0 deletions
@@ -647,6 +647,42 @@ inline static ggml_fp16_t ggml_silu_f16(ggml_fp16_t x) {
 #error "ref: https://github.com/ggml-org/llama.cpp/pull/7154#issuecomment-2143844461"
 #endif

+/* Below function was borrowed from the GitHub repository:
+https://github.com/openvinotoolkit/openvino/blob/master/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp */
+#if defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
+inline static svfloat32_t exp_ps_sve(svbool_t pg, svfloat32_t src) {
+    // Constants
+    const svfloat32_t log2_e = svdup_n_f32(1.4426950409f);
+    const svfloat32_t ln2 = svdup_n_f32(0.6931473921f);
+    const svfloat32_t half_ln2_sq = svdup_n_f32(0.2413862043f);
+    const svuint32_t not_mask17 = svdup_n_u32(~((1u << 17) - 1));
+    const svfloat32_t one = svdup_n_f32(1.0f);
+    const svfloat32_t inactive1 = svdup_n_f32(0.0f);
+    const svint32_t inactive2 = svdup_n_s32(0);
+
+    // Algorithm starts here
+    svfloat32_t t0 = svmul_f32_m(pg, src, log2_e);      // y = x * log2(e)
+    svfloat32_t t1 = svrintm_f32_m(inactive1, pg, t0);  // rount to int (float)
+    svint32_t t2 = svcvt_s32_f32_m(inactive2, pg, t1);  // n
+
+    t1 = svsub_f32_m(pg, t0, t1);  // a = y - floor(y)
+    t1 = svadd_f32_m(pg, t1, one); // b = a + 1
+
+    svuint32_t t3 = svlsr_n_u32_m(pg, svreinterpret_u32_f32(t1), 17);  // v = b >> 17 (u32)
+    svfloat32_t t4 = svexpa_f32(t3);                                   // c = fexpa(v)
+    t4 = svscale_f32_m(pg, t4, t2);                                    // fexpa(v) * 2^(n)
+
+    // and_(t2.d, t1.d, not_mask17.d)
+    svfloat32_t t5 = svreinterpret_f32_u32(svand_u32_m(pg, svreinterpret_u32_f32(t1), not_mask17));
+    t5 = svsub_f32_m(pg, t1, t5);               // z
+    t0 = svmla_f32_m(pg, ln2, t5, half_ln2_sq); // ln2 + half_ln2_sq * z
+    t0 = svmla_f32_m(pg, one, t5, t0);          // 1 + (ln2 * z) + (half_ln2_sq * z * z)
+    t0 = svmul_f32_m(pg, t0, t4);               // Final result
+
+    return t0;
+}
+#endif
+
 #if defined(__ARM_NEON) && defined(__aarch64__)

 // adapted from arm limited optimized routine
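exp_ps_sve evaluates exp(x) as 2^(x * log2(e)): the integer part of the exponent is applied with svscale (FSCALE), the truncated fractional part goes through the FEXPA table lookup (svexpa), and the small residual left by the >>17 truncation is corrected with a short quadratic in z. Below is a scalar model of that split, written only for exposition; exp2f and ldexpf stand in for the FEXPA lookup and FSCALE, and the name exp_model is hypothetical, not part of the commit:

    // Sketch only: scalar model of the exp_ps_sve decomposition.
    #include <math.h>
    #include <stdint.h>
    #include <string.h>

    static float exp_model(float x) {
        float y = x * 1.4426950409f;              // x * log2(e)
        float n = floorf(y);
        float b = (y - n) + 1.0f;                 // b = 1 + frac, frac in [0, 1)

        uint32_t bi;  memcpy(&bi, &b, sizeof bi);
        uint32_t hi = bi & ~((1u << 17) - 1);     // keep sign/exponent + top 6 mantissa bits
        float bhi;    memcpy(&bhi, &hi, sizeof bhi);

        float coarse = exp2f(bhi - 1.0f);         // stands in for the FEXPA table lookup
        float z      = b - bhi;                   // residual from the dropped low bits
        float fine   = 1.0f + 0.6931473921f*z + 0.2413862043f*z*z;  // ~ 2^z
        return ldexpf(coarse * fine, (int) n);    // * 2^n, as svscale does in the kernel
    }

Because the quadratic only has to cover the small residual left after the truncation, the approximation error stays low, which is what lets the ssm_scan kernel use this helper in place of a per-element expf.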

0 commit comments
