Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit cebb7a6

Browse filesBrowse files
authored
Merge pull request #23770 from ngoldbaum/rm-copyswap-in-where
MAINT: do not use copyswap in where internals
2 parents 126b46c + 01a251b commit cebb7a6
Copy full SHA for cebb7a6

File tree

Expand file treeCollapse file tree

2 files changed

+73
-24
lines changed
Filter options
Expand file treeCollapse file tree

2 files changed

+73
-24
lines changed

‎benchmarks/benchmarks/bench_function_base.py

Copy file name to clipboardExpand all lines: benchmarks/benchmarks/bench_function_base.py
+7Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,9 @@ def time_sort_worst(self):
308308
class Where(Benchmark):
309309
def setup(self):
310310
self.d = np.arange(20000)
311+
self.d_o = self.d.astype(object)
311312
self.e = self.d.copy()
313+
self.e_o = self.d_o.copy()
312314
self.cond = (self.d > 5000)
313315
size = 1024 * 1024 // 8
314316
rnd_array = np.random.rand(size)
@@ -332,6 +334,11 @@ def time_1(self):
332334
def time_2(self):
333335
np.where(self.cond, self.d, self.e)
334336

337+
def time_2_object(self):
338+
# object and byteswapped arrays have a
339+
# special slow path in the where internals
340+
np.where(self.cond, self.d_o, self.e_o)
341+
335342
def time_2_broadcast(self):
336343
np.where(self.cond, self.d, 0)
337344

‎numpy/core/src/multiarray/multiarraymodule.c

Copy file name to clipboardExpand all lines: numpy/core/src/multiarray/multiarraymodule.c
+66-24Lines changed: 66 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ NPY_NO_EXPORT int NPY_NUMUSERTYPES = 0;
4949
#include "convert_datatype.h"
5050
#include "conversion_utils.h"
5151
#include "nditer_pywrap.h"
52+
#define NPY_ITERATOR_IMPLEMENTATION_CODE
53+
#include "nditer_impl.h"
5254
#include "methods.h"
5355
#include "_datetime.h"
5456
#include "datetime_strings.h"
@@ -67,6 +69,8 @@ NPY_NO_EXPORT int NPY_NUMUSERTYPES = 0;
6769
#include "mem_overlap.h"
6870
#include "typeinfo.h"
6971
#include "convert.h" /* for PyArray_AssignZero */
72+
#include "lowlevel_strided_loops.h"
73+
#include "dtype_transfer.h"
7074

7175
#include "get_attr_string.h"
7276
#include "experimental_public_dtype_api.h" /* _get_experimental_dtype_api */
@@ -3381,6 +3385,8 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y)
33813385
return NULL;
33823386
}
33833387

3388+
NPY_cast_info cast_info = {.func = NULL};
3389+
33843390
ax = (PyArrayObject*)PyArray_FROM_O(x);
33853391
ay = (PyArrayObject*)PyArray_FROM_O(y);
33863392
if (ax == NULL || ay == NULL) {
@@ -3394,14 +3400,15 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y)
33943400
};
33953401
npy_uint32 op_flags[4] = {
33963402
NPY_ITER_WRITEONLY | NPY_ITER_ALLOCATE | NPY_ITER_NO_SUBTYPE,
3397-
NPY_ITER_READONLY, NPY_ITER_READONLY, NPY_ITER_READONLY
3403+
NPY_ITER_READONLY,
3404+
NPY_ITER_READONLY | NPY_ITER_ALIGNED,
3405+
NPY_ITER_READONLY | NPY_ITER_ALIGNED
33983406
};
33993407
PyArray_Descr * common_dt = PyArray_ResultType(2, &op_in[0] + 2,
34003408
0, NULL);
34013409
PyArray_Descr * op_dt[4] = {common_dt, PyArray_DescrFromType(NPY_BOOL),
34023410
common_dt, common_dt};
34033411
NpyIter * iter;
3404-
int needs_api;
34053412
NPY_BEGIN_THREADS_DEF;
34063413

34073414
if (common_dt == NULL || op_dt[1] == NULL) {
@@ -3418,61 +3425,94 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y)
34183425
goto fail;
34193426
}
34203427

3421-
needs_api = NpyIter_IterationNeedsAPI(iter);
3422-
34233428
/* Get the result from the iterator object array */
34243429
ret = (PyObject*)NpyIter_GetOperandArray(iter)[0];
34253430

3426-
NPY_BEGIN_THREADS_NDITER(iter);
3431+
npy_intp itemsize = common_dt->elsize;
3432+
3433+
int has_ref = PyDataType_REFCHK(common_dt);
3434+
3435+
NPY_ARRAYMETHOD_FLAGS transfer_flags = 0;
3436+
3437+
npy_intp transfer_strides[2] = {itemsize, itemsize};
3438+
npy_intp one = 1;
3439+
3440+
if (has_ref || ((itemsize != 16) && (itemsize != 8) && (itemsize != 4) &&
3441+
(itemsize != 2) && (itemsize != 1))) {
3442+
// The iterator has NPY_ITER_ALIGNED flag so no need to check alignment
3443+
// of the input arrays.
3444+
//
3445+
// There's also no need to set up a cast for y, since the iterator
3446+
// ensures both casts are identical.
3447+
if (PyArray_GetDTypeTransferFunction(
3448+
1, itemsize, itemsize, common_dt, common_dt, 0,
3449+
&cast_info, &transfer_flags) != NPY_SUCCEED) {
3450+
goto fail;
3451+
}
3452+
}
3453+
3454+
transfer_flags = PyArrayMethod_COMBINED_FLAGS(
3455+
transfer_flags, NpyIter_GetTransferFlags(iter));
3456+
3457+
if (!(transfer_flags & NPY_METH_REQUIRES_PYAPI)) {
3458+
NPY_BEGIN_THREADS_THRESHOLDED(NpyIter_GetIterSize(iter));
3459+
}
34273460

34283461
if (NpyIter_GetIterSize(iter) != 0) {
34293462
NpyIter_IterNextFunc *iternext = NpyIter_GetIterNext(iter, NULL);
34303463
npy_intp * innersizeptr = NpyIter_GetInnerLoopSizePtr(iter);
34313464
char **dataptrarray = NpyIter_GetDataPtrArray(iter);
3465+
npy_intp *strides = NpyIter_GetInnerStrideArray(iter);
34323466

34333467
do {
3434-
PyArray_Descr * dtx = NpyIter_GetDescrArray(iter)[2];
3435-
PyArray_Descr * dty = NpyIter_GetDescrArray(iter)[3];
3436-
int axswap = PyDataType_ISBYTESWAPPED(dtx);
3437-
int ayswap = PyDataType_ISBYTESWAPPED(dty);
3438-
PyArray_CopySwapFunc *copyswapx = dtx->f->copyswap;
3439-
PyArray_CopySwapFunc *copyswapy = dty->f->copyswap;
3440-
int native = (axswap == ayswap) && (axswap == 0) && !needs_api;
34413468
npy_intp n = (*innersizeptr);
3442-
npy_intp itemsize = NpyIter_GetDescrArray(iter)[0]->elsize;
3443-
npy_intp cstride = NpyIter_GetInnerStrideArray(iter)[1];
3444-
npy_intp xstride = NpyIter_GetInnerStrideArray(iter)[2];
3445-
npy_intp ystride = NpyIter_GetInnerStrideArray(iter)[3];
34463469
char * dst = dataptrarray[0];
34473470
char * csrc = dataptrarray[1];
34483471
char * xsrc = dataptrarray[2];
34493472
char * ysrc = dataptrarray[3];
34503473

3474+
// the iterator might mutate these pointers,
3475+
// so need to update them every iteration
3476+
npy_intp cstride = strides[1];
3477+
npy_intp xstride = strides[2];
3478+
npy_intp ystride = strides[3];
3479+
34513480
/* constant sizes so compiler replaces memcpy */
3452-
if (native && itemsize == 16) {
3481+
if (!has_ref && itemsize == 16) {
34533482
INNER_WHERE_LOOP(16);
34543483
}
3455-
else if (native && itemsize == 8) {
3484+
else if (!has_ref && itemsize == 8) {
34563485
INNER_WHERE_LOOP(8);
34573486
}
3458-
else if (native && itemsize == 4) {
3487+
else if (!has_ref && itemsize == 4) {
34593488
INNER_WHERE_LOOP(4);
34603489
}
3461-
else if (native && itemsize == 2) {
3490+
else if (!has_ref && itemsize == 2) {
34623491
INNER_WHERE_LOOP(2);
34633492
}
3464-
else if (native && itemsize == 1) {
3493+
else if (!has_ref && itemsize == 1) {
34653494
INNER_WHERE_LOOP(1);
34663495
}
34673496
else {
3468-
/* copyswap is faster than memcpy even if we are native */
34693497
npy_intp i;
34703498
for (i = 0; i < n; i++) {
34713499
if (*csrc) {
3472-
copyswapx(dst, xsrc, axswap, ret);
3500+
char *args[2] = {xsrc, dst};
3501+
3502+
if (cast_info.func(
3503+
&cast_info.context, args, &one,
3504+
transfer_strides, cast_info.auxdata) < 0) {
3505+
goto fail;
3506+
}
34733507
}
34743508
else {
3475-
copyswapy(dst, ysrc, ayswap, ret);
3509+
char *args[2] = {ysrc, dst};
3510+
3511+
if (cast_info.func(
3512+
&cast_info.context, args, &one,
3513+
transfer_strides, cast_info.auxdata) < 0) {
3514+
goto fail;
3515+
}
34763516
}
34773517
dst += itemsize;
34783518
xsrc += xstride;
@@ -3489,6 +3529,7 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y)
34893529
Py_DECREF(arr);
34903530
Py_DECREF(ax);
34913531
Py_DECREF(ay);
3532+
NPY_cast_info_xfree(&cast_info);
34923533

34933534
if (NpyIter_Deallocate(iter) != NPY_SUCCEED) {
34943535
Py_DECREF(ret);
@@ -3502,6 +3543,7 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y)
35023543
Py_DECREF(arr);
35033544
Py_XDECREF(ax);
35043545
Py_XDECREF(ay);
3546+
NPY_cast_info_xfree(&cast_info);
35053547
return NULL;
35063548
}
35073549

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.