@@ -49,6 +49,8 @@ NPY_NO_EXPORT int NPY_NUMUSERTYPES = 0;
49
49
#include "convert_datatype.h"
50
50
#include "conversion_utils.h"
51
51
#include "nditer_pywrap.h"
52
+ #define NPY_ITERATOR_IMPLEMENTATION_CODE
53
+ #include "nditer_impl.h"
52
54
#include "methods.h"
53
55
#include "_datetime.h"
54
56
#include "datetime_strings.h"
@@ -67,6 +69,8 @@ NPY_NO_EXPORT int NPY_NUMUSERTYPES = 0;
67
69
#include "mem_overlap.h"
68
70
#include "typeinfo.h"
69
71
#include "convert.h" /* for PyArray_AssignZero */
72
+ #include "lowlevel_strided_loops.h"
73
+ #include "dtype_transfer.h"
70
74
71
75
#include "get_attr_string.h"
72
76
#include "experimental_public_dtype_api.h" /* _get_experimental_dtype_api */
@@ -3381,6 +3385,8 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y)
3381
3385
return NULL ;
3382
3386
}
3383
3387
3388
+ NPY_cast_info cast_info = {.func = NULL };
3389
+
3384
3390
ax = (PyArrayObject * )PyArray_FROM_O (x );
3385
3391
ay = (PyArrayObject * )PyArray_FROM_O (y );
3386
3392
if (ax == NULL || ay == NULL ) {
@@ -3394,14 +3400,15 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y)
3394
3400
};
3395
3401
npy_uint32 op_flags [4 ] = {
3396
3402
NPY_ITER_WRITEONLY | NPY_ITER_ALLOCATE | NPY_ITER_NO_SUBTYPE ,
3397
- NPY_ITER_READONLY , NPY_ITER_READONLY , NPY_ITER_READONLY
3403
+ NPY_ITER_READONLY ,
3404
+ NPY_ITER_READONLY | NPY_ITER_ALIGNED ,
3405
+ NPY_ITER_READONLY | NPY_ITER_ALIGNED
3398
3406
};
3399
3407
PyArray_Descr * common_dt = PyArray_ResultType (2 , & op_in [0 ] + 2 ,
3400
3408
0 , NULL );
3401
3409
PyArray_Descr * op_dt [4 ] = {common_dt , PyArray_DescrFromType (NPY_BOOL ),
3402
3410
common_dt , common_dt };
3403
3411
NpyIter * iter ;
3404
- int needs_api ;
3405
3412
NPY_BEGIN_THREADS_DEF ;
3406
3413
3407
3414
if (common_dt == NULL || op_dt [1 ] == NULL ) {
@@ -3418,61 +3425,94 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y)
3418
3425
goto fail ;
3419
3426
}
3420
3427
3421
- needs_api = NpyIter_IterationNeedsAPI (iter );
3422
-
3423
3428
/* Get the result from the iterator object array */
3424
3429
ret = (PyObject * )NpyIter_GetOperandArray (iter )[0 ];
3425
3430
3426
- NPY_BEGIN_THREADS_NDITER (iter );
3431
+ npy_intp itemsize = common_dt -> elsize ;
3432
+
3433
+ int has_ref = PyDataType_REFCHK (common_dt );
3434
+
3435
+ NPY_ARRAYMETHOD_FLAGS transfer_flags = 0 ;
3436
+
3437
+ npy_intp transfer_strides [2 ] = {itemsize , itemsize };
3438
+ npy_intp one = 1 ;
3439
+
3440
+ if (has_ref || ((itemsize != 16 ) && (itemsize != 8 ) && (itemsize != 4 ) &&
3441
+ (itemsize != 2 ) && (itemsize != 1 ))) {
3442
+ // The iterator has NPY_ITER_ALIGNED flag so no need to check alignment
3443
+ // of the input arrays.
3444
+ //
3445
+ // There's also no need to set up a cast for y, since the iterator
3446
+ // ensures both casts are identical.
3447
+ if (PyArray_GetDTypeTransferFunction (
3448
+ 1 , itemsize , itemsize , common_dt , common_dt , 0 ,
3449
+ & cast_info , & transfer_flags ) != NPY_SUCCEED ) {
3450
+ goto fail ;
3451
+ }
3452
+ }
3453
+
3454
+ transfer_flags = PyArrayMethod_COMBINED_FLAGS (
3455
+ transfer_flags , NpyIter_GetTransferFlags (iter ));
3456
+
3457
+ if (!(transfer_flags & NPY_METH_REQUIRES_PYAPI )) {
3458
+ NPY_BEGIN_THREADS_THRESHOLDED (NpyIter_GetIterSize (iter ));
3459
+ }
3427
3460
3428
3461
if (NpyIter_GetIterSize (iter ) != 0 ) {
3429
3462
NpyIter_IterNextFunc * iternext = NpyIter_GetIterNext (iter , NULL );
3430
3463
npy_intp * innersizeptr = NpyIter_GetInnerLoopSizePtr (iter );
3431
3464
char * * dataptrarray = NpyIter_GetDataPtrArray (iter );
3465
+ npy_intp * strides = NpyIter_GetInnerStrideArray (iter );
3432
3466
3433
3467
do {
3434
- PyArray_Descr * dtx = NpyIter_GetDescrArray (iter )[2 ];
3435
- PyArray_Descr * dty = NpyIter_GetDescrArray (iter )[3 ];
3436
- int axswap = PyDataType_ISBYTESWAPPED (dtx );
3437
- int ayswap = PyDataType_ISBYTESWAPPED (dty );
3438
- PyArray_CopySwapFunc * copyswapx = dtx -> f -> copyswap ;
3439
- PyArray_CopySwapFunc * copyswapy = dty -> f -> copyswap ;
3440
- int native = (axswap == ayswap ) && (axswap == 0 ) && !needs_api ;
3441
3468
npy_intp n = (* innersizeptr );
3442
- npy_intp itemsize = NpyIter_GetDescrArray (iter )[0 ]-> elsize ;
3443
- npy_intp cstride = NpyIter_GetInnerStrideArray (iter )[1 ];
3444
- npy_intp xstride = NpyIter_GetInnerStrideArray (iter )[2 ];
3445
- npy_intp ystride = NpyIter_GetInnerStrideArray (iter )[3 ];
3446
3469
char * dst = dataptrarray [0 ];
3447
3470
char * csrc = dataptrarray [1 ];
3448
3471
char * xsrc = dataptrarray [2 ];
3449
3472
char * ysrc = dataptrarray [3 ];
3450
3473
3474
+ // the iterator might mutate these pointers,
3475
+ // so need to update them every iteration
3476
+ npy_intp cstride = strides [1 ];
3477
+ npy_intp xstride = strides [2 ];
3478
+ npy_intp ystride = strides [3 ];
3479
+
3451
3480
/* constant sizes so compiler replaces memcpy */
3452
- if (native && itemsize == 16 ) {
3481
+ if (! has_ref && itemsize == 16 ) {
3453
3482
INNER_WHERE_LOOP (16 );
3454
3483
}
3455
- else if (native && itemsize == 8 ) {
3484
+ else if (! has_ref && itemsize == 8 ) {
3456
3485
INNER_WHERE_LOOP (8 );
3457
3486
}
3458
- else if (native && itemsize == 4 ) {
3487
+ else if (! has_ref && itemsize == 4 ) {
3459
3488
INNER_WHERE_LOOP (4 );
3460
3489
}
3461
- else if (native && itemsize == 2 ) {
3490
+ else if (! has_ref && itemsize == 2 ) {
3462
3491
INNER_WHERE_LOOP (2 );
3463
3492
}
3464
- else if (native && itemsize == 1 ) {
3493
+ else if (! has_ref && itemsize == 1 ) {
3465
3494
INNER_WHERE_LOOP (1 );
3466
3495
}
3467
3496
else {
3468
- /* copyswap is faster than memcpy even if we are native */
3469
3497
npy_intp i ;
3470
3498
for (i = 0 ; i < n ; i ++ ) {
3471
3499
if (* csrc ) {
3472
- copyswapx (dst , xsrc , axswap , ret );
3500
+ char * args [2 ] = {xsrc , dst };
3501
+
3502
+ if (cast_info .func (
3503
+ & cast_info .context , args , & one ,
3504
+ transfer_strides , cast_info .auxdata ) < 0 ) {
3505
+ goto fail ;
3506
+ }
3473
3507
}
3474
3508
else {
3475
- copyswapy (dst , ysrc , ayswap , ret );
3509
+ char * args [2 ] = {ysrc , dst };
3510
+
3511
+ if (cast_info .func (
3512
+ & cast_info .context , args , & one ,
3513
+ transfer_strides , cast_info .auxdata ) < 0 ) {
3514
+ goto fail ;
3515
+ }
3476
3516
}
3477
3517
dst += itemsize ;
3478
3518
xsrc += xstride ;
@@ -3489,6 +3529,7 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y)
3489
3529
Py_DECREF (arr );
3490
3530
Py_DECREF (ax );
3491
3531
Py_DECREF (ay );
3532
+ NPY_cast_info_xfree (& cast_info );
3492
3533
3493
3534
if (NpyIter_Deallocate (iter ) != NPY_SUCCEED ) {
3494
3535
Py_DECREF (ret );
@@ -3502,6 +3543,7 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y)
3502
3543
Py_DECREF (arr );
3503
3544
Py_XDECREF (ax );
3504
3545
Py_XDECREF (ay );
3546
+ NPY_cast_info_xfree (& cast_info );
3505
3547
return NULL ;
3506
3548
}
3507
3549
0 commit comments