ENH: speed up matmul for non-contiguous operands #23588 (PR #23752)

Merged: 1 commit, Mar 9, 2025
benchmarks/benchmarks/bench_linalg.py (38 additions, 0 deletions)
@@ -217,3 +217,41 @@ def time_transpose(self, shape, npdtypes):

    def time_vdot(self, shape, npdtypes):
        np.vdot(self.xarg, self.x2arg)


class MatmulStrided(Benchmark):
    # some interesting points selected from
    # https://github.com/numpy/numpy/pull/23752#issuecomment-2629521597
    # (m, p, n, batch_size)
    args = [
        (2, 2, 2, 1), (2, 2, 2, 10), (5, 5, 5, 1), (5, 5, 5, 10),
        (10, 10, 10, 1), (10, 10, 10, 10), (20, 20, 20, 1), (20, 20, 20, 10),
        (50, 50, 50, 1), (50, 50, 50, 10),
        (150, 150, 100, 1), (150, 150, 100, 10),
        (400, 400, 100, 1), (400, 400, 100, 10)
    ]

    param_names = ['configuration']

    def __init__(self):
        self.args_map = {
            'matmul_m%03d_p%03d_n%03d_bs%02d' % arg: arg for arg in self.args
        }

        self.params = [list(self.args_map.keys())]

    def setup(self, configuration):
        m, p, n, batch_size = self.args_map[configuration]

        self.a1raw = np.random.rand(batch_size * m * 2 * n).reshape(
            (batch_size, m, 2 * n)
        )

        self.a1 = self.a1raw[:, :, ::2]

        self.a2 = np.random.rand(batch_size * n * p).reshape(
            (batch_size, n, p)
        )

    def time_matmul(self, configuration):
        return np.matmul(self.a1, self.a2)
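
The setup above is what makes these benchmarks exercise the non-contiguous path: slicing every other column of a1raw yields a strided view that cannot be handed to BLAS directly. A minimal sketch of the construction for one configuration (illustrative only, not part of the diff):

import numpy as np

batch_size, m, n, p = 10, 50, 50, 50    # one of the benchmark configurations

# Same construction as MatmulStrided.setup: a view with a doubled column stride.
a1raw = np.random.rand(batch_size, m, 2 * n)
a1 = a1raw[:, :, ::2]                   # non-contiguous operand
a2 = np.random.rand(batch_size, n, p)   # contiguous operand

print(a1.flags['C_CONTIGUOUS'])         # False: not directly BLAS-compatible
print(np.matmul(a1, a2).shape)          # (10, 50, 50)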
numpy/_core/src/umath/matmul.c.src (145 additions, 25 deletions)
@@ -79,6 +79,47 @@ static const npy_cfloat oneF = 1.0f, zeroF = 0.0f;
* #step1 = 1.F, 1., &oneF, &oneD#
* #step0 = 0.F, 0., &zeroF, &zeroD#
*/

static inline void
@name@_matrix_copy(npy_bool transpose,
void *_ip, npy_intp is_m, npy_intp is_n,
void *_op, npy_intp os_m, npy_intp os_n,
npy_intp dm, npy_intp dn)
{

char *ip = (char *)_ip, *op = (char *)_op;

npy_intp m, n, ib, ob;

if (transpose) {
ib = is_m * dm, ob = os_m * dm;

for (n = 0; n < dn; n++) {
for (m = 0; m < dm; m++) {
*(@ctype@ *)op = *(@ctype@ *)ip;
ip += is_m;
op += os_m;
}
ip += is_n - ib;
op += os_n - ob;
}

return;
}

ib = is_n * dn, ob = os_n * dn;

for (m = 0; m < dm; m++) {
for (n = 0; n < dn; n++) {
*(@ctype@ *)op = *(@ctype@ *)ip;
ip += is_n;
op += os_n;
}
ip += is_m - ib;
op += os_m - ob;
}
}

NPY_NO_EXPORT void
@name@_gemv(void *ip1, npy_intp is1_m, npy_intp is1_n,
void *ip2, npy_intp is2_n,
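
The @name@_matrix_copy helper added above packs a strided (and possibly transposed) operand into a dense scratch buffer so that the BLAS kernels can still be used. At the Python level the effect is roughly that of forcing a contiguous copy in whichever memory order the source is closer to; a rough analogy (illustrative only, not the actual C code path):

import numpy as np

a = np.random.rand(6, 9)[::2, ::3]    # strided, non-BLAS-able 3x3 operand

# transpose == false: pack into a row-major (C-ordered) buffer
packed_c = np.ascontiguousarray(a)
# transpose == true: pack into a column-major (Fortran-ordered) buffer
packed_f = np.asfortranarray(a)

# Either packing preserves the values; only the memory layout changes.
assert np.array_equal(a, packed_c) and np.array_equal(a, packed_f)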
@@ -429,10 +470,43 @@ NPY_NO_EXPORT void
npy_bool i2blasable = i2_c_blasable || i2_f_blasable;
npy_bool o_c_blasable = is_blasable2d(os_m, os_p, dm, dp, sz);
npy_bool o_f_blasable = is_blasable2d(os_p, os_m, dp, dm, sz);
npy_bool oblasable = o_c_blasable || o_f_blasable;
npy_bool vector_matrix = ((dm == 1) && i2blasable &&
is_blasable2d(is1_n, sz, dn, 1, sz));
npy_bool matrix_vector = ((dp == 1) && i1blasable &&
is_blasable2d(is2_n, sz, dn, 1, sz));
npy_bool noblas_fallback = too_big_for_blas || any_zero_dim;
npy_bool matrix_matrix = !noblas_fallback && !special_case;
npy_bool allocate_buffer = matrix_matrix && (
!i1blasable || !i2blasable || !oblasable
);

uint8_t *tmp_ip12op = NULL;
void *tmp_ip1 = NULL, *tmp_ip2 = NULL, *tmp_op = NULL;

    if (allocate_buffer) {
npy_intp ip1_size = i1blasable ? 0 : sz * dm * dn,
ip2_size = i2blasable ? 0 : sz * dn * dp,
op_size = oblasable ? 0 : sz * dm * dp,
total_size = ip1_size + ip2_size + op_size;

tmp_ip12op = (uint8_t*)malloc(total_size);

if (tmp_ip12op == NULL) {
PyGILState_STATE gil_state = PyGILState_Ensure();
PyErr_SetString(
PyExc_MemoryError, "Out of memory in matmul"
);
PyGILState_Release(gil_state);

return;
}

tmp_ip1 = tmp_ip12op;
tmp_ip2 = tmp_ip12op + ip1_size;
tmp_op = tmp_ip12op + ip1_size + ip2_size;
}

#endif

for (iOuter = 0; iOuter < dOuter; iOuter++,
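
With allocate_buffer set, a single malloc covers only the operands and/or output that are not BLAS-compatible, and tmp_ip1, tmp_ip2 and tmp_op are carved out of that one block. As a worked example of the sizing (my own arithmetic, not taken from the PR): for float64 (sz = 8) with dm = 400, dn = 100, dp = 400 and only the first operand non-blasable, the scratch buffer is 8 * 400 * 100 = 320 000 bytes:

sz = 8                                   # itemsize of float64
dm, dn, dp = 400, 100, 400
i1blasable, i2blasable, oblasable = False, True, True

ip1_size = 0 if i1blasable else sz * dm * dn
ip2_size = 0 if i2blasable else sz * dn * dp
op_size = 0 if oblasable else sz * dm * dp

print(ip1_size + ip2_size + op_size)     # 320000 bytes, one allocation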
@@ -444,7 +518,7 @@ NPY_NO_EXPORT void
* PyUFunc_MatmulLoopSelector. But that call does not have access to
* n, m, p and strides.
*/
if (too_big_for_blas || any_zero_dim) {
if (noblas_fallback) {
@TYPE@_matmul_inner_noblas(ip1, is1_m, is1_n,
ip2, is2_n, is2_p,
op, os_m, os_p, dm, dn, dp);
@@ -478,30 +552,73 @@ NPY_NO_EXPORT void
op, os_m, os_p, dm, dn, dp);
}
} else {
/* matrix @ matrix */
if (i1blasable && i2blasable && o_c_blasable) {
@TYPE@_matmul_matrixmatrix(ip1, is1_m, is1_n,
ip2, is2_n, is2_p,
op, os_m, os_p,
dm, dn, dp);
} else if (i1blasable && i2blasable && o_f_blasable) {
/*
* Use transpose equivalence:
* matmul(a, b, o) == matmul(b.T, a.T, o.T)
*/
@TYPE@_matmul_matrixmatrix(ip2, is2_p, is2_n,
ip1, is1_n, is1_m,
op, os_p, os_m,
dp, dn, dm);
} else {
/*
* If parameters are castable to int and we copy the
* non-blasable (or non-ccontiguous output)
* we could still use BLAS, see gh-12365.
*/
@TYPE@_matmul_inner_noblas(ip1, is1_m, is1_n,
ip2, is2_n, is2_p,
op, os_m, os_p, dm, dn, dp);
/* matrix @ matrix
* copy if not blasable, see gh-12365 & gh-23588 */
npy_bool i1_transpose = is1_m < is1_n,
i2_transpose = is2_n < is2_p,
o_transpose = os_m < os_p;

npy_intp tmp_is1_m = i1_transpose ? sz : sz*dn,
tmp_is1_n = i1_transpose ? sz*dm : sz,
tmp_is2_n = i2_transpose ? sz : sz*dp,
tmp_is2_p = i2_transpose ? sz*dn : sz,
tmp_os_m = o_transpose ? sz : sz*dp,
tmp_os_p = o_transpose ? sz*dm : sz;

if (!i1blasable) {
@TYPE@_matrix_copy(
i1_transpose, ip1, is1_m, is1_n,
tmp_ip1, tmp_is1_m, tmp_is1_n,
dm, dn
);
}

if (!i2blasable) {
@TYPE@_matrix_copy(
i2_transpose, ip2, is2_n, is2_p,
tmp_ip2, tmp_is2_n, tmp_is2_p,
dn, dp
);
}

void *ip1_ = i1blasable ? ip1 : tmp_ip1,
*ip2_ = i2blasable ? ip2 : tmp_ip2,
*op_ = oblasable ? op : tmp_op;

npy_intp is1_m_ = i1blasable ? is1_m : tmp_is1_m,
is1_n_ = i1blasable ? is1_n : tmp_is1_n,
is2_n_ = i2blasable ? is2_n : tmp_is2_n,
is2_p_ = i2blasable ? is2_p : tmp_is2_p,
os_m_ = oblasable ? os_m : tmp_os_m,
os_p_ = oblasable ? os_p : tmp_os_p;

/*
* Use transpose equivalence:
* matmul(a, b, o) == matmul(b.T, a.T, o.T)
*/
if (o_f_blasable) {
@TYPE@_matmul_matrixmatrix(
ip2_, is2_p_, is2_n_,
ip1_, is1_n_, is1_m_,
op_, os_p_, os_m_,
dp, dn, dm
);
}
else {
@TYPE@_matmul_matrixmatrix(
ip1_, is1_m_, is1_n_,
ip2_, is2_n_, is2_p_,
op_, os_m_, os_p_,
dm, dn, dp
);
}

            if (!oblasable) {
@TYPE@_matrix_copy(
o_transpose, tmp_op, tmp_os_m, tmp_os_p,
op, os_m, os_p,
dm, dp
);
}
}
#else
@@ -511,6 +628,9 @@ NPY_NO_EXPORT void

#endif
}
#if @USEBLAS@ && defined(HAVE_CBLAS)
if (allocate_buffer) free(tmp_ip12op);
#endif
}

/**end repeat**/
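
The transpose equivalence the matrix @ matrix branch relies on when only a Fortran-ordered output is BLAS-compatible, matmul(a, b, o) == matmul(b.T, a.T, o.T), is easy to verify from Python (a small sanity check, not part of the PR):

import numpy as np

rng = np.random.default_rng(0)
a = rng.random((150, 100))
b = rng.random((100, 150))

# Computing matmul(b.T, a.T) and reading the result transposed gives matmul(a, b).
assert np.allclose(a @ b, (b.T @ a.T).T)

# Writing into a Fortran-ordered output buffer also round-trips correctly.
out = np.empty((150, 150), order='F')
np.matmul(a, b, out=out)
assert np.allclose(out, a @ b)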