Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit af1fcaa

Browse filesBrowse files
committed
BUG,ENH: Fix internal __array_wrap__ for direct calls
Since adding `return_scalar` as am argument, the array-wrap implementations were slightly wrong when that argument was actually passed and the function called directly. NumPy itself rarely (or never) did so for our builtin types now so that was not a problem within NumPy. Further, the scalar version was completely broken, converting to scalar even when such a conversion was impossible. As explained in the code. For array subclasses we NEVER want to convert to scalar by default. The subclass must make that choice explicitly. (There are plenty of tests for this behavior.)
1 parent fc7cc1e commit af1fcaa
Copy full SHA for af1fcaa

File tree

6 files changed

+100
-30
lines changed
Filter options

6 files changed

+100
-30
lines changed
+4Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
* Calling ``__array_wrap__`` directly on NumPy arrays or scalars
2+
now does the right thing when ``return_scalar`` is passed
3+
(Added in NumPy 2). It is further safe now to call the scalar
4+
``__array_wrap__`` on a non-scalar result.

‎numpy/_core/src/multiarray/methods.c

Copy file name to clipboardExpand all lines: numpy/_core/src/multiarray/methods.c
+19-15Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -892,24 +892,24 @@ static PyObject *
892892
array_wraparray(PyArrayObject *self, PyObject *args)
893893
{
894894
PyArrayObject *arr;
895-
PyObject *obj;
895+
PyObject *context_ignored = NULL;
896+
/*
897+
* return_scalar should always be passed, if it is not default to how
898+
* this function behaved in older NumPy versions.
899+
*/
900+
int return_scalar = 0;
896901

897-
if (PyTuple_Size(args) < 1) {
898-
PyErr_SetString(PyExc_TypeError,
899-
"only accepts 1 argument");
902+
if (!PyArg_ParseTuple(args, "O!|OO&:__array_wrap__",
903+
&PyArray_Type, &arr, &context_ignored,
904+
&PyArray_OptionalBoolConverter, &return_scalar)) {
900905
return NULL;
901906
}
902-
obj = PyTuple_GET_ITEM(args, 0);
903-
if (obj == NULL) {
904-
return NULL;
905-
}
906-
if (!PyArray_Check(obj)) {
907-
PyErr_SetString(PyExc_TypeError,
908-
"can only be called with ndarray object");
909-
return NULL;
910-
}
911-
arr = (PyArrayObject *)obj;
912907

908+
/*
909+
* Subclasses must implement `__array_wrap__` with all the arguments.
910+
* If they do not, we default to never returning a scalar to allow
911+
* preserving the (presumably important) subclass information.
912+
*/
913913
if (Py_TYPE(self) != Py_TYPE(arr)) {
914914
PyArray_Descr *dtype = PyArray_DESCR(arr);
915915
Py_INCREF(dtype);
@@ -919,7 +919,11 @@ array_wraparray(PyArrayObject *self, PyObject *args)
919919
PyArray_NDIM(arr),
920920
PyArray_DIMS(arr),
921921
PyArray_STRIDES(arr), PyArray_DATA(arr),
922-
PyArray_FLAGS(arr), (PyObject *)self, obj);
922+
PyArray_FLAGS(arr), (PyObject *)self, (PyObject *)arr);
923+
}
924+
else if (return_scalar) {
925+
Py_INCREF(arr);
926+
return PyArray_Return(arr);
923927
}
924928
else {
925929
/*

‎numpy/_core/src/multiarray/scalartypes.c.src

Copy file name to clipboardExpand all lines: numpy/_core/src/multiarray/scalartypes.c.src
+19-14Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "numpyos.h"
2626
#include "can_cast_table.h"
2727
#include "common.h"
28+
#include "conversion_utils.h"
2829
#include "flagsobject.h"
2930
#include "scalartypes.h"
3031
#include "_datetime.h"
@@ -2035,29 +2036,33 @@ gentype_getarray(PyObject *scalar, PyObject *args)
20352036
return ret;
20362037
}
20372038

2038-
static char doc_sc_wraparray[] = "sc.__array_wrap__(obj) return scalar from array";
2039+
static char doc_sc_wraparray[] = "Array wrap to implementation for scalar types";
20392040

2041+
/*
2042+
* Array wrap for scalars, returning a scalar again preferentially.
2043+
* (note that NumPy itself may well never call this itself).
2044+
*/
20402045
static PyObject *
20412046
gentype_wraparray(PyObject *NPY_UNUSED(scalar), PyObject *args)
20422047
{
2043-
PyObject *obj;
20442048
PyArrayObject *arr;
2049+
PyObject *context_ignored = NULL;
2050+
/* return_scalar should be passed, but it was a scalar so prefer scalar */
2051+
int return_scalar = 1;
20452052

2046-
if (PyTuple_Size(args) < 1) {
2047-
PyErr_SetString(PyExc_TypeError,
2048-
"only accepts 1 argument.");
2053+
if (!PyArg_ParseTuple(args, "O!|OO&:__array_wrap__",
2054+
&PyArray_Type, &arr, &context_ignored,
2055+
&PyArray_OptionalBoolConverter, &return_scalar)) {
20492056
return NULL;
20502057
}
2051-
obj = PyTuple_GET_ITEM(args, 0);
2052-
if (!PyArray_Check(obj)) {
2053-
PyErr_SetString(PyExc_TypeError,
2054-
"can only be called with ndarray object");
2055-
return NULL;
2056-
}
2057-
arr = (PyArrayObject *)obj;
20582058

2059-
return PyArray_Scalar(PyArray_DATA(arr),
2060-
PyArray_DESCR(arr), (PyObject *)arr);
2059+
Py_INCREF(arr);
2060+
if (!return_scalar) {
2061+
return (PyObject *)arr;
2062+
}
2063+
else {
2064+
return PyArray_Return(arr);
2065+
}
20612066
}
20622067

20632068
/*

‎numpy/_core/tests/test_arrayobject.py

Copy file name to clipboardExpand all lines: numpy/_core/tests/test_arrayobject.py
+36Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,39 @@ def test_matrix_transpose_equals_swapaxes(shape):
3131
tgt = np.swapaxes(arr, num_of_axes - 2, num_of_axes - 1)
3232
mT = arr.mT
3333
assert_array_equal(tgt, mT)
34+
35+
36+
@pytest.mark.parametrize("subclass", [False, True])
37+
def test_array_wrap(subclass):
38+
# NumPy should allow `__array_wrap__` to be called on arrays, it's logic
39+
# is designed in a way that:
40+
#
41+
# * Subclasses never return scalars by default (to preserve their
42+
# information). They can choose to if they wish.
43+
# * NumPy returns scalars, if `return_scalar` is passed as True to allow
44+
# manual calls to `arr.__array_wrap__` to do the right thing.
45+
46+
class MyArr(np.ndarray):
47+
def __array_wrap__(self, arr, context=None, return_scalar=None):
48+
return super().__array_wrap__(arr, context, return_scalar)
49+
50+
arr = np.arange(3)
51+
if subclass:
52+
arr = arr.view(MyArr)
53+
54+
arr0d = np.array(3, dtype=np.int8)
55+
# Third argument not passed, None, or True "decays" to scalar.
56+
# (I don't think NumPy would pass `None`, but it seems clear to support)
57+
if not subclass:
58+
assert type(arr.__array_wrap__(arr0d, None, True)) is np.int8
59+
else:
60+
assert type(arr.__array_wrap__(arr0d, None, True)) is type(arr)
61+
62+
# Otherwise, result should be viewed as the subclass
63+
assert type(arr.__array_wrap__(arr0d)) is type(arr)
64+
assert type(arr.__array_wrap__(arr0d, None, None)) is type(arr)
65+
assert type(arr.__array_wrap__(arr0d, None, False)) is type(arr)
66+
67+
# Non 0-D array can't be converted to scalar, so we ignore that
68+
arr1d = np.array([3], dtype=np.int8)
69+
assert type(arr.__array_wrap__(arr1d, None, True)) is type(arr)

‎numpy/_core/tests/test_multiarray.py

Copy file name to clipboardExpand all lines: numpy/_core/tests/test_multiarray.py
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9807,7 +9807,7 @@ class MyArr(np.ndarray):
98079807

98089808
def __array_wrap__(self, new, context=None, return_scalar=False):
98099809
type(self).called_wrap += 1
9810-
return super().__array_wrap__(new)
9810+
return super().__array_wrap__(new, context, return_scalar)
98119811

98129812
numpy_arr = np.zeros(5, dtype=dt1)
98139813
my_arr = np.zeros(5, dtype=dt2).view(MyArr)

‎numpy/_core/tests/test_scalar_methods.py

Copy file name to clipboardExpand all lines: numpy/_core/tests/test_scalar_methods.py
+21Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,3 +223,24 @@ def test_to_device(self, scalar):
223223
@pytest.mark.parametrize("scalar", scalars)
224224
def test___array_namespace__(self, scalar):
225225
assert scalar.__array_namespace__() is np
226+
227+
228+
@pytest.mark.parametrize("scalar", [np.bool(True), np.int8(1), np.float64(1)])
229+
def test_array_wrap(scalar):
230+
# Test scalars array wrap as long as it exists. NumPy itself should
231+
# probably not use it, so it may not be necessary to keep it around.
232+
233+
arr0d = np.array(3, dtype=np.int8)
234+
# Third argument not passed, None, or True "decays" to scalar.
235+
# (I don't think NumPy would pass `None`, but it seems clear to support)
236+
assert type(scalar.__array_wrap__(arr0d)) is np.int8
237+
assert type(scalar.__array_wrap__(arr0d, None, None)) is np.int8
238+
assert type(scalar.__array_wrap__(arr0d, None, True)) is np.int8
239+
240+
# Otherwise, result should be the input
241+
assert scalar.__array_wrap__(arr0d, None, False) is arr0d
242+
243+
# An old bug. A non 0-d array cannot be converted to scalar:
244+
arr1d = np.array([3], dtype=np.int8)
245+
assert scalar.__array_wrap__(arr1d) is arr1d
246+
assert scalar.__array_wrap__(arr1d, None, True) is arr1d

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.