@@ -49,16 +49,17 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
49
49
Minimum length of streamline in axes coordinates.
50
50
51
51
Returns:
52
-
52
+
53
53
*stream_container* : StreamplotSet
54
54
Container object with attributes
55
- lines : `matplotlib.collections.LineCollection` of streamlines
56
- arrows : collection of `matplotlib.patches.FancyArrowPatch` objects
57
- repesenting arrows half-way along stream lines.
58
- This container will probably change in the future to allow changes to
59
- the colormap, alpha, etc. for both lines and arrows, but these changes
60
- should be backward compatible.
61
-
55
+ lines: `matplotlib.collections.LineCollection` of streamlines
56
+ arrows: collection of `matplotlib.patches.FancyArrowPatch`
57
+ objects representing arrows half-way along stream
58
+ lines.
59
+ This container will probably change in the future to allow changes
60
+ to the colormap, alpha, etc. for both lines and arrows, but these
61
+ changes should be backward compatible.
62
+
62
63
"""
63
64
grid = Grid (x , y )
64
65
mask = StreamMask (density )
@@ -71,7 +72,7 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
71
72
linewidth = matplotlib .rcParams ['lines.linewidth' ]
72
73
73
74
line_kw = {}
74
- arrow_kw = dict (arrowstyle = arrowstyle , mutation_scale = 10 * arrowsize )
75
+ arrow_kw = dict (arrowstyle = arrowstyle , mutation_scale = 10 * arrowsize )
75
76
76
77
use_multicolor_lines = isinstance (color , np .ndarray )
77
78
if use_multicolor_lines :
@@ -104,7 +105,7 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
104
105
if mask [ym , xm ] == 0 :
105
106
xg , yg = dmap .mask2grid (xm , ym )
106
107
t = integrate (xg , yg )
107
- if t != None :
108
+ if t is not None :
108
109
trajectories .append (t )
109
110
110
111
if use_multicolor_lines :
@@ -128,10 +129,10 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
128
129
streamlines .extend (np .hstack ([points [:- 1 ], points [1 :]]))
129
130
130
131
# Add arrows half way along each trajectory.
131
- s = np .cumsum (np .sqrt (np .diff (tx )** 2 + np .diff (ty )** 2 ))
132
+ s = np .cumsum (np .sqrt (np .diff (tx ) ** 2 + np .diff (ty ) ** 2 ))
132
133
n = np .searchsorted (s , s [- 1 ] / 2. )
133
134
arrow_tail = (tx [n ], ty [n ])
134
- arrow_head = (np .mean (tx [n :n + 2 ]), np .mean (ty [n :n + 2 ]))
135
+ arrow_head = (np .mean (tx [n :n + 2 ]), np .mean (ty [n :n + 2 ]))
135
136
136
137
if isinstance (linewidth , np .ndarray ):
137
138
line_widths = interpgrid (linewidth , tgx , tgy )[:- 1 ]
@@ -143,15 +144,15 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
143
144
line_colors .extend (color_values )
144
145
arrow_kw ['color' ] = cmap (norm (color_values [n ]))
145
146
146
- p = patches .FancyArrowPatch (arrow_tail ,
147
- arrow_head ,
148
- transform = transform ,
147
+ p = patches .FancyArrowPatch (arrow_tail ,
148
+ arrow_head ,
149
+ transform = transform ,
149
150
** arrow_kw )
150
151
axes .add_patch (p )
151
152
arrows .append (p )
152
153
153
- lc = mcollections .LineCollection (streamlines ,
154
- transform = transform ,
154
+ lc = mcollections .LineCollection (streamlines ,
155
+ transform = transform ,
155
156
** line_kw )
156
157
if use_multicolor_lines :
157
158
lc .set_array (np .asarray (line_colors ))
@@ -275,7 +276,7 @@ def within_grid(self, xi, yi):
275
276
"""Return True if point is a valid index of grid."""
276
277
# Note that xi/yi can be floats; so, for example, we can't simply check
277
278
# `xi < self.nx` since `xi` can be `self.nx - 1 < xi < self.nx`
278
- return xi >= 0 and xi <= self .nx - 1 and yi >= 0 and yi <= self .ny - 1
279
+ return xi >= 0 and xi <= self .nx - 1 and yi >= 0 and yi <= self .ny - 1
279
280
280
281
281
282
class StreamMask (object ):
@@ -330,6 +331,7 @@ def _update_trajectory(self, xm, ym):
330
331
class InvalidIndexError (Exception ):
331
332
pass
332
333
334
+
333
335
class TerminateTrajectory (Exception ):
334
336
pass
335
337
@@ -345,7 +347,7 @@ def get_integrator(u, v, dmap, minlength):
345
347
# speed (path length) will be in axes-coordinates
346
348
u_ax = u / dmap .grid .nx
347
349
v_ax = v / dmap .grid .ny
348
- speed = np .ma .sqrt (u_ax ** 2 + v_ax ** 2 )
350
+ speed = np .ma .sqrt (u_ax ** 2 + v_ax ** 2 )
349
351
350
352
def forward_time (xi , yi ):
351
353
ds_dt = interpgrid (speed , xi , yi )
@@ -382,7 +384,7 @@ def integrate(x0, y0):
382
384
383
385
if stotal > minlength :
384
386
return x_traj , y_traj
385
- else : # reject short trajectories
387
+ else : # reject short trajectories
386
388
dmap .undo_trajectory ()
387
389
return None
388
390
@@ -423,7 +425,7 @@ def _integrate_rk12(x0, y0, dmap, f):
423
425
## increment the location gradually. However, due to the efficient
424
426
## nature of the interpolation, this doesn't boost speed by much
425
427
## for quite a bit of complexity.
426
- maxds = min (1. / dmap .mask .nx , 1. / dmap .mask .ny , 0.1 )
428
+ maxds = min (1. / dmap .mask .nx , 1. / dmap .mask .ny , 0.1 )
427
429
428
430
ds = maxds
429
431
stotal = 0
@@ -455,7 +457,7 @@ def _integrate_rk12(x0, y0, dmap, f):
455
457
456
458
nx , ny = dmap .grid .shape
457
459
# Error is normalized to the axes coordinates
458
- error = np .sqrt (((dx2 - dx1 )/ nx )** 2 + ((dy2 - dy1 )/ ny )** 2 )
460
+ error = np .sqrt (((dx2 - dx1 ) / nx ) ** 2 + ((dy2 - dy1 ) / ny ) ** 2 )
459
461
460
462
# Only save step if within error tolerance
461
463
if error < maxerror :
@@ -473,7 +475,7 @@ def _integrate_rk12(x0, y0, dmap, f):
473
475
if error == 0 :
474
476
ds = maxds
475
477
else :
476
- ds = min (maxds , 0.85 * ds * (maxerror / error )** 0.5 )
478
+ ds = min (maxds , 0.85 * ds * (maxerror / error ) ** 0.5 )
477
479
478
480
return stotal , xf_traj , yf_traj
479
481
@@ -497,8 +499,8 @@ def _euler_step(xf_traj, yf_traj, dmap, f):
497
499
else :
498
500
dsy = (ny - 1 - yi ) / cy
499
501
ds = min (dsx , dsy )
500
- xf_traj .append (xi + cx * ds )
501
- yf_traj .append (yi + cy * ds )
502
+ xf_traj .append (xi + cx * ds )
503
+ yf_traj .append (yi + cy * ds )
502
504
return ds , xf_traj , yf_traj
503
505
504
506
@@ -519,10 +521,14 @@ def interpgrid(a, xi, yi):
519
521
x = np .int (xi )
520
522
y = np .int (yi )
521
523
# conditional is faster than clipping for integers
522
- if x == (Nx - 2 ): xn = x
523
- else : xn = x + 1
524
- if y == (Ny - 2 ): yn = y
525
- else : yn = y + 1
524
+ if x == (Nx - 2 ):
525
+ xn = x
526
+ else :
527
+ xn = x + 1
528
+ if y == (Ny - 2 ):
529
+ yn = y
530
+ else :
531
+ yn = y + 1
526
532
527
533
a00 = a [y , x ]
528
534
a01 = a [y , xn ]
@@ -563,20 +569,20 @@ def _gen_starting_points(shape):
563
569
if direction == 'right' :
564
570
x += 1
565
571
if x >= xlast :
566
- xlast -= 1
572
+ xlast -= 1
567
573
direction = 'up'
568
574
elif direction == 'up' :
569
575
y += 1
570
576
if y >= ylast :
571
- ylast -= 1
577
+ ylast -= 1
572
578
direction = 'left'
573
579
elif direction == 'left' :
574
580
x -= 1
575
581
if x <= xfirst :
576
- xfirst += 1
582
+ xfirst += 1
577
583
direction = 'down'
578
584
elif direction == 'down' :
579
585
y -= 1
580
586
if y <= yfirst :
581
- yfirst += 1
587
+ yfirst += 1
582
588
direction = 'right'
0 commit comments