Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 6019840

Browse filesBrowse files
committed
BUG: Ensure that scalar binops prioritize __array_ufunc__
If array-ufunc is implemented, we must call always use it for all operators (that seems to be the promise). If __array_function__ is defined we are in the clear w.r.t. recursion because the object is either an array (can be unpacked, but already checked earlier now also), or it cannot call the ufunc without unpacking itself (otherwise it would cause recursion). There is an oddity about `__array_wrap__`. Rather than trying to do odd things to deal with it, I added a comment explaining why it doens't matter (roughly: don't use our scalar priority if you want to be sure to get a chance).
1 parent e4a495d commit 6019840
Copy full SHA for 6019840

File tree

2 files changed

+37
-4
lines changed
Filter options

2 files changed

+37
-4
lines changed

‎numpy/_core/src/multiarray/scalartypes.c.src

Copy file name to clipboardExpand all lines: numpy/_core/src/multiarray/scalartypes.c.src
+25-4Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -194,15 +194,30 @@ find_binary_operation_path(
194194
PyLong_CheckExact(other) ||
195195
PyFloat_CheckExact(other) ||
196196
PyComplex_CheckExact(other) ||
197-
PyBool_Check(other)) {
197+
PyBool_Check(other) ||
198+
PyArray_Check(other)) {
198199
/*
199200
* The other operand is ready for the operation already. Must pass on
200201
* on float/long/complex mainly for weak promotion (NEP 50).
201202
*/
202-
Py_INCREF(other);
203-
*other_op = other;
203+
*other_op = Py_NewRef(other);
204204
return 0;
205205
}
206+
/*
207+
* If other has __array_ufunc__ we promise to use the ufunc and don't need
208+
* to worry about recursion (yet). Possible `__array_ufunc__ = None`
209+
* deferral was already dealt with here.
210+
* It may be nice to avoid double lookup in `BINOP_GIVE_UP_IF_NEEDED`.
211+
*/
212+
PyObject *attr = PyArray_LookupSpecial(other, npy_interned_str.array_ufunc);
213+
if (attr != NULL) {
214+
Py_DECREF(attr);
215+
*other_op = Py_NewRef(other);
216+
return 0;
217+
}
218+
else if (PyErr_Occurred()) {
219+
PyErr_Clear(); /* TODO[gh-14801]: propagate crashes during attribute access? */
220+
}
206221

207222
/*
208223
* Now check `other`. We want to know whether it is an object scalar
@@ -216,7 +231,13 @@ find_binary_operation_path(
216231
}
217232

218233
if (!was_scalar || PyArray_DESCR(arr)->type_num != NPY_OBJECT) {
219-
/* The array is OK for usage and we can simply forward it
234+
/*
235+
* The array is OK for usage and we can simply forward it. There
236+
* is a theoretical subtlety here: If the other object implements
237+
* `__array_wrap__`, we may ignore that. However, this only matters
238+
* if the other object has the identical `__array_priority__` and
239+
* additionally already deferred back to us.
240+
* (`obj + scalar` and `scalar + obj` are not symmetric.)
220241
*
221242
* NOTE: Future NumPy may need to distinguish scalars here, one option
222243
* could be marking the array.

‎numpy/_core/tests/test_multiarray.py

Copy file name to clipboardExpand all lines: numpy/_core/tests/test_multiarray.py
+12Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4025,6 +4025,18 @@ class LowPriority(np.ndarray):
40254025
assert res.shape == (3,)
40264026
assert res[0] == 'result'
40274027

4028+
@pytest.mark.parametrize("scalar", [
4029+
np.longdouble(1), np.timedelta64(120,'m')])
4030+
@pytest.mark.parametrize("op", [operator.add, operator.xor])
4031+
def test_scalar_binop_guarantees_ufunc(self, scalar, op):
4032+
# Test that __array_ufunc__ will always cause ufunc use even when
4033+
# we have to protect some other calls from recursing (see gh-26904).
4034+
class SomeClass:
4035+
def __array_ufunc__(self, ufunc, method, *inputs, **kw):
4036+
return "result"
4037+
4038+
assert SomeClass() + np.longdouble(1) == "result"
4039+
assert np.longdouble(1) + SomeClass() == "result"
40284040

40294041
def test_ufunc_override_normalize_signature(self):
40304042
# gh-5674

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.