Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 049d837

Browse filesBrowse files
Merge pull request plotly#1838 from plotly/facet_wrap
initial build-out of facet wrapping
2 parents be1a182 + ee48cca commit 049d837
Copy full SHA for 049d837

File tree

Expand file treeCollapse file tree

5 files changed

+93
-33
lines changed
Filter options
Expand file treeCollapse file tree

5 files changed

+93
-33
lines changed

‎packages/python/plotly/plotly/express/_chart_types.py

Copy file name to clipboardExpand all lines: packages/python/plotly/plotly/express/_chart_types.py
+10Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def scatter(
1616
text=None,
1717
facet_row=None,
1818
facet_col=None,
19+
facet_col_wrap=0,
1920
error_x=None,
2021
error_x_minus=None,
2122
error_y=None,
@@ -65,6 +66,7 @@ def density_contour(
6566
color=None,
6667
facet_row=None,
6768
facet_col=None,
69+
facet_col_wrap=0,
6870
hover_name=None,
6971
hover_data=None,
7072
animation_frame=None,
@@ -120,6 +122,7 @@ def density_heatmap(
120122
z=None,
121123
facet_row=None,
122124
facet_col=None,
125+
facet_col_wrap=0,
123126
hover_name=None,
124127
hover_data=None,
125128
animation_frame=None,
@@ -180,6 +183,7 @@ def line(
180183
text=None,
181184
facet_row=None,
182185
facet_col=None,
186+
facet_col_wrap=0,
183187
error_x=None,
184188
error_x_minus=None,
185189
error_y=None,
@@ -225,6 +229,7 @@ def area(
225229
text=None,
226230
facet_row=None,
227231
facet_col=None,
232+
facet_col_wrap=0,
228233
animation_frame=None,
229234
animation_group=None,
230235
category_orders={},
@@ -267,6 +272,7 @@ def bar(
267272
color=None,
268273
facet_row=None,
269274
facet_col=None,
275+
facet_col_wrap=0,
270276
hover_name=None,
271277
hover_data=None,
272278
custom_data=None,
@@ -318,6 +324,7 @@ def histogram(
318324
color=None,
319325
facet_row=None,
320326
facet_col=None,
327+
facet_col_wrap=0,
321328
hover_name=None,
322329
hover_data=None,
323330
animation_frame=None,
@@ -376,6 +383,7 @@ def violin(
376383
color=None,
377384
facet_row=None,
378385
facet_col=None,
386+
facet_col_wrap=0,
379387
hover_name=None,
380388
hover_data=None,
381389
custom_data=None,
@@ -427,6 +435,7 @@ def box(
427435
color=None,
428436
facet_row=None,
429437
facet_col=None,
438+
facet_col_wrap=0,
430439
hover_name=None,
431440
hover_data=None,
432441
custom_data=None,
@@ -473,6 +482,7 @@ def strip(
473482
color=None,
474483
facet_row=None,
475484
facet_col=None,
485+
facet_col_wrap=0,
476486
hover_name=None,
477487
hover_data=None,
478488
custom_data=None,

‎packages/python/plotly/plotly/express/_core.py

Copy file name to clipboardExpand all lines: packages/python/plotly/plotly/express/_core.py
+47-33Lines changed: 47 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref):
233233
result["y"] = trendline[:, 1]
234234
hover_header = "<b>LOWESS trendline</b><br><br>"
235235
elif v == "ols":
236-
fit_results = sm.OLS(y, sm.add_constant(x)).fit()
236+
fit_results = sm.OLS(y.values, sm.add_constant(x.values)).fit()
237237
result["y"] = fit_results.predict()
238238
hover_header = "<b>OLS trendline</b><br>"
239239
hover_header += "%s = %f * %s + %f<br>" % (
@@ -747,10 +747,10 @@ def apply_default_cascade(args):
747747
]
748748

749749
# If both marginals and faceting are specified, faceting wins
750-
if args.get("facet_col", None) and args.get("marginal_y", None):
750+
if args.get("facet_col", None) is not None and args.get("marginal_y", None):
751751
args["marginal_y"] = None
752752

753-
if args.get("facet_row", None) and args.get("marginal_x", None):
753+
if args.get("facet_row", None) is not None and args.get("marginal_x", None):
754754
args["marginal_x"] = None
755755

756756

@@ -874,7 +874,7 @@ def build_dataframe(args, attrables, array_attrables):
874874
"pandas MultiIndex is not supported by plotly express "
875875
"at the moment." % field
876876
)
877-
## ----------------- argument is a col name ----------------------
877+
# ----------------- argument is a col name ----------------------
878878
if isinstance(argument, str) or isinstance(
879879
argument, int
880880
): # just a column name given as str or int
@@ -1042,6 +1042,13 @@ def infer_config(args, constructor, trace_patch):
10421042
args[position] = args["marginal"]
10431043
args[other_position] = None
10441044

1045+
if (
1046+
args.get("marginal_x", None) is not None
1047+
or args.get("marginal_y", None) is not None
1048+
or args.get("facet_row", None) is not None
1049+
):
1050+
args["facet_col_wrap"] = 0
1051+
10451052
# Compute applicable grouping attributes
10461053
for k in group_attrables:
10471054
if k in args:
@@ -1098,15 +1105,14 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
10981105

10991106
orders, sorted_group_names = get_orderings(args, grouper, grouped)
11001107

1101-
has_marginal_x = bool(args.get("marginal_x", False))
1102-
has_marginal_y = bool(args.get("marginal_y", False))
1103-
11041108
subplot_type = _subplot_type_for_trace_type(constructor().type)
11051109

11061110
trace_names_by_frame = {}
11071111
frames = OrderedDict()
11081112
trendline_rows = []
11091113
nrows = ncols = 1
1114+
col_labels = []
1115+
row_labels = []
11101116
for group_name in sorted_group_names:
11111117
group = grouped.get_group(group_name if len(group_name) > 1 else group_name[0])
11121118
mapping_labels = OrderedDict()
@@ -1188,27 +1194,36 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
11881194
# Find row for trace, handling facet_row and marginal_x
11891195
if m.facet == "row":
11901196
row = m.val_map[val]
1191-
trace._subplot_row_val = val
1197+
if args["facet_row"] and len(row_labels) < row:
1198+
row_labels.append(args["facet_row"] + "=" + str(val))
11921199
else:
1193-
if has_marginal_x and trace_spec.marginal != "x":
1200+
if (
1201+
bool(args.get("marginal_x", False))
1202+
and trace_spec.marginal != "x"
1203+
):
11941204
row = 2
11951205
else:
11961206
row = 1
11971207

1198-
nrows = max(nrows, row)
1199-
if row > 1:
1200-
trace._subplot_row = row
1201-
1208+
facet_col_wrap = args.get("facet_col_wrap", 0)
12021209
# Find col for trace, handling facet_col and marginal_y
12031210
if m.facet == "col":
12041211
col = m.val_map[val]
1205-
trace._subplot_col_val = val
1212+
if args["facet_col"] and len(col_labels) < col:
1213+
col_labels.append(args["facet_col"] + "=" + str(val))
1214+
if facet_col_wrap: # assumes no facet_row, no marginals
1215+
row = 1 + ((col - 1) // facet_col_wrap)
1216+
col = 1 + ((col - 1) % facet_col_wrap)
12061217
else:
12071218
if trace_spec.marginal == "y":
12081219
col = 2
12091220
else:
12101221
col = 1
12111222

1223+
nrows = max(nrows, row)
1224+
if row > 1:
1225+
trace._subplot_row = row
1226+
12121227
ncols = max(ncols, col)
12131228
if col > 1:
12141229
trace._subplot_col = col
@@ -1238,7 +1253,6 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
12381253
if show_colorbar:
12391254
colorvar = "z" if constructor == go.Histogram2d else "color"
12401255
range_color = args["range_color"] or [None, None]
1241-
d = len(args["color_continuous_scale"]) - 1
12421256

12431257
colorscale_validator = ColorscaleValidator("colorscale", "make_figure")
12441258
layout_patch["coloraxis1"] = dict(
@@ -1260,7 +1274,7 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
12601274
layout_patch["legend"]["itemsizing"] = "constant"
12611275

12621276
fig = init_figure(
1263-
args, subplot_type, frame_list, ncols, nrows, has_marginal_x, has_marginal_y
1277+
args, subplot_type, frame_list, nrows, ncols, col_labels, row_labels
12641278
)
12651279

12661280
# Position traces in subplots
@@ -1290,49 +1304,39 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
12901304
return fig
12911305

12921306

1293-
def init_figure(
1294-
args, subplot_type, frame_list, ncols, nrows, has_marginal_x, has_marginal_y
1295-
):
1307+
def init_figure(args, subplot_type, frame_list, nrows, ncols, col_labels, row_labels):
12961308
# Build subplot specs
12971309
specs = [[{}] * ncols for _ in range(nrows)]
1298-
column_titles = [None] * ncols
1299-
row_titles = [None] * nrows
13001310
for frame in frame_list:
13011311
for trace in frame["data"]:
13021312
row0 = trace._subplot_row - 1
13031313
col0 = trace._subplot_col - 1
1304-
13051314
if isinstance(trace, go.Splom):
13061315
# Splom not compatible with make_subplots, treat as domain
13071316
specs[row0][col0] = {"type": "domain"}
13081317
else:
13091318
specs[row0][col0] = {"type": trace.type}
1310-
if args.get("facet_row", None) and hasattr(trace, "_subplot_row_val"):
1311-
row_titles[row0] = args["facet_row"] + "=" + str(trace._subplot_row_val)
1312-
1313-
if args.get("facet_col", None) and hasattr(trace, "_subplot_col_val"):
1314-
column_titles[col0] = (
1315-
args["facet_col"] + "=" + str(trace._subplot_col_val)
1316-
)
13171319

13181320
# Default row/column widths uniform
13191321
column_widths = [1.0] * ncols
13201322
row_heights = [1.0] * nrows
13211323

13221324
# Build column_widths/row_heights
13231325
if subplot_type == "xy":
1324-
if has_marginal_x:
1326+
if bool(args.get("marginal_x", False)):
13251327
if args["marginal_x"] == "histogram" or ("color" in args and args["color"]):
13261328
main_size = 0.74
13271329
else:
13281330
main_size = 0.84
13291331

13301332
row_heights = [main_size] * (nrows - 1) + [1 - main_size]
13311333
vertical_spacing = 0.01
1334+
elif args.get("facet_col_wrap", 0):
1335+
vertical_spacing = 0.07
13321336
else:
13331337
vertical_spacing = 0.03
13341338

1335-
if has_marginal_y:
1339+
if bool(args.get("marginal_y", False)):
13361340
if args["marginal_y"] == "histogram" or ("color" in args and args["color"]):
13371341
main_size = 0.74
13381342
else:
@@ -1351,15 +1355,25 @@ def init_figure(
13511355
vertical_spacing = 0.1
13521356
horizontal_spacing = 0.1
13531357

1358+
facet_col_wrap = args.get("facet_col_wrap", 0)
1359+
if facet_col_wrap:
1360+
subplot_labels = [None] * nrows * ncols
1361+
while len(col_labels) < nrows * ncols:
1362+
col_labels.append(None)
1363+
for i in range(nrows):
1364+
for j in range(ncols):
1365+
subplot_labels[i * ncols + j] = col_labels[(nrows - 1 - i) * ncols + j]
1366+
13541367
# Create figure with subplots
13551368
fig = make_subplots(
13561369
rows=nrows,
13571370
cols=ncols,
13581371
specs=specs,
13591372
shared_xaxes="all",
13601373
shared_yaxes="all",
1361-
row_titles=list(reversed(row_titles)),
1362-
column_titles=column_titles,
1374+
row_titles=[] if facet_col_wrap else list(reversed(row_labels)),
1375+
column_titles=[] if facet_col_wrap else col_labels,
1376+
subplot_titles=subplot_labels if facet_col_wrap else [],
13631377
horizontal_spacing=horizontal_spacing,
13641378
vertical_spacing=vertical_spacing,
13651379
row_heights=row_heights,

‎packages/python/plotly/plotly/express/_doc.py

Copy file name to clipboardExpand all lines: packages/python/plotly/plotly/express/_doc.py
+6Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,12 @@
183183
colref_desc,
184184
"Values from this column or array_like are used to assign marks to facetted subplots in the horizontal direction.",
185185
],
186+
facet_col_wrap=[
187+
"int",
188+
"Maximum number of facet columns.",
189+
"Wraps the column variable at this width, so that the column facets span multiple rows.",
190+
"Ignored if 0, and forced to 0 if `facet_row` or a `marginal` is set.",
191+
],
186192
animation_frame=[
187193
colref_type,
188194
colref_desc,

‎packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py

Copy file name to clipboardExpand all lines: packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py
+3Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ def test_pandas_series():
6161
assert fig.data[0].hovertemplate == "day=%{x}<br>y=%{y}"
6262
fig = px.bar(tips, x="day", y=before_tip, labels={"y": "bill"})
6363
assert fig.data[0].hovertemplate == "day=%{x}<br>bill=%{y}"
64+
# lock down that we can pass df.col to facet_*
65+
fig = px.bar(tips, x="day", y="tip", facet_row=tips.day, facet_col=tips.day)
66+
assert fig.data[0].hovertemplate == "day=%{x}<br>tip=%{y}"
6467

6568

6669
def test_several_dataframes():

‎test/percy/plotly-express.py

Copy file name to clipboardExpand all lines: test/percy/plotly-express.py
+27Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,33 @@
184184

185185
import plotly.express as px
186186

187+
tips = px.data.tips()
188+
fig = px.scatter(
189+
tips,
190+
x="day",
191+
y="tip",
192+
facet_col="day",
193+
facet_col_wrap=2,
194+
category_orders={"day": ["Thur", "Fri", "Sat", "Sun"]},
195+
)
196+
fig.write_html(os.path.join(dir_name, "facet_wrap_neat.html"))
197+
198+
import plotly.express as px
199+
200+
tips = px.data.tips()
201+
fig = px.scatter(
202+
tips,
203+
x="day",
204+
y="tip",
205+
color="sex",
206+
facet_col="day",
207+
facet_col_wrap=3,
208+
category_orders={"day": ["Thur", "Fri", "Sat", "Sun"]},
209+
)
210+
fig.write_html(os.path.join(dir_name, "facet_wrap_ragged.html"))
211+
212+
import plotly.express as px
213+
187214
gapminder = px.data.gapminder()
188215
fig = px.area(gapminder, x="year", y="pop", color="continent", line_group="country")
189216
fig.write_html(os.path.join(dir_name, "area.html"))

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.