Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

BUG: fixes for three related stringdtype issues #26436

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 2 numpy/_core/src/multiarray/mapping.c
Original file line number Diff line number Diff line change
Expand Up @@ -1580,7 +1580,7 @@ array_subscript(PyArrayObject *self, PyObject *op)

if (PyArray_GetDTypeTransferFunction(is_aligned,
itemsize, itemsize,
PyArray_DESCR(self), PyArray_DESCR(self),
PyArray_DESCR(self), PyArray_DESCR((PyArrayObject *)result),
0, &cast_info, &transfer_flags) != NPY_SUCCEED) {
goto finish;
}
Expand Down
80 changes: 52 additions & 28 deletions 80 numpy/_core/src/multiarray/multiarraymodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -3258,7 +3258,8 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y)
return NULL;
}

NPY_cast_info cast_info = {.func = NULL};
NPY_cast_info x_cast_info = {.func = NULL};
NPY_cast_info y_cast_info = {.func = NULL};

ax = (PyArrayObject*)PyArray_FROM_O(x);
if (ax == NULL) {
Expand All @@ -3282,13 +3283,33 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y)
NPY_ITER_READONLY | NPY_ITER_ALIGNED,
NPY_ITER_READONLY | NPY_ITER_ALIGNED
};

common_dt = PyArray_ResultType(2, &op_in[2], 0, NULL);
if (common_dt == NULL) {
goto fail;
}
npy_intp itemsize = common_dt->elsize;

// If x and y don't have references, we ask the iterator to create buffers
// using the common data type of x and y and then do fast trivial copies
// in the loop below.
// Otherwise trivial copies aren't possible and we handle the cast item by item
// in the loop.
PyArray_Descr *x_dt, *y_dt;
int trivial_copy_loop = !PyDataType_REFCHK(common_dt) &&
((itemsize == 16) || (itemsize == 8) || (itemsize == 4) ||
(itemsize == 2) || (itemsize == 1));
if (trivial_copy_loop) {
x_dt = common_dt;
y_dt = common_dt;
}
ngoldbaum marked this conversation as resolved.
Show resolved Hide resolved
else {
x_dt = PyArray_DESCR(op_in[2]);
y_dt = PyArray_DESCR(op_in[3]);
}
/* `PyArray_DescrFromType` cannot fail for simple builtin types: */
PyArray_Descr * op_dt[4] = {common_dt, PyArray_DescrFromType(NPY_BOOL),
common_dt, common_dt};
PyArray_Descr * op_dt[4] = {common_dt, PyArray_DescrFromType(NPY_BOOL), x_dt, y_dt};

NpyIter * iter;
NPY_BEGIN_THREADS_DEF;

Expand All @@ -3302,26 +3323,27 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y)

/* Get the result from the iterator object array */
ret = (PyObject*)NpyIter_GetOperandArray(iter)[0];

npy_intp itemsize = common_dt->elsize;

int has_ref = PyDataType_REFCHK(common_dt);
PyArray_Descr *ret_dt = PyArray_DESCR((PyArrayObject *)ret);

NPY_ARRAYMETHOD_FLAGS transfer_flags = 0;

npy_intp transfer_strides[2] = {itemsize, itemsize};
npy_intp x_strides[2] = {x_dt->elsize, itemsize};
npy_intp y_strides[2] = {y_dt->elsize, itemsize};
npy_intp one = 1;

if (has_ref || ((itemsize != 16) && (itemsize != 8) && (itemsize != 4) &&
(itemsize != 2) && (itemsize != 1))) {
if (!trivial_copy_loop) {
// The iterator has NPY_ITER_ALIGNED flag so no need to check alignment
// of the input arrays.
//
// There's also no need to set up a cast for y, since the iterator
// ensures both casts are identical.
if (PyArray_GetDTypeTransferFunction(
1, itemsize, itemsize, common_dt, common_dt, 0,
&cast_info, &transfer_flags) != NPY_SUCCEED) {
1, x_strides[0], x_strides[1],
PyArray_DESCR(op_in[2]), ret_dt, 0,
&x_cast_info, &transfer_flags) != NPY_SUCCEED) {
goto fail;
}
if (PyArray_GetDTypeTransferFunction(
1, y_strides[0], y_strides[1],
PyArray_DESCR(op_in[3]), ret_dt, 0,
&y_cast_info, &transfer_flags) != NPY_SUCCEED) {
goto fail;
}
}
Expand Down Expand Up @@ -3353,19 +3375,19 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y)
npy_intp ystride = strides[3];

/* constant sizes so compiler replaces memcpy */
if (!has_ref && itemsize == 16) {
if (trivial_copy_loop && itemsize == 16) {
INNER_WHERE_LOOP(16);
}
else if (!has_ref && itemsize == 8) {
else if (trivial_copy_loop && itemsize == 8) {
INNER_WHERE_LOOP(8);
}
else if (!has_ref && itemsize == 4) {
else if (trivial_copy_loop && itemsize == 4) {
INNER_WHERE_LOOP(4);
}
else if (!has_ref && itemsize == 2) {
else if (trivial_copy_loop && itemsize == 2) {
INNER_WHERE_LOOP(2);
}
else if (!has_ref && itemsize == 1) {
else if (trivial_copy_loop && itemsize == 1) {
INNER_WHERE_LOOP(1);
}
else {
Expand All @@ -3374,18 +3396,18 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y)
if (*csrc) {
char *args[2] = {xsrc, dst};

if (cast_info.func(
&cast_info.context, args, &one,
transfer_strides, cast_info.auxdata) < 0) {
if (x_cast_info.func(
&x_cast_info.context, args, &one,
x_strides, x_cast_info.auxdata) < 0) {
goto fail;
}
}
else {
char *args[2] = {ysrc, dst};

if (cast_info.func(
&cast_info.context, args, &one,
transfer_strides, cast_info.auxdata) < 0) {
if (y_cast_info.func(
&y_cast_info.context, args, &one,
y_strides, y_cast_info.auxdata) < 0) {
goto fail;
}
}
Expand All @@ -3405,7 +3427,8 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y)
Py_DECREF(ax);
Py_DECREF(ay);
Py_DECREF(common_dt);
NPY_cast_info_xfree(&cast_info);
NPY_cast_info_xfree(&x_cast_info);
NPY_cast_info_xfree(&y_cast_info);

if (NpyIter_Deallocate(iter) != NPY_SUCCEED) {
Py_DECREF(ret);
Expand All @@ -3419,7 +3442,8 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y)
Py_XDECREF(ax);
Py_XDECREF(ay);
Py_XDECREF(common_dt);
NPY_cast_info_xfree(&cast_info);
NPY_cast_info_xfree(&x_cast_info);
NPY_cast_info_xfree(&y_cast_info);
return NULL;
}

Expand Down
10 changes: 8 additions & 2 deletions 10 numpy/_core/src/multiarray/stringdtype/casts.c
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,7 @@ string_to_bool(PyArrayMethod_Context *context, char *const data[],
npy_string_allocator *allocator = NpyString_acquire_allocator(descr);
int has_null = descr->na_object != NULL;
int has_string_na = descr->has_string_na;
int has_nan_na = descr->has_nan_na;
const npy_static_string *default_string = &descr->default_string;

npy_intp N = dimensions[0];
Expand All @@ -415,8 +416,13 @@ string_to_bool(PyArrayMethod_Context *context, char *const data[],
}
else if (is_null) {
if (has_null && !has_string_na) {
// numpy treats NaN as truthy, following python
*out = NPY_TRUE;
if (has_nan_na) {
// numpy treats NaN as truthy, following python
*out = NPY_TRUE;
}
else {
*out = NPY_FALSE;
}
}
else {
*out = (npy_bool)(default_string->size == 0);
Expand Down
17 changes: 16 additions & 1 deletion 17 numpy/_core/src/multiarray/stringdtype/dtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -415,8 +415,23 @@ stringdtype_getitem(PyArray_StringDTypeObject *descr, char **dataptr)
// PyArray_NonzeroFunc
// Unicode strings are nonzero if their length is nonzero.
npy_bool
nonzero(void *data, void *NPY_UNUSED(arr))
nonzero(void *data, void *arr)
{
PyArray_StringDTypeObject *descr = (PyArray_StringDTypeObject *)PyArray_DESCR(arr);
int has_null = descr->na_object != NULL;
int has_nan_na = descr->has_nan_na;
int has_string_na = descr->has_string_na;
if (has_null && NpyString_isnull((npy_packed_static_string *)data)) {
if (!has_string_na) {
if (has_nan_na) {
// numpy treats NaN as truthy, following python
return 1;
}
else {
return 0;
}
}
}
return NpyString_size((npy_packed_static_string *)data) != 0;
}

Expand Down
44 changes: 39 additions & 5 deletions 44 numpy/_core/tests/test_stringdtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,19 @@ def na_object(request):
return request.param


@pytest.fixture()
def dtype(na_object, coerce):
def get_dtype(na_object, coerce=True):
# explicit is check for pd_NA because != with pd_NA returns pd_NA
if na_object is pd_NA or na_object != "unset":
return StringDType(na_object=na_object, coerce=coerce)
else:
return StringDType(coerce=coerce)


@pytest.fixture()
def dtype(na_object, coerce):
return get_dtype(na_object, coerce)


# second copy for cast tests to do a cartesian product over dtypes
@pytest.fixture(params=[True, False])
def coerce2(request):
Expand Down Expand Up @@ -456,11 +460,41 @@ def test_sort(strings, arr_sorted):
["", "a", "😸", "ááðfáíóåéë"],
],
)
def test_nonzero(strings):
arr = np.array(strings, dtype="T")
is_nonzero = np.array([i for i, item in enumerate(arr) if len(item) != 0])
def test_nonzero(strings, na_object):
dtype = get_dtype(na_object)
arr = np.array(strings, dtype=dtype)
is_nonzero = np.array(
[i for i, item in enumerate(strings) if len(item) != 0])
assert_array_equal(arr.nonzero()[0], is_nonzero)

if na_object is not pd_NA and na_object == 'unset':
return

strings_with_na = np.array(strings + [na_object], dtype=dtype)
is_nan = np.isnan(np.array([dtype.na_object], dtype=dtype))[0]

if is_nan:
assert strings_with_na.nonzero()[0][-1] == 4
else:
assert strings_with_na.nonzero()[0][-1] == 3

# check that the casting to bool and nonzero give consistent results
assert_array_equal(strings_with_na[strings_with_na.nonzero()],
strings_with_na[strings_with_na.astype(bool)])


def test_where(string_list, na_object):
dtype = get_dtype(na_object)
a = np.array(string_list, dtype=dtype)
b = a[::-1]
res = np.where([True, False, True, False, True, False], a, b)
assert_array_equal(res, [a[0], b[1], a[2], b[3], a[4], b[5]])


def test_fancy_indexing(string_list):
sarr = np.array(string_list, dtype="T")
assert_array_equal(sarr, sarr[np.arange(sarr.shape[0])])


def test_creation_functions():
assert_array_equal(np.zeros(3, dtype="T"), ["", "", ""])
Expand Down
Loading
Morty Proxy This is a proxified and sanitized view of the page, visit original site.