Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 11072f2

Browse filesBrowse files
authored
Merge pull request #12422 from efiring/scatter_color
FIX/API: scatter with invalid data
2 parents 739b86a + a2cec14 commit 11072f2
Copy full SHA for 11072f2

File tree

Expand file treeCollapse file tree

9 files changed

+174
-47
lines changed
Filter options
Expand file treeCollapse file tree

9 files changed

+174
-47
lines changed
+14Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
PathCollections created with `~.Axes.scatter` now keep track of invalid points
2+
``````````````````````````````````````````````````````````````````````````````
3+
4+
Previously, points with nonfinite (infinite or nan) coordinates would not be
5+
included in the offsets (as returned by `PathCollection.get_offsets`) of a
6+
`PathCollection` created by `~.Axes.scatter`, and points with nonfinite values
7+
(as specified by the *c* kwarg) would not be included in the array (as returned
8+
by `PathCollection.get_array`)
9+
10+
Such points are now included, but masked out by returning a masked array.
11+
12+
If the *plotnonfinite* kwarg to `~.Axes.scatter` is set, then points with
13+
nonfinite values are plotted using the bad color of the `PathCollection`\ 's
14+
colormap (as set by `Colormap.set_bad`).

‎examples/units/basic_units.py

Copy file name to clipboardExpand all lines: examples/units/basic_units.py
+18-2Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,10 @@ def get_compressed_copy(self, mask):
174174
def convert_to(self, unit):
175175
if unit == self.unit or not unit:
176176
return self
177-
new_value = self.unit.convert_value_to(self.value, unit)
177+
try:
178+
new_value = self.unit.convert_value_to(self.value, unit)
179+
except AttributeError:
180+
new_value = self
178181
return TaggedValue(new_value, unit)
179182

180183
def get_value(self):
@@ -345,7 +348,20 @@ def convert(val, unit, axis):
345348
if units.ConversionInterface.is_numlike(val):
346349
return val
347350
if np.iterable(val):
348-
return [thisval.convert_to(unit).get_value() for thisval in val]
351+
if isinstance(val, np.ma.MaskedArray):
352+
val = val.astype(float).filled(np.nan)
353+
out = np.empty(len(val))
354+
for i, thisval in enumerate(val):
355+
if np.ma.is_masked(thisval):
356+
out[i] = np.nan
357+
else:
358+
try:
359+
out[i] = thisval.convert_to(unit).get_value()
360+
except AttributeError:
361+
out[i] = thisval
362+
return out
363+
if np.ma.is_masked(val):
364+
return np.nan
349365
else:
350366
return val.convert_to(unit).get_value()
351367

‎examples/units/units_scatter.py

Copy file name to clipboardExpand all lines: examples/units/units_scatter.py
+2-3Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,8 @@
2727
ax2.scatter(xsecs, xsecs, yunits=hertz)
2828
ax2.axis([0, 10, 0, 1])
2929

30-
ax3.scatter(xsecs, xsecs, yunits=hertz)
31-
ax3.yaxis.set_units(minutes)
32-
ax3.axis([0, 10, 0, 1])
30+
ax3.scatter(xsecs, xsecs, yunits=minutes)
31+
ax3.axis([0, 10, 0, 0.2])
3332

3433
fig.tight_layout()
3534
plt.show()

‎lib/matplotlib/axes/_axes.py

Copy file name to clipboardExpand all lines: lib/matplotlib/axes/_axes.py
+15-8Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4180,7 +4180,7 @@ def _parse_scatter_color_args(c, edgecolors, kwargs, xshape, yshape,
41804180
label_namer="y")
41814181
def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,
41824182
vmin=None, vmax=None, alpha=None, linewidths=None,
4183-
verts=None, edgecolors=None,
4183+
verts=None, edgecolors=None, *, plotnonfinite=False,
41844184
**kwargs):
41854185
"""
41864186
A scatter plot of *y* vs *x* with varying marker size and/or color.
@@ -4257,6 +4257,10 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,
42574257
For non-filled markers, the *edgecolors* kwarg is ignored and
42584258
forced to 'face' internally.
42594259
4260+
plotnonfinite : boolean, optional, default: False
4261+
Set to plot points with nonfinite *c*, in conjunction with
4262+
`~matplotlib.colors.Colormap.set_bad`.
4263+
42604264
Returns
42614265
-------
42624266
paths : `~matplotlib.collections.PathCollection`
@@ -4310,11 +4314,14 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,
43104314
c, edgecolors, kwargs, xshape, yshape,
43114315
get_next_color_func=self._get_patches_for_fill.get_next_color)
43124316

4313-
# `delete_masked_points` only modifies arguments of the same length as
4314-
# `x`.
4315-
x, y, s, c, colors, edgecolors, linewidths =\
4316-
cbook.delete_masked_points(
4317-
x, y, s, c, colors, edgecolors, linewidths)
4317+
if plotnonfinite and colors is None:
4318+
c = np.ma.masked_invalid(c)
4319+
x, y, s, edgecolors, linewidths = \
4320+
cbook._combine_masks(x, y, s, edgecolors, linewidths)
4321+
else:
4322+
x, y, s, c, colors, edgecolors, linewidths = \
4323+
cbook._combine_masks(
4324+
x, y, s, c, colors, edgecolors, linewidths)
43184325

43194326
scales = s # Renamed for readability below.
43204327

@@ -4340,7 +4347,7 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,
43404347
edgecolors = 'face'
43414348
linewidths = rcParams['lines.linewidth']
43424349

4343-
offsets = np.column_stack([x, y])
4350+
offsets = np.ma.column_stack([x, y])
43444351

43454352
collection = mcoll.PathCollection(
43464353
(path,), scales,
@@ -4358,7 +4365,7 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,
43584365
if norm is not None and not isinstance(norm, mcolors.Normalize):
43594366
raise ValueError(
43604367
"'norm' must be an instance of 'mcolors.Normalize'")
4361-
collection.set_array(np.asarray(c))
4368+
collection.set_array(c)
43624369
collection.set_cmap(cmap)
43634370
collection.set_norm(norm)
43644371

‎lib/matplotlib/cbook/__init__.py

Copy file name to clipboardExpand all lines: lib/matplotlib/cbook/__init__.py
+60Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1081,6 +1081,66 @@ def delete_masked_points(*args):
10811081
return margs
10821082

10831083

1084+
def _combine_masks(*args):
1085+
"""
1086+
Find all masked and/or non-finite points in a set of arguments,
1087+
and return the arguments as masked arrays with a common mask.
1088+
1089+
Arguments can be in any of 5 categories:
1090+
1091+
1) 1-D masked arrays
1092+
2) 1-D ndarrays
1093+
3) ndarrays with more than one dimension
1094+
4) other non-string iterables
1095+
5) anything else
1096+
1097+
The first argument must be in one of the first four categories;
1098+
any argument with a length differing from that of the first
1099+
argument (and hence anything in category 5) then will be
1100+
passed through unchanged.
1101+
1102+
Masks are obtained from all arguments of the correct length
1103+
in categories 1, 2, and 4; a point is bad if masked in a masked
1104+
array or if it is a nan or inf. No attempt is made to
1105+
extract a mask from categories 2 and 4 if :meth:`np.isfinite`
1106+
does not yield a Boolean array. Category 3 is included to
1107+
support RGB or RGBA ndarrays, which are assumed to have only
1108+
valid values and which are passed through unchanged.
1109+
1110+
All input arguments that are not passed unchanged are returned
1111+
as masked arrays if any masked points are found, otherwise as
1112+
ndarrays.
1113+
1114+
"""
1115+
if not len(args):
1116+
return ()
1117+
if is_scalar_or_string(args[0]):
1118+
raise ValueError("First argument must be a sequence")
1119+
nrecs = len(args[0])
1120+
margs = [] # Output args; some may be modified.
1121+
seqlist = [False] * len(args) # Flags: True if output will be masked.
1122+
masks = [] # List of masks.
1123+
for i, x in enumerate(args):
1124+
if is_scalar_or_string(x) or len(x) != nrecs:
1125+
margs.append(x) # Leave it unmodified.
1126+
else:
1127+
if isinstance(x, np.ma.MaskedArray) and x.ndim > 1:
1128+
raise ValueError("Masked arrays must be 1-D")
1129+
x = np.asanyarray(x)
1130+
if x.ndim == 1:
1131+
x = safe_masked_invalid(x)
1132+
seqlist[i] = True
1133+
if np.ma.is_masked(x):
1134+
masks.append(np.ma.getmaskarray(x))
1135+
margs.append(x) # Possibly modified.
1136+
if len(masks):
1137+
mask = np.logical_or.reduce(masks)
1138+
for i, x in enumerate(margs):
1139+
if seqlist[i]:
1140+
margs[i] = np.ma.array(x, mask=mask)
1141+
return margs
1142+
1143+
10841144
def boxplot_stats(X, whis=1.5, bootstrap=None, labels=None,
10851145
autorange=False):
10861146
"""

‎lib/matplotlib/pyplot.py

Copy file name to clipboardExpand all lines: lib/matplotlib/pyplot.py
+4-3Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2835,12 +2835,13 @@ def quiverkey(Q, X, Y, U, label, **kw):
28352835
def scatter(
28362836
x, y, s=None, c=None, marker=None, cmap=None, norm=None,
28372837
vmin=None, vmax=None, alpha=None, linewidths=None, verts=None,
2838-
edgecolors=None, *, data=None, **kwargs):
2838+
edgecolors=None, *, plotnonfinite=False, data=None, **kwargs):
28392839
__ret = gca().scatter(
28402840
x, y, s=s, c=c, marker=marker, cmap=cmap, norm=norm,
28412841
vmin=vmin, vmax=vmax, alpha=alpha, linewidths=linewidths,
2842-
verts=verts, edgecolors=edgecolors, **({"data": data} if data
2843-
is not None else {}), **kwargs)
2842+
verts=verts, edgecolors=edgecolors,
2843+
plotnonfinite=plotnonfinite, **({"data": data} if data is not
2844+
None else {}), **kwargs)
28442845
sci(__ret)
28452846
return __ret
28462847

‎lib/matplotlib/testing/decorators.py

Copy file name to clipboardExpand all lines: lib/matplotlib/testing/decorators.py
+31-13Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -452,19 +452,37 @@ def decorator(func):
452452

453453
_, result_dir = map(Path, _image_directories(func))
454454

455-
@pytest.mark.parametrize("ext", extensions)
456-
def wrapper(ext):
457-
fig_test = plt.figure("test")
458-
fig_ref = plt.figure("reference")
459-
func(fig_test, fig_ref)
460-
test_image_path = str(
461-
result_dir / (func.__name__ + "." + ext))
462-
ref_image_path = str(
463-
result_dir / (func.__name__ + "-expected." + ext))
464-
fig_test.savefig(test_image_path)
465-
fig_ref.savefig(ref_image_path)
466-
_raise_on_image_difference(
467-
ref_image_path, test_image_path, tol=tol)
455+
if len(inspect.signature(func).parameters) == 2:
456+
# Free-standing function.
457+
@pytest.mark.parametrize("ext", extensions)
458+
def wrapper(ext):
459+
fig_test = plt.figure("test")
460+
fig_ref = plt.figure("reference")
461+
func(fig_test, fig_ref)
462+
test_image_path = str(
463+
result_dir / (func.__name__ + "." + ext))
464+
ref_image_path = str(
465+
result_dir / (func.__name__ + "-expected." + ext))
466+
fig_test.savefig(test_image_path)
467+
fig_ref.savefig(ref_image_path)
468+
_raise_on_image_difference(
469+
ref_image_path, test_image_path, tol=tol)
470+
471+
elif len(inspect.signature(func).parameters) == 3:
472+
# Method.
473+
@pytest.mark.parametrize("ext", extensions)
474+
def wrapper(self, ext):
475+
fig_test = plt.figure("test")
476+
fig_ref = plt.figure("reference")
477+
func(self, fig_test, fig_ref)
478+
test_image_path = str(
479+
result_dir / (func.__name__ + "." + ext))
480+
ref_image_path = str(
481+
result_dir / (func.__name__ + "-expected." + ext))
482+
fig_test.savefig(test_image_path)
483+
fig_ref.savefig(ref_image_path)
484+
_raise_on_image_difference(
485+
ref_image_path, test_image_path, tol=tol)
468486

469487
return wrapper
470488

‎lib/matplotlib/tests/test_axes.py

Copy file name to clipboardExpand all lines: lib/matplotlib/tests/test_axes.py
+28-15Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1749,6 +1749,34 @@ def test_scatter_color(self):
17491749
with pytest.raises(ValueError):
17501750
plt.scatter([1, 2, 3], [1, 2, 3], color=[1, 2, 3])
17511751

1752+
@check_figures_equal(extensions=["png"])
1753+
def test_scatter_invalid_color(self, fig_test, fig_ref):
1754+
ax = fig_test.subplots()
1755+
cmap = plt.get_cmap("viridis", 16)
1756+
cmap.set_bad("k", 1)
1757+
# Set a nonuniform size to prevent the last call to `scatter` (plotting
1758+
# the invalid points separately in fig_ref) from using the marker
1759+
# stamping fast path, which would result in slightly offset markers.
1760+
ax.scatter(range(4), range(4),
1761+
c=[1, np.nan, 2, np.nan], s=[1, 2, 3, 4],
1762+
cmap=cmap, plotnonfinite=True)
1763+
ax = fig_ref.subplots()
1764+
cmap = plt.get_cmap("viridis", 16)
1765+
ax.scatter([0, 2], [0, 2], c=[1, 2], s=[1, 3], cmap=cmap)
1766+
ax.scatter([1, 3], [1, 3], s=[2, 4], color="k")
1767+
1768+
@check_figures_equal(extensions=["png"])
1769+
def test_scatter_no_invalid_color(self, fig_test, fig_ref):
1770+
# With plotninfinite=False we plot only 2 points.
1771+
ax = fig_test.subplots()
1772+
cmap = plt.get_cmap("viridis", 16)
1773+
cmap.set_bad("k", 1)
1774+
ax.scatter(range(4), range(4),
1775+
c=[1, np.nan, 2, np.nan], s=[1, 2, 3, 4],
1776+
cmap=cmap, plotnonfinite=False)
1777+
ax = fig_ref.subplots()
1778+
ax.scatter([0, 2], [0, 2], c=[1, 2], s=[1, 3], cmap=cmap)
1779+
17521780
# Parameters for *test_scatter_c*. NB: assuming that the
17531781
# scatter plot will have 4 elements. The tuple scheme is:
17541782
# (*c* parameter case, exception regexp key or None if no exception)
@@ -5743,21 +5771,6 @@ def test_color_length_mismatch():
57435771
ax.scatter(x, y, c=[c_rgb] * N)
57445772

57455773

5746-
def test_scatter_color_masking():
5747-
x = np.array([1, 2, 3])
5748-
y = np.array([1, np.nan, 3])
5749-
colors = np.array(['k', 'w', 'k'])
5750-
linewidths = np.array([1, 2, 3])
5751-
s = plt.scatter(x, y, color=colors, linewidths=linewidths)
5752-
5753-
facecolors = s.get_facecolors()
5754-
linecolors = s.get_edgecolors()
5755-
linewidths = s.get_linewidths()
5756-
assert_array_equal(facecolors[1], np.array([0, 0, 0, 1]))
5757-
assert_array_equal(linecolors[1], np.array([0, 0, 0, 1]))
5758-
assert linewidths[1] == 3
5759-
5760-
57615774
def test_eventplot_legend():
57625775
plt.eventplot([1.0], label='Label')
57635776
plt.legend()

‎lib/matplotlib/tests/test_colorbar.py

Copy file name to clipboardExpand all lines: lib/matplotlib/tests/test_colorbar.py
+2-3Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -197,9 +197,8 @@ def test_colorbar_single_scatter():
197197
# the norm scaling within the colorbar must ensure a
198198
# finite range, otherwise a zero denominator will occur in _locate.
199199
plt.figure()
200-
x = np.arange(4)
201-
y = x.copy()
202-
z = np.ma.masked_greater(np.arange(50, 54), 50)
200+
x = y = [0]
201+
z = [50]
203202
cmap = plt.get_cmap('jet', 16)
204203
cs = plt.scatter(x, y, z, c=z, cmap=cmap)
205204
plt.colorbar(cs)

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.