@@ -7633,39 +7633,83 @@ static void ggml_compute_forward_ssm_scan_f32(
7633
7633
const int ir1 = MIN (ir0 + dr, nr);
7634
7634
const int ir = ir1 - ir0;
7635
7635
7636
- for (int i3 = 0 ; i3 < n_s; ++i3) {
7637
- for (int i2 = 0 ; i2 < n_t ; ++i2) {
7638
- const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb [1 ]) + i3*(src0->nb [2 ])); // {d_state, d_inner, n_s}
7639
- const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb [0 ]) + i2*(src1->nb [1 ]) + i3*(src1->nb [2 ])); // {d_inner, n_t, n_s}
7640
- const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb [0 ]) + i2*(src2->nb [1 ]) + i3*(src2->nb [2 ])); // {d_inner, n_t, n_s}
7641
- const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb [1 ])); // {d_state, d_inner}
7642
- const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb [1 ]) + i3*(src4->nb [2 ])); // {d_state, n_t, n_s}
7643
- const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb [1 ]) + i3*(src5->nb [2 ])); // {d_state, n_t, n_s}
7644
- float * y = ( float *) (( char *) dst->data + ir0*(src1->nb [0 ]) + i2*(src1->nb [1 ]) + i3*(src1->nb [2 ])); // {d_inner, n_t, n_s}
7645
- float * s = ( float *) (( char *) dst->data + ir0*(src0->nb [1 ]) + i3*(src0->nb [2 ]) + src1->nb [3 ]); // {d_state, d_inner, n_s}
7646
-
7647
- // use the output as the source for the next token-wise iterations
7648
- if (i2 > 0 ) { s0 = s; }
7649
-
7650
- // d_inner
7651
- for (int i1 = 0 ; i1 < ir; ++i1) {
7652
- // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
7653
- float dt_soft_plus = dt[i1] <= 20 .0f ? log1pf (expf (dt[i1])) : dt[i1];
7654
- float x_dt = x[i1] * dt_soft_plus;
7655
- float sumf = 0 .0f ;
7656
- // d_state
7657
- for (int i0 = 0 ; i0 < nc; ++i0) {
7658
- int i = i0 + i1*nc;
7659
- // state = prev_state * dA + dB * x
7660
- float state = (s0[i] * expf (dt_soft_plus * A[i])) + (B[i0] * x_dt);
7661
- // y = rowwise_dotprod(state, C)
7662
- sumf += state * C[i0];
7663
- s[i] = state;
7636
+ #ifdef __ARM_FEATURE_SVE
7637
+ for (int i3 = 0 ; i3 < n_s; ++i3) {
7638
+ for (int i2 = 0 ; i2 < n_t ; ++i2) {
7639
+ const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb [1 ]) + i3*(src0->nb [2 ])); // {d_state, d_inner, n_s}
7640
+ const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb [0 ]) + i2*(src1->nb [1 ]) + i3*(src1->nb [2 ])); // {d_inner, n_t, n_s}
7641
+ const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb [0 ]) + i2*(src2->nb [1 ]) + i3*(src2->nb [2 ])); // {d_inner, n_t, n_s}
7642
+ const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb [1 ])); // {d_state, d_inner}
7643
+ const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb [1 ]) + i3*(src4->nb [2 ])); // {d_state, n_t, n_s}
7644
+ const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb [1 ]) + i3*(src5->nb [2 ])); // {d_state, n_t, n_s}
7645
+ float * y = ( float *) (( char *) dst->data + ir0*(src1->nb [0 ]) + i2*(src1->nb [1 ]) + i3*(src1->nb [2 ])); // {d_inner, n_t, n_s}
7646
+ float * s = ( float *) (( char *) dst->data + ir0*(src0->nb [1 ]) + i3*(src0->nb [2 ]) + src1->nb [3 ]); // {d_state, d_inner, n_s}
7647
+
7648
+ // use the output as the source for the next token-wise iterations
7649
+ if (i2 > 0 ) { s0 = s; }
7650
+
7651
+ // d_inner
7652
+ for (int i1 = 0 ; i1 < ir; ++i1) {
7653
+ float dt_soft_plus = dt[i1] <= 20 .0f ? log1pf (expf (dt[i1])) : dt[i1];
7654
+ float x_dt = x[i1] * dt_soft_plus;
7655
+ svfloat32_t vx_dt = GGML_F32_VEC_SET1 (x_dt);
7656
+ svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1 (dt_soft_plus);
7657
+ svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
7658
+
7659
+ for (int64_t k = 0 ; k < nc; k += svcntw ()) {
7660
+ svfloat32_t vA = GGML_F32_VEC_LOAD (&A[i1*nc + k]);
7661
+ svfloat32_t vB = GGML_F32_VEC_LOAD (&B[k]);
7662
+ svfloat32_t vC = GGML_F32_VEC_LOAD (&C[k]);
7663
+ svfloat32_t vs0 = GGML_F32_VEC_LOAD (&s0[i1*nc + k]);
7664
+
7665
+ svfloat32_t t1 = GGML_F32_VEC_MUL (vdt_soft_plus, vA);
7666
+ t1 = exp_ps_sve (svptrue_b32 (), t1);
7667
+ svfloat32_t t2 = GGML_F32_VEC_MUL (vx_dt, vB);
7668
+
7669
+ vs0 = GGML_F32_VEC_FMA (vs0, t1, t2);
7670
+ r1_vector = GGML_F32_VEC_ADD (GGML_F32_VEC_MUL (vs0, vC), r1_vector);
7671
+
7672
+ GGML_F32_VEC_STORE (&s[i1*nc + k], vs0);
7673
+ }
7674
+ y[i1] = GGML_F32xt_REDUCE_ONE (r1_vector);
7664
7675
}
7665
- y[i1] = sumf;
7666
7676
}
7667
7677
}
7668
- }
7678
+ #else
7679
+ for (int i3 = 0 ; i3 < n_s; ++i3) {
7680
+ for (int i2 = 0 ; i2 < n_t ; ++i2) {
7681
+ const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb [1 ]) + i3*(src0->nb [2 ])); // {d_state, d_inner, n_s}
7682
+ const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb [0 ]) + i2*(src1->nb [1 ]) + i3*(src1->nb [2 ])); // {d_inner, n_t, n_s}
7683
+ const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb [0 ]) + i2*(src2->nb [1 ]) + i3*(src2->nb [2 ])); // {d_inner, n_t, n_s}
7684
+ const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb [1 ])); // {d_state, d_inner}
7685
+ const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb [1 ]) + i3*(src4->nb [2 ])); // {d_state, n_t, n_s}
7686
+ const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb [1 ]) + i3*(src5->nb [2 ])); // {d_state, n_t, n_s}
7687
+ float * y = ( float *) (( char *) dst->data + ir0*(src1->nb [0 ]) + i2*(src1->nb [1 ]) + i3*(src1->nb [2 ])); // {d_inner, n_t, n_s}
7688
+ float * s = ( float *) (( char *) dst->data + ir0*(src0->nb [1 ]) + i3*(src0->nb [2 ]) + src1->nb [3 ]); // {d_state, d_inner, n_s}
7689
+
7690
+ // use the output as the source for the next token-wise iterations
7691
+ if (i2 > 0 ) { s0 = s; }
7692
+
7693
+ // d_inner
7694
+ for (int i1 = 0 ; i1 < ir; ++i1) {
7695
+ // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
7696
+ float dt_soft_plus = dt[i1] <= 20 .0f ? log1pf (expf (dt[i1])) : dt[i1];
7697
+ float x_dt = x[i1] * dt_soft_plus;
7698
+ float sumf = 0 .0f ;
7699
+ // d_state
7700
+ for (int i0 = 0 ; i0 < nc; ++i0) {
7701
+ int i = i0 + i1*nc;
7702
+ // state = prev_state * dA + dB * x
7703
+ float state = (s0[i] * expf (dt_soft_plus * A[i])) + (B[i0] * x_dt);
7704
+ // y = rowwise_dotprod(state, C)
7705
+ sumf += state * C[i0];
7706
+ s[i] = state;
7707
+ }
7708
+ y[i1] = sumf;
7709
+ }
7710
+ }
7711
+ }
7712
+ #endif
7669
7713
}
7670
7714
7671
7715
void ggml_compute_forward_ssm_scan (
0 commit comments