Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit ee9500a

Browse filesBrowse files
authored
Merge pull request #19812 from tacaswell/fix_scatter3D
FIX: size and color rendering for Path3DCollection
2 parents 45e3d1f + f888285 commit ee9500a
Copy full SHA for ee9500a

File tree

3 files changed

+117
-149
lines changed
Filter options

3 files changed

+117
-149
lines changed

‎lib/mpl_toolkits/mplot3d/art3d.py

Copy file name to clipboardExpand all lines: lib/mpl_toolkits/mplot3d/art3d.py
+104-149Lines changed: 104 additions & 149 deletions
Original file line numberDiff line numberDiff line change
@@ -302,8 +302,6 @@ def do_3d_projection(self, renderer=None):
302302
"""
303303
Project the points according to renderer matrix.
304304
"""
305-
# see _update_scalarmappable docstring for why this must be here
306-
_update_scalarmappable(self)
307305
xyslist = [proj3d.proj_trans_points(points, self.axes.M)
308306
for points in self._segments3d]
309307
segments_2d = [np.column_stack([xs, ys]) for xs, ys, zs in xyslist]
@@ -448,16 +446,6 @@ def set_depthshade(self, depthshade):
448446
self._depthshade = depthshade
449447
self.stale = True
450448

451-
def set_facecolor(self, c):
452-
# docstring inherited
453-
super().set_facecolor(c)
454-
self._facecolor3d = self.get_facecolor()
455-
456-
def set_edgecolor(self, c):
457-
# docstring inherited
458-
super().set_edgecolor(c)
459-
self._edgecolor3d = self.get_edgecolor()
460-
461449
def set_sort_zpos(self, val):
462450
"""Set the position to use for z-sorting."""
463451
self._sort_zpos = val
@@ -474,34 +462,43 @@ def set_3d_properties(self, zs, zdir):
474462
xs = []
475463
ys = []
476464
self._offsets3d = juggle_axes(xs, ys, np.atleast_1d(zs), zdir)
477-
self._facecolor3d = self.get_facecolor()
478-
self._edgecolor3d = self.get_edgecolor()
465+
self._vzs = None
479466
self.stale = True
480467

481468
@_api.delete_parameter('3.4', 'renderer')
482469
def do_3d_projection(self, renderer=None):
483-
# see _update_scalarmappable docstring for why this must be here
484-
_update_scalarmappable(self)
485470
xs, ys, zs = self._offsets3d
486471
vxs, vys, vzs, vis = proj3d.proj_transform_clip(xs, ys, zs,
487472
self.axes.M)
488-
489-
fcs = (_zalpha(self._facecolor3d, vzs) if self._depthshade else
490-
self._facecolor3d)
491-
fcs = mcolors.to_rgba_array(fcs, self._alpha)
492-
super().set_facecolor(fcs)
493-
494-
ecs = (_zalpha(self._edgecolor3d, vzs) if self._depthshade else
495-
self._edgecolor3d)
496-
ecs = mcolors.to_rgba_array(ecs, self._alpha)
497-
super().set_edgecolor(ecs)
473+
self._vzs = vzs
498474
super().set_offsets(np.column_stack([vxs, vys]))
499475

500476
if vzs.size > 0:
501477
return min(vzs)
502478
else:
503479
return np.nan
504480

481+
def _maybe_depth_shade_and_sort_colors(self, color_array):
482+
color_array = (
483+
_zalpha(color_array, self._vzs)
484+
if self._vzs is not None and self._depthshade
485+
else color_array
486+
)
487+
if len(color_array) > 1:
488+
color_array = color_array[self._z_markers_idx]
489+
return mcolors.to_rgba_array(color_array, self._alpha)
490+
491+
def get_facecolor(self):
492+
return self._maybe_depth_shade_and_sort_colors(super().get_facecolor())
493+
494+
def get_edgecolor(self):
495+
# We need this check here to make sure we do not double-apply the depth
496+
# based alpha shading when the edge color is "face" which means the
497+
# edge colour should be identical to the face colour.
498+
if cbook._str_equal(self._edgecolors, 'face'):
499+
return self.get_facecolor()
500+
return self._maybe_depth_shade_and_sort_colors(super().get_edgecolor())
501+
505502

506503
class Path3DCollection(PathCollection):
507504
"""
@@ -525,9 +522,14 @@ def __init__(self, *args, zs=0, zdir='z', depthshade=True, **kwargs):
525522
This is typically desired in scatter plots.
526523
"""
527524
self._depthshade = depthshade
525+
self._in_draw = False
528526
super().__init__(*args, **kwargs)
529527
self.set_3d_properties(zs, zdir)
530528

529+
def draw(self, renderer):
530+
with cbook._setattr_cm(self, _in_draw=True):
531+
super().draw(renderer)
532+
531533
def set_sort_zpos(self, val):
532534
"""Set the position to use for z-sorting."""
533535
self._sort_zpos = val
@@ -544,12 +546,37 @@ def set_3d_properties(self, zs, zdir):
544546
xs = []
545547
ys = []
546548
self._offsets3d = juggle_axes(xs, ys, np.atleast_1d(zs), zdir)
547-
self._facecolor3d = self.get_facecolor()
548-
self._edgecolor3d = self.get_edgecolor()
549-
self._sizes3d = self.get_sizes()
550-
self._linewidth3d = self.get_linewidth()
549+
# In the base draw methods we access the attributes directly which
550+
# means we can not resolve the shuffling in the getter methods like
551+
# we do for the edge and face colors.
552+
#
553+
# This means we need to carry around a cache of the unsorted sizes and
554+
# widths (postfixed with 3d) and in `do_3d_projection` set the
555+
# depth-sorted version of that data into the private state used by the
556+
# base collection class in its draw method.
557+
#
558+
# grab the current sizes and linewidths to preserve them
559+
self._sizes3d = self._sizes
560+
self._linewidths3d = self._linewidths
561+
xs, ys, zs = self._offsets3d
562+
563+
# Sort the points based on z coordinates
564+
# Performance optimization: Create a sorted index array and reorder
565+
# points and point properties according to the index array
566+
self._z_markers_idx = slice(-1)
567+
self._vzs = None
551568
self.stale = True
552569

570+
def set_sizes(self, sizes, dpi=72.0):
571+
super().set_sizes(sizes, dpi)
572+
if not self._in_draw:
573+
self._sizes3d = sizes
574+
575+
def set_linewidth(self, lw):
576+
super().set_linewidth(lw)
577+
if not self._in_draw:
578+
self._linewidth3d = lw
579+
553580
def get_depthshade(self):
554581
return self._depthshade
555582

@@ -566,140 +593,57 @@ def set_depthshade(self, depthshade):
566593
self._depthshade = depthshade
567594
self.stale = True
568595

569-
def set_facecolor(self, c):
570-
# docstring inherited
571-
super().set_facecolor(c)
572-
self._facecolor3d = self.get_facecolor()
573-
574-
def set_edgecolor(self, c):
575-
# docstring inherited
576-
super().set_edgecolor(c)
577-
self._edgecolor3d = self.get_edgecolor()
578-
579-
def set_sizes(self, sizes, dpi=72.0):
580-
# docstring inherited
581-
super().set_sizes(sizes, dpi=dpi)
582-
self._sizes3d = self.get_sizes()
583-
584-
def set_linewidth(self, lw):
585-
# docstring inherited
586-
super().set_linewidth(lw)
587-
self._linewidth3d = self.get_linewidth()
588-
589596
@_api.delete_parameter('3.4', 'renderer')
590597
def do_3d_projection(self, renderer=None):
591-
# see _update_scalarmappable docstring for why this must be here
592-
_update_scalarmappable(self)
593598
xs, ys, zs = self._offsets3d
594599
vxs, vys, vzs, vis = proj3d.proj_transform_clip(xs, ys, zs,
595600
self.axes.M)
596-
597-
fcs = (_zalpha(self._facecolor3d, vzs) if self._depthshade else
598-
self._facecolor3d)
599-
ecs = (_zalpha(self._edgecolor3d, vzs) if self._depthshade else
600-
self._edgecolor3d)
601-
sizes = self._sizes3d
602-
lws = self._linewidth3d
603-
604601
# Sort the points based on z coordinates
605602
# Performance optimization: Create a sorted index array and reorder
606603
# points and point properties according to the index array
607-
z_markers_idx = np.argsort(vzs)[::-1]
604+
z_markers_idx = self._z_markers_idx = np.argsort(vzs)[::-1]
605+
self._vzs = vzs
606+
607+
# we have to special case the sizes because of code in collections.py
608+
# as the draw method does
609+
# self.set_sizes(self._sizes, self.figure.dpi)
610+
# so we can not rely on doing the sorting on the way out via get_*
611+
612+
if len(self._sizes3d) > 1:
613+
self._sizes = self._sizes3d[z_markers_idx]
614+
615+
if len(self._linewidths3d) > 1:
616+
self._linewidths = self._linewidths3d[z_markers_idx]
608617

609618
# Re-order items
610619
vzs = vzs[z_markers_idx]
611620
vxs = vxs[z_markers_idx]
612621
vys = vys[z_markers_idx]
613-
if len(fcs) > 1:
614-
fcs = fcs[z_markers_idx]
615-
if len(ecs) > 1:
616-
ecs = ecs[z_markers_idx]
617-
if len(sizes) > 1:
618-
sizes = sizes[z_markers_idx]
619-
if len(lws) > 1:
620-
lws = lws[z_markers_idx]
621-
vps = np.column_stack((vxs, vys))
622-
623-
fcs = mcolors.to_rgba_array(fcs, self._alpha)
624-
ecs = mcolors.to_rgba_array(ecs, self._alpha)
625-
626-
super().set_edgecolor(ecs)
627-
super().set_facecolor(fcs)
628-
super().set_sizes(sizes)
629-
super().set_linewidth(lws)
630-
631-
PathCollection.set_offsets(self, vps)
632622

633-
return np.min(vzs) if vzs.size else np.nan
623+
PathCollection.set_offsets(self, np.column_stack((vxs, vys)))
634624

625+
return np.min(vzs) if vzs.size else np.nan
635626

636-
def _update_scalarmappable(sm):
637-
"""
638-
Update a 3D ScalarMappable.
639-
640-
With ScalarMappable objects if the data, colormap, or norm are
641-
changed, we need to update the computed colors. This is handled
642-
by the base class method update_scalarmappable. This method works
643-
by detecting if work needs to be done, and if so stashing it on
644-
the ``self._facecolors`` attribute.
645-
646-
With 3D collections we internally sort the components so that
647-
things that should be "in front" are rendered later to simulate
648-
having a z-buffer (in addition to doing the projections). This is
649-
handled in the ``do_3d_projection`` methods which are called from the
650-
draw method of the 3D Axes. These methods:
651-
652-
- do the projection from 3D -> 2D
653-
- internally sort based on depth
654-
- stash the results of the above in the 2D analogs of state
655-
- return the z-depth of the whole artist
656-
657-
the last step is so that we can, at the Axes level, sort the children by
658-
depth.
659-
660-
The base `draw` method of the 2D artists unconditionally calls
661-
update_scalarmappable and rely on the method's internal caching logic to
662-
lazily evaluate.
663-
664-
These things together mean you can have the sequence of events:
665-
666-
- we create the artist, do the color mapping and stash the results
667-
in a 3D specific state.
668-
- change something about the ScalarMappable that marks it as in
669-
need of an update (`ScalarMappable.changed` and friends).
670-
- We call do_3d_projection and shuffle the stashed colors into the
671-
2D version of face colors
672-
- the draw method calls the update_scalarmappable method which
673-
overwrites our shuffled colors
674-
- we get a render that is wrong
675-
- if we re-render (either with a second save or implicitly via
676-
tight_layout / constrained_layout / bbox_inches='tight' (ex via
677-
inline's defaults)) we again shuffle the 3D colors
678-
- because the CM is not marked as changed update_scalarmappable is
679-
a no-op and we get a correct looking render.
680-
681-
This function is an internal helper to:
682-
683-
- sort out if we need to do the color mapping at all (has data!)
684-
- sort out if update_scalarmappable is going to be a no-op
685-
- copy the data over from the 2D -> 3D version
686-
687-
This must be called first thing in do_3d_projection to make sure that
688-
the correct colors get shuffled.
627+
def _maybe_depth_shade_and_sort_colors(self, color_array):
628+
color_array = (
629+
_zalpha(color_array, self._vzs)
630+
if self._vzs is not None and self._depthshade
631+
else color_array
632+
)
633+
if len(color_array) > 1:
634+
color_array = color_array[self._z_markers_idx]
635+
return mcolors.to_rgba_array(color_array, self._alpha)
689636

690-
Parameters
691-
----------
692-
sm : ScalarMappable
693-
The ScalarMappable to update and stash the 3D data from
637+
def get_facecolor(self):
638+
return self._maybe_depth_shade_and_sort_colors(super().get_facecolor())
694639

695-
"""
696-
if sm._A is None:
697-
return
698-
sm.update_scalarmappable()
699-
if sm._face_is_mapped:
700-
sm._facecolor3d = sm._facecolors
701-
elif sm._edge_is_mapped: # Should this be plain "if"?
702-
sm._edgecolor3d = sm._edgecolors
640+
def get_edgecolor(self):
641+
# We need this check here to make sure we do not double-apply the depth
642+
# based alpha shading when the edge color is "face" which means the
643+
# edge colour should be identical to the face colour.
644+
if cbook._str_equal(self._edgecolors, 'face'):
645+
return self.get_facecolor()
646+
return self._maybe_depth_shade_and_sort_colors(super().get_edgecolor())
703647

704648

705649
def patch_collection_2d_to_3d(col, zs=0, zdir='z', depthshade=True):
@@ -725,6 +669,7 @@ def patch_collection_2d_to_3d(col, zs=0, zdir='z', depthshade=True):
725669
elif isinstance(col, PatchCollection):
726670
col.__class__ = Patch3DCollection
727671
col._depthshade = depthshade
672+
col._in_draw = False
728673
col.set_3d_properties(zs, zdir)
729674

730675

@@ -839,9 +784,19 @@ def do_3d_projection(self, renderer=None):
839784
"""
840785
Perform the 3D projection for this object.
841786
"""
842-
# see _update_scalarmappable docstring for why this must be here
843-
_update_scalarmappable(self)
844-
787+
if self._A is not None:
788+
# force update of color mapping because we re-order them
789+
# below. If we do not do this here, the 2D draw will call
790+
# this, but we will never port the color mapped values back
791+
# to the 3D versions.
792+
#
793+
# We hold the 3D versions in a fixed order (the order the user
794+
# passed in) and sort the 2D version by view depth.
795+
self.update_scalarmappable()
796+
if self._face_is_mapped:
797+
self._facecolor3d = self._facecolors
798+
if self._edge_is_mapped:
799+
self._edgecolor3d = self._edgecolors
845800
txs, tys, tzs = proj3d._proj_transform_vec(self._vec, self.axes.M)
846801
xyzlist = [(txs[sl], tys[sl], tzs[sl]) for sl in self._segslices]
847802

Loading

‎lib/mpl_toolkits/tests/test_mplot3d.py

Copy file name to clipboardExpand all lines: lib/mpl_toolkits/tests/test_mplot3d.py
+13Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1510,3 +1510,16 @@ def test_computed_zorder():
15101510
zorder=4)
15111511
ax.view_init(azim=-20, elev=20)
15121512
ax.axis('off')
1513+
1514+
1515+
@image_comparison(baseline_images=['scatter_spiral.png'],
1516+
remove_text=True,
1517+
style='default')
1518+
def test_scatter_spiral():
1519+
fig = plt.figure()
1520+
ax = fig.add_subplot(projection='3d')
1521+
th = np.linspace(0, 2 * np.pi * 6, 256)
1522+
sc = ax.scatter(np.sin(th), np.cos(th), th, s=(1 + th * 5), c=th ** 2)
1523+
1524+
# force at least 1 draw!
1525+
fig.canvas.draw()

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.