Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit f484fb4

Browse filesBrowse files
committed
Further shorten quiver3d computation...
... by using einsum instead of juggling axis order. This has essentially no effect on performance.
1 parent 30282cf commit f484fb4
Copy full SHA for f484fb4

File tree

Expand file treeCollapse file tree

1 file changed

+9
-22
lines changed
Filter options
Expand file treeCollapse file tree

1 file changed

+9
-22
lines changed

‎lib/mpl_toolkits/mplot3d/axes3d.py

Copy file name to clipboardExpand all lines: lib/mpl_toolkits/mplot3d/axes3d.py
+9-22Lines changed: 9 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2526,44 +2526,31 @@ def quiver(self, *args,
25262526
Any additional keyword arguments are delegated to
25272527
:class:`~matplotlib.collections.LineCollection`
25282528
"""
2529+
25292530
def calc_arrows(UVW, angle=15):
25302531
# get unit direction vector perpendicular to (u, v, w)
25312532
x = UVW[:, 0]
25322533
y = UVW[:, 1]
25332534
norm = np.linalg.norm(UVW[:, :2], axis=1)
25342535
x_p = np.divide(y, norm, where=norm != 0, out=np.zeros_like(x))
25352536
y_p = np.divide(-x, norm, where=norm != 0, out=np.ones_like(x))
2536-
25372537
# compute the two arrowhead direction unit vectors
25382538
ra = math.radians(angle)
25392539
c = math.cos(ra)
25402540
s = math.sin(ra)
2541-
2542-
# construct the rotation matrices
2541+
# construct the rotation matrices of shape (3, 3, n)
25432542
Rpos = np.array(
25442543
[[c + (x_p ** 2) * (1 - c), x_p * y_p * (1 - c), y_p * s],
25452544
[y_p * x_p * (1 - c), c + (y_p ** 2) * (1 - c), -x_p * s],
25462545
[-y_p * s, x_p * s, np.full_like(x_p, c)]])
2547-
Rpos = Rpos.transpose(2, 0, 1)
2548-
25492546
# opposite rotation negates all the sin terms
25502547
Rneg = Rpos.copy()
2551-
Rneg[:, [0, 1, 2, 2], [2, 2, 0, 1]] = \
2552-
-Rneg[:, [0, 1, 2, 2], [2, 2, 0, 1]]
2553-
2554-
# expand dimensions for batched matrix multiplication
2555-
UVW = np.expand_dims(UVW, axis=-1)
2556-
2557-
# multiply them to get the rotated vector
2558-
Rpos_vecs = np.matmul(Rpos, UVW)
2559-
Rneg_vecs = np.matmul(Rneg, UVW)
2560-
2561-
# transpose for concatenation
2562-
Rpos_vecs = Rpos_vecs.transpose(0, 2, 1)
2563-
Rneg_vecs = Rneg_vecs.transpose(0, 2, 1)
2564-
2565-
head_dirs = np.concatenate([Rpos_vecs, Rneg_vecs], axis=1)
2566-
2548+
Rneg[[0, 1, 2, 2], [2, 2, 0, 1]] *= -1
2549+
# Batch n (3, 3) x (3) matrix multiplications ((3, 3, n) x (n, 3)).
2550+
Rpos_vecs = np.einsum("ij...,...j->...i", Rpos, UVW)
2551+
Rneg_vecs = np.einsum("ij...,...j->...i", Rneg, UVW)
2552+
# Stack into (n, 2, 3) result.
2553+
head_dirs = np.stack([Rpos_vecs, Rneg_vecs], axis=1)
25672554
return head_dirs
25682555

25692556
had_data = self.has_data()
@@ -2630,7 +2617,7 @@ def calc_arrows(UVW, angle=15):
26302617
# compute all head lines at once, starting from the shaft ends
26312618
heads = shafts[:, :1] - np.multiply.outer(arrow_dt, head_dirs)
26322619
# stack left and right head lines together
2633-
heads.shape = (len(arrow_dt), -1, 3)
2620+
heads = heads.reshape((len(arrow_dt), -1, 3))
26342621
# transpose to get a list of lines
26352622
heads = heads.swapaxes(0, 1)
26362623

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.