21
21
from matplotlib .collections import LineCollection
22
22
23
23
24
- def colored_line (x , y , c , ax , ** lc_kwargs ):
24
+ def colored_line (x , y , c , ax = None , scalex = True , scaley = True , ** lc_kwargs ):
25
25
"""
26
26
Plot a line with a color specified along the line by a third value.
27
27
@@ -36,9 +36,12 @@ def colored_line(x, y, c, ax, **lc_kwargs):
36
36
The horizontal and vertical coordinates of the data points.
37
37
c : array-like
38
38
The color values, which should be the same size as x and y.
39
- ax : Axes
40
- Axis object on which to plot the colored line.
41
- **lc_kwargs
39
+ ax : matplotlib.axes.Axes, optional
40
+ The axes to plot on. If not provided, the current axes will be used.
41
+ scalex, scaley : bool
42
+ These parameters determine if the view limits are adapted to the data limits.
43
+ The values are passed on to autoscale_view.
44
+ **lc_kwargs : Any
42
45
Any additional arguments to pass to matplotlib.collections.LineCollection
43
46
constructor. This should not include the array keyword argument because
44
47
that is set to the color argument. If provided, it will be overridden.
@@ -49,36 +52,34 @@ def colored_line(x, y, c, ax, **lc_kwargs):
49
52
The generated line collection representing the colored line.
50
53
"""
51
54
if "array" in lc_kwargs :
52
- warnings .warn ('The provided "array" keyword argument will be overridden' )
55
+ warnings .warn (
56
+ 'The provided "array" keyword argument will be overridden' ,
57
+ UserWarning ,
58
+ stacklevel = 2 ,
59
+ )
53
60
54
- # Default the capstyle to butt so that the line segments smoothly line up
55
- default_kwargs = {"capstyle" : "butt" }
56
- default_kwargs .update (lc_kwargs )
57
-
58
- # Compute the midpoints of the line segments. Include the first and last points
59
- # twice so we don't need any special syntax later to handle them.
60
- x = np .asarray (x )
61
- y = np .asarray (y )
62
- x_midpts = np .hstack ((x [0 ], 0.5 * (x [1 :] + x [:- 1 ]), x [- 1 ]))
63
- y_midpts = np .hstack ((y [0 ], 0.5 * (y [1 :] + y [:- 1 ]), y [- 1 ]))
64
-
65
- # Determine the start, middle, and end coordinate pair of each line segment.
66
- # Use the reshape to add an extra dimension so each pair of points is in its
67
- # own list. Then concatenate them to create:
68
- # [
69
- # [(x1_start, y1_start), (x1_mid, y1_mid), (x1_end, y1_end)],
70
- # [(x2_start, y2_start), (x2_mid, y2_mid), (x2_end, y2_end)],
61
+ xy = np .stack ((x , y ), axis = - 1 )
62
+ xy_mid = np .concat (
63
+ (xy [0 , :][None , :], (xy [:- 1 , :] + xy [1 :, :]) / 2 , xy [- 1 , :][None , :]), axis = 0
64
+ )
65
+ segments = np .stack ((xy_mid [:- 1 , :], xy , xy_mid [1 :, :]), axis = - 2 )
66
+ # Note that segments is [
67
+ # [[x[0], y[0]], [x[0], y[0]], [mean(x[0], x[1]), mean(y[0], y[1])]],
68
+ # [[mean(x[0], x[1]), mean(y[0], y[1])], [x[1], y[1]], [mean(x[1], x[2]), mean(y[1], y[2])]],
71
69
# ...
70
+ # [[mean(x[-2], x[-1]), mean(y[-2], y[-1])], [x[-1], y[-1]], [x[-1], y[-1]]]
72
71
# ]
73
- coord_start = np .column_stack ((x_midpts [:- 1 ], y_midpts [:- 1 ]))[:, np .newaxis , :]
74
- coord_mid = np .column_stack ((x , y ))[:, np .newaxis , :]
75
- coord_end = np .column_stack ((x_midpts [1 :], y_midpts [1 :]))[:, np .newaxis , :]
76
- segments = np .concatenate ((coord_start , coord_mid , coord_end ), axis = 1 )
77
72
78
- lc = LineCollection (segments , ** default_kwargs )
79
- lc .set_array (c ) # set the colors of each segment
73
+ lc_kwargs ["array" ] = c
74
+ lc = LineCollection (segments , ** lc_kwargs )
75
+
76
+ # Plot the line collection to the axes
77
+ ax = ax or plt .gca ()
78
+ ax .add_collection (lc )
79
+ ax .autoscale_view (scalex = scalex , scaley = scaley )
80
80
81
- return ax .add_collection (lc )
81
+ # Return the LineCollection object
82
+ return lc
82
83
83
84
84
85
# -------------- Create and show plot --------------
@@ -93,11 +94,6 @@ def colored_line(x, y, c, ax, **lc_kwargs):
93
94
lines = colored_line (x , y , color , ax1 , linewidth = 10 , cmap = "plasma" )
94
95
fig1 .colorbar (lines ) # add a color legend
95
96
96
- # Set the axis limits and tick positions
97
- ax1 .set_xlim (- 1 , 1 )
98
- ax1 .set_ylim (- 1 , 1 )
99
- ax1 .set_xticks ((- 1 , 0 , 1 ))
100
- ax1 .set_yticks ((- 1 , 0 , 1 ))
101
97
ax1 .set_title ("Color at each point" )
102
98
103
99
plt .show ()
0 commit comments