Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 29f3a5c

Browse filesBrowse files
authored
Merge pull request matplotlib#29397 from scottshambaugh/masked_array_performance
3D plotting performance improvements
2 parents 743a005 + 4fc3745 commit 29f3a5c
Copy full SHA for 29f3a5c

File tree

6 files changed

+187
-92
lines changed
Filter options

6 files changed

+187
-92
lines changed
+6Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
3D performance improvements
2+
~~~~~~~~~~~~~~~~~~~~~~~~~~~
3+
4+
Draw time for 3D plots has been improved, especially for surface and wireframe
5+
plots. Users should see up to a 10x speedup in some cases. This should make
6+
interacting with 3D plots much more responsive.

‎lib/mpl_toolkits/mplot3d/art3d.py

Copy file name to clipboardExpand all lines: lib/mpl_toolkits/mplot3d/art3d.py
+150-81Lines changed: 150 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def get_dir_vector(zdir):
7575

7676
def _viewlim_mask(xs, ys, zs, axes):
7777
"""
78-
Return original points with points outside the axes view limits masked.
78+
Return the mask of the points outside the axes view limits.
7979
8080
Parameters
8181
----------
@@ -86,19 +86,16 @@ def _viewlim_mask(xs, ys, zs, axes):
8686
8787
Returns
8888
-------
89-
xs_masked, ys_masked, zs_masked : np.ma.array
90-
The masked points.
89+
mask : np.array
90+
The mask of the points as a bool array.
9191
"""
9292
mask = np.logical_or.reduce((xs < axes.xy_viewLim.xmin,
9393
xs > axes.xy_viewLim.xmax,
9494
ys < axes.xy_viewLim.ymin,
9595
ys > axes.xy_viewLim.ymax,
9696
zs < axes.zz_viewLim.xmin,
9797
zs > axes.zz_viewLim.xmax))
98-
xs_masked = np.ma.array(xs, mask=mask)
99-
ys_masked = np.ma.array(ys, mask=mask)
100-
zs_masked = np.ma.array(zs, mask=mask)
101-
return xs_masked, ys_masked, zs_masked
98+
return mask
10299

103100

104101
class Text3D(mtext.Text):
@@ -182,14 +179,13 @@ def set_3d_properties(self, z=0, zdir='z', axlim_clip=False):
182179
@artist.allow_rasterization
183180
def draw(self, renderer):
184181
if self._axlim_clip:
185-
xs, ys, zs = _viewlim_mask(self._x, self._y, self._z, self.axes)
186-
position3d = np.ma.row_stack((xs, ys, zs)).ravel().filled(np.nan)
182+
mask = _viewlim_mask(self._x, self._y, self._z, self.axes)
183+
pos3d = np.ma.array([self._x, self._y, self._z],
184+
mask=mask, dtype=float).filled(np.nan)
187185
else:
188-
xs, ys, zs = self._x, self._y, self._z
189-
position3d = np.asanyarray([xs, ys, zs])
186+
pos3d = np.array([self._x, self._y, self._z], dtype=float)
190187

191-
proj = proj3d._proj_trans_points(
192-
[position3d, position3d + self._dir_vec], self.axes.M)
188+
proj = proj3d._proj_trans_points([pos3d, pos3d + self._dir_vec], self.axes.M)
193189
dx = proj[0][1] - proj[0][0]
194190
dy = proj[1][1] - proj[1][0]
195191
angle = math.degrees(math.atan2(dy, dx))
@@ -313,7 +309,12 @@ def get_data_3d(self):
313309
@artist.allow_rasterization
314310
def draw(self, renderer):
315311
if self._axlim_clip:
316-
xs3d, ys3d, zs3d = _viewlim_mask(*self._verts3d, self.axes)
312+
mask = np.broadcast_to(
313+
_viewlim_mask(*self._verts3d, self.axes),
314+
(len(self._verts3d), *self._verts3d[0].shape)
315+
)
316+
xs3d, ys3d, zs3d = np.ma.array(self._verts3d,
317+
dtype=float, mask=mask).filled(np.nan)
317318
else:
318319
xs3d, ys3d, zs3d = self._verts3d
319320
xs, ys, zs, tis = proj3d._proj_transform_clip(xs3d, ys3d, zs3d,
@@ -404,7 +405,8 @@ def do_3d_projection(self):
404405
"""Project the points according to renderer matrix."""
405406
vs_list = [vs for vs, _ in self._3dverts_codes]
406407
if self._axlim_clip:
407-
vs_list = [np.ma.row_stack(_viewlim_mask(*vs.T, self.axes)).T
408+
vs_list = [np.ma.array(vs, mask=np.broadcast_to(
409+
_viewlim_mask(*vs.T, self.axes), vs.shape))
408410
for vs in vs_list]
409411
xyzs_list = [proj3d.proj_transform(*vs.T, self.axes.M) for vs in vs_list]
410412
self._paths = [mpath.Path(np.ma.column_stack([xs, ys]), cs)
@@ -450,22 +452,32 @@ def do_3d_projection(self):
450452
"""
451453
Project the points according to renderer matrix.
452454
"""
453-
segments = self._segments3d
455+
segments = np.asanyarray(self._segments3d)
456+
457+
mask = False
458+
if np.ma.isMA(segments):
459+
mask = segments.mask
460+
454461
if self._axlim_clip:
455-
all_points = np.ma.vstack(segments)
456-
masked_points = np.ma.column_stack([*_viewlim_mask(*all_points.T,
457-
self.axes)])
458-
segment_lengths = [np.shape(segment)[0] for segment in segments]
459-
segments = np.split(masked_points, np.cumsum(segment_lengths[:-1]))
460-
xyslist = [proj3d._proj_trans_points(points, self.axes.M)
461-
for points in segments]
462-
segments_2d = [np.ma.column_stack([xs, ys]) for xs, ys, zs in xyslist]
462+
viewlim_mask = _viewlim_mask(segments[..., 0],
463+
segments[..., 1],
464+
segments[..., 2],
465+
self.axes)
466+
if np.any(viewlim_mask):
467+
# broadcast mask to 3D
468+
viewlim_mask = np.broadcast_to(viewlim_mask[..., np.newaxis],
469+
(*viewlim_mask.shape, 3))
470+
mask = mask | viewlim_mask
471+
xyzs = np.ma.array(proj3d._proj_transform_vectors(segments, self.axes.M),
472+
mask=mask)
473+
segments_2d = xyzs[..., 0:2]
463474
LineCollection.set_segments(self, segments_2d)
464475

465476
# FIXME
466-
minz = 1e9
467-
for xs, ys, zs in xyslist:
468-
minz = min(minz, min(zs))
477+
if len(xyzs) > 0:
478+
minz = min(xyzs[..., 2].min(), 1e9)
479+
else:
480+
minz = np.nan
469481
return minz
470482

471483

@@ -531,7 +543,9 @@ def get_path(self):
531543
def do_3d_projection(self):
532544
s = self._segment3d
533545
if self._axlim_clip:
534-
xs, ys, zs = _viewlim_mask(*zip(*s), self.axes)
546+
mask = _viewlim_mask(*zip(*s), self.axes)
547+
xs, ys, zs = np.ma.array(zip(*s),
548+
dtype=float, mask=mask).filled(np.nan)
535549
else:
536550
xs, ys, zs = zip(*s)
537551
vxs, vys, vzs, vis = proj3d._proj_transform_clip(xs, ys, zs,
@@ -587,7 +601,9 @@ def set_3d_properties(self, path, zs=0, zdir='z', axlim_clip=False):
587601
def do_3d_projection(self):
588602
s = self._segment3d
589603
if self._axlim_clip:
590-
xs, ys, zs = _viewlim_mask(*zip(*s), self.axes)
604+
mask = _viewlim_mask(*zip(*s), self.axes)
605+
xs, ys, zs = np.ma.array(zip(*s),
606+
dtype=float, mask=mask).filled(np.nan)
591607
else:
592608
xs, ys, zs = zip(*s)
593609
vxs, vys, vzs, vis = proj3d._proj_transform_clip(xs, ys, zs,
@@ -701,14 +717,18 @@ def set_3d_properties(self, zs, zdir, axlim_clip=False):
701717

702718
def do_3d_projection(self):
703719
if self._axlim_clip:
704-
xs, ys, zs = _viewlim_mask(*self._offsets3d, self.axes)
720+
mask = _viewlim_mask(*self._offsets3d, self.axes)
721+
xs, ys, zs = np.ma.array(self._offsets3d, mask=mask)
705722
else:
706723
xs, ys, zs = self._offsets3d
707724
vxs, vys, vzs, vis = proj3d._proj_transform_clip(xs, ys, zs,
708725
self.axes.M,
709726
self.axes._focal_length)
710727
self._vzs = vzs
711-
super().set_offsets(np.ma.column_stack([vxs, vys]))
728+
if np.ma.isMA(vxs):
729+
super().set_offsets(np.ma.column_stack([vxs, vys]))
730+
else:
731+
super().set_offsets(np.column_stack([vxs, vys]))
712732

713733
if vzs.size > 0:
714734
return min(vzs)
@@ -851,11 +871,18 @@ def set_depthshade(self, depthshade):
851871
self.stale = True
852872

853873
def do_3d_projection(self):
874+
mask = False
875+
for xyz in self._offsets3d:
876+
if np.ma.isMA(xyz):
877+
mask = mask | xyz.mask
854878
if self._axlim_clip:
855-
xs, ys, zs = _viewlim_mask(*self._offsets3d, self.axes)
879+
mask = mask | _viewlim_mask(*self._offsets3d, self.axes)
880+
mask = np.broadcast_to(mask,
881+
(len(self._offsets3d), *self._offsets3d[0].shape))
882+
xyzs = np.ma.array(self._offsets3d, mask=mask)
856883
else:
857-
xs, ys, zs = self._offsets3d
858-
vxs, vys, vzs, vis = proj3d._proj_transform_clip(xs, ys, zs,
884+
xyzs = self._offsets3d
885+
vxs, vys, vzs, vis = proj3d._proj_transform_clip(*xyzs,
859886
self.axes.M,
860887
self.axes._focal_length)
861888
# Sort the points based on z coordinates
@@ -1062,16 +1089,37 @@ def get_vector(self, segments3d):
10621089
return self._get_vector(segments3d)
10631090

10641091
def _get_vector(self, segments3d):
1065-
"""Optimize points for projection."""
1066-
if len(segments3d):
1067-
xs, ys, zs = np.vstack(segments3d).T
1068-
else: # vstack can't stack zero arrays.
1069-
xs, ys, zs = [], [], []
1070-
ones = np.ones(len(xs))
1071-
self._vec = np.array([xs, ys, zs, ones])
1092+
"""
1093+
Optimize points for projection.
10721094
1073-
indices = [0, *np.cumsum([len(segment) for segment in segments3d])]
1074-
self._segslices = [*map(slice, indices[:-1], indices[1:])]
1095+
Parameters
1096+
----------
1097+
segments3d : NumPy array or list of NumPy arrays
1098+
List of vertices of the boundary of every segment. If all paths are
1099+
of equal length and this argument is a NumPy array, then it should
1100+
be of shape (num_faces, num_vertices, 3).
1101+
"""
1102+
if isinstance(segments3d, np.ndarray):
1103+
if segments3d.ndim != 3 or segments3d.shape[-1] != 3:
1104+
raise ValueError("segments3d must be a MxNx3 array, but got "
1105+
f"shape {segments3d.shape}")
1106+
if isinstance(segments3d, np.ma.MaskedArray):
1107+
self._faces = segments3d.data
1108+
self._invalid_vertices = segments3d.mask.any(axis=-1)
1109+
else:
1110+
self._faces = segments3d
1111+
self._invalid_vertices = False
1112+
else:
1113+
# Turn the potentially ragged list into a numpy array for later speedups
1114+
# If it is ragged, set the unused vertices per face as invalid
1115+
num_faces = len(segments3d)
1116+
num_verts = np.fromiter(map(len, segments3d), dtype=np.intp)
1117+
max_verts = num_verts.max(initial=0)
1118+
segments = np.empty((num_faces, max_verts, 3))
1119+
for i, face in enumerate(segments3d):
1120+
segments[i, :len(face)] = face
1121+
self._faces = segments
1122+
self._invalid_vertices = np.arange(max_verts) >= num_verts[:, None]
10751123

10761124
def set_verts(self, verts, closed=True):
10771125
"""
@@ -1133,64 +1181,85 @@ def do_3d_projection(self):
11331181
self._facecolor3d = self._facecolors
11341182
if self._edge_is_mapped:
11351183
self._edgecolor3d = self._edgecolors
1184+
1185+
needs_masking = np.any(self._invalid_vertices)
1186+
num_faces = len(self._faces)
1187+
mask = self._invalid_vertices
1188+
1189+
# Some faces might contain masked vertices, so we want to ignore any
1190+
# errors that those might cause
1191+
with np.errstate(invalid='ignore', divide='ignore'):
1192+
pfaces = proj3d._proj_transform_vectors(self._faces, self.axes.M)
1193+
11361194
if self._axlim_clip:
1137-
xs, ys, zs = _viewlim_mask(*self._vec[0:3], self.axes)
1138-
if self._vec.shape[0] == 4: # Will be 3 (xyz) or 4 (xyzw)
1139-
w_masked = np.ma.masked_where(zs.mask, self._vec[3])
1140-
vec = np.ma.array([xs, ys, zs, w_masked])
1141-
else:
1142-
vec = np.ma.array([xs, ys, zs])
1143-
else:
1144-
vec = self._vec
1145-
txs, tys, tzs = proj3d._proj_transform_vec(vec, self.axes.M)
1146-
xyzlist = [(txs[sl], tys[sl], tzs[sl]) for sl in self._segslices]
1195+
viewlim_mask = _viewlim_mask(self._faces[..., 0], self._faces[..., 1],
1196+
self._faces[..., 2], self.axes)
1197+
if np.any(viewlim_mask):
1198+
needs_masking = True
1199+
mask = mask | viewlim_mask
1200+
1201+
pzs = pfaces[..., 2]
1202+
if needs_masking:
1203+
pzs = np.ma.MaskedArray(pzs, mask=mask)
11471204

11481205
# This extra fuss is to re-order face / edge colors
11491206
cface = self._facecolor3d
11501207
cedge = self._edgecolor3d
1151-
if len(cface) != len(xyzlist):
1152-
cface = cface.repeat(len(xyzlist), axis=0)
1153-
if len(cedge) != len(xyzlist):
1208+
if len(cface) != num_faces:
1209+
cface = cface.repeat(num_faces, axis=0)
1210+
if len(cedge) != num_faces:
11541211
if len(cedge) == 0:
11551212
cedge = cface
11561213
else:
1157-
cedge = cedge.repeat(len(xyzlist), axis=0)
1158-
1159-
if xyzlist:
1160-
# sort by depth (furthest drawn first)
1161-
z_segments_2d = sorted(
1162-
((self._zsortfunc(zs.data), np.ma.column_stack([xs, ys]), fc, ec, idx)
1163-
for idx, ((xs, ys, zs), fc, ec)
1164-
in enumerate(zip(xyzlist, cface, cedge))),
1165-
key=lambda x: x[0], reverse=True)
1166-
1167-
_, segments_2d, self._facecolors2d, self._edgecolors2d, idxs = \
1168-
zip(*z_segments_2d)
1169-
else:
1170-
segments_2d = []
1171-
self._facecolors2d = np.empty((0, 4))
1172-
self._edgecolors2d = np.empty((0, 4))
1173-
idxs = []
1174-
1175-
if self._codes3d is not None:
1176-
codes = [self._codes3d[idx] for idx in idxs]
1177-
PolyCollection.set_verts_and_codes(self, segments_2d, codes)
1214+
cedge = cedge.repeat(num_faces, axis=0)
1215+
1216+
if len(pzs) > 0:
1217+
face_z = self._zsortfunc(pzs, axis=-1)
11781218
else:
1179-
PolyCollection.set_verts(self, segments_2d, self._closed)
1219+
face_z = pzs
1220+
if needs_masking:
1221+
face_z = face_z.data
1222+
face_order = np.argsort(face_z, axis=-1)[::-1]
11801223

1181-
if len(self._edgecolor3d) != len(cface):
1224+
if len(pfaces) > 0:
1225+
faces_2d = pfaces[face_order, :, :2]
1226+
else:
1227+
faces_2d = pfaces
1228+
if self._codes3d is not None and len(self._codes3d) > 0:
1229+
if needs_masking:
1230+
segment_mask = ~mask[face_order, :]
1231+
faces_2d = [face[mask, :] for face, mask
1232+
in zip(faces_2d, segment_mask)]
1233+
codes = [self._codes3d[idx] for idx in face_order]
1234+
PolyCollection.set_verts_and_codes(self, faces_2d, codes)
1235+
else:
1236+
if needs_masking and len(faces_2d) > 0:
1237+
invalid_vertices_2d = np.broadcast_to(
1238+
mask[face_order, :, None],
1239+
faces_2d.shape)
1240+
faces_2d = np.ma.MaskedArray(
1241+
faces_2d, mask=invalid_vertices_2d)
1242+
PolyCollection.set_verts(self, faces_2d, self._closed)
1243+
1244+
if len(cface) > 0:
1245+
self._facecolors2d = cface[face_order]
1246+
else:
1247+
self._facecolors2d = cface
1248+
if len(self._edgecolor3d) == len(cface) and len(cedge) > 0:
1249+
self._edgecolors2d = cedge[face_order]
1250+
else:
11821251
self._edgecolors2d = self._edgecolor3d
11831252

11841253
# Return zorder value
11851254
if self._sort_zpos is not None:
11861255
zvec = np.array([[0], [0], [self._sort_zpos], [1]])
11871256
ztrans = proj3d._proj_transform_vec(zvec, self.axes.M)
11881257
return ztrans[2][0]
1189-
elif tzs.size > 0:
1258+
elif pzs.size > 0:
11901259
# FIXME: Some results still don't look quite right.
11911260
# In particular, examine contourf3d_demo2.py
11921261
# with az = -54 and elev = -45.
1193-
return np.min(tzs)
1262+
return np.min(pzs)
11941263
else:
11951264
return np.nan
11961265

‎lib/mpl_toolkits/mplot3d/axes3d.py

Copy file name to clipboardExpand all lines: lib/mpl_toolkits/mplot3d/axes3d.py
+3-1Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2892,7 +2892,9 @@ def add_collection3d(self, col, zs=0, zdir='z', autolim=True, *,
28922892
self.auto_scale_xyz(*np.array(col._segments3d).transpose(),
28932893
had_data=had_data)
28942894
elif isinstance(col, art3d.Poly3DCollection):
2895-
self.auto_scale_xyz(*col._vec[:-1], had_data=had_data)
2895+
self.auto_scale_xyz(col._faces[..., 0],
2896+
col._faces[..., 1],
2897+
col._faces[..., 2], had_data=had_data)
28962898
elif isinstance(col, art3d.Patch3DCollection):
28972899
pass
28982900
# FIXME: Implement auto-scaling function for Patch3DCollection

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.