Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 84e0fab

Browse filesBrowse files
authored
Merge pull request #20390 from anntzer/arrow_demo
Cleanup arrow_demo.
2 parents 6877d16 + 3eaed9e commit 84e0fab
Copy full SHA for 84e0fab

File tree

Expand file treeCollapse file tree

1 file changed

+81
-229
lines changed
Filter options
Expand file treeCollapse file tree

1 file changed

+81
-229
lines changed

‎examples/text_labels_and_annotations/arrow_demo.py

Copy file name to clipboardExpand all lines: examples/text_labels_and_annotations/arrow_demo.py
+81-229Lines changed: 81 additions & 229 deletions
Original file line numberDiff line numberDiff line change
@@ -3,153 +3,84 @@
33
Arrow Demo
44
==========
55
6-
Arrow drawing example for the new fancy_arrow facilities.
7-
8-
Code contributed by: Rob Knight <rob@spot.colorado.edu>
9-
10-
usage:
11-
12-
python arrow_demo.py realistic|full|sample|extreme
6+
Three ways of drawing arrows to encode arrow "strength" (e.g., transition
7+
probabilities in a Markov model) using arrow length, width, or alpha (opacity).
8+
"""
139

10+
import itertools
1411

15-
"""
1612
import matplotlib.pyplot as plt
1713
import numpy as np
1814

19-
rates_to_bases = {'r1': 'AT', 'r2': 'TA', 'r3': 'GA', 'r4': 'AG', 'r5': 'CA',
20-
'r6': 'AC', 'r7': 'GT', 'r8': 'TG', 'r9': 'CT', 'r10': 'TC',
21-
'r11': 'GC', 'r12': 'CG'}
22-
numbered_bases_to_rates = {v: k for k, v in rates_to_bases.items()}
23-
lettered_bases_to_rates = {v: 'r' + v for k, v in rates_to_bases.items()}
24-
2515

26-
def make_arrow_plot(data, size=4, display='length', shape='right',
27-
max_arrow_width=0.03, arrow_sep=0.02, alpha=0.5,
28-
normalize_data=False, ec=None, labelcolor=None,
29-
head_starts_at_zero=True,
30-
rate_labels=lettered_bases_to_rates,
31-
**kwargs):
16+
def make_arrow_graph(ax, data, size=4, display='length', shape='right',
17+
max_arrow_width=0.03, arrow_sep=0.02, alpha=0.5,
18+
normalize_data=False, ec=None, labelcolor=None,
19+
**kwargs):
3220
"""
3321
Makes an arrow plot.
3422
3523
Parameters
3624
----------
25+
ax
26+
The axes where the graph is drawn.
3727
data
3828
Dict with probabilities for the bases and pair transitions.
3929
size
40-
Size of the graph in inches.
30+
Size of the plot, in inches.
4131
display : {'length', 'width', 'alpha'}
4232
The arrow property to change.
4333
shape : {'full', 'left', 'right'}
4434
For full or half arrows.
4535
max_arrow_width : float
46-
Maximum width of an arrow, data coordinates.
36+
Maximum width of an arrow, in data coordinates.
4737
arrow_sep : float
48-
Separation between arrows in a pair, data coordinates.
38+
Separation between arrows in a pair, in data coordinates.
4939
alpha : float
5040
Maximum opacity of arrows.
5141
**kwargs
52-
Can be anything allowed by a Arrow object, e.g. *linewidth* or
53-
*edgecolor*.
42+
`.FancyArrow` properties, e.g. *linewidth* or *edgecolor*.
5443
"""
5544

56-
plt.xlim(-0.5, 1.5)
57-
plt.ylim(-0.5, 1.5)
58-
plt.gcf().set_size_inches(size, size)
59-
plt.xticks([])
60-
plt.yticks([])
45+
ax.set(xlim=(-0.5, 1.5), ylim=(-0.5, 1.5), xticks=[], yticks=[])
46+
ax.text(.01, .01, f'flux encoded as arrow {display}',
47+
transform=ax.transAxes)
6148
max_text_size = size * 12
6249
min_text_size = size
6350
label_text_size = size * 2.5
64-
text_params = {'ha': 'center', 'va': 'center', 'family': 'sans-serif',
65-
'fontweight': 'bold'}
66-
r2 = np.sqrt(2)
67-
68-
deltas = {
69-
'AT': (1, 0),
70-
'TA': (-1, 0),
71-
'GA': (0, 1),
72-
'AG': (0, -1),
73-
'CA': (-1 / r2, 1 / r2),
74-
'AC': (1 / r2, -1 / r2),
75-
'GT': (1 / r2, 1 / r2),
76-
'TG': (-1 / r2, -1 / r2),
77-
'CT': (0, 1),
78-
'TC': (0, -1),
79-
'GC': (1, 0),
80-
'CG': (-1, 0)}
8151

82-
colors = {
83-
'AT': 'r',
84-
'TA': 'k',
85-
'GA': 'g',
86-
'AG': 'r',
87-
'CA': 'b',
88-
'AC': 'r',
89-
'GT': 'g',
90-
'TG': 'k',
91-
'CT': 'b',
92-
'TC': 'k',
93-
'GC': 'g',
94-
'CG': 'b'}
95-
96-
label_positions = {
97-
'AT': 'center',
98-
'TA': 'center',
99-
'GA': 'center',
100-
'AG': 'center',
101-
'CA': 'left',
102-
'AC': 'left',
103-
'GT': 'left',
104-
'TG': 'left',
105-
'CT': 'center',
106-
'TC': 'center',
107-
'GC': 'center',
108-
'CG': 'center'}
109-
110-
def do_fontsize(k):
111-
return float(np.clip(max_text_size * np.sqrt(data[k]),
112-
min_text_size, max_text_size))
113-
114-
plt.text(0, 1, '$A_3$', color='r', size=do_fontsize('A'), **text_params)
115-
plt.text(1, 1, '$T_3$', color='k', size=do_fontsize('T'), **text_params)
116-
plt.text(0, 0, '$G_3$', color='g', size=do_fontsize('G'), **text_params)
117-
plt.text(1, 0, '$C_3$', color='b', size=do_fontsize('C'), **text_params)
52+
bases = 'ATGC'
53+
coords = {
54+
'A': np.array([0, 1]),
55+
'T': np.array([1, 1]),
56+
'G': np.array([0, 0]),
57+
'C': np.array([1, 0]),
58+
}
59+
colors = {'A': 'r', 'T': 'k', 'G': 'g', 'C': 'b'}
60+
61+
for base in bases:
62+
fontsize = np.clip(max_text_size * data[base]**(1/2),
63+
min_text_size, max_text_size)
64+
ax.text(*coords[base], f'${base}_3$',
65+
color=colors[base], size=fontsize,
66+
horizontalalignment='center', verticalalignment='center',
67+
weight='bold')
11868

11969
arrow_h_offset = 0.25 # data coordinates, empirically determined
12070
max_arrow_length = 1 - 2 * arrow_h_offset
12171
max_head_width = 2.5 * max_arrow_width
12272
max_head_length = 2 * max_arrow_width
123-
arrow_params = {'length_includes_head': True, 'shape': shape,
124-
'head_starts_at_zero': head_starts_at_zero}
12573
sf = 0.6 # max arrow size represents this in data coords
12674

127-
d = (r2 / 2 + arrow_h_offset - 0.5) / r2 # distance for diags
128-
r2v = arrow_sep / r2 # offset for diags
129-
130-
# tuple of x, y for start position
131-
positions = {
132-
'AT': (arrow_h_offset, 1 + arrow_sep),
133-
'TA': (1 - arrow_h_offset, 1 - arrow_sep),
134-
'GA': (-arrow_sep, arrow_h_offset),
135-
'AG': (arrow_sep, 1 - arrow_h_offset),
136-
'CA': (1 - d - r2v, d - r2v),
137-
'AC': (d + r2v, 1 - d + r2v),
138-
'GT': (d - r2v, d + r2v),
139-
'TG': (1 - d + r2v, 1 - d - r2v),
140-
'CT': (1 - arrow_sep, arrow_h_offset),
141-
'TC': (1 + arrow_sep, 1 - arrow_h_offset),
142-
'GC': (arrow_h_offset, arrow_sep),
143-
'CG': (1 - arrow_h_offset, -arrow_sep)}
144-
14575
if normalize_data:
14676
# find maximum value for rates, i.e. where keys are 2 chars long
14777
max_val = max((v for k, v in data.items() if len(k) == 2), default=0)
14878
# divide rates by max val, multiply by arrow scale factor
14979
for k, v in data.items():
15080
data[k] = v / max_val * sf
15181

152-
def draw_arrow(pair, alpha=alpha, ec=ec, labelcolor=labelcolor):
82+
# iterate over strings 'AT', 'TA', 'AG', 'GA', etc.
83+
for pair in map(''.join, itertools.permutations(bases, 2)):
15384
# set the length of the arrow
15485
if display == 'length':
15586
length = (max_head_length
@@ -159,7 +90,6 @@ def draw_arrow(pair, alpha=alpha, ec=ec, labelcolor=labelcolor):
15990
# set the transparency of the arrow
16091
if display == 'alpha':
16192
alpha = min(data[pair] / sf, alpha)
162-
16393
# set the width of the arrow
16494
if display == 'width':
16595
scale = data[pair] / sf
@@ -171,137 +101,59 @@ def draw_arrow(pair, alpha=alpha, ec=ec, labelcolor=labelcolor):
171101
head_width = max_head_width
172102
head_length = max_head_length
173103

174-
fc = colors[pair]
175-
ec = ec or fc
176-
177-
x_scale, y_scale = deltas[pair]
178-
x_pos, y_pos = positions[pair]
179-
plt.arrow(x_pos, y_pos, x_scale * length, y_scale * length,
180-
fc=fc, ec=ec, alpha=alpha, width=width,
181-
head_width=head_width, head_length=head_length,
182-
**arrow_params)
183-
184-
# figure out coordinates for text
104+
fc = colors[pair[0]]
105+
106+
cp0 = coords[pair[0]]
107+
cp1 = coords[pair[1]]
108+
# unit vector in arrow direction
109+
delta = cos, sin = (cp1 - cp0) / np.hypot(*(cp1 - cp0))
110+
x_pos, y_pos = (
111+
(cp0 + cp1) / 2 # midpoint
112+
- delta * length / 2 # half the arrow length
113+
+ np.array([-sin, cos]) * arrow_sep # shift outwards by arrow_sep
114+
)
115+
ax.arrow(
116+
x_pos, y_pos, cos * length, sin * length,
117+
fc=fc, ec=ec or fc, alpha=alpha, width=width,
118+
head_width=head_width, head_length=head_length, shape=shape,
119+
length_includes_head=True,
120+
)
121+
122+
# figure out coordinates for text:
185123
# if drawing relative to base: x and y are same as for arrow
186124
# dx and dy are one arrow width left and up
187-
# need to rotate based on direction of arrow, use x_scale and y_scale
188-
# as sin x and cos x?
189-
sx, cx = y_scale, x_scale
190-
191-
where = label_positions[pair]
192-
if where == 'left':
193-
orig_position = 3 * np.array([[max_arrow_width, max_arrow_width]])
194-
elif where == 'absolute':
195-
orig_position = np.array([[max_arrow_length / 2.0,
196-
3 * max_arrow_width]])
197-
elif where == 'right':
198-
orig_position = np.array([[length - 3 * max_arrow_width,
199-
3 * max_arrow_width]])
200-
elif where == 'center':
201-
orig_position = np.array([[length / 2.0, 3 * max_arrow_width]])
202-
else:
203-
raise ValueError("Got unknown position parameter %s" % where)
204-
205-
M = np.array([[cx, sx], [-sx, cx]])
206-
coords = np.dot(orig_position, M) + [[x_pos, y_pos]]
207-
x, y = np.ravel(coords)
208-
orig_label = rate_labels[pair]
209-
label = r'$%s_{_{\mathrm{%s}}}$' % (orig_label[0], orig_label[1:])
210-
211-
plt.text(x, y, label, size=label_text_size, ha='center', va='center',
212-
color=labelcolor or fc)
213-
214-
for p in sorted(positions):
215-
draw_arrow(p)
216-
217-
218-
# test data
219-
all_on_max = dict([(i, 1) for i in 'TCAG'] +
220-
[(i + j, 0.6) for i in 'TCAG' for j in 'TCAG'])
221-
222-
realistic_data = {
223-
'A': 0.4,
224-
'T': 0.3,
225-
'G': 0.5,
226-
'C': 0.2,
227-
'AT': 0.4,
228-
'AC': 0.3,
229-
'AG': 0.2,
230-
'TA': 0.2,
231-
'TC': 0.3,
232-
'TG': 0.4,
233-
'CT': 0.2,
234-
'CG': 0.3,
235-
'CA': 0.2,
236-
'GA': 0.1,
237-
'GT': 0.4,
238-
'GC': 0.1}
239-
240-
extreme_data = {
241-
'A': 0.75,
242-
'T': 0.10,
243-
'G': 0.10,
244-
'C': 0.05,
245-
'AT': 0.6,
246-
'AC': 0.3,
247-
'AG': 0.1,
248-
'TA': 0.02,
249-
'TC': 0.3,
250-
'TG': 0.01,
251-
'CT': 0.2,
252-
'CG': 0.5,
253-
'CA': 0.2,
254-
'GA': 0.1,
255-
'GT': 0.4,
256-
'GC': 0.2}
257-
258-
sample_data = {
259-
'A': 0.2137,
260-
'T': 0.3541,
261-
'G': 0.1946,
262-
'C': 0.2376,
263-
'AT': 0.0228,
264-
'AC': 0.0684,
265-
'AG': 0.2056,
266-
'TA': 0.0315,
267-
'TC': 0.0629,
268-
'TG': 0.0315,
269-
'CT': 0.1355,
270-
'CG': 0.0401,
271-
'CA': 0.0703,
272-
'GA': 0.1824,
273-
'GT': 0.0387,
274-
'GC': 0.1106}
125+
orig_positions = {
126+
'base': [3 * max_arrow_width, 3 * max_arrow_width],
127+
'center': [length / 2, 3 * max_arrow_width],
128+
'tip': [length - 3 * max_arrow_width, 3 * max_arrow_width],
129+
}
130+
# for diagonal arrows, put the label at the arrow base
131+
# for vertical or horizontal arrows, center the label
132+
where = 'base' if (cp0 != cp1).all() else 'center'
133+
# rotate based on direction of arrow (cos, sin)
134+
M = [[cos, -sin], [sin, cos]]
135+
x, y = np.dot(M, orig_positions[where]) + [x_pos, y_pos]
136+
label = r'$r_{_{\mathrm{%s}}}$' % (pair,)
137+
ax.text(x, y, label, size=label_text_size, ha='center', va='center',
138+
color=labelcolor or fc)
275139

276140

277141
if __name__ == '__main__':
278-
from sys import argv
279-
d = None
280-
if len(argv) > 1:
281-
if argv[1] == 'full':
282-
d = all_on_max
283-
scaled = False
284-
elif argv[1] == 'extreme':
285-
d = extreme_data
286-
scaled = False
287-
elif argv[1] == 'realistic':
288-
d = realistic_data
289-
scaled = False
290-
elif argv[1] == 'sample':
291-
d = sample_data
292-
scaled = True
293-
if d is None:
294-
d = all_on_max
295-
scaled = False
296-
if len(argv) > 2:
297-
display = argv[2]
298-
else:
299-
display = 'length'
142+
data = { # test data
143+
'A': 0.4, 'T': 0.3, 'G': 0.6, 'C': 0.2,
144+
'AT': 0.4, 'AC': 0.3, 'AG': 0.2,
145+
'TA': 0.2, 'TC': 0.3, 'TG': 0.4,
146+
'CT': 0.2, 'CG': 0.3, 'CA': 0.2,
147+
'GA': 0.1, 'GT': 0.4, 'GC': 0.1,
148+
}
300149

301150
size = 4
302-
plt.figure(figsize=(size, size))
151+
fig = plt.figure(figsize=(3 * size, size), constrained_layout=True)
152+
axs = fig.subplot_mosaic([["length", "width", "alpha"]])
303153

304-
make_arrow_plot(d, display=display, linewidth=0.001, edgecolor=None,
305-
normalize_data=scaled, head_starts_at_zero=True, size=size)
154+
for display, ax in axs.items():
155+
make_arrow_graph(
156+
ax, data, display=display, linewidth=0.001, edgecolor=None,
157+
normalize_data=True, size=size)
306158

307159
plt.show()

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.