@@ -121,6 +121,19 @@ default_resolve_descriptors(
121
121
}
122
122
123
123
124
+ NPY_INLINE static int
125
+ is_contiguous (
126
+ npy_intp const * strides , PyArray_Descr * const * descriptors , int nargs )
127
+ {
128
+ for (int i = 0 ; i < nargs ; i ++ ) {
129
+ if (strides [i ] != descriptors [i ]-> elsize ) {
130
+ return 0 ;
131
+ }
132
+ }
133
+ return 1 ;
134
+ }
135
+
136
+
124
137
/**
125
138
* The default method to fetch the correct loop for a cast or ufunc
126
139
* (at the time of writing only casts).
@@ -138,18 +151,36 @@ default_resolve_descriptors(
138
151
* @param flags
139
152
* @return 0 on success -1 on failure.
140
153
*/
141
- static int
142
- default_get_strided_loop (
143
- PyArrayMethod_Context * NPY_UNUSED (context ),
144
- int NPY_UNUSED (aligned ), int NPY_UNUSED (move_references ),
145
- npy_intp * NPY_UNUSED (strides ),
146
- PyArray_StridedUnaryOp * * NPY_UNUSED (out_loop ),
147
- NpyAuxData * * NPY_UNUSED (out_transferdata ),
148
- NPY_ARRAYMETHOD_FLAGS * NPY_UNUSED (flags ))
154
+ NPY_NO_EXPORT int
155
+ npy_default_get_strided_loop (
156
+ PyArrayMethod_Context * context ,
157
+ int aligned , int NPY_UNUSED (move_references ), npy_intp * strides ,
158
+ PyArray_StridedUnaryOp * * out_loop , NpyAuxData * * out_transferdata ,
159
+ NPY_ARRAYMETHOD_FLAGS * flags )
149
160
{
150
- PyErr_SetString (PyExc_NotImplementedError ,
151
- "default loop getter is not implemented" );
152
- return -1 ;
161
+ PyArray_Descr * * descrs = context -> descriptors ;
162
+ PyArrayMethodObject * meth = context -> method ;
163
+ * flags = meth -> flags & NPY_METH_RUNTIME_FLAGS ;
164
+ * out_transferdata = NULL ;
165
+
166
+ int nargs = meth -> nin + meth -> nout ;
167
+ if (aligned ) {
168
+ if (meth -> contiguous_loop == NULL ||
169
+ !is_contiguous (strides , descrs , nargs )) {
170
+ * out_loop = meth -> strided_loop ;
171
+ return 0 ;
172
+ }
173
+ * out_loop = meth -> contiguous_loop ;
174
+ }
175
+ else {
176
+ if (meth -> unaligned_contiguous_loop == NULL ||
177
+ !is_contiguous (strides , descrs , nargs )) {
178
+ * out_loop = meth -> unaligned_strided_loop ;
179
+ return 0 ;
180
+ }
181
+ * out_loop = meth -> unaligned_contiguous_loop ;
182
+ }
183
+ return 0 ;
153
184
}
154
185
155
186
@@ -225,7 +256,7 @@ fill_arraymethod_from_slots(
225
256
PyArrayMethodObject * meth = res -> method ;
226
257
227
258
/* Set the defaults */
228
- meth -> get_strided_loop = & default_get_strided_loop ;
259
+ meth -> get_strided_loop = & npy_default_get_strided_loop ;
229
260
meth -> resolve_descriptors = & default_resolve_descriptors ;
230
261
231
262
/* Fill in the slots passed by the user */
@@ -295,7 +326,7 @@ fill_arraymethod_from_slots(
295
326
}
296
327
}
297
328
}
298
- if (meth -> get_strided_loop != & default_get_strided_loop ) {
329
+ if (meth -> get_strided_loop != & npy_default_get_strided_loop ) {
299
330
/* Do not check the actual loop fields. */
300
331
return 0 ;
301
332
}
@@ -468,6 +499,9 @@ boundarraymethod_dealloc(PyObject *self)
468
499
* May raise an error, but usually should not.
469
500
* The function validates the casting attribute compared to the returned
470
501
* casting level.
502
+ *
503
+ * TODO: This function is not public API, and certain code paths will need
504
+ * changes and especially testing if they were to be made public.
471
505
*/
472
506
static PyObject *
473
507
boundarraymethod__resolve_descripors (
@@ -481,7 +515,7 @@ boundarraymethod__resolve_descripors(
481
515
482
516
if (!PyTuple_CheckExact (descr_tuple ) ||
483
517
PyTuple_Size (descr_tuple ) != nin + nout ) {
484
- PyErr_Format (PyExc_ValueError ,
518
+ PyErr_Format (PyExc_TypeError ,
485
519
"_resolve_descriptors() takes exactly one tuple with as many "
486
520
"elements as the method takes arguments (%d+%d)." , nin , nout );
487
521
return NULL ;
@@ -494,15 +528,15 @@ boundarraymethod__resolve_descripors(
494
528
}
495
529
else if (tmp == Py_None ) {
496
530
if (i < nin ) {
497
- PyErr_SetString (PyExc_ValueError ,
531
+ PyErr_SetString (PyExc_TypeError ,
498
532
"only output dtypes may be omitted (set to None)." );
499
533
return NULL ;
500
534
}
501
535
given_descrs [i ] = NULL ;
502
536
}
503
537
else if (PyArray_DescrCheck (tmp )) {
504
538
if (Py_TYPE (tmp ) != (PyTypeObject * )self -> dtypes [i ]) {
505
- PyErr_Format (PyExc_ValueError ,
539
+ PyErr_Format (PyExc_TypeError ,
506
540
"input dtype %S was not an exact instance of the bound "
507
541
"DType class %S." , tmp , self -> dtypes [i ]);
508
542
return NULL ;
@@ -580,9 +614,145 @@ boundarraymethod__resolve_descripors(
580
614
}
581
615
582
616
617
+ /*
618
+ * TODO: This function is not public API, and certain code paths will need
619
+ * changes and especially testing if they were to be made public.
620
+ */
621
+ static PyObject *
622
+ boundarraymethod__simple_strided_call (
623
+ PyBoundArrayMethodObject * self , PyObject * arr_tuple )
624
+ {
625
+ PyArrayObject * arrays [NPY_MAXARGS ];
626
+ PyArray_Descr * descrs [NPY_MAXARGS ];
627
+ PyArray_Descr * out_descrs [NPY_MAXARGS ];
628
+ ssize_t length = -1 ;
629
+ int aligned = 1 ;
630
+ npy_intp strides [NPY_MAXARGS ];
631
+ int nin = self -> method -> nin ;
632
+ int nout = self -> method -> nout ;
633
+
634
+ if (!PyTuple_CheckExact (arr_tuple ) ||
635
+ PyTuple_Size (arr_tuple ) != nin + nout ) {
636
+ PyErr_Format (PyExc_TypeError ,
637
+ "_simple_strided_call() takes exactly one tuple with as many "
638
+ "arrays as the method takes arguments (%d+%d)." , nin , nout );
639
+ return NULL ;
640
+ }
641
+
642
+ for (int i = 0 ; i < nin + nout ; i ++ ) {
643
+ PyObject * tmp = PyTuple_GetItem (arr_tuple , i );
644
+ if (tmp == NULL ) {
645
+ return NULL ;
646
+ }
647
+ else if (!PyArray_CheckExact (tmp )) {
648
+ PyErr_SetString (PyExc_TypeError ,
649
+ "All inputs must be NumPy arrays." );
650
+ return NULL ;
651
+ }
652
+ arrays [i ] = (PyArrayObject * )tmp ;
653
+ descrs [i ] = PyArray_DESCR (arrays [i ]);
654
+
655
+ /* Check that the input is compatible with a simple method call. */
656
+ if (Py_TYPE (descrs [i ]) != (PyTypeObject * )self -> dtypes [i ]) {
657
+ PyErr_Format (PyExc_TypeError ,
658
+ "input dtype %S was not an exact instance of the bound "
659
+ "DType class %S." , descrs [i ], self -> dtypes [i ]);
660
+ return NULL ;
661
+ }
662
+ if (PyArray_NDIM (arrays [i ]) != 1 ) {
663
+ PyErr_SetString (PyExc_ValueError ,
664
+ "All arrays must be one dimensional." );
665
+ return NULL ;
666
+ }
667
+ if (i == 0 ) {
668
+ length = PyArray_SIZE (arrays [i ]);
669
+ }
670
+ else if (PyArray_SIZE (arrays [i ]) != length ) {
671
+ PyErr_SetString (PyExc_ValueError ,
672
+ "All arrays must have the same length." );
673
+ return NULL ;
674
+ }
675
+ if (i >= nout ) {
676
+ if (PyArray_FailUnlessWriteable (
677
+ arrays [i ], "_simple_strided_call() output" ) < 0 ) {
678
+ return NULL ;
679
+ }
680
+ }
681
+
682
+ strides [i ] = PyArray_STRIDES (arrays [i ])[0 ];
683
+ /* TODO: We may need to distinguish aligned and itemsize-aligned */
684
+ aligned &= PyArray_ISALIGNED (arrays [i ]);
685
+ }
686
+ if (!aligned && !(self -> method -> flags & NPY_METH_SUPPORTS_UNALIGNED )) {
687
+ PyErr_SetString (PyExc_ValueError ,
688
+ "method does not support unaligned input." );
689
+ return NULL ;
690
+ }
691
+
692
+ NPY_CASTING casting = self -> method -> resolve_descriptors (
693
+ self -> method , self -> dtypes , descrs , out_descrs );
694
+
695
+ if (casting < 0 ) {
696
+ PyObject * err_type = NULL , * err_value = NULL , * err_traceback = NULL ;
697
+ PyErr_Fetch (& err_type , & err_value , & err_traceback );
698
+ PyErr_SetString (PyExc_TypeError ,
699
+ "cannot perform method call with the given dtypes." );
700
+ npy_PyErr_ChainExceptions (err_type , err_value , err_traceback );
701
+ return NULL ;
702
+ }
703
+
704
+ int dtypes_were_adapted = 0 ;
705
+ for (int i = 0 ; i < nin + nout ; i ++ ) {
706
+ /* NOTE: This check is probably much stricter than necessary... */
707
+ dtypes_were_adapted |= descrs [i ] != out_descrs [i ];
708
+ Py_DECREF (out_descrs [i ]);
709
+ }
710
+ if (dtypes_were_adapted ) {
711
+ PyErr_SetString (PyExc_TypeError ,
712
+ "_simple_strided_call(): requires dtypes to not require a cast "
713
+ "(must match exactly with `_resolve_descriptors()`)." );
714
+ return NULL ;
715
+ }
716
+
717
+ PyArrayMethod_Context context = {
718
+ .caller = NULL ,
719
+ .method = self -> method ,
720
+ .descriptors = descrs ,
721
+ };
722
+ PyArray_StridedUnaryOp * strided_loop = NULL ;
723
+ NpyAuxData * loop_data = NULL ;
724
+ NPY_ARRAYMETHOD_FLAGS flags = 0 ;
725
+
726
+ if (self -> method -> get_strided_loop (
727
+ & context , aligned , 0 , strides ,
728
+ & strided_loop , & loop_data , & flags ) < 0 ) {
729
+ return NULL ;
730
+ }
731
+
732
+ /*
733
+ * TODO: Add floating point error checks if requested and
734
+ * possibly release GIL if allowed by the flags.
735
+ */
736
+ /* TODO: strided_loop is currently a cast loop, this will change. */
737
+ int res = strided_loop (
738
+ PyArray_BYTES (arrays [1 ]), strides [1 ],
739
+ PyArray_BYTES (arrays [0 ]), strides [0 ],
740
+ length , descrs [0 ]-> elsize , loop_data );
741
+ if (loop_data != NULL ) {
742
+ loop_data -> free (loop_data );
743
+ }
744
+ if (res < 0 ) {
745
+ return NULL ;
746
+ }
747
+ Py_RETURN_NONE ;
748
+ }
749
+
750
+
583
751
/*
 * Method table `boundarraymethod_methods`.  Both entries are registered
 * with METH_O (each receives exactly one argument object) and the table is
 * terminated by the all-NULL sentinel entry.
 */
PyMethodDef boundarraymethod_methods [] = {
584
752
{"_resolve_descriptors" , (PyCFunction )boundarraymethod__resolve_descripors ,
585
753
METH_O , "Resolve the given dtypes." },
754
+ {"_simple_strided_call" , (PyCFunction )boundarraymethod__simple_strided_call ,
755
+ METH_O , "call on 1-d inputs and pre-allocated outputs (single call)." },
586
756
{NULL , 0 , 0 , NULL },
587
757
};
588
758
0 commit comments