Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit e3c9376

Browse filesBrowse files
authored
Merge pull request #13959 from efiring/scatter_ravel_consistency
Scatter: make "c" and "s" argument handling more consistent.
2 parents 71771ea + c1add44 commit e3c9376
Copy full SHA for e3c9376

File tree

Expand file treeCollapse file tree

2 files changed

+49
-38
lines changed
Filter options
Expand file treeCollapse file tree

2 files changed

+49
-38
lines changed

‎lib/matplotlib/axes/_axes.py

Copy file name to clipboardExpand all lines: lib/matplotlib/axes/_axes.py
+21-28Lines changed: 21 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4138,7 +4138,7 @@ def dopatch(xs, ys, **kwargs):
41384138
medians=medians, fliers=fliers, means=means)
41394139

41404140
@staticmethod
4141-
def _parse_scatter_color_args(c, edgecolors, kwargs, xshape, yshape,
4141+
def _parse_scatter_color_args(c, edgecolors, kwargs, xsize,
41424142
get_next_color_func):
41434143
"""
41444144
Helper function to process color related arguments of `.Axes.scatter`.
@@ -4168,8 +4168,8 @@ def _parse_scatter_color_args(c, edgecolors, kwargs, xshape, yshape,
41684168
Additional kwargs. If these keys exist, we pop and process them:
41694169
'facecolors', 'facecolor', 'edgecolor', 'color'
41704170
Note: The dict is modified by this function.
4171-
xshape, yshape : tuple of int
4172-
The shape of the x and y arrays passed to `.Axes.scatter`.
4171+
xsize : int
4172+
The size of the x and y arrays passed to `.Axes.scatter`.
41734173
get_next_color_func : callable
41744174
A callable that returns a color. This color is used as facecolor
41754175
if no other color is provided.
@@ -4192,9 +4192,6 @@ def _parse_scatter_color_args(c, edgecolors, kwargs, xshape, yshape,
41924192
The edgecolor specification.
41934193
41944194
"""
4195-
xsize = functools.reduce(operator.mul, xshape, 1)
4196-
ysize = functools.reduce(operator.mul, yshape, 1)
4197-
41984195
facecolors = kwargs.pop('facecolors', None)
41994196
facecolors = kwargs.pop('facecolor', facecolors)
42004197
edgecolors = kwargs.pop('edgecolor', edgecolors)
@@ -4234,7 +4231,7 @@ def _parse_scatter_color_args(c, edgecolors, kwargs, xshape, yshape,
42344231
# favor of mapping, not rgb or rgba.
42354232
# Convenience vars to track shape mismatch *and* conversion failures.
42364233
valid_shape = True # will be put to the test!
4237-
n_elem = -1 # used only for (some) exceptions
4234+
csize = -1 # Number of colors; used for some exceptions.
42384235

42394236
if (c_was_none or
42404237
kwcolor is not None or
@@ -4246,9 +4243,9 @@ def _parse_scatter_color_args(c, edgecolors, kwargs, xshape, yshape,
42464243
else:
42474244
try: # First, does 'c' look suitable for value-mapping?
42484245
c_array = np.asanyarray(c, dtype=float)
4249-
n_elem = c_array.shape[0]
4250-
if c_array.shape in [xshape, yshape]:
4251-
c = np.ma.ravel(c_array)
4246+
csize = c_array.size
4247+
if csize == xsize:
4248+
c = c_array.ravel()
42524249
else:
42534250
if c_array.shape in ((3,), (4,)):
42544251
_log.warning(
@@ -4267,28 +4264,23 @@ def _parse_scatter_color_args(c, edgecolors, kwargs, xshape, yshape,
42674264
if c_array is None:
42684265
try: # Then is 'c' acceptable as PathCollection facecolors?
42694266
colors = mcolors.to_rgba_array(c)
4270-
n_elem = colors.shape[0]
4271-
if colors.shape[0] not in (0, 1, xsize, ysize):
4267+
csize = colors.shape[0]
4268+
if csize not in (0, 1, xsize):
42724269
# NB: remember that a single color is also acceptable.
42734270
# Besides *colors* will be an empty array if c == 'none'.
42744271
valid_shape = False
42754272
raise ValueError
42764273
except ValueError:
42774274
if not valid_shape: # but at least one conversion succeeded.
42784275
raise ValueError(
4279-
"'c' argument has {nc} elements, which is not "
4280-
"acceptable for use with 'x' with size {xs}, "
4281-
"'y' with size {ys}."
4282-
.format(nc=n_elem, xs=xsize, ys=ysize)
4283-
)
4276+
f"'c' argument has {csize} elements, which is "
4277+
"inconsistent with 'x' and 'y' with size {xsize}.")
42844278
else:
42854279
# Both the mapping *and* the RGBA conversion failed: pretty
42864280
# severe failure => one may appreciate a verbose feedback.
42874281
raise ValueError(
4288-
"'c' argument must be a mpl color, a sequence of mpl "
4289-
"colors or a sequence of numbers, not {}."
4290-
.format(c) # note: could be long depending on c
4291-
)
4282+
f"'c' argument must be a mpl color, a sequence of mpl "
4283+
"colors, or a sequence of numbers, not {c}.")
42924284
else:
42934285
colors = None # use cmap, norm after collection is created
42944286
return c, colors, edgecolors
@@ -4306,7 +4298,7 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,
43064298
43074299
Parameters
43084300
----------
4309-
x, y : array_like, shape (n, )
4301+
x, y : scalar or array_like, shape (n, )
43104302
The data positions.
43114303
43124304
s : scalar or array_like, shape (n, ), optional
@@ -4318,8 +4310,8 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,
43184310
43194311
- A single color format string.
43204312
- A sequence of color specifications of length n.
4321-
- A sequence of n numbers to be mapped to colors using *cmap* and
4322-
*norm*.
4313+
- A scalar or sequence of n numbers to be mapped to colors using
4314+
*cmap* and *norm*.
43234315
- A 2-D array in which the rows are RGB or RGBA.
43244316
43254317
Note that *c* should not be a single numeric RGB or RGBA sequence
@@ -4408,7 +4400,7 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,
44084400
plotted.
44094401
44104402
* Fundamentally, scatter works with 1-D arrays; *x*, *y*, *s*, and *c*
4411-
may be input as 2-D arrays, but within scatter they will be
4403+
may be input as N-D arrays, but within scatter they will be
44124404
flattened. The exception is *c*, which will be flattened only if its
44134405
size matches the size of *x* and *y*.
44144406
@@ -4421,7 +4413,6 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,
44214413

44224414
# np.ma.ravel yields an ndarray, not a masked array,
44234415
# unless its argument is a masked array.
4424-
xshape, yshape = np.shape(x), np.shape(y)
44254416
x = np.ma.ravel(x)
44264417
y = np.ma.ravel(y)
44274418
if x.size != y.size:
@@ -4430,11 +4421,13 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,
44304421
if s is None:
44314422
s = (20 if rcParams['_internal.classic_mode'] else
44324423
rcParams['lines.markersize'] ** 2.0)
4433-
s = np.ma.ravel(s) # This doesn't have to match x, y in size.
4424+
s = np.ma.ravel(s)
4425+
if len(s) not in (1, x.size):
4426+
raise ValueError("s must be a scalar, or the same size as x and y")
44344427

44354428
c, colors, edgecolors = \
44364429
self._parse_scatter_color_args(
4437-
c, edgecolors, kwargs, xshape, yshape,
4430+
c, edgecolors, kwargs, x.size,
44384431
get_next_color_func=self._get_patches_for_fill.get_next_color)
44394432

44404433
if plotnonfinite and colors is None:

‎lib/matplotlib/tests/test_axes.py

Copy file name to clipboardExpand all lines: lib/matplotlib/tests/test_axes.py
+28-10Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1820,6 +1820,13 @@ def test_scatter_color(self):
18201820
with pytest.raises(ValueError):
18211821
plt.scatter([1, 2, 3], [1, 2, 3], color=[1, 2, 3])
18221822

1823+
def test_scatter_size_arg_size(self):
1824+
x = np.arange(4)
1825+
with pytest.raises(ValueError):
1826+
plt.scatter(x, x, x[1:])
1827+
with pytest.raises(ValueError):
1828+
plt.scatter(x[1:], x[1:], x)
1829+
18231830
@check_figures_equal(extensions=["png"])
18241831
def test_scatter_invalid_color(self, fig_test, fig_ref):
18251832
ax = fig_test.subplots()
@@ -1848,6 +1855,21 @@ def test_scatter_no_invalid_color(self, fig_test, fig_ref):
18481855
ax = fig_ref.subplots()
18491856
ax.scatter([0, 2], [0, 2], c=[1, 2], s=[1, 3], cmap=cmap)
18501857

1858+
@check_figures_equal(extensions=["png"])
1859+
def test_scatter_single_point(self, fig_test, fig_ref):
1860+
ax = fig_test.subplots()
1861+
ax.scatter(1, 1, c=1)
1862+
ax = fig_ref.subplots()
1863+
ax.scatter([1], [1], c=[1])
1864+
1865+
@check_figures_equal(extensions=["png"])
1866+
def test_scatter_different_shapes(self, fig_test, fig_ref):
1867+
x = np.arange(10)
1868+
ax = fig_test.subplots()
1869+
ax.scatter(x, x.reshape(2, 5), c=x.reshape(5, 2))
1870+
ax = fig_ref.subplots()
1871+
ax.scatter(x.reshape(5, 2), x, c=x.reshape(2, 5))
1872+
18511873
# Parameters for *test_scatter_c*. NB: assuming that the
18521874
# scatter plot will have 4 elements. The tuple scheme is:
18531875
# (*c* parameter case, exception regexp key or None if no exception)
@@ -1904,7 +1926,7 @@ def get_next_color():
19041926

19051927
from matplotlib.axes import Axes
19061928

1907-
xshape = yshape = (4,)
1929+
xsize = 4
19081930

19091931
# Additional checking of *c* (introduced in #11383).
19101932
REGEXP = {
@@ -1914,21 +1936,18 @@ def get_next_color():
19141936

19151937
if re_key is None:
19161938
Axes._parse_scatter_color_args(
1917-
c=c_case, edgecolors="black", kwargs={},
1918-
xshape=xshape, yshape=yshape,
1939+
c=c_case, edgecolors="black", kwargs={}, xsize=xsize,
19191940
get_next_color_func=get_next_color)
19201941
else:
19211942
with pytest.raises(ValueError, match=REGEXP[re_key]):
19221943
Axes._parse_scatter_color_args(
1923-
c=c_case, edgecolors="black", kwargs={},
1924-
xshape=xshape, yshape=yshape,
1944+
c=c_case, edgecolors="black", kwargs={}, xsize=xsize,
19251945
get_next_color_func=get_next_color)
19261946

19271947

1928-
def _params(c=None, xshape=(2,), yshape=(2,), **kwargs):
1948+
def _params(c=None, xsize=2, **kwargs):
19291949
edgecolors = kwargs.pop('edgecolors', None)
1930-
return (c, edgecolors, kwargs if kwargs is not None else {},
1931-
xshape, yshape)
1950+
return (c, edgecolors, kwargs if kwargs is not None else {}, xsize)
19321951
_result = namedtuple('_result', 'c, colors')
19331952

19341953

@@ -1980,8 +1999,7 @@ def get_next_color():
19801999
c = kwargs.pop('c', None)
19812000
edgecolors = kwargs.pop('edgecolors', None)
19822001
_, _, result_edgecolors = \
1983-
Axes._parse_scatter_color_args(c, edgecolors, kwargs,
1984-
xshape=(2,), yshape=(2,),
2002+
Axes._parse_scatter_color_args(c, edgecolors, kwargs, xsize=2,
19852003
get_next_color_func=get_next_color)
19862004
assert result_edgecolors == expected_edgecolors
19872005

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.