@@ -79,6 +79,26 @@ static const npy_cfloat oneF = 1.0f, zeroF = 0.0f;
79
79
* #step1 = 1.F, 1., &oneF, &oneD#
80
80
* #step0 = 0.F, 0., &zeroF, &zeroD#
81
81
*/
82
+
83
+ static inline void
84
+ @name @_matrix_copy (void * _ip , npy_intp is_m , npy_intp is_n ,
85
+ void * _op , npy_intp os_m , npy_intp os_n ,
86
+ npy_intp dm , npy_intp dn )
87
+ {
88
+ npy_intp m , n , ib_n = is_n * dn , ob_n = os_n * dn ;
89
+ char * ip = (char * )_ip , * op = (char * )_op ;
90
+
91
+ for (m = 0 ; m < dm ; m ++ ) {
92
+ for (n = 0 ; n < dn ; n ++ ) {
93
+ * (@ctype @ * )op = * (@ctype @ * )ip ;
94
+ ip += is_n ;
95
+ op += os_n ;
96
+ }
97
+ ip += is_m - ib_n ;
98
+ op += os_m - ob_n ;
99
+ }
100
+ }
101
+
82
102
NPY_NO_EXPORT void
83
103
@name @_gemv (void * ip1 , npy_intp is1_m , npy_intp is1_n ,
84
104
void * ip2 , npy_intp is2_n , npy_intp NPY_UNUSED (is2_p ),
@@ -433,6 +453,11 @@ NPY_NO_EXPORT void
433
453
is_blasable2d (is1_n , sz , dn , 1 , sz ));
434
454
npy_bool matrix_vector = ((dp == 1 ) && i1blasable &&
435
455
is_blasable2d (is2_n , sz , dn , 1 , sz ));
456
+
457
+ void * tmp_ip1 = NULL ;
458
+ void * tmp_ip2 = NULL ;
459
+ void * tmp_op = NULL ;
460
+
436
461
#endif
437
462
438
463
for (iOuter = 0 ; iOuter < dOuter ; iOuter ++ ,
@@ -500,9 +525,46 @@ NPY_NO_EXPORT void
500
525
* non-blasable (or non-ccontiguous output)
501
526
* we could still use BLAS, see gh-12365.
502
527
*/
528
+ if (tmp_ip1 == NULL ) {
529
+ tmp_ip1 = malloc (sz * dm * dn );
530
+ }
531
+
532
+ if (tmp_ip2 == NULL ) {
533
+ tmp_ip2 = malloc (sz * dn * dp );
534
+ }
535
+
536
+ if (tmp_op == NULL ) {
537
+ tmp_op = malloc (sz * dm * dp );
538
+ }
539
+
540
+ /* Not enough memory */
541
+ if (
542
+ tmp_ip1 == NULL || tmp_ip2 == NULL || tmp_op == NULL
543
+ ) {
503
544
@TYPE @_matmul_inner_noblas (ip1 , is1_m , is1_n ,
504
545
ip2 , is2_n , is2_p ,
505
546
op , os_m , os_p , dm , dn , dp );
547
+ }
548
+ else {
549
+ @TYPE @_matrix_copy (
550
+ ip1 , is1_m , is1_n , tmp_ip1 , sz * dn , sz , dm , dn
551
+ );
552
+
553
+ @TYPE @_matrix_copy (
554
+ ip2 , is2_n , is2_p , tmp_ip2 , sz * dp , sz , dn , dp
555
+ );
556
+
557
+ @TYPE @_matmul_matrixmatrix (
558
+ tmp_ip1 , sz * dn , sz ,
559
+ tmp_ip2 , sz * dp , sz ,
560
+ tmp_op , sz * dp , sz ,
561
+ dm , dn , dp
562
+ );
563
+
564
+ @TYPE @_matrix_copy (
565
+ tmp_op , sz * dp , sz , op , os_m , os_p , dm , dp
566
+ );
567
+ }
506
568
}
507
569
}
508
570
#else
@@ -512,6 +574,11 @@ NPY_NO_EXPORT void
512
574
513
575
#endif
514
576
}
577
+ #if @USEBLAS @ && defined(HAVE_CBLAS )
578
+ if (tmp_ip1 != NULL ) free (tmp_ip1 );
579
+ if (tmp_ip2 != NULL ) free (tmp_ip2 );
580
+ if (tmp_op != NULL ) free (tmp_op );
581
+ #endif
515
582
}
516
583
517
584
/**end repeat**/
0 commit comments