Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Turn shared_axes, stale_viewlims into {axis_name: value} dicts. #20407

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 64 additions & 60 deletions 124 lib/matplotlib/axes/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,8 +540,8 @@ def _plot_args(self, tup, kwargs, return_kwargs=False):
class _AxesBase(martist.Artist):
name = "rectilinear"

_shared_x_axes = cbook.Grouper()
_shared_y_axes = cbook.Grouper()
_axis_names = ("x", "y") # See _get_axis_map.
_shared_axes = {name: cbook.Grouper() for name in _axis_names}
_twinned_axes = cbook.Grouper()

def __str__(self):
Expand Down Expand Up @@ -606,8 +606,7 @@ def __init__(self, fig, rect,
self._aspect = 'auto'
self._adjustable = 'box'
self._anchor = 'C'
self._stale_viewlim_x = False
self._stale_viewlim_y = False
self._stale_viewlims = {name: False for name in self._axis_names}
self._sharex = sharex
self._sharey = sharey
self.set_label(label)
Expand Down Expand Up @@ -685,20 +684,21 @@ def __getstate__(self):
# that point.
state = super().__getstate__()
# Prune the sharing & twinning info to only contain the current group.
for grouper_name in [
'_shared_x_axes', '_shared_y_axes', '_twinned_axes']:
grouper = getattr(self, grouper_name)
state[grouper_name] = (grouper.get_siblings(self)
if self in grouper else None)
state["_shared_axes"] = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this mean we break pickling-compatibility between old and implementation? If so this needs an API-change note.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The pickling representation is explicitly not stable (see __mpl_version__ in Figure.__setstate__), changes fairly regularly without API note (nearly every time we add an attribute to Figure or Axes, we make old pickles unloadable on new versions, as they would be missing that attribute), and I wouldn't want to start pretending that we put any care in making the pickles stable.

name: self._shared_axes[name].get_siblings(self)
for name in self._axis_names if self in self._shared_axes[name]}
state["_twinned_axes"] = (self._twinned_axes.get_siblings(self)
if self in self._twinned_axes else None)
return state

def __setstate__(self, state):
# Merge the grouping info back into the global groupers.
for grouper_name in [
'_shared_x_axes', '_shared_y_axes', '_twinned_axes']:
siblings = state.pop(grouper_name)
if siblings:
getattr(self, grouper_name).join(*siblings)
shared_axes = state.pop("_shared_axes")
for name, shared_siblings in shared_axes.items():
self._shared_axes[name].join(*shared_siblings)
twinned_siblings = state.pop("_twinned_axes")
if twinned_siblings:
self._twinned_axes.join(*twinned_siblings)
self.__dict__ = state
self._stale = True

Expand Down Expand Up @@ -763,16 +763,16 @@ def set_figure(self, fig):
def _unstale_viewLim(self):
# We should arrange to store this information once per share-group
# instead of on every axis.
scalex = any(ax._stale_viewlim_x
for ax in self._shared_x_axes.get_siblings(self))
scaley = any(ax._stale_viewlim_y
for ax in self._shared_y_axes.get_siblings(self))
if scalex or scaley:
for ax in self._shared_x_axes.get_siblings(self):
ax._stale_viewlim_x = False
for ax in self._shared_y_axes.get_siblings(self):
ax._stale_viewlim_y = False
self.autoscale_view(scalex=scalex, scaley=scaley)
need_scale = {
name: any(ax._stale_viewlims[name]
for ax in self._shared_axes[name].get_siblings(self))
for name in self._axis_names}
if any(need_scale.values()):
for name in need_scale:
for ax in self._shared_axes[name].get_siblings(self):
ax._stale_viewlims[name] = False
self.autoscale_view(**{f"scale{name}": scale
for name, scale in need_scale.items()})

@property
def viewLim(self):
Expand All @@ -781,13 +781,22 @@ def viewLim(self):

# API could be better, right now this is just to match the old calls to
# autoscale_view() after each plotting method.
def _request_autoscale_view(self, tight=None, scalex=True, scaley=True):
def _request_autoscale_view(self, tight=None, **kwargs):
# kwargs are "scalex", "scaley" (& "scalez" for 3D) and default to True
want_scale = {name: True for name in self._axis_names}
for k, v in kwargs.items(): # Validate args before changing anything.
if k.startswith("scale"):
name = k[5:]
if name in want_scale:
want_scale[name] = v
continue
raise TypeError(
f"_request_autoscale_view() got an unexpected argument {k!r}")
if tight is not None:
self._tight = tight
if scalex:
self._stale_viewlim_x = True # Else keep old state.
if scaley:
self._stale_viewlim_y = True
for k, v in want_scale.items():
if v:
self._stale_viewlims[k] = True # Else keep old state.

def _set_lim_and_transforms(self):
"""
Expand Down Expand Up @@ -1141,7 +1150,7 @@ def sharex(self, other):
_api.check_isinstance(_AxesBase, other=other)
if self._sharex is not None and other is not self._sharex:
raise ValueError("x-axis is already shared")
self._shared_x_axes.join(self, other)
self._shared_axes["x"].join(self, other)
self._sharex = other
self.xaxis.major = other.xaxis.major # Ticker instances holding
self.xaxis.minor = other.xaxis.minor # locator and formatter.
Expand All @@ -1160,7 +1169,7 @@ def sharey(self, other):
_api.check_isinstance(_AxesBase, other=other)
if self._sharey is not None and other is not self._sharey:
raise ValueError("y-axis is already shared")
self._shared_y_axes.join(self, other)
self._shared_axes["y"].join(self, other)
self._sharey = other
self.yaxis.major = other.yaxis.major # Ticker instances holding
self.yaxis.minor = other.yaxis.minor # locator and formatter.
Expand Down Expand Up @@ -1289,8 +1298,8 @@ def cla(self):
self.xaxis.set_clip_path(self.patch)
self.yaxis.set_clip_path(self.patch)

self._shared_x_axes.clean()
self._shared_y_axes.clean()
self._shared_axes["x"].clean()
self._shared_axes["y"].clean()
if self._sharex is not None:
self.xaxis.set_visible(xaxis_visible)
self.patch.set_visible(patch_visible)
Expand Down Expand Up @@ -1620,8 +1629,8 @@ def set_aspect(self, aspect, adjustable=None, anchor=None, share=False):
aspect = float(aspect) # raise ValueError if necessary

if share:
axes = {*self._shared_x_axes.get_siblings(self),
*self._shared_y_axes.get_siblings(self)}
axes = {sibling for name in self._axis_names
for sibling in self._shared_axes[name].get_siblings(self)}
else:
axes = [self]

Expand Down Expand Up @@ -1682,8 +1691,8 @@ def set_adjustable(self, adjustable, share=False):
"""
_api.check_in_list(["box", "datalim"], adjustable=adjustable)
if share:
axs = {*self._shared_x_axes.get_siblings(self),
*self._shared_y_axes.get_siblings(self)}
axs = {sibling for name in self._axis_names
for sibling in self._shared_axes[name].get_siblings(self)}
else:
axs = [self]
if (adjustable == "datalim"
Expand Down Expand Up @@ -1803,8 +1812,8 @@ def set_anchor(self, anchor, share=False):
raise ValueError('argument must be among %s' %
', '.join(mtransforms.Bbox.coefs))
if share:
axes = {*self._shared_x_axes.get_siblings(self),
*self._shared_y_axes.get_siblings(self)}
axes = {sibling for name in self._axis_names
for sibling in self._shared_axes[name].get_siblings(self)}
else:
axes = [self]
for ax in axes:
Expand Down Expand Up @@ -1919,8 +1928,8 @@ def apply_aspect(self, position=None):
xm = 0
ym = 0

shared_x = self in self._shared_x_axes
shared_y = self in self._shared_y_axes
shared_x = self in self._shared_axes["x"]
shared_y = self in self._shared_axes["y"]
# Not sure whether we need this check:
if shared_x and shared_y:
raise RuntimeError("adjustable='datalim' is not allowed when both "
Expand Down Expand Up @@ -2830,13 +2839,13 @@ def autoscale_view(self, tight=None, scalex=True, scaley=True):
if self._xmargin and scalex and self._autoscaleXon:
x_stickies = np.sort(np.concatenate([
artist.sticky_edges.x
for ax in self._shared_x_axes.get_siblings(self)
for ax in self._shared_axes["x"].get_siblings(self)
if hasattr(ax, "_children")
for artist in ax.get_children()]))
if self._ymargin and scaley and self._autoscaleYon:
y_stickies = np.sort(np.concatenate([
artist.sticky_edges.y
for ax in self._shared_y_axes.get_siblings(self)
for ax in self._shared_axes["y"].get_siblings(self)
if hasattr(ax, "_children")
for artist in ax.get_children()]))
if self.get_xscale() == 'log':
Expand Down Expand Up @@ -2910,14 +2919,14 @@ def handle_single_axis(scale, autoscaleon, shared_axes, name,
# End of definition of internal function 'handle_single_axis'.

handle_single_axis(
scalex, self._autoscaleXon, self._shared_x_axes, 'x',
scalex, self._autoscaleXon, self._shared_axes["x"], 'x',
self.xaxis, self._xmargin, x_stickies, self.set_xbound)
handle_single_axis(
scaley, self._autoscaleYon, self._shared_y_axes, 'y',
scaley, self._autoscaleYon, self._shared_axes["y"], 'y',
self.yaxis, self._ymargin, y_stickies, self.set_ybound)

def _get_axis_list(self):
return self.xaxis, self.yaxis
return tuple(getattr(self, f"{name}axis") for name in self._axis_names)

def _get_axis_map(self):
"""
Expand All @@ -2930,12 +2939,7 @@ def _get_axis_map(self):
In practice, this means that the entries are typically "x" and "y", and
additionally "z" for 3D axes.
"""
d = {}
axis_list = self._get_axis_list()
for k, v in vars(self).items():
if k.endswith("axis") and v in axis_list:
d[k[:-len("axis")]] = v
return d
return dict(zip(self._axis_names, self._get_axis_list()))

def _update_title_position(self, renderer):
"""
Expand Down Expand Up @@ -3706,15 +3710,15 @@ def set_xlim(self, left=None, right=None, emit=True, auto=False,

self._viewLim.intervalx = (left, right)
# Mark viewlims as no longer stale without triggering an autoscale.
for ax in self._shared_x_axes.get_siblings(self):
ax._stale_viewlim_x = False
for ax in self._shared_axes["x"].get_siblings(self):
ax._stale_viewlims["x"] = False
if auto is not None:
self._autoscaleXon = bool(auto)

if emit:
self.callbacks.process('xlim_changed', self)
# Call all of the other x-axes that are shared with this one
for other in self._shared_x_axes.get_siblings(self):
for other in self._shared_axes["x"].get_siblings(self):
if other is not self:
other.set_xlim(self.viewLim.intervalx,
emit=False, auto=auto)
Expand Down Expand Up @@ -4033,15 +4037,15 @@ def set_ylim(self, bottom=None, top=None, emit=True, auto=False,

self._viewLim.intervaly = (bottom, top)
# Mark viewlims as no longer stale without triggering an autoscale.
for ax in self._shared_y_axes.get_siblings(self):
ax._stale_viewlim_y = False
for ax in self._shared_axes["y"].get_siblings(self):
ax._stale_viewlims["y"] = False
if auto is not None:
self._autoscaleYon = bool(auto)

if emit:
self.callbacks.process('ylim_changed', self)
# Call all of the other y-axes that are shared with this one
for other in self._shared_y_axes.get_siblings(self):
for other in self._shared_axes["y"].get_siblings(self):
if other is not self:
other.set_ylim(self.viewLim.intervaly,
emit=False, auto=auto)
Expand Down Expand Up @@ -4705,8 +4709,8 @@ def twiny(self):

def get_shared_x_axes(self):
"""Return a reference to the shared axes Grouper object for x axes."""
return self._shared_x_axes
return self._shared_axes["x"]

def get_shared_y_axes(self):
"""Return a reference to the shared axes Grouper object for y axes."""
return self._shared_y_axes
return self._shared_axes["y"]
41 changes: 15 additions & 26 deletions 41 lib/matplotlib/axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -1517,16 +1517,13 @@ def set_units(self, u):
"""
if u == self.units:
return
if self is self.axes.xaxis:
shared = [
ax.xaxis
for ax in self.axes.get_shared_x_axes().get_siblings(self.axes)
]
elif self is self.axes.yaxis:
shared = [
ax.yaxis
for ax in self.axes.get_shared_y_axes().get_siblings(self.axes)
]
for name, axis in self.axes._get_axis_map().items():
if self is axis:
shared = [
getattr(ax, f"{name}axis")
for ax
in self.axes._shared_axes[name].get_siblings(self.axes)]
break
else:
shared = [self]
for axis in shared:
Expand Down Expand Up @@ -1800,21 +1797,13 @@ def _set_tick_locations(self, ticks, *, minor=False):

# XXX if the user changes units, the information will be lost here
ticks = self.convert_units(ticks)
if self is self.axes.xaxis:
shared = [
ax.xaxis
for ax in self.axes.get_shared_x_axes().get_siblings(self.axes)
]
elif self is self.axes.yaxis:
shared = [
ax.yaxis
for ax in self.axes.get_shared_y_axes().get_siblings(self.axes)
]
elif hasattr(self.axes, "zaxis") and self is self.axes.zaxis:
shared = [
ax.zaxis
for ax in self.axes._shared_z_axes.get_siblings(self.axes)
]
for name, axis in self.axes._get_axis_map().items():
if self is axis:
shared = [
getattr(ax, f"{name}axis")
for ax
in self.axes._shared_axes[name].get_siblings(self.axes)]
break
else:
shared = [self]
for axis in shared:
Expand Down Expand Up @@ -1882,7 +1871,7 @@ def _get_tick_boxes_siblings(self, renderer):
bboxes2 = []
# If we want to align labels from other axes:
for ax in grouper.get_siblings(self.axes):
axis = ax._get_axis_map()[axis_name]
axis = getattr(ax, f"{axis_name}axis")
ticks_to_draw = axis._update_ticks()
tlb, tlb2 = axis._get_tick_bboxes(ticks_to_draw, renderer)
bboxes.extend(tlb)
Expand Down
6 changes: 4 additions & 2 deletions 6 lib/matplotlib/colorbar.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,8 +519,10 @@ def draw_all(self):
# also adds the outline path to self.outline spine:
self._do_extends(extendlen)

self.ax.set_xlim(self.vmin, self.vmax)
self.ax.set_ylim(self.vmin, self.vmax)
# These calls must be done on inner_ax, not ax (even though they mostly
# share internals), because otherwise viewLim unstaling gets confused.
self.ax.inner_ax.set_xlim(self.vmin, self.vmax)
self.ax.inner_ax.set_ylim(self.vmin, self.vmax)

# set up the tick locators and formatters. A bit complicated because
# boundary norms + uniform spacing requires a manual locator.
Expand Down
11 changes: 4 additions & 7 deletions 11 lib/matplotlib/figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -937,13 +937,10 @@ def _break_share_link(ax, grouper):
self.stale = True
self._localaxes.remove(ax)

last_ax = _break_share_link(ax, ax._shared_y_axes)
if last_ax is not None:
_reset_locators_and_formatters(last_ax.yaxis)

last_ax = _break_share_link(ax, ax._shared_x_axes)
if last_ax is not None:
_reset_locators_and_formatters(last_ax.xaxis)
for name in ax._axis_names:
last_ax = _break_share_link(ax, ax._shared_axes[name])
if last_ax is not None:
_reset_locators_and_formatters(getattr(last_ax, f"{name}axis"))

# Note: in the docstring below, the newlines in the examples after the
# calls to legend() allow replacing it with figlegend() to generate the
Expand Down
4 changes: 1 addition & 3 deletions 4 lib/matplotlib/tests/test_subplots.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@ def check_shared(axs, x_shared, y_shared):
enumerate(zip("xy", [x_shared, y_shared]))):
if i2 <= i1:
continue
assert \
(getattr(axs[0], "_shared_{}_axes".format(name)).joined(ax1, ax2)
== shared[i1, i2]), \
assert axs[0]._shared_axes[name].joined(ax1, ax2) == shared[i1, i2], \
"axes %i and %i incorrectly %ssharing %s axis" % (
i1, i2, "not " if shared[i1, i2] else "", name)

Expand Down
Loading
Morty Proxy This is a proxified and sanitized view of the page, visit original site.