Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit c4ad02b

Browse filesBrowse files
committed
Simplifications to quiver3d.
- No need to convert input_args to lists: broadcast_arrays will handle the broadcasting just fine on scalars, and later using np.ravel() will ensure that the result is at least 1D. - The assertions don't really add much (and are clearly always true based on the code just above). - np.linspace(x, y, 2) is a pretty obfuscated way to write np.array([x, y]).
1 parent da58453 commit c4ad02b
Copy full SHA for c4ad02b

File tree

Expand file treeCollapse file tree

1 file changed

+3
-12
lines changed
Filter options
Expand file treeCollapse file tree

1 file changed

+3
-12
lines changed

‎lib/mpl_toolkits/mplot3d/axes3d.py

Copy file name to clipboardExpand all lines: lib/mpl_toolkits/mplot3d/axes3d.py
+3-12Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2637,15 +2637,12 @@ def calc_arrow(uvw, angle=15):
26372637

26382638
# first 6 arguments are X, Y, Z, U, V, W
26392639
input_args = args[:argi]
2640-
# if any of the args are scalar, convert into list
2641-
input_args = [[k] if isinstance(k, (int, float)) else k
2642-
for k in input_args]
26432640

26442641
# extract the masks, if any
26452642
masks = [k.mask for k in input_args
26462643
if isinstance(k, np.ma.MaskedArray)]
26472644
# broadcast to match the shape
2648-
bcast = np.broadcast_arrays(*(input_args + masks))
2645+
bcast = np.broadcast_arrays(*input_args, *masks)
26492646
input_args = bcast[:argi]
26502647
masks = bcast[argi:]
26512648
if masks:
@@ -2655,21 +2652,15 @@ def calc_arrow(uvw, angle=15):
26552652
input_args = [np.ma.array(k, mask=mask).compressed()
26562653
for k in input_args]
26572654
else:
2658-
input_args = [k.flatten() for k in input_args]
2655+
input_args = [np.ravel(k) for k in input_args]
26592656

26602657
if any(len(v) == 0 for v in input_args):
26612658
# No quivers, so just make an empty collection and return early
26622659
linec = art3d.Line3DCollection([], *args[argi:], **kwargs)
26632660
self.add_collection(linec)
26642661
return linec
26652662

2666-
# Following assertions must be true before proceeding
2667-
# must all be ndarray
2668-
assert all(isinstance(k, np.ndarray) for k in input_args)
2669-
# must all in same shape
2670-
assert len({k.shape for k in input_args}) == 1
2671-
2672-
shaft_dt = np.linspace(0, length, num=2)
2663+
shaft_dt = np.array([0, length])
26732664
arrow_dt = shaft_dt * arrow_length_ratio
26742665

26752666
if pivot == 'tail':

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.