@@ -28,11 +28,13 @@ NPY_FINLINE npyv_f32 npyv_square_f32(npyv_f32 a)
28
28
// Based on ARM doc, see https://developer.arm.com/documentation/dui0204/j/CIHDIACI
29
29
NPY_FINLINE npyv_f32 npyv_sqrt_f32 (npyv_f32 a )
30
30
{
31
+ const npyv_f32 one = vdupq_n_f32 (1.0f );
31
32
const npyv_f32 zero = vdupq_n_f32 (0.0f );
32
33
const npyv_u32 pinf = vdupq_n_u32 (0x7f800000 );
33
34
npyv_u32 is_zero = vceqq_f32 (a , zero ), is_inf = vceqq_u32 (vreinterpretq_u32_f32 (a ), pinf );
34
- // guard against floating-point division-by-zero error
35
- npyv_f32 guard_byz = vbslq_f32 (is_zero , vreinterpretq_f32_u32 (pinf ), a );
35
+ npyv_u32 is_special = vorrq_u32 (is_zero , is_inf );
36
+ // guard against division-by-zero and infinity input to vrsqrte to avoid invalid fp error
37
+ npyv_f32 guard_byz = vbslq_f32 (is_special , one , a );
36
38
// estimate to (1/√a)
37
39
npyv_f32 rsqrte = vrsqrteq_f32 (guard_byz );
38
40
/**
@@ -47,10 +49,8 @@ NPY_FINLINE npyv_f32 npyv_square_f32(npyv_f32 a)
47
49
rsqrte = vmulq_f32 (vrsqrtsq_f32 (vmulq_f32 (a , rsqrte ), rsqrte ), rsqrte );
48
50
// a * (1/√a)
49
51
npyv_f32 sqrt = vmulq_f32 (a , rsqrte );
50
- // return zero if the a is zero
51
- // - return zero if a is zero.
52
- // - return positive infinity if a is positive infinity
53
- return vbslq_f32 (vorrq_u32 (is_zero , is_inf ), a , sqrt );
52
+ // Handle special cases: return a for zeros and positive infinities
53
+ return vbslq_f32 (is_special , a , sqrt );
54
54
}
55
55
#endif // NPY_SIMD_F64
56
56
0 commit comments