diff --git a/lib/matplotlib/collections.py b/lib/matplotlib/collections.py index a70e547ef849..acc1913a402b 100644 --- a/lib/matplotlib/collections.py +++ b/lib/matplotlib/collections.py @@ -154,13 +154,8 @@ def __init__(self, *, """ artist.Artist.__init__(self) cm.ScalarMappable.__init__(self, norm, cmap) - # list of un-scaled dash patterns - # this is needed scaling the dash pattern by linewidth - self._us_linestyles = [(0, None)] - # list of dash patterns - self._linestyles = [(0, None)] - # list of unbroadcast/scaled linewidths - self._us_lw = [0] + self._unscaled_linestyles = [(0, None)] # dash patterns prior to scaling by lw + self._linestyles = [(0, None)] # dash patterns scaled by lw self._linewidths = [0] self._gapcolor = None # Currently only used by LineCollection. @@ -578,12 +573,9 @@ def set_linewidth(self, lw): """ if lw is None: lw = self._get_default_linewidth() - # get the un-scaled/broadcast lw - self._us_lw = np.atleast_1d(lw) - - # scale all of the dash patterns. - self._linewidths, self._linestyles = self._bcast_lwls( - self._us_lw, self._us_linestyles) + self._linewidths = np.atleast_1d(lw) + self._linestyles = self._compute_scaled_linestyles( + self._unscaled_linestyles, self._linewidths) self.stale = True def set_linestyle(self, ls): @@ -620,13 +612,9 @@ def set_linestyle(self, ls): except ValueError as err: emsg = f'Do not know how to convert {ls!r} to dashes' raise ValueError(emsg) from err - - # get the list of raw 'unscaled' dash patterns - self._us_linestyles = dashes - - # broadcast and scale the lw and dash patterns - self._linewidths, self._linestyles = self._bcast_lwls( - self._us_lw, self._us_linestyles) + self._unscaled_linestyles = dashes # raw 'unscaled' dash patterns + self._linestyles = self._compute_scaled_linestyles( + self._unscaled_linestyles, self._linewidths) @_docstring.interpd def set_capstyle(self, cs): @@ -657,42 +645,35 @@ def get_joinstyle(self): return self._joinstyle.name @staticmethod - def _bcast_lwls(linewidths, dashes): + def _compute_scaled_linestyles(unscaled_linestyles, linewidths): """ - Internal helper function to broadcast + scale ls/lw + Internal helper function to scale linestyles by linewidths. In the collection drawing code, the linewidth and linestyle are cycled through as circular buffers (via ``v[i % len(v)]``). Thus, if we are going to scale the dash pattern at set time (not draw time) we need to - do the broadcasting now and expand both lists to be the same length. + do the cycling now and expand both lists to be the same length. Parameters ---------- + unscaled_linestyles + dash specification (offset, (dash pattern tuple)) linewidths : list line widths of collection - dashes : list - dash specification (offset, (dash pattern tuple)) Returns ------- - linewidths, dashes : list + dashes : list Will be the same length, dashes are scaled by paired linewidth """ if mpl.rcParams['_internal.classic_mode']: - return linewidths, dashes - # make sure they are the same length so we can zip them - if len(dashes) != len(linewidths): - l_dashes = len(dashes) - l_lw = len(linewidths) - gcd = math.gcd(l_dashes, l_lw) - dashes = list(dashes) * (l_lw // gcd) - linewidths = list(linewidths) * (l_dashes // gcd) - - # scale the dash patterns - dashes = [mlines._scale_dashes(o, d, lw) - for (o, d), lw in zip(dashes, linewidths)] - - return linewidths, dashes + return unscaled_linestyles + n_ls = len(unscaled_linestyles) + n_lw = len(linewidths) + return [mlines._scale_dashes(o, d, lw) for (o, d), lw in zip( + # How many cycles do we need? + list(unscaled_linestyles) * (n_lw // math.gcd(n_ls, n_lw)), + itertools.cycle(linewidths))] def set_antialiased(self, aa): """ @@ -919,7 +900,7 @@ def update_from(self, other): self._facecolors = other._facecolors self._linewidths = other._linewidths self._linestyles = other._linestyles - self._us_linestyles = other._us_linestyles + self._unscaled_linestyles = other._unscaled_linestyles self._pickradius = other._pickradius self._hatch = other._hatch diff --git a/lib/matplotlib/legend_handler.py b/lib/matplotlib/legend_handler.py index 8aad6ce5219c..28349a29f36b 100644 --- a/lib/matplotlib/legend_handler.py +++ b/lib/matplotlib/legend_handler.py @@ -407,7 +407,7 @@ def get_numpoints(self, legend): def _default_update_prop(self, legend_handle, orig_handle): lw = orig_handle.get_linewidths()[0] - dashes = orig_handle._us_linestyles[0] + dashes = orig_handle._unscaled_linestyles[0] color = orig_handle.get_colors()[0] legend_handle.set_color(color) legend_handle.set_linestyle(dashes) diff --git a/lib/matplotlib/tests/test_collections.py b/lib/matplotlib/tests/test_collections.py index 56a9c688af1f..331ff9e721f3 100644 --- a/lib/matplotlib/tests/test_collections.py +++ b/lib/matplotlib/tests/test_collections.py @@ -605,7 +605,7 @@ def test_lslw_bcast(): col.set_linewidths([1, 2, 3]) assert col.get_linestyles() == [(0, None)] * 6 - assert col.get_linewidths() == [1, 2, 3] * 2 + assert (col.get_linewidths() == [1, 2, 3]).all() col.set_linestyles(['-', '-', '-']) assert col.get_linestyles() == [(0, None)] * 3