From 53f911d0723f2e8a319644aee14d18781cf35868 Mon Sep 17 00:00:00 2001
From: Michael Siebert
Date: Fri, 21 Feb 2025 20:57:13 +0100
Subject: [PATCH] speed up matmul #23588

---
 benchmarks/benchmarks/bench_linalg.py |  38 ++++++
 numpy/_core/src/umath/matmul.c.src    | 170 ++++++++++++++++++++++----
 2 files changed, 183 insertions(+), 25 deletions(-)

diff --git a/benchmarks/benchmarks/bench_linalg.py b/benchmarks/benchmarks/bench_linalg.py
index feaf1bab1fb9..e3a4be70276e 100644
--- a/benchmarks/benchmarks/bench_linalg.py
+++ b/benchmarks/benchmarks/bench_linalg.py
@@ -217,3 +217,41 @@ def time_transpose(self, shape, npdtypes):
 
     def time_vdot(self, shape, npdtypes):
         np.vdot(self.xarg, self.x2arg)
+
+
+class MatmulStrided(Benchmark):
+    # some interesting points selected from
+    # https://github.com/numpy/numpy/pull/23752#issuecomment-2629521597
+    # (m, p, n, batch_size)
+    args = [
+        (2, 2, 2, 1), (2, 2, 2, 10), (5, 5, 5, 1), (5, 5, 5, 10),
+        (10, 10, 10, 1), (10, 10, 10, 10), (20, 20, 20, 1), (20, 20, 20, 10),
+        (50, 50, 50, 1), (50, 50, 50, 10),
+        (150, 150, 100, 1), (150, 150, 100, 10),
+        (400, 400, 100, 1), (400, 400, 100, 10)
+    ]
+
+    param_names = ['configuration']
+
+    def __init__(self):
+        self.args_map = {
+            'matmul_m%03d_p%03d_n%03d_bs%02d' % arg: arg for arg in self.args
+        }
+
+        self.params = [list(self.args_map.keys())]
+
+    def setup(self, configuration):
+        m, p, n, batch_size = self.args_map[configuration]
+
+        self.a1raw = np.random.rand(batch_size * m * 2 * n).reshape(
+            (batch_size, m, 2 * n)
+        )
+
+        self.a1 = self.a1raw[:, :, ::2]
+
+        self.a2 = np.random.rand(batch_size * n * p).reshape(
+            (batch_size, n, p)
+        )
+
+    def time_matmul(self, configuration):
+        return np.matmul(self.a1, self.a2)
diff --git a/numpy/_core/src/umath/matmul.c.src b/numpy/_core/src/umath/matmul.c.src
index f0f8b2f4153f..d9be7b1d6826 100644
--- a/numpy/_core/src/umath/matmul.c.src
+++ b/numpy/_core/src/umath/matmul.c.src
@@ -79,6 +79,47 @@ static const npy_cfloat oneF = 1.0f, zeroF = 0.0f;
  * #step1 = 1.F, 1., &oneF, &oneD#
  * #step0 = 0.F, 0., &zeroF, &zeroD#
  */
+
+static inline void
+@name@_matrix_copy(npy_bool transpose,
+                   void *_ip, npy_intp is_m, npy_intp is_n,
+                   void *_op, npy_intp os_m, npy_intp os_n,
+                   npy_intp dm, npy_intp dn)
+{
+
+    char *ip = (char *)_ip, *op = (char *)_op;
+
+    npy_intp m, n, ib, ob;
+
+    if (transpose) {
+        ib = is_m * dm, ob = os_m * dm;
+
+        for (n = 0; n < dn; n++) {
+            for (m = 0; m < dm; m++) {
+                *(@ctype@ *)op = *(@ctype@ *)ip;
+                ip += is_m;
+                op += os_m;
+            }
+            ip += is_n - ib;
+            op += os_n - ob;
+        }
+
+        return;
+    }
+
+    ib = is_n * dn, ob = os_n * dn;
+
+    for (m = 0; m < dm; m++) {
+        for (n = 0; n < dn; n++) {
+            *(@ctype@ *)op = *(@ctype@ *)ip;
+            ip += is_n;
+            op += os_n;
+        }
+        ip += is_m - ib;
+        op += os_m - ob;
+    }
+}
+
 NPY_NO_EXPORT void
 @name@_gemv(void *ip1, npy_intp is1_m, npy_intp is1_n,
             void *ip2, npy_intp is2_n,
@@ -429,10 +470,43 @@ NPY_NO_EXPORT void
     npy_bool i2blasable = i2_c_blasable || i2_f_blasable;
     npy_bool o_c_blasable = is_blasable2d(os_m, os_p, dm, dp, sz);
     npy_bool o_f_blasable = is_blasable2d(os_p, os_m, dp, dm, sz);
+    npy_bool oblasable = o_c_blasable || o_f_blasable;
     npy_bool vector_matrix = ((dm == 1) && i2blasable &&
                               is_blasable2d(is1_n, sz, dn, 1, sz));
     npy_bool matrix_vector = ((dp == 1) && i1blasable &&
                               is_blasable2d(is2_n, sz, dn, 1, sz));
+    npy_bool noblas_fallback = too_big_for_blas || any_zero_dim;
+    npy_bool matrix_matrix = !noblas_fallback && !special_case;
+    npy_bool allocate_buffer = matrix_matrix && (
+        !i1blasable || !i2blasable || !oblasable
+    );
+
+    uint8_t *tmp_ip12op = NULL;
+    void *tmp_ip1 = NULL, *tmp_ip2 = NULL, *tmp_op = NULL;
+
+    if (allocate_buffer){
+        npy_intp ip1_size = i1blasable ? 0 : sz * dm * dn,
+                 ip2_size = i2blasable ? 0 : sz * dn * dp,
+                 op_size = oblasable ? 0 : sz * dm * dp,
+                 total_size = ip1_size + ip2_size + op_size;
+
+        tmp_ip12op = (uint8_t*)malloc(total_size);
+
+        if (tmp_ip12op == NULL) {
+            PyGILState_STATE gil_state = PyGILState_Ensure();
+            PyErr_SetString(
+                PyExc_MemoryError, "Out of memory in matmul"
+            );
+            PyGILState_Release(gil_state);
+
+            return;
+        }
+
+        tmp_ip1 = tmp_ip12op;
+        tmp_ip2 = tmp_ip12op + ip1_size;
+        tmp_op = tmp_ip12op + ip1_size + ip2_size;
+    }
+
 #endif
 
     for (iOuter = 0; iOuter < dOuter; iOuter++,
@@ -444,7 +518,7 @@ NPY_NO_EXPORT void
          * PyUFunc_MatmulLoopSelector. But that call does not have access to
          * n, m, p and strides.
          */
-        if (too_big_for_blas || any_zero_dim) {
+        if (noblas_fallback) {
             @TYPE@_matmul_inner_noblas(ip1, is1_m, is1_n,
                                        ip2, is2_n, is2_p,
                                        op, os_m, os_p, dm, dn, dp);
@@ -478,30 +552,73 @@ NPY_NO_EXPORT void
                                        op, os_m, os_p, dm, dn, dp);
             }
         } else {
-            /* matrix @ matrix */
-            if (i1blasable && i2blasable && o_c_blasable) {
-                @TYPE@_matmul_matrixmatrix(ip1, is1_m, is1_n,
-                                           ip2, is2_n, is2_p,
-                                           op, os_m, os_p,
-                                           dm, dn, dp);
-            } else if (i1blasable && i2blasable && o_f_blasable) {
-                /*
-                 * Use transpose equivalence:
-                 * matmul(a, b, o) == matmul(b.T, a.T, o.T)
-                 */
-                @TYPE@_matmul_matrixmatrix(ip2, is2_p, is2_n,
-                                           ip1, is1_n, is1_m,
-                                           op, os_p, os_m,
-                                           dp, dn, dm);
-            } else {
-                /*
-                 * If parameters are castable to int and we copy the
-                 * non-blasable (or non-ccontiguous output)
-                 * we could still use BLAS, see gh-12365.
-                 */
-                @TYPE@_matmul_inner_noblas(ip1, is1_m, is1_n,
-                                           ip2, is2_n, is2_p,
-                                           op, os_m, os_p, dm, dn, dp);
+            /* matrix @ matrix
+             * copy if not blasable, see gh-12365 & gh-23588 */
+            npy_bool i1_transpose = is1_m < is1_n,
+                     i2_transpose = is2_n < is2_p,
+                     o_transpose = os_m < os_p;
+
+            npy_intp tmp_is1_m = i1_transpose ? sz : sz*dn,
+                     tmp_is1_n = i1_transpose ? sz*dm : sz,
+                     tmp_is2_n = i2_transpose ? sz : sz*dp,
+                     tmp_is2_p = i2_transpose ? sz*dn : sz,
+                     tmp_os_m = o_transpose ? sz : sz*dp,
+                     tmp_os_p = o_transpose ? sz*dm : sz;
+
+            if (!i1blasable) {
+                @TYPE@_matrix_copy(
+                    i1_transpose, ip1, is1_m, is1_n,
+                    tmp_ip1, tmp_is1_m, tmp_is1_n,
+                    dm, dn
+                );
+            }
+
+            if (!i2blasable) {
+                @TYPE@_matrix_copy(
+                    i2_transpose, ip2, is2_n, is2_p,
+                    tmp_ip2, tmp_is2_n, tmp_is2_p,
+                    dn, dp
+                );
+            }
+
+            void *ip1_ = i1blasable ? ip1 : tmp_ip1,
+                 *ip2_ = i2blasable ? ip2 : tmp_ip2,
+                 *op_ = oblasable ? op : tmp_op;
+
+            npy_intp is1_m_ = i1blasable ? is1_m : tmp_is1_m,
+                     is1_n_ = i1blasable ? is1_n : tmp_is1_n,
+                     is2_n_ = i2blasable ? is2_n : tmp_is2_n,
+                     is2_p_ = i2blasable ? is2_p : tmp_is2_p,
+                     os_m_ = oblasable ? os_m : tmp_os_m,
+                     os_p_ = oblasable ? os_p : tmp_os_p;
+
+            /*
+             * Use transpose equivalence:
+             * matmul(a, b, o) == matmul(b.T, a.T, o.T)
+             */
+            if (o_f_blasable) {
+                @TYPE@_matmul_matrixmatrix(
+                    ip2_, is2_p_, is2_n_,
+                    ip1_, is1_n_, is1_m_,
+                    op_, os_p_, os_m_,
+                    dp, dn, dm
+                );
+            }
+            else {
+                @TYPE@_matmul_matrixmatrix(
+                    ip1_, is1_m_, is1_n_,
+                    ip2_, is2_n_, is2_p_,
+                    op_, os_m_, os_p_,
+                    dm, dn, dp
+                );
+            }
+
+            if(!oblasable){
+                @TYPE@_matrix_copy(
+                    o_transpose, tmp_op, tmp_os_m, tmp_os_p,
+                    op, os_m, os_p,
+                    dm, dp
+                );
             }
         }
 #else
@@ -511,6 +628,9 @@ NPY_NO_EXPORT void
 #endif
     }
+#if @USEBLAS@ && defined(HAVE_CBLAS)
+    if (allocate_buffer) free(tmp_ip12op);
+#endif
 }
 
 /**end repeat**/
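
Not part of the patch: a minimal plain-NumPy sketch of the strided case the new copy-to-buffer path targets. The shapes below are illustrative (taken from the MatmulStrided benchmark added above); the comparison against a contiguous copy is only a suggested sanity check, not a test included in this change.

import numpy as np

# Illustrative shapes; the benchmark sweeps several (m, p, n, batch_size) points.
batch_size, m, n, p = 10, 50, 50, 50

# Build a non-contiguous LHS the same way the MatmulStrided benchmark does:
# a1 is a view with an inner stride of 2 elements, so it is not BLAS-compatible
# without a copy.
a1raw = np.random.rand(batch_size * m * 2 * n).reshape((batch_size, m, 2 * n))
a1 = a1raw[:, :, ::2]
a2 = np.random.rand(batch_size * n * p).reshape((batch_size, n, p))

# Before this patch the strided operand fell back to the non-BLAS inner loop;
# with the patch it is packed into a temporary contiguous buffer and dispatched
# to BLAS. Either way the result must match the fully contiguous computation.
out_strided = np.matmul(a1, a2)
out_contig = np.matmul(np.ascontiguousarray(a1), a2)
assert np.allclose(out_strided, out_contig)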