Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 8633bbd

Browse filesBrowse files
committed
Store vertices as (masked) arrays in Poly3DCollection
This removes a bunch of very expensive list comprehensions, dropping rendering times by half for some 3D surface plots.
1 parent ef9fc20 commit 8633bbd
Copy full SHA for 8633bbd

File tree

Expand file treeCollapse file tree

2 files changed

+82
-30
lines changed
Filter options
Expand file treeCollapse file tree

2 files changed

+82
-30
lines changed

‎lib/mpl_toolkits/mplot3d/art3d.py

Copy file name to clipboardExpand all lines: lib/mpl_toolkits/mplot3d/art3d.py
+52-27Lines changed: 52 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -598,16 +598,33 @@ def set_zsort(self, zsort):
598598
self.stale = True
599599

600600
def get_vector(self, segments3d):
601-
"""Optimize points for projection."""
602-
if len(segments3d):
603-
xs, ys, zs = np.row_stack(segments3d).T
604-
else: # row_stack can't stack zero arrays.
605-
xs, ys, zs = [], [], []
606-
ones = np.ones(len(xs))
607-
self._vec = np.array([xs, ys, zs, ones])
601+
"""
602+
Optimize points for projection.
608603
609-
indices = [0, *np.cumsum([len(segment) for segment in segments3d])]
610-
self._segslices = [*map(slice, indices[:-1], indices[1:])]
604+
Parameters
605+
----------
606+
segments3d : NumPy array or list of NumPy arrays
607+
List of vertices of the boundary of every segment. If all paths are
608+
of equal length and this argument is a NumPy arrray, then it should
609+
be of shape (num_faces, num_vertices, 3).
610+
"""
611+
if isinstance(segments3d, np.ndarray):
612+
if segments3d.ndim != 3 or segments3d.shape[-1] != 3:
613+
raise ValueError("segments3d must be a MxNx3 array, but got " +
614+
"shape {}".format(segments3d.shape))
615+
self._segments = segments3d
616+
else:
617+
num_faces = len(segments3d)
618+
num_verts = np.fromiter(map(len, segments3d), dtype=np.intp)
619+
max_verts = num_verts.max(initial=0)
620+
padded = np.empty((num_faces, max_verts, 3))
621+
for i, face in enumerate(segments3d):
622+
padded[i, :len(face)] = face
623+
mask = np.arange(max_verts) >= num_verts[:, None]
624+
mask = mask[..., None] # add a component axis
625+
# ma.array does not broadcast the mask for us
626+
mask = np.broadcast_to(mask, padded.shape)
627+
self._segments = np.ma.array(padded, mask=mask)
611628

612629
def set_verts(self, verts, closed=True):
613630
"""Set 3D vertices."""
@@ -649,37 +666,45 @@ def do_3d_projection(self, renderer):
649666
self.update_scalarmappable()
650667
self._facecolors3d = self._facecolors
651668

652-
txs, tys, tzs = proj3d._proj_transform_vec(self._vec, renderer.M)
653-
xyzlist = [(txs[sl], tys[sl], tzs[sl]) for sl in self._segslices]
669+
psegments = proj3d._proj_transform_vectors(self._segments, renderer.M)
670+
is_masked = isinstance(psegments, np.ma.MaskedArray)
671+
num_faces = len(psegments)
654672

655673
# This extra fuss is to re-order face / edge colors
656674
cface = self._facecolors3d
657675
cedge = self._edgecolors3d
658-
if len(cface) != len(xyzlist):
659-
cface = cface.repeat(len(xyzlist), axis=0)
660-
if len(cedge) != len(xyzlist):
676+
if len(cface) != num_faces:
677+
cface = cface.repeat(num_faces, axis=0)
678+
if len(cedge) != num_faces:
661679
if len(cedge) == 0:
662680
cedge = cface
663681
else:
664-
cedge = cedge.repeat(len(xyzlist), axis=0)
682+
cedge = cedge.repeat(num_faces, axis=0)
665683

666-
# sort by depth (furthest drawn first)
667-
z_segments_2d = sorted(
668-
((self._zsortfunc(zs), np.column_stack([xs, ys]), fc, ec, idx)
669-
for idx, ((xs, ys, zs), fc, ec)
670-
in enumerate(zip(xyzlist, cface, cedge))),
671-
key=lambda x: x[0], reverse=True)
684+
face_z = self._zsortfunc(psegments[..., 2], axis=-1)
685+
if is_masked:
686+
# NOTE: Unpacking .data is safe here, because every face has to
687+
# contain a valid vertex.
688+
face_z = face_z.data
689+
face_order = np.argsort(face_z, axis=-1)[::-1]
672690

673-
segments_2d = [s for z, s, fc, ec, idx in z_segments_2d]
691+
segments_2d = psegments[face_order, :, :2]
674692
if self._codes3d is not None:
675-
codes = [self._codes3d[idx] for z, s, fc, ec, idx in z_segments_2d]
693+
if is_masked:
694+
# NOTE: We cannot assert the same on segments_2d, as it is a
695+
# result of advanced indexing, so its mask is a newly
696+
# allocated tensor. However, both of those asserts are
697+
# equivalent.
698+
assert psegments.mask.strides[-1] == 0
699+
segments_2d = [s.compressed().reshape(-1, 2) for s in segments_2d]
700+
codes = [self._codes3d[idx] for idx in face_order]
676701
PolyCollection.set_verts_and_codes(self, segments_2d, codes)
677702
else:
678703
PolyCollection.set_verts(self, segments_2d, self._closed)
679704

680-
self._facecolors2d = [fc for z, s, fc, ec, idx in z_segments_2d]
705+
self._facecolors2d = cface[face_order]
681706
if len(self._edgecolors3d) == len(cface):
682-
self._edgecolors2d = [ec for z, s, fc, ec, idx in z_segments_2d]
707+
self._edgecolors2d = cedge[face_order]
683708
else:
684709
self._edgecolors2d = self._edgecolors3d
685710

@@ -688,11 +713,11 @@ def do_3d_projection(self, renderer):
688713
zvec = np.array([[0], [0], [self._sort_zpos], [1]])
689714
ztrans = proj3d._proj_transform_vec(zvec, renderer.M)
690715
return ztrans[2][0]
691-
elif tzs.size > 0:
716+
elif psegments.size > 0:
692717
# FIXME: Some results still don't look quite right.
693718
# In particular, examine contourf3d_demo2.py
694719
# with az = -54 and elev = -45.
695-
return np.min(tzs)
720+
return np.min(psegments[..., 2])
696721
else:
697722
return np.nan
698723

‎lib/mpl_toolkits/mplot3d/proj3d.py

Copy file name to clipboardExpand all lines: lib/mpl_toolkits/mplot3d/proj3d.py
+30-3Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,8 @@ def ortho_transformation(zfront, zback):
103103

104104
def _proj_transform_vec(vec, M):
105105
vecw = np.dot(M, vec)
106-
w = vecw[3]
107106
# clip here..
108-
txs, tys, tzs = vecw[0]/w, vecw[1]/w, vecw[2]/w
109-
return txs, tys, tzs
107+
return vecw[:3] / vecw[3]
110108

111109

112110
def _proj_transform_vec_clip(vec, M):
@@ -146,6 +144,35 @@ def proj_transform(xs, ys, zs, M):
146144
transform = proj_transform
147145

148146

147+
def _proj_transform_vectors(vecs, M):
148+
"""Vector version of ``project_transform`` able to handle MaskedArrays.
149+
150+
Parameters
151+
----------
152+
vecs : ... x 3 np.ndarray or np.ma.MaskedArray
153+
Input vectors
154+
155+
M : 4 x 4 np.ndarray
156+
Projection matrix
157+
"""
158+
vecs_shape = vecs.shape
159+
vecs = vecs.reshape(-1, 3).T
160+
161+
is_masked = isinstance(vecs, np.ma.MaskedArray)
162+
if is_masked:
163+
mask = vecs.mask
164+
165+
vecs_pad = np.empty((vecs.shape[0] + 1,) + vecs.shape[1:])
166+
vecs_pad[:-1] = vecs
167+
vecs_pad[-1] = 1
168+
product = np.dot(M, vecs_pad)
169+
tvecs = product[:3] / product[3]
170+
171+
if is_masked:
172+
tvecs = np.ma.array(tvecs, mask=mask)
173+
return tvecs.T.reshape(vecs_shape)
174+
175+
149176
def proj_transform_clip(xs, ys, zs, M):
150177
"""
151178
Transform the points by the projection matrix

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.