Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 78c00d9

Browse filesBrowse files
committed
speed up matmul #23588
1 parent e59c074 commit 78c00d9
Copy full SHA for 78c00d9

File tree

Expand file treeCollapse file tree

1 file changed

+67
-0
lines changed
Filter options
Expand file treeCollapse file tree

1 file changed

+67
-0
lines changed

‎numpy/_core/src/umath/matmul.c.src

Copy file name to clipboardExpand all lines: numpy/_core/src/umath/matmul.c.src
+67Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,26 @@ static const npy_cfloat oneF = 1.0f, zeroF = 0.0f;
7979
* #step1 = 1.F, 1., &oneF, &oneD#
8080
* #step0 = 0.F, 0., &zeroF, &zeroD#
8181
*/
82+
83+
static inline void
84+
@name@_matrix_copy(void *_ip, npy_intp is_m, npy_intp is_n,
85+
void *_op, npy_intp os_m, npy_intp os_n,
86+
npy_intp dm, npy_intp dn)
87+
{
88+
npy_intp m, n, ib_n = is_n*dn, ob_n = os_n*dn;
89+
char *ip = (char *)_ip, *op = (char *)_op;
90+
91+
for (m = 0; m < dm; m++) {
92+
for (n = 0; n < dn; n++) {
93+
*(@ctype@ *)op = *(@ctype@ *)ip;
94+
ip += is_n;
95+
op += os_n;
96+
}
97+
ip += is_m - ib_n;
98+
op += os_m - ob_n;
99+
}
100+
}
101+
82102
NPY_NO_EXPORT void
83103
@name@_gemv(void *ip1, npy_intp is1_m, npy_intp is1_n,
84104
void *ip2, npy_intp is2_n, npy_intp NPY_UNUSED(is2_p),
@@ -433,6 +453,11 @@ NPY_NO_EXPORT void
433453
is_blasable2d(is1_n, sz, dn, 1, sz));
434454
npy_bool matrix_vector = ((dp == 1) && i1blasable &&
435455
is_blasable2d(is2_n, sz, dn, 1, sz));
456+
457+
void *tmp_ip1 = NULL;
458+
void *tmp_ip2 = NULL;
459+
void *tmp_op = NULL;
460+
436461
#endif
437462

438463
for (iOuter = 0; iOuter < dOuter; iOuter++,
@@ -500,9 +525,46 @@ NPY_NO_EXPORT void
500525
* non-blasable (or non-ccontiguous output)
501526
* we could still use BLAS, see gh-12365.
502527
*/
528+
if(tmp_ip1 == NULL) {
529+
tmp_ip1 = malloc(sz * dm * dn);
530+
}
531+
532+
if(tmp_ip2 == NULL) {
533+
tmp_ip2 = malloc(sz * dn * dp);
534+
}
535+
536+
if(tmp_op == NULL) {
537+
tmp_op = malloc(sz * dm * dp);
538+
}
539+
540+
/* Not enough memory */
541+
if(
542+
tmp_ip1 == NULL || tmp_ip2 == NULL || tmp_op == NULL
543+
) {
503544
@TYPE@_matmul_inner_noblas(ip1, is1_m, is1_n,
504545
ip2, is2_n, is2_p,
505546
op, os_m, os_p, dm, dn, dp);
547+
}
548+
else {
549+
@TYPE@_matrix_copy(
550+
ip1, is1_m, is1_n, tmp_ip1, sz*dn, sz, dm, dn
551+
);
552+
553+
@TYPE@_matrix_copy(
554+
ip2, is2_n, is2_p, tmp_ip2, sz*dp, sz, dn, dp
555+
);
556+
557+
@TYPE@_matmul_matrixmatrix(
558+
tmp_ip1, sz*dn, sz,
559+
tmp_ip2, sz*dp, sz,
560+
tmp_op, sz*dp, sz,
561+
dm, dn, dp
562+
);
563+
564+
@TYPE@_matrix_copy(
565+
tmp_op, sz*dp, sz, op, os_m, os_p, dm, dp
566+
);
567+
}
506568
}
507569
}
508570
#else
@@ -512,6 +574,11 @@ NPY_NO_EXPORT void
512574

513575
#endif
514576
}
577+
#if @USEBLAS@ && defined(HAVE_CBLAS)
578+
if(tmp_ip1 != NULL) free(tmp_ip1);
579+
if(tmp_ip2 != NULL) free(tmp_ip2);
580+
if(tmp_op != NULL) free(tmp_op);
581+
#endif
515582
}
516583

517584
/**end repeat**/

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.