#10559 introduced the following code to prevent dispatching numpy.tensordot
in a case where einsum was broadcasting over a singleton dimension.
```python
# Handle broadcasting vs BLAS cases
if blas:
    # Checks have already been handled
    input_str, results_index = einsum_str.split('->')
    input_left, input_right = input_str.split(',')
    if 1 in tmp_operands[0] or 1 in tmp_operands[1]:
        left_dims = {dim: size for dim, size in
                     zip(input_left, tmp_operands[0].shape)}
        right_dims = {dim: size for dim, size in
                      zip(input_right, tmp_operands[1].shape)}
        # If dims do not match we are broadcasting, BLAS off
        if any(left_dims[ind] != right_dims[ind] for ind in idx_rm):
            blas = False
```
However, this checks whether the value `1` occurs among the operand array's *elements* rather than in the operand's *shape*. Incidentally, this likely also produced a nasty performance regression, since the membership test scans every element of the array.
Thus the line

```python
if 1 in tmp_operands[0] or 1 in tmp_operands[1]:
```

should be

```python
if 1 in tmp_operands[0].shape or 1 in tmp_operands[1].shape:
```
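The difference between the two checks is easy to demonstrate directly (a minimal sketch, using the same operands as the failing example below):

```python
import numpy as np

x = np.array([0., -1., 0.])  # shape (3,); no element equals 1
y = np.array([0.0])          # shape (1,); the broadcast singleton

# Membership on the array itself tests the *values*:
print(1 in x)        # False -- no element is 1
print(1 in y)        # False

# Membership on .shape tests the *dimensions*:
print(1 in x.shape)  # False -- shape is (3,)
print(1 in y.shape)  # True  -- the singleton dimension the code means to detect
```

So with these operands the buggy check never fires, BLAS stays on, and `tensordot` is dispatched on mismatched shapes.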
This wasn't caught by the unit test because arrays of ones were used 🌌
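To illustrate why ones mask the bug (assuming the test operands were built with `np.ones`): for an array of ones the value-based check is always `True`, so it returns `True` in every case where the intended shape-based check would, and the two are indistinguishable to the test.

```python
import numpy as np

ones = np.ones(3)  # shape (3,), no singleton dimension

# The buggy value check is True simply because the elements are 1:
print(1 in ones)        # True

# The intended shape check is False -- there is no singleton dimension:
print(1 in ones.shape)  # False
```

The disagreement above only disables BLAS unnecessarily (a silent performance hit), never produces a wrong answer, so a ones-based unit test passes either way.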
This leads to the following behavior:
```python
>>> x = np.array([0., 1., 0.])  # contains 1, no blas
>>> y = np.array([0.0])
>>> np.einsum("i,i", x, y, optimize=True)
0.
>>> x = np.array([0., -1., 0.])  # doesn't contain 1, yes blas
>>> y = np.array([0.0])
>>> np.einsum("i,i", x, y, optimize=True)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-184-b0dcea8eedea> in <module>()
      1 x = np.array([0., -1., 0.])
      2 y = np.array([0.0])
----> 3 np.einsum("i,i", x, y, optimize=True)

c:\anaconda\envs\py36\lib\site-packages\numpy\core\einsumfunc.py in einsum(*operands, **kwargs)
   1132
   1133                 # Contract!
-> 1134                 new_view = tensordot(*tmp_operands, axes=(tuple(left_pos), tuple(right_pos)))
   1135
   1136                 # Build a new view if needed

c:\anaconda\envs\py36\lib\site-packages\numpy\core\numeric.py in tensordot(a, b, axes)
   1281                 axes_b[k] += ndb
   1282     if not equal:
-> 1283         raise ValueError("shape-mismatch for sum")
   1284
   1285     # Move the axes to sum over to the end of "a"

ValueError: shape-mismatch for sum
```