Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit e3a34c5

Browse filesBrowse files
authored
Merge pull request #9049 from dopplershift/fix-8908
BUG: Fix weird behavior with mask and units (Fixes #8908)
2 parents fe0095e + 0525c6b commit e3a34c5
Copy full SHA for e3a34c5

File tree

Expand file treeCollapse file tree

8 files changed

+105
-53
lines changed
Filter options
Expand file treeCollapse file tree

8 files changed

+105
-53
lines changed

‎lib/matplotlib/axes/_base.py

Copy file name to clipboardExpand all lines: lib/matplotlib/axes/_base.py
+13-4Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -556,12 +556,12 @@ def __init__(self, fig, rect,
556556
self.update(kwargs)
557557

558558
if self.xaxis is not None:
559-
self._xcid = self.xaxis.callbacks.connect('units finalize',
560-
self.relim)
559+
self._xcid = self.xaxis.callbacks.connect(
560+
'units finalize', lambda: self._on_units_changed(scalex=True))
561561

562562
if self.yaxis is not None:
563-
self._ycid = self.yaxis.callbacks.connect('units finalize',
564-
self.relim)
563+
self._ycid = self.yaxis.callbacks.connect(
564+
'units finalize', lambda: self._on_units_changed(scaley=True))
565565

566566
self.tick_params(
567567
top=rcParams['xtick.top'] and rcParams['xtick.minor.top'],
@@ -1891,6 +1891,15 @@ def add_container(self, container):
18911891
container.set_remove_method(lambda h: self.containers.remove(h))
18921892
return container
18931893

1894+
def _on_units_changed(self, scalex=False, scaley=False):
1895+
"""
1896+
Callback for processing changes to axis units.
1897+
1898+
Currently forces updates of data limits and view limits.
1899+
"""
1900+
self.relim()
1901+
self.autoscale_view(scalex=scalex, scaley=scaley)
1902+
18941903
def relim(self, visible_only=False):
18951904
"""
18961905
Recompute the data limits based on current artists. If you want to

‎lib/matplotlib/cbook/__init__.py

Copy file name to clipboardExpand all lines: lib/matplotlib/cbook/__init__.py
+12-1Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1997,6 +1997,17 @@ def is_math_text(s):
19971997
return even_dollars
19981998

19991999

2000+
def _to_unmasked_float_array(x):
2001+
"""
2002+
Convert a sequence to a float array; if input was a masked array, masked
2003+
values are converted to nans.
2004+
"""
2005+
if hasattr(x, 'mask'):
2006+
return np.ma.asarray(x, float).filled(np.nan)
2007+
else:
2008+
return np.asarray(x, float)
2009+
2010+
20002011
def _check_1d(x):
20012012
'''
20022013
Converts a sequence of less than 1 dimension, to an array of 1
@@ -2283,7 +2294,7 @@ def index_of(y):
22832294
try:
22842295
return y.index.values, y.values
22852296
except AttributeError:
2286-
y = np.atleast_1d(y)
2297+
y = _check_1d(y)
22872298
return np.arange(y.shape[0], dtype=float), y
22882299

22892300

‎lib/matplotlib/lines.py

Copy file name to clipboardExpand all lines: lib/matplotlib/lines.py
+4-11Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
from . import artist, colors as mcolors, docstring, rcParams
1717
from .artist import Artist, allow_rasterization
1818
from .cbook import (
19-
iterable, is_numlike, ls_mapper, ls_mapper_r, STEP_LOOKUP_MAP)
19+
_to_unmasked_float_array, iterable, is_numlike, ls_mapper, ls_mapper_r,
20+
STEP_LOOKUP_MAP)
2021
from .markers import MarkerStyle
2122
from .path import Path
2223
from .transforms import Bbox, TransformedPath, IdentityTransform
@@ -648,20 +649,12 @@ def recache_always(self):
648649
def recache(self, always=False):
649650
if always or self._invalidx:
650651
xconv = self.convert_xunits(self._xorig)
651-
if isinstance(self._xorig, np.ma.MaskedArray):
652-
x = np.ma.asarray(xconv, float).filled(np.nan)
653-
else:
654-
x = np.asarray(xconv, float)
655-
x = x.ravel()
652+
x = _to_unmasked_float_array(xconv).ravel()
656653
else:
657654
x = self._x
658655
if always or self._invalidy:
659656
yconv = self.convert_yunits(self._yorig)
660-
if isinstance(self._yorig, np.ma.MaskedArray):
661-
y = np.ma.asarray(yconv, float).filled(np.nan)
662-
else:
663-
y = np.asarray(yconv, float)
664-
y = y.ravel()
657+
y = _to_unmasked_float_array(yconv).ravel()
665658
else:
666659
y = self._y
667660

‎lib/matplotlib/path.py

Copy file name to clipboardExpand all lines: lib/matplotlib/path.py
+4-11Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323
import numpy as np
2424

2525
from . import _path, rcParams
26-
from .cbook import simple_linear_interpolation, maxdict
26+
from .cbook import (_to_unmasked_float_array, simple_linear_interpolation,
27+
maxdict)
2728

2829

2930
class Path(object):
@@ -132,11 +133,7 @@ def __init__(self, vertices, codes=None, _interpolation_steps=1,
132133
Makes the path behave in an immutable way and sets the vertices
133134
and codes as read-only arrays.
134135
"""
135-
if isinstance(vertices, np.ma.MaskedArray):
136-
vertices = vertices.astype(float).filled(np.nan)
137-
else:
138-
vertices = np.asarray(vertices, float)
139-
136+
vertices = _to_unmasked_float_array(vertices)
140137
if (vertices.ndim != 2) or (vertices.shape[1] != 2):
141138
msg = "'vertices' must be a 2D list or array with shape Nx2"
142139
raise ValueError(msg)
@@ -188,11 +185,7 @@ def _fast_from_codes_and_verts(cls, verts, codes, internals=None):
188185
"""
189186
internals = internals or {}
190187
pth = cls.__new__(cls)
191-
if isinstance(verts, np.ma.MaskedArray):
192-
verts = verts.astype(float).filled(np.nan)
193-
else:
194-
verts = np.asarray(verts, float)
195-
pth._vertices = verts
188+
pth._vertices = _to_unmasked_float_array(verts)
196189
pth._codes = codes
197190
pth._readonly = internals.pop('readonly', False)
198191
pth.should_simplify = internals.pop('should_simplify', True)
Loading
Loading

‎lib/matplotlib/tests/test_units.py

Copy file name to clipboard
+58-21Lines changed: 58 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
from matplotlib.cbook import iterable
12
import matplotlib.pyplot as plt
3+
from matplotlib.testing.decorators import image_comparison
24
import matplotlib.units as munits
35
import numpy as np
46

@@ -9,49 +11,84 @@
911
from mock import MagicMock
1012

1113

12-
# Tests that the conversion machinery works properly for classes that
13-
# work as a facade over numpy arrays (like pint)
14-
def test_numpy_facade():
15-
# Basic class that wraps numpy array and has units
16-
class Quantity(object):
17-
def __init__(self, data, units):
18-
self.magnitude = data
19-
self.units = units
14+
# Basic class that wraps numpy array and has units
15+
class Quantity(object):
16+
def __init__(self, data, units):
17+
self.magnitude = data
18+
self.units = units
19+
20+
def to(self, new_units):
21+
factors = {('hours', 'seconds'): 3600, ('minutes', 'hours'): 1 / 60,
22+
('minutes', 'seconds'): 60, ('feet', 'miles'): 1 / 5280.,
23+
('feet', 'inches'): 12, ('miles', 'inches'): 12 * 5280}
24+
if self.units != new_units:
25+
mult = factors[self.units, new_units]
26+
return Quantity(mult * self.magnitude, new_units)
27+
else:
28+
return Quantity(self.magnitude, self.units)
29+
30+
def __getattr__(self, attr):
31+
return getattr(self.magnitude, attr)
2032

21-
def to(self, new_units):
22-
return Quantity(self.magnitude, new_units)
33+
def __getitem__(self, item):
34+
return Quantity(self.magnitude[item], self.units)
2335

24-
def __getattr__(self, attr):
25-
return getattr(self.magnitude, attr)
36+
def __array__(self):
37+
return np.asarray(self.magnitude)
2638

27-
def __getitem__(self, item):
28-
return self.magnitude[item]
2939

40+
# Tests that the conversion machinery works properly for classes that
41+
# work as a facade over numpy arrays (like pint)
42+
@image_comparison(baseline_images=['plot_pint'],
43+
extensions=['png'], remove_text=False, style='mpl20')
44+
def test_numpy_facade():
3045
# Create an instance of the conversion interface and
3146
# mock so we can check methods called
3247
qc = munits.ConversionInterface()
3348

3449
def convert(value, unit, axis):
3550
if hasattr(value, 'units'):
36-
return value.to(unit)
51+
return value.to(unit).magnitude
52+
elif iterable(value):
53+
try:
54+
return [v.to(unit).magnitude for v in value]
55+
except AttributeError:
56+
return [Quantity(v, axis.get_units()).to(unit).magnitude
57+
for v in value]
3758
else:
3859
return Quantity(value, axis.get_units()).to(unit).magnitude
3960

4061
qc.convert = MagicMock(side_effect=convert)
41-
qc.axisinfo = MagicMock(return_value=None)
62+
qc.axisinfo = MagicMock(side_effect=lambda u, a: munits.AxisInfo(label=u))
4263
qc.default_units = MagicMock(side_effect=lambda x, a: x.units)
4364

4465
# Register the class
4566
munits.registry[Quantity] = qc
4667

4768
# Simple test
48-
t = Quantity(np.linspace(0, 10), 'sec')
49-
d = Quantity(30 * np.linspace(0, 10), 'm/s')
69+
y = Quantity(np.linspace(0, 30), 'miles')
70+
x = Quantity(np.linspace(0, 5), 'hours')
5071

51-
fig, ax = plt.subplots(1, 1)
52-
l, = plt.plot(t, d)
53-
ax.yaxis.set_units('inch')
72+
fig, ax = plt.subplots()
73+
fig.subplots_adjust(left=0.15) # Make space for label
74+
ax.plot(x, y, 'tab:blue')
75+
ax.axhline(Quantity(26400, 'feet'), color='tab:red')
76+
ax.axvline(Quantity(120, 'minutes'), color='tab:green')
77+
ax.yaxis.set_units('inches')
78+
ax.xaxis.set_units('seconds')
5479

5580
assert qc.convert.called
5681
assert qc.axisinfo.called
5782
assert qc.default_units.called
83+
84+
85+
# Tests gh-8908
86+
@image_comparison(baseline_images=['plot_masked_units'],
87+
extensions=['png'], remove_text=True, style='mpl20')
88+
def test_plot_masked_units():
89+
data = np.linspace(-5, 5)
90+
data_masked = np.ma.array(data, mask=(data > -2) & (data < 2))
91+
data_masked_units = Quantity(data_masked, 'meters')
92+
93+
fig, ax = plt.subplots()
94+
ax.plot(data_masked_units)

‎lib/mpl_toolkits/mplot3d/axes3d.py

Copy file name to clipboardExpand all lines: lib/mpl_toolkits/mplot3d/axes3d.py
+14-5Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -111,13 +111,13 @@ def __init__(self, fig, rect=None, *args, **kwargs):
111111
# func used to format z -- fall back on major formatters
112112
self.fmt_zdata = None
113113

114-
if zscale is not None :
114+
if zscale is not None:
115115
self.set_zscale(zscale)
116116

117-
if self.zaxis is not None :
118-
self._zcid = self.zaxis.callbacks.connect('units finalize',
119-
self.relim)
120-
else :
117+
if self.zaxis is not None:
118+
self._zcid = self.zaxis.callbacks.connect(
119+
'units finalize', lambda: self._on_units_changed(scalez=True))
120+
else:
121121
self._zcid = None
122122

123123
self._ready = 1
@@ -308,6 +308,15 @@ def get_axis_position(self):
308308
zhigh = tc[0][2] > tc[2][2]
309309
return xhigh, yhigh, zhigh
310310

311+
def _on_units_changed(self, scalex=False, scaley=False, scalez=False):
312+
"""
313+
Callback for processing changes to axis units.
314+
315+
Currently forces updates of data limits and view limits.
316+
"""
317+
self.relim()
318+
self.autoscale_view(scalex=scalex, scaley=scaley, scalez=scalez)
319+
311320
def update_datalim(self, xys, **kwargs):
312321
pass
313322

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.