Commit cb22dd3

Author: pengxu (committed)
Loongarch: add lsx functions
1 parent 455256c commit cb22dd3

File tree: 9 files changed (+2078 -0 lines changed)
Lines changed: 312 additions & 0 deletions
@@ -0,0 +1,312 @@
#ifndef NPY_SIMD
#error "Not a standalone header"
#endif

#ifndef _NPY_SIMD_LSX_ARITHMETIC_H
#define _NPY_SIMD_LSX_ARITHMETIC_H

/***************************
 * Addition
 ***************************/
// non-saturated
#define npyv_add_u8  __lsx_vadd_b
#define npyv_add_s8  __lsx_vadd_b
#define npyv_add_u16 __lsx_vadd_h
#define npyv_add_s16 __lsx_vadd_h
#define npyv_add_u32 __lsx_vadd_w
#define npyv_add_s32 __lsx_vadd_w
#define npyv_add_u64 __lsx_vadd_d
#define npyv_add_s64 __lsx_vadd_d
#define npyv_add_f32 __lsx_vfadd_s
#define npyv_add_f64 __lsx_vfadd_d

// saturated
#define npyv_adds_u8  __lsx_vsadd_bu
#define npyv_adds_s8  __lsx_vsadd_b
#define npyv_adds_u16 __lsx_vsadd_hu
#define npyv_adds_s16 __lsx_vsadd_h
// TODO: rest, after implementing the pack intrinsics
#define npyv_adds_u32 __lsx_vsadd_wu
#define npyv_adds_s32 __lsx_vsadd_w
#define npyv_adds_u64 __lsx_vsadd_du
#define npyv_adds_s64 __lsx_vsadd_d


/***************************
 * Subtraction
 ***************************/
// non-saturated
#define npyv_sub_u8  __lsx_vsub_b
#define npyv_sub_s8  __lsx_vsub_b
#define npyv_sub_u16 __lsx_vsub_h
#define npyv_sub_s16 __lsx_vsub_h
#define npyv_sub_u32 __lsx_vsub_w
#define npyv_sub_s32 __lsx_vsub_w
#define npyv_sub_u64 __lsx_vsub_d
#define npyv_sub_s64 __lsx_vsub_d
#define npyv_sub_f32 __lsx_vfsub_s
#define npyv_sub_f64 __lsx_vfsub_d

// saturated
#define npyv_subs_u8  __lsx_vssub_bu
#define npyv_subs_s8  __lsx_vssub_b
#define npyv_subs_u16 __lsx_vssub_hu
#define npyv_subs_s16 __lsx_vssub_h
#define npyv_subs_u32 __lsx_vssub_wu
#define npyv_subs_s32 __lsx_vssub_w
#define npyv_subs_u64 __lsx_vssub_du
#define npyv_subs_s64 __lsx_vssub_d


/***************************
 * Multiplication
 ***************************/
// non-saturated
#define npyv_mul_u8  __lsx_vmul_b
#define npyv_mul_s8  __lsx_vmul_b
#define npyv_mul_u16 __lsx_vmul_h
#define npyv_mul_s16 __lsx_vmul_h
#define npyv_mul_u32 __lsx_vmul_w
#define npyv_mul_s32 __lsx_vmul_w
#define npyv_mul_f32 __lsx_vfmul_s
#define npyv_mul_f64 __lsx_vfmul_d


/***************************
 * Integer Division
 ***************************/
// See simd/intdiv.h for more clarification
// divide each unsigned 8-bit element by a precomputed divisor
NPY_FINLINE npyv_u8 npyv_divc_u8(npyv_u8 a, const npyv_u8x3 divisor)
{
    const __m128i bmask = __lsx_vreplgr2vr_w(0x00FF00FF);
    const __m128i shf1b = __lsx_vreplgr2vr_b(0xFFU >> __lsx_vpickve2gr_w(divisor.val[1], 0));
    const __m128i shf2b = __lsx_vreplgr2vr_b(0xFFU >> __lsx_vpickve2gr_w(divisor.val[2], 0));
    // high part of unsigned multiplication
    __m128i mulhi_even = __lsx_vmul_h(__lsx_vand_v(a, bmask), divisor.val[0]);
    __m128i mulhi_odd  = __lsx_vmul_h(__lsx_vsrli_h(a, 8), divisor.val[0]);
    mulhi_even = __lsx_vsrli_h(mulhi_even, 8);
    __m128i mulhi = npyv_select_u8(bmask, mulhi_even, mulhi_odd);
    // floor(a/d) = (mulhi + ((a-mulhi) >> sh1)) >> sh2
    __m128i q = __lsx_vsub_b(a, mulhi);
    q = __lsx_vand_v(__lsx_vsrl_h(q, divisor.val[1]), shf1b);
    q = __lsx_vadd_b(mulhi, q);
    q = __lsx_vand_v(__lsx_vsrl_h(q, divisor.val[2]), shf2b);
    return q;
}
// divide each signed 8-bit element by a precomputed divisor (round towards zero)
NPY_FINLINE npyv_s16 npyv_divc_s16(npyv_s16 a, const npyv_s16x3 divisor);
NPY_FINLINE npyv_s8 npyv_divc_s8(npyv_s8 a, const npyv_s8x3 divisor)
{
    const __m128i bmask = __lsx_vreplgr2vr_w(0x00FF00FF);
    // instead of _mm_cvtepi8_epi16/_mm_packs_epi16 to wrap around overflow
    __m128i divc_even = npyv_divc_s16(__lsx_vsrai_h(__lsx_vslli_h(a, 8), 8), divisor);
    __m128i divc_odd  = npyv_divc_s16(__lsx_vsrai_h(a, 8), divisor);
    divc_odd = __lsx_vslli_h(divc_odd, 8);
    return npyv_select_u8(bmask, divc_even, divc_odd);
}
// divide each unsigned 16-bit element by a precomputed divisor
NPY_FINLINE npyv_u16 npyv_divc_u16(npyv_u16 a, const npyv_u16x3 divisor)
{
    // high part of unsigned multiplication
    __m128i mulhi = __lsx_vmuh_hu(a, divisor.val[0]);
    // floor(a/d) = (mulhi + ((a-mulhi) >> sh1)) >> sh2
    __m128i q = __lsx_vsub_h(a, mulhi);
    q = __lsx_vsrl_h(q, divisor.val[1]);
    q = __lsx_vadd_h(mulhi, q);
    q = __lsx_vsrl_h(q, divisor.val[2]);
    return q;
}
// divide each signed 16-bit element by a precomputed divisor (round towards zero)
NPY_FINLINE npyv_s16 npyv_divc_s16(npyv_s16 a, const npyv_s16x3 divisor)
{
    // high part of signed multiplication
    __m128i mulhi = __lsx_vmuh_h(a, divisor.val[0]);
    // q = ((a + mulhi) >> sh1) - XSIGN(a)
    // trunc(a/d) = (q ^ dsign) - dsign
    __m128i q = __lsx_vsra_h(__lsx_vadd_h(a, mulhi), divisor.val[1]);
    q = __lsx_vsub_h(q, __lsx_vsrai_h(a, 15));
    q = __lsx_vsub_h(__lsx_vxor_v(q, divisor.val[2]), divisor.val[2]);
    return q;
}
// divide each unsigned 32-bit element by a precomputed divisor
NPY_FINLINE npyv_u32 npyv_divc_u32(npyv_u32 a, const npyv_u32x3 divisor)
{
    // high part of unsigned multiplication
    __m128i mulhi_even = __lsx_vsrli_d(__lsx_vmulwev_d_wu(a, divisor.val[0]), 32);
    __m128i mulhi_odd  = __lsx_vmulwev_d_wu(__lsx_vsrli_d(a, 32), divisor.val[0]);
    mulhi_odd = __lsx_vand_v(mulhi_odd, (__m128i)(v4i32){0, -1, 0, -1});
    __m128i mulhi = __lsx_vor_v(mulhi_even, mulhi_odd);
    // floor(a/d) = (mulhi + ((a-mulhi) >> sh1)) >> sh2
    __m128i q = __lsx_vsub_w(a, mulhi);
    q = __lsx_vsrl_w(q, divisor.val[1]);
    q = __lsx_vadd_w(mulhi, q);
    q = __lsx_vsrl_w(q, divisor.val[2]);
    return q;
}
// divide each signed 32-bit element by a precomputed divisor (round towards zero)
NPY_FINLINE npyv_s32 npyv_divc_s32(npyv_s32 a, const npyv_s32x3 divisor)
{
    __m128i asign = __lsx_vsrai_w(a, 31);
    __m128i mulhi_even = __lsx_vsrli_d(__lsx_vmulwev_d_wu(a, divisor.val[0]), 32);
    __m128i mulhi_odd  = __lsx_vmulwev_d_wu(__lsx_vsrli_d(a, 32), divisor.val[0]);
    mulhi_odd = __lsx_vand_v(mulhi_odd, (__m128i)(v4i32){0, -1, 0, -1});
    __m128i mulhi = __lsx_vor_v(mulhi_even, mulhi_odd);
    // convert unsigned to signed high multiplication
    // mulhi - ((a < 0) ? m : 0) - ((m < 0) ? a : 0);
    const __m128i msign = __lsx_vsrai_w(divisor.val[0], 31);
    __m128i m_asign = __lsx_vand_v(divisor.val[0], asign);
    __m128i a_msign = __lsx_vand_v(a, msign);
    mulhi = __lsx_vsub_w(mulhi, m_asign);
    mulhi = __lsx_vsub_w(mulhi, a_msign);
    // q = ((a + mulhi) >> sh1) - XSIGN(a)
    // trunc(a/d) = (q ^ dsign) - dsign
    __m128i q = __lsx_vsra_w(__lsx_vadd_w(a, mulhi), divisor.val[1]);
    q = __lsx_vsub_w(q, asign);
    q = __lsx_vsub_w(__lsx_vxor_v(q, divisor.val[2]), divisor.val[2]);
    return q;
}
// returns the high 64 bits of unsigned 64-bit multiplication
// xref https://stackoverflow.com/a/28827013
NPY_FINLINE npyv_u64 npyv__mullhi_u64(npyv_u64 a, npyv_u64 b)
{
    __m128i lomask = npyv_setall_s64(0xffffffff);
    __m128i a_hi = __lsx_vsrli_d(a, 32);         // a0l, a0h, a1l, a1h
    __m128i b_hi = __lsx_vsrli_d(b, 32);         // b0l, b0h, b1l, b1h
    // compute partial products
    __m128i w0 = __lsx_vmulwev_d_wu(a, b);       // a0l*b0l, a1l*b1l
    __m128i w1 = __lsx_vmulwev_d_wu(a, b_hi);    // a0l*b0h, a1l*b1h
    __m128i w2 = __lsx_vmulwev_d_wu(a_hi, b);    // a0h*b0l, a1h*b1l
    __m128i w3 = __lsx_vmulwev_d_wu(a_hi, b_hi); // a0h*b0h, a1h*b1h
    // sum partial products
    __m128i w0h = __lsx_vsrli_d(w0, 32);
    __m128i s1  = __lsx_vadd_d(w1, w0h);
    __m128i s1l = __lsx_vand_v(s1, lomask);
    __m128i s1h = __lsx_vsrli_d(s1, 32);

    __m128i s2  = __lsx_vadd_d(w2, s1l);
    __m128i s2h = __lsx_vsrli_d(s2, 32);

    __m128i hi = __lsx_vadd_d(w3, s1h);
    hi = __lsx_vadd_d(hi, s2h);
    return hi;
}
// divide each unsigned 64-bit element by a precomputed divisor
NPY_FINLINE npyv_u64 npyv_divc_u64(npyv_u64 a, const npyv_u64x3 divisor)
{
    // high part of unsigned multiplication
    __m128i mulhi = npyv__mullhi_u64(a, divisor.val[0]);
    // floor(a/d) = (mulhi + ((a-mulhi) >> sh1)) >> sh2
    __m128i q = __lsx_vsub_d(a, mulhi);
    q = __lsx_vsrl_d(q, divisor.val[1]);
    q = __lsx_vadd_d(mulhi, q);
    q = __lsx_vsrl_d(q, divisor.val[2]);
    return q;
}
// divide each signed 64-bit element by a precomputed divisor (round towards zero)
NPY_FINLINE npyv_s64 npyv_divc_s64(npyv_s64 a, const npyv_s64x3 divisor)
{
    // high part of unsigned multiplication
    __m128i mulhi = npyv__mullhi_u64(a, divisor.val[0]);
    // convert unsigned to signed high multiplication
    // mulhi - ((a < 0) ? m : 0) - ((m < 0) ? a : 0);
    const __m128i msign = __lsx_vslt_d(divisor.val[0], __lsx_vldi(0));
    __m128i asign = __lsx_vslt_d(a, __lsx_vldi(0));
    __m128i m_asign = __lsx_vand_v(divisor.val[0], asign);
    __m128i a_msign = __lsx_vand_v(a, msign);
    mulhi = __lsx_vsub_d(mulhi, m_asign);
    mulhi = __lsx_vsub_d(mulhi, a_msign);
    // q = (a + mulhi) >> sh
    __m128i q = __lsx_vadd_d(a, mulhi);
    // emulate arithmetic right shift
    const __m128i sigb = npyv_setall_s64(1LL << 63);
    q = __lsx_vsrl_d(__lsx_vadd_d(q, sigb), divisor.val[1]);
    q = __lsx_vsub_d(q, __lsx_vsrl_d(sigb, divisor.val[1]));
    // q = q - XSIGN(a)
    // trunc(a/d) = (q ^ dsign) - dsign
    q = __lsx_vsub_d(q, asign);
    q = __lsx_vsub_d(__lsx_vxor_v(q, divisor.val[2]), divisor.val[2]);
    return q;
}
/***************************
 * Division
 ***************************/
#define npyv_div_f32 __lsx_vfdiv_s
#define npyv_div_f64 __lsx_vfdiv_d
/***************************
 * FUSED
 ***************************/
// multiply and add, a*b + c
#define npyv_muladd_f32 __lsx_vfmadd_s
#define npyv_muladd_f64 __lsx_vfmadd_d
// multiply and subtract, a*b - c
#define npyv_mulsub_f32 __lsx_vfmsub_s
#define npyv_mulsub_f64 __lsx_vfmsub_d
// negate multiply and add, -(a*b) + c, equal to -(a*b - c)
#define npyv_nmuladd_f32 __lsx_vfnmsub_s
#define npyv_nmuladd_f64 __lsx_vfnmsub_d
// negate multiply and subtract, -(a*b) - c, equal to -(a*b + c)
#define npyv_nmulsub_f32 __lsx_vfnmadd_s
#define npyv_nmulsub_f64 __lsx_vfnmadd_d
// multiply, add for odd elements and subtract for even elements
// (a * b) -+ c
NPY_FINLINE npyv_f32 npyv_muladdsub_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c)
{
    return __lsx_vfmadd_s(a, b, (__m128)__lsx_vxor_v((__m128i)c, (__m128i)(v4f32){-0.0, 0.0, -0.0, 0.0}));
}
NPY_FINLINE npyv_f64 npyv_muladdsub_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c)
{
    return __lsx_vfmadd_d(a, b, (__m128d)__lsx_vxor_v((__m128i)c, (__m128i)(v2f64){-0.0, 0.0}));
}

/***************************
 * Summation
 ***************************/
// reduce sum across vector
NPY_FINLINE npy_uint32 npyv_sum_u32(npyv_u32 a)
{
    __m128i t1 = __lsx_vhaddw_du_wu(a, a);
    __m128i t2 = __lsx_vhaddw_qu_du(t1, t1);
    return __lsx_vpickve2gr_wu(t2, 0);
}

NPY_FINLINE npy_uint64 npyv_sum_u64(npyv_u64 a)
{
    __m128i t = __lsx_vhaddw_qu_du(a, a);
    return __lsx_vpickve2gr_du(t, 0);
}

NPY_FINLINE float npyv_sum_f32(npyv_f32 a)
{
    __m128 ft = __lsx_vfadd_s(a, (__m128)__lsx_vbsrl_v((__m128i)a, 8));
    ft = __lsx_vfadd_s(ft, (__m128)__lsx_vbsrl_v(ft, 4));
    return ft[0];
}

NPY_FINLINE double npyv_sum_f64(npyv_f64 a)
{
    __m128d fd = __lsx_vfadd_d(a, (__m128d)__lsx_vreplve_d((__m128i)a, 1));
    return fd[0];
}

// expand the source vector and perform sum reduce
NPY_FINLINE npy_uint16 npyv_sumup_u8(npyv_u8 a)
{
    __m128i first  = __lsx_vhaddw_hu_bu((__m128i)a, (__m128i)a);
    __m128i second = __lsx_vhaddw_wu_hu((__m128i)first, (__m128i)first);
    __m128i third  = __lsx_vhaddw_du_wu((__m128i)second, (__m128i)second);
    __m128i four   = __lsx_vhaddw_qu_du((__m128i)third, (__m128i)third);
    return four[0];
}

NPY_FINLINE npy_uint32 npyv_sumup_u16(npyv_u16 a)
{
    __m128i t1 = __lsx_vhaddw_wu_hu(a, a);
    __m128i t2 = __lsx_vhaddw_du_wu(t1, t1);
    __m128i t3 = __lsx_vhaddw_qu_du(t2, t2);
    return __lsx_vpickve2gr_w(t3, 0);
}

#endif // _NPY_SIMD_LSX_ARITHMETIC_H
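Editor's note (not part of the diff): the npyv_divc_* routines above all follow the division-by-invariant-multiplication scheme referenced by the comment "See simd/intdiv.h for more clarification", where a divisor is precomputed into a multiplier and two shift counts so that floor(a/d) = (mulhi + ((a-mulhi) >> sh1)) >> sh2. The scalar sketch below illustrates the unsigned 8-bit case only; precompute_divisor_u8, divc_u8_scalar, and the struct field names are hypothetical stand-ins for NumPy's real npyv_divisor_u8 packing, and the vector code additionally has to emulate the 8-bit high multiply and shifts with 16-bit operations plus byte masks.

#include <assert.h>
#include <stdint.h>

/* Hypothetical scalar model of the (multiplier, sh1, sh2) triple that
 * simd/intdiv.h precomputes; names are illustrative, not NumPy's API. */
typedef struct { uint32_t mul; uint32_t sh1; uint32_t sh2; } divisor_u8_t;

static divisor_u8_t precompute_divisor_u8(uint8_t d)
{
    assert(d > 1);                  /* d == 0 and d == 1 are handled specially upstream */
    uint32_t l = 0;
    while ((1u << l) < d) {         /* l = ceil(log2(d)) */
        ++l;
    }
    divisor_u8_t r;
    r.mul = 1 + (((1u << l) - d) << 8) / d;  /* magic multiplier, fits in 8 bits */
    r.sh1 = 1;                               /* min(l, 1) for d > 1 */
    r.sh2 = l - 1;                           /* max(l - 1, 0) */
    return r;
}

static uint8_t divc_u8_scalar(uint8_t a, divisor_u8_t dv)
{
    uint32_t mulhi = ((uint32_t)a * dv.mul) >> 8;  /* high part of the widening multiply */
    /* floor(a/d) = (mulhi + ((a - mulhi) >> sh1)) >> sh2, same identity as the vector code */
    return (uint8_t)((mulhi + (((uint32_t)a - mulhi) >> dv.sh1)) >> dv.sh2);
}

int main(void)
{
    /* exhaustive check of the scalar model over all dividends and divisors >= 2 */
    for (uint32_t d = 2; d < 256; ++d) {
        divisor_u8_t dv = precompute_divisor_u8((uint8_t)d);
        for (uint32_t a = 0; a < 256; ++a) {
            assert(divc_u8_scalar((uint8_t)a, dv) == a / d);
        }
    }
    return 0;
}

The signed variants add one more step on top of this: compute the signed high product, form q = (a + mulhi) >> sh1 minus the sign of a, then apply trunc(a/d) = (q ^ dsign) - dsign to restore the result's sign, which is exactly what the vsra/vxor/vsub sequences in npyv_divc_s16/s32/s64 do.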

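Similarly, npyv__mullhi_u64 exists because LSX (like SSE2) has no 64-bit high-multiply instruction, so the upper half of the 64x64-bit product is rebuilt from four 32x32->64-bit partial products. The scalar rendering below mirrors the w0..w3/s1/s2 bookkeeping of the vector routine; the self-check against unsigned __int128 is a GCC/Clang extension used only for illustration and is not part of the commit.

#include <assert.h>
#include <stdint.h>

/* Scalar model of npyv__mullhi_u64: high 64 bits of an unsigned 64x64-bit
 * product from four 32-bit partial products, matching the vector code. */
static uint64_t mullhi_u64_scalar(uint64_t a, uint64_t b)
{
    uint64_t a_lo = a & 0xffffffffu, a_hi = a >> 32;
    uint64_t b_lo = b & 0xffffffffu, b_hi = b >> 32;

    uint64_t w0 = a_lo * b_lo;             /* low  x low   (w0 in the vector code) */
    uint64_t w1 = a_lo * b_hi;             /* low  x high  (w1) */
    uint64_t w2 = a_hi * b_lo;             /* high x low   (w2) */
    uint64_t w3 = a_hi * b_hi;             /* high x high  (w3) */

    uint64_t s1 = w1 + (w0 >> 32);         /* carry the top half of w0 into w1 */
    uint64_t s2 = w2 + (s1 & 0xffffffffu); /* fold the low half of s1 into w2  */

    return w3 + (s1 >> 32) + (s2 >> 32);   /* everything that lands above bit 63 */
}

int main(void)
{
    /* spot-check against the compiler's 128-bit multiply (GCC/Clang extension) */
    const uint64_t samples[] = {
        0u, 1u, 0xffffffffu, 0x100000000u,
        0xdeadbeefcafebabeu, 0xffffffffffffffffu, 0x123456789abcdef0u
    };
    const int n = (int)(sizeof(samples) / sizeof(samples[0]));
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j) {
            unsigned __int128 wide = (unsigned __int128)samples[i] * samples[j];
            assert(mullhi_u64_scalar(samples[i], samples[j]) == (uint64_t)(wide >> 64));
        }
    }
    return 0;
}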
0 commit comments
