Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 5830055

Browse filesBrowse files
Merge pull request plotly#1875 from plotly/px_real_template
PX shouldn't modify attrs controlled by template
2 parents 06a2cb9 + 1e88daa commit 5830055
Copy full SHA for 5830055

File tree

Expand file treeCollapse file tree

2 files changed

+147
-40
lines changed
Filter options
Expand file treeCollapse file tree

2 files changed

+147
-40
lines changed

‎packages/python/plotly/plotly/express/_core.py

Copy file name to clipboardExpand all lines: packages/python/plotly/plotly/express/_core.py
+40-40Lines changed: 40 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -375,16 +375,18 @@ def configure_cartesian_marginal_axes(args, fig, orders):
375375

376376
# Configure axis ticks on marginal subplots
377377
if args["marginal_x"]:
378-
fig.update_yaxes(
379-
showticklabels=False, showgrid=args["marginal_x"] == "histogram", row=nrows
380-
)
381-
fig.update_xaxes(showgrid=True, row=nrows)
378+
fig.update_yaxes(showticklabels=False, row=nrows)
379+
if args["template"].layout.yaxis.showgrid is None:
380+
fig.update_yaxes(showgrid=args["marginal_x"] == "histogram", row=nrows)
381+
if args["template"].layout.xaxis.showgrid is None:
382+
fig.update_xaxes(showgrid=True, row=nrows)
382383

383384
if args["marginal_y"]:
384-
fig.update_xaxes(
385-
showticklabels=False, showgrid=args["marginal_y"] == "histogram", col=ncols
386-
)
387-
fig.update_yaxes(showgrid=True, col=ncols)
385+
fig.update_xaxes(showticklabels=False, col=ncols)
386+
if args["template"].layout.xaxis.showgrid is None:
387+
fig.update_xaxes(showgrid=args["marginal_y"] == "histogram", col=ncols)
388+
if args["template"].layout.yaxis.showgrid is None:
389+
fig.update_yaxes(showgrid=True, col=ncols)
388390

389391
# Add axis titles to non-marginal subplots
390392
y_title = get_decorated_label(args, args["y"], "y")
@@ -687,55 +689,47 @@ def apply_default_cascade(args):
687689
else:
688690
args["template"] = "plotly"
689691

690-
# retrieve the actual template if we were given a name
691692
try:
692-
template = pio.templates[args["template"]]
693+
# retrieve the actual template if we were given a name
694+
args["template"] = pio.templates[args["template"]]
693695
except Exception:
694-
template = args["template"]
696+
# otherwise try to build a real template
697+
args["template"] = go.layout.Template(args["template"])
695698

696699
# if colors not set explicitly or in px.defaults, defer to a template
697700
# if the template doesn't have one, we set some final fallback defaults
698701
if "color_continuous_scale" in args:
699-
if args["color_continuous_scale"] is None:
700-
try:
701-
args["color_continuous_scale"] = [
702-
x[1] for x in template.layout.colorscale.sequential
703-
]
704-
except (AttributeError, TypeError):
705-
pass
702+
if (
703+
args["color_continuous_scale"] is None
704+
and args["template"].layout.colorscale.sequential
705+
):
706+
args["color_continuous_scale"] = [
707+
x[1] for x in args["template"].layout.colorscale.sequential
708+
]
706709
if args["color_continuous_scale"] is None:
707710
args["color_continuous_scale"] = sequential.Viridis
708711

709712
if "color_discrete_sequence" in args:
710-
if args["color_discrete_sequence"] is None:
711-
try:
712-
args["color_discrete_sequence"] = template.layout.colorway
713-
except (AttributeError, TypeError):
714-
pass
713+
if args["color_discrete_sequence"] is None and args["template"].layout.colorway:
714+
args["color_discrete_sequence"] = args["template"].layout.colorway
715715
if args["color_discrete_sequence"] is None:
716716
args["color_discrete_sequence"] = qualitative.D3
717717

718718
# if symbol_sequence/line_dash_sequence not set explicitly or in px.defaults,
719719
# see if we can defer to template. If not, set reasonable defaults
720720
if "symbol_sequence" in args:
721-
if args["symbol_sequence"] is None:
722-
try:
723-
args["symbol_sequence"] = [
724-
scatter.marker.symbol for scatter in template.data.scatter
725-
]
726-
except (AttributeError, TypeError):
727-
pass
721+
if args["symbol_sequence"] is None and args["template"].data.scatter:
722+
args["symbol_sequence"] = [
723+
scatter.marker.symbol for scatter in args["template"].data.scatter
724+
]
728725
if not args["symbol_sequence"] or not any(args["symbol_sequence"]):
729726
args["symbol_sequence"] = ["circle", "diamond", "square", "x", "cross"]
730727

731728
if "line_dash_sequence" in args:
732-
if args["line_dash_sequence"] is None:
733-
try:
734-
args["line_dash_sequence"] = [
735-
scatter.line.dash for scatter in template.data.scatter
736-
]
737-
except (AttributeError, TypeError):
738-
pass
729+
if args["line_dash_sequence"] is None and args["template"].data.scatter:
730+
args["line_dash_sequence"] = [
731+
scatter.line.dash for scatter in args["template"].data.scatter
732+
]
739733
if not args["line_dash_sequence"] or not any(args["line_dash_sequence"]):
740734
args["line_dash_sequence"] = [
741735
"solid",
@@ -1264,13 +1258,17 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
12641258
cmax=range_color[1],
12651259
colorbar=dict(title=get_decorated_label(args, args[colorvar], colorvar)),
12661260
)
1267-
for v in ["title", "height", "width", "template"]:
1261+
for v in ["title", "height", "width"]:
12681262
if args[v]:
12691263
layout_patch[v] = args[v]
12701264
layout_patch["legend"] = {"tracegroupgap": 0}
1271-
if "title" not in layout_patch:
1265+
if "title" not in layout_patch and args["template"].layout.margin.t is None:
12721266
layout_patch["margin"] = {"t": 60}
1273-
if "size" in args and args["size"]:
1267+
if (
1268+
"size" in args
1269+
and args["size"]
1270+
and args["template"].layout.legend.itemsizing is None
1271+
):
12741272
layout_patch["legend"]["itemsizing"] = "constant"
12751273

12761274
fig = init_figure(
@@ -1295,6 +1293,8 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
12951293
# Add traces, layout and frames to figure
12961294
fig.add_traces(frame_list[0]["data"] if len(frame_list) > 0 else [])
12971295
fig.layout.update(layout_patch)
1296+
if "template" in args and args["template"] is not None:
1297+
fig.update_layout(template=args["template"], overwrite=True)
12981298
fig.frames = frame_list if len(frames) > 1 else []
12991299

13001300
fig._px_trendlines = pd.DataFrame(trendline_rows)

‎packages/python/plotly/plotly/tests/test_core/test_px/test_px.py

Copy file name to clipboardExpand all lines: packages/python/plotly/plotly/tests/test_core/test_px/test_px.py
+107Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,110 @@ def test_custom_data_scatter():
5151
fig.data[0].hovertemplate
5252
== "sepal_width=%{x}<br>sepal_length=%{y}<br>petal_length=%{customdata[2]}<br>petal_width=%{customdata[3]}<br>species_id=%{customdata[0]}"
5353
)
54+
55+
56+
def test_px_templates():
57+
import plotly.io as pio
58+
import plotly.graph_objects as go
59+
60+
tips = px.data.tips()
61+
62+
# use the normal defaults
63+
fig = px.scatter()
64+
assert fig.layout.template == pio.templates[pio.templates.default]
65+
66+
# respect changes to defaults
67+
pio.templates.default = "seaborn"
68+
fig = px.scatter()
69+
assert fig.layout.template == pio.templates["seaborn"]
70+
71+
# special px-level defaults over pio defaults
72+
pio.templates.default = "seaborn"
73+
px.defaults.template = "ggplot2"
74+
fig = px.scatter()
75+
assert fig.layout.template == pio.templates["ggplot2"]
76+
77+
# accept names in args over pio and px defaults
78+
fig = px.scatter(template="seaborn")
79+
assert fig.layout.template == pio.templates["seaborn"]
80+
81+
# accept objects in args
82+
fig = px.scatter(template={})
83+
assert fig.layout.template == go.layout.Template()
84+
85+
# read colorway from the template
86+
fig = px.scatter(
87+
tips,
88+
x="total_bill",
89+
y="tip",
90+
color="sex",
91+
template=dict(layout_colorway=["red", "blue"]),
92+
)
93+
assert fig.data[0].marker.color == "red"
94+
assert fig.data[1].marker.color == "blue"
95+
96+
# default colorway fallback
97+
fig = px.scatter(tips, x="total_bill", y="tip", color="sex", template=dict())
98+
assert fig.data[0].marker.color == px.colors.qualitative.D3[0]
99+
assert fig.data[1].marker.color == px.colors.qualitative.D3[1]
100+
101+
# pio default template colorway fallback
102+
pio.templates.default = "seaborn"
103+
px.defaults.template = None
104+
fig = px.scatter(tips, x="total_bill", y="tip", color="sex")
105+
assert fig.data[0].marker.color == pio.templates["seaborn"].layout.colorway[0]
106+
assert fig.data[1].marker.color == pio.templates["seaborn"].layout.colorway[1]
107+
108+
# pio default template colorway fallback
109+
pio.templates.default = "seaborn"
110+
px.defaults.template = "ggplot2"
111+
fig = px.scatter(tips, x="total_bill", y="tip", color="sex")
112+
assert fig.data[0].marker.color == pio.templates["ggplot2"].layout.colorway[0]
113+
assert fig.data[1].marker.color == pio.templates["ggplot2"].layout.colorway[1]
114+
115+
# don't overwrite top margin when set in template
116+
fig = px.scatter(title="yo")
117+
assert fig.layout.margin.t is None
118+
119+
fig = px.scatter()
120+
assert fig.layout.margin.t == 60
121+
122+
fig = px.scatter(template=dict(layout_margin_t=2))
123+
assert fig.layout.margin.t is None
124+
125+
# don't force histogram gridlines when set in template
126+
pio.templates.default = "none"
127+
px.defaults.template = None
128+
fig = px.scatter(
129+
tips, x="total_bill", y="tip", marginal_x="histogram", marginal_y="histogram"
130+
)
131+
assert fig.layout.xaxis2.showgrid
132+
assert fig.layout.xaxis3.showgrid
133+
assert fig.layout.yaxis2.showgrid
134+
assert fig.layout.yaxis3.showgrid
135+
136+
fig = px.scatter(
137+
tips,
138+
x="total_bill",
139+
y="tip",
140+
marginal_x="histogram",
141+
marginal_y="histogram",
142+
template=dict(layout_yaxis_showgrid=False),
143+
)
144+
assert fig.layout.xaxis2.showgrid
145+
assert fig.layout.xaxis3.showgrid
146+
assert fig.layout.yaxis2.showgrid is None
147+
assert fig.layout.yaxis3.showgrid is None
148+
149+
fig = px.scatter(
150+
tips,
151+
x="total_bill",
152+
y="tip",
153+
marginal_x="histogram",
154+
marginal_y="histogram",
155+
template=dict(layout_xaxis_showgrid=False),
156+
)
157+
assert fig.layout.xaxis2.showgrid is None
158+
assert fig.layout.xaxis3.showgrid is None
159+
assert fig.layout.yaxis2.showgrid
160+
assert fig.layout.yaxis3.showgrid

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.