@@ -108,41 +108,43 @@ def update_lim(self, axes):
108
108
delta2 = _api .deprecated ("3.6" )(
109
109
property (lambda self : 0.00001 , lambda self , value : None ))
110
110
111
+ def _to_xy (self , * , variable , fixed ):
112
+ """
113
+ Stack the *variable* and *fixed* array-likes along the last axis so
114
+ that *variable* is at the ``self.nth_coord`` position.
115
+ """
116
+ if self .nth_coord == 0 :
117
+ return np .stack (np .broadcast_arrays (variable , fixed ), axis = - 1 )
118
+ elif self .nth_coord == 1 :
119
+ return np .stack (np .broadcast_arrays (fixed , variable ), axis = - 1 )
120
+ else :
121
+ raise ValueError ("Unxpected nth_coord" )
122
+
111
123
class Fixed (_Base ):
112
124
"""Helper class for a fixed (in the axes coordinate) axis."""
113
125
126
+ # deprecated with passthru_pt
114
127
_default_passthru_pt = dict (left = (0 , 0 ),
115
128
right = (1 , 0 ),
116
129
bottom = (0 , 0 ),
117
130
top = (0 , 1 ))
131
+ passthru_pt = _api .deprecated ("3.6" )(property (
132
+ lambda self : self ._default_passthru_pt [self ._loc ]))
118
133
119
134
def __init__ (self , loc , nth_coord = None ):
120
135
"""
121
136
nth_coord = along which coordinate value varies
122
137
in 2D, nth_coord = 0 -> x axis, nth_coord = 1 -> y axis
123
138
"""
124
- _api .check_in_list (["left" , "right" , "bottom" , "top" ], loc = loc )
125
139
self ._loc = loc
126
-
127
- if nth_coord is None :
128
- if loc in ["left" , "right" ]:
129
- nth_coord = 1
130
- elif loc in ["bottom" , "top" ]:
131
- nth_coord = 0
132
-
133
- self .nth_coord = nth_coord
134
-
140
+ self ._pos = _api .check_getitem (
141
+ {"bottom" : 0 , "top" : 1 , "left" : 0 , "right" : 1 }, loc = loc )
142
+ self .nth_coord = (
143
+ nth_coord if nth_coord is not None else
144
+ {"bottom" : 0 , "top" : 0 , "left" : 1 , "right" : 1 }[loc ])
135
145
super ().__init__ ()
136
-
137
- self .passthru_pt = self ._default_passthru_pt [loc ]
138
-
139
- _verts = np .array ([[0. , 0. ],
140
- [1. , 1. ]])
141
- fixed_coord = 1 - nth_coord
142
- _verts [:, fixed_coord ] = self .passthru_pt [fixed_coord ]
143
-
144
146
# axis line in transAxes
145
- self ._path = Path (_verts )
147
+ self ._path = Path (self . _to_xy ( variable = ( 0 , 1 ), fixed = self . _pos ) )
146
148
147
149
def get_nth_coord (self ):
148
150
return self .nth_coord
@@ -225,8 +227,7 @@ def get_tick_iterators(self, axes):
225
227
226
228
def _f (locs , labels ):
227
229
for x , l in zip (locs , labels ):
228
- c = list (self .passthru_pt ) # copy
229
- c [self .nth_coord ] = x
230
+ c = self ._to_xy (variable = x , fixed = self ._pos )
230
231
# check if the tick point is inside axes
231
232
c2 = tick_to_axes .transform (c )
232
233
if mpl .transforms ._interval_contains_close (
@@ -243,15 +244,10 @@ def __init__(self, axes, nth_coord,
243
244
self .axis = [axes .xaxis , axes .yaxis ][self .nth_coord ]
244
245
245
246
def get_line (self , axes ):
246
- _verts = np .array ([[0. , 0. ],
247
- [1. , 1. ]])
248
-
249
247
fixed_coord = 1 - self .nth_coord
250
248
data_to_axes = axes .transData - axes .transAxes
251
249
p = data_to_axes .transform ([self ._value , self ._value ])
252
- _verts [:, fixed_coord ] = p [fixed_coord ]
253
-
254
- return Path (_verts )
250
+ return Path (self ._to_xy (variable = (0 , 1 ), fixed = p [fixed_coord ]))
255
251
256
252
def get_line_transform (self , axes ):
257
253
return axes .transAxes
@@ -266,13 +262,12 @@ def get_axislabel_pos_angle(self, axes):
266
262
get_label_transform() returns a transform of (transAxes+offset)
267
263
"""
268
264
angle = [0 , 90 ][self .nth_coord ]
269
- _verts = [0.5 , 0.5 ]
270
265
fixed_coord = 1 - self .nth_coord
271
266
data_to_axes = axes .transData - axes .transAxes
272
267
p = data_to_axes .transform ([self ._value , self ._value ])
273
- _verts [ fixed_coord ] = p [fixed_coord ]
274
- if 0 <= _verts [fixed_coord ] <= 1 :
275
- return _verts , angle
268
+ verts = self . _to_xy ( variable = 0.5 , fixed = p [fixed_coord ])
269
+ if 0 <= verts [fixed_coord ] <= 1 :
270
+ return verts , angle
276
271
else :
277
272
return None , None
278
273
@@ -298,8 +293,7 @@ def get_tick_iterators(self, axes):
298
293
299
294
def _f (locs , labels ):
300
295
for x , l in zip (locs , labels ):
301
- c = [self ._value , self ._value ]
302
- c [self .nth_coord ] = x
296
+ c = self ._to_xy (variable = x , fixed = self ._value )
303
297
c1 , c2 = data_to_axes .transform (c )
304
298
if 0 <= c1 <= 1 and 0 <= c2 <= 1 :
305
299
yield c , angle_normal , angle_tangent , l
0 commit comments