Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 697c240

Browse filesBrowse files
authored
Merge pull request #29954 from 34j/perf/better-multicolored-line
Simplify `colored_line()` implementation in Multicolored lines example
2 parents 76a47d9 + 05c622b commit 697c240
Copy full SHA for 697c240

File tree

Expand file treeCollapse file tree

1 file changed

+27
-36
lines changed
Filter options
Expand file treeCollapse file tree

1 file changed

+27
-36
lines changed

‎galleries/examples/lines_bars_and_markers/multicolored_line.py

Copy file name to clipboardExpand all lines: galleries/examples/lines_bars_and_markers/multicolored_line.py
+27-36Lines changed: 27 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from matplotlib.collections import LineCollection
2222

2323

24-
def colored_line(x, y, c, ax, **lc_kwargs):
24+
def colored_line(x, y, c, ax=None, **lc_kwargs):
2525
"""
2626
Plot a line with a color specified along the line by a third value.
2727
@@ -36,8 +36,8 @@ def colored_line(x, y, c, ax, **lc_kwargs):
3636
The horizontal and vertical coordinates of the data points.
3737
c : array-like
3838
The color values, which should be the same size as x and y.
39-
ax : Axes
40-
Axis object on which to plot the colored line.
39+
ax : matplotlib.axes.Axes, optional
40+
The axes to plot on. If not provided, the current axes will be used.
4141
**lc_kwargs
4242
Any additional arguments to pass to matplotlib.collections.LineCollection
4343
constructor. This should not include the array keyword argument because
@@ -49,36 +49,32 @@ def colored_line(x, y, c, ax, **lc_kwargs):
4949
The generated line collection representing the colored line.
5050
"""
5151
if "array" in lc_kwargs:
52-
warnings.warn('The provided "array" keyword argument will be overridden')
52+
warnings.warn(
53+
'The provided "array" keyword argument will be overridden',
54+
UserWarning,
55+
stacklevel=2,
56+
)
57+
58+
xy = np.stack((x, y), axis=-1)
59+
xy_mid = np.concat(
60+
(xy[0, :][None, :], (xy[:-1, :] + xy[1:, :]) / 2, xy[-1, :][None, :]), axis=0
61+
)
62+
segments = np.stack((xy_mid[:-1, :], xy, xy_mid[1:, :]), axis=-2)
63+
# Note that
64+
# segments[0, :, :] is [xy[0, :], xy[0, :], (xy[0, :] + xy[1, :]) / 2]
65+
# segments[i, :, :] is [(xy[i - 1, :] + xy[i, :]) / 2, xy[i, :],
66+
# (xy[i, :] + xy[i + 1, :]) / 2] if i not in {0, len(x) - 1}
67+
# segments[-1, :, :] is [(xy[-2, :] + xy[-1, :]) / 2, xy[-1, :], xy[-1, :]]
68+
69+
lc_kwargs["array"] = c
70+
lc = LineCollection(segments, **lc_kwargs)
5371

54-
# Default the capstyle to butt so that the line segments smoothly line up
55-
default_kwargs = {"capstyle": "butt"}
56-
default_kwargs.update(lc_kwargs)
57-
58-
# Compute the midpoints of the line segments. Include the first and last points
59-
# twice so we don't need any special syntax later to handle them.
60-
x = np.asarray(x)
61-
y = np.asarray(y)
62-
x_midpts = np.hstack((x[0], 0.5 * (x[1:] + x[:-1]), x[-1]))
63-
y_midpts = np.hstack((y[0], 0.5 * (y[1:] + y[:-1]), y[-1]))
64-
65-
# Determine the start, middle, and end coordinate pair of each line segment.
66-
# Use the reshape to add an extra dimension so each pair of points is in its
67-
# own list. Then concatenate them to create:
68-
# [
69-
# [(x1_start, y1_start), (x1_mid, y1_mid), (x1_end, y1_end)],
70-
# [(x2_start, y2_start), (x2_mid, y2_mid), (x2_end, y2_end)],
71-
# ...
72-
# ]
73-
coord_start = np.column_stack((x_midpts[:-1], y_midpts[:-1]))[:, np.newaxis, :]
74-
coord_mid = np.column_stack((x, y))[:, np.newaxis, :]
75-
coord_end = np.column_stack((x_midpts[1:], y_midpts[1:]))[:, np.newaxis, :]
76-
segments = np.concatenate((coord_start, coord_mid, coord_end), axis=1)
77-
78-
lc = LineCollection(segments, **default_kwargs)
79-
lc.set_array(c) # set the colors of each segment
72+
# Plot the line collection to the axes
73+
ax = ax or plt.gca()
74+
ax.add_collection(lc)
75+
ax.autoscale_view()
8076

81-
return ax.add_collection(lc)
77+
return lc
8278

8379

8480
# -------------- Create and show plot --------------
@@ -93,11 +89,6 @@ def colored_line(x, y, c, ax, **lc_kwargs):
9389
lines = colored_line(x, y, color, ax1, linewidth=10, cmap="plasma")
9490
fig1.colorbar(lines) # add a color legend
9591

96-
# Set the axis limits and tick positions
97-
ax1.set_xlim(-1, 1)
98-
ax1.set_ylim(-1, 1)
99-
ax1.set_xticks((-1, 0, 1))
100-
ax1.set_yticks((-1, 0, 1))
10192
ax1.set_title("Color at each point")
10293

10394
plt.show()

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.