Commit 66da888

tugsbayasgalan authored and facebook-github-bot committed

Make pt2 inference private API use training IR (#153972)

Summary: #buildall

Test Plan: CI

Differential Revision: D74582970

1 parent cad0727
File tree

2 files changed: 180 additions, 2 deletions

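For orientation, here is a minimal sketch (not part of the commit) of the code path these changes affect, assuming a PyTorch build that carries this change:

    import torch
    from torch.export import export, unflatten

    class Bar(torch.nn.Module):
        def forward(self, x):
            return x.sum()

    class Foo(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.bar = Bar()

        def forward(self, x):
            # aten.to at the submodule boundary is the case the tests below cover
            return self.bar(x.to(torch.float)).sum()

    # Export with a preserved module-call signature for the submodule "bar".
    ep = export(Foo(), (torch.randn(4, 4),), strict=False,
                preserve_module_call_signature=("bar",))
    # Lowering to the inference IR now starts from the training IR; the
    # preserved signature is remapped onto the decomposed graph via node
    # provenance.
    ep = ep.run_decompositions({})
    print(unflatten(ep).bar.graph)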
test/export/test_export.py: 103 additions, 0 deletions
@@ -6501,6 +6501,109 @@ def forward(self, x):
 
         self.assertEqual(ep.module()(*inputs), model(*inputs))
 
+    def test_export_aten_to_unflatten(self):
+        class Bar(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return x.sum()
+
+        class Foo(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.bar = Bar()
+
+            def forward(self, x):
+                to = x.to(torch.float)
+                return self.bar(to).sum()
+
+        inp = torch.randn(4, 4)
+
+        ep = export(
+            Foo(), (inp,), strict=False, preserve_module_call_signature=("bar",)
+        )
+        mod = ep.module()
+        self.assertTrue(torch.allclose(mod(inp), Foo()(inp)))
+
+    @testing.expectedFailureLegacyExportNonStrict
+    @testing.expectedFailureLegacyExportStrict
+    @testing.expectedFailureRetraceabilityNonStrict  # when we retrace, ep.module() is hierarchical
+    @testing.expectedFailureRetraceability  # when we retrace, ep.module() is hierarchical
+    def test_export_aten_to_unflatten_subclass(self):
+        class Bar(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return x.sum()
+
+        class Foo(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.bar = Bar()
+                self.param = torch.nn.Parameter(
+                    TwoTensor(torch.ones(4, 4), torch.ones(4, 4))
+                )
+
+            def forward(self, x):
+                to = self.param.to(torch.float)
+                return (self.bar(to).sum() + x.sum()).get_elem_a()
+
+        inp = torch.randn(4, 4)
+
+        with self.assertRaisesRegex(
+            ValueError, "It looks like p_param is a tensor subclass."
+        ):
+            export(
+                Foo(), (inp,), strict=False, preserve_module_call_signature=("bar",)
+            ).run_decompositions({})
+
+    def test_export_aten_to_unflatten_subclass_pre_dispatch(self):
+        class Bar(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return x.sum()
+
+        class Foo(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.bar = Bar()
+                self.param = torch.nn.Parameter(
+                    TwoTensor(torch.ones(4, 4), torch.ones(4, 4))
+                )
+
+            def forward(self, x):
+                to = self.param.to(torch.float)
+                return (self.bar(to).sum() + x.sum()).get_elem_a()
+
+        inp = torch.randn(4, 4)
+
+        ep = export_for_training(
+            Foo(), (inp,), strict=False, preserve_module_call_signature=("bar",)
+        )
+        unflat = unflatten(ep).bar
+        self.assertExpectedInline(
+            str(unflat.graph).strip(),
+            """\
+graph():
+    %_positional_arg_0 : [num_users=1] = placeholder[target=_positional_arg_0]
+    %_spec_0 : [num_users=1] = get_attr[target=_spec_0]
+    %tree_flatten_spec : [num_users=1] = call_function[target=torch.fx._pytree.tree_flatten_spec](args = (((%_positional_arg_0,), {}), %_spec_0), kwargs = {})
+    %to : [num_users=1] = call_function[target=operator.getitem](args = (%tree_flatten_spec, 0), kwargs = {})
+    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%to,), kwargs = {})
+    %_spec_1 : [num_users=1] = get_attr[target=_spec_1]
+    %tree_unflatten : [num_users=1] = call_function[target=torch.utils._pytree.tree_unflatten](args = ((%sum_1,), %_spec_1), kwargs = {})
+    return tree_unflatten""",
+        )
+
+        with self.assertRaisesRegex(
+            ValueError, "It looks like p_param is a tensor subclass."
+        ):
+            ep.run_decompositions()
+
     def test_float_conversion(self):
         class Module(torch.nn.Module):
             def forward(self, x):
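Note: TwoTensor, used in the subclass tests above, is the reference tensor subclass from PyTorch's internal test utilities. A minimal sketch of what it looks like (assuming torch.testing._internal.two_tensor is importable, as it is in the test suite):

    import torch
    from torch.testing._internal.two_tensor import TwoTensor

    # A TwoTensor wraps two same-shaped dense tensors; lowering it to the
    # inference IR desugars it into its two components, which is why
    # preserving a submodule signature over such a parameter is rejected.
    t = TwoTensor(torch.ones(4, 4), torch.ones(4, 4))
    print(type(t), t.a.shape, t.b.shape)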

torch/export/exported_program.py: 77 additions, 2 deletions
@@ -7,7 +7,7 @@
 import operator
 import types
 import warnings
-from collections import namedtuple
+from collections import defaultdict, namedtuple
 from collections.abc import Iterator
 from contextlib import contextmanager
 from typing import Any, Callable, final, Optional, TYPE_CHECKING, Union
@@ -816,25 +816,97 @@ def _common_getitem_elimination_pass(
 
 
 def _get_updated_module_call_graph(
+    old_gm: torch.fx.GraphModule,
+    old_graph_signature: ExportGraphSignature,
     gm: torch.fx.GraphModule,
+    graph_signature: ExportGraphSignature,
     old_module_call_graph: list[ModuleCallEntry],
 ):
     new_module_call_graph = copy.deepcopy(old_module_call_graph)
 
+    old_nodes = {node.name: node for node in old_gm.graph.nodes}
+
+    old_graph_params_buffers = {
+        **old_graph_signature.inputs_to_parameters,
+        **old_graph_signature.inputs_to_buffers,
+    }
+    new_graph_params_buffers = {
+        **graph_signature.inputs_to_parameters,
+        **graph_signature.inputs_to_buffers,
+    }
+
     # use node-level provenance metadata to create a map
     # from old node names to new node names
     provenance: dict[str, str] = {}
     for node in gm.graph.nodes:
         if history := node.meta.get("from_node", []):
             provenance[history[-1].name] = node.name
 
+        # For params and buffers, a parametrization rule may have been
+        # applied, so their names might have changed. But for user inputs,
+        # we know we must preserve the old name.
+        elif node.op == "placeholder":
+            if not (
+                node.name in new_graph_params_buffers
+                or node.name in graph_signature.input_tokens
+            ):
+                assert node.name in old_nodes
+                provenance[node.name] = node.name
+
+    # For all the parameters and buffers, we first check whether they
+    # are the result of parametrizations; if they are, we record them
+    # and error later.
+    old_param_to_desugared = defaultdict(list)
+    for name, target in new_graph_params_buffers.items():
+        # if the parameters are not parametrized, the naming won't change.
+        if not target.startswith("parametrizations."):
+            assert name in old_graph_params_buffers
+            provenance[name] = name
+        else:
+            old_target = ".".join(target.split(".")[1:-1])
+            old_param_to_desugared[old_target].append(name)
+
     # map old names to new names in module call signatures
     for entry in new_module_call_graph:
         signature = entry.signature
         if signature is None:
             continue
         for x in [*signature.inputs, *signature.outputs]:
-            x.name = provenance.get(x.name, x.name)
+            # The submodule takes a tensor subclass as input; we can't
+            # preserve the signature here.
+            if x.name in old_param_to_desugared:
+                raise ValueError(
+                    f"It looks like {x.name} is a tensor subclass. "
+                    f"Preserving a submodule that takes a subclass parameter is not"
+                    f" supported in inference IR because we desugar them, resulting"
+                    f" in more tensors"
+                )
+
+            if x.name in provenance:
+                x.name = provenance[x.name]
+
+            # This can happen when aten.to is called at graph boundaries.
+            # At the post-dispatch level, aten.to can be either a copy or
+            # an alias. In the alias case, we no-op it so it disappears
+            # from the graph. If we detect such a case, we should reuse
+            # the input to aten.to as the new input to the submodule.
+            # Technically this can happen for other maybe-aliasing ops,
+            # but aten.to is probably the most common one.
+            elif x.name in old_nodes:
+                old_node = old_nodes[x.name]
+                if old_node.op == "call_function" and old_node.target in [
+                    torch.ops.aten.to.dtype_layout,
+                    torch.ops.aten.to.device,
+                    torch.ops.aten.to.dtype,
+                ]:
+                    old_target = old_node.args[0].name
+                    if old_target not in provenance:
+                        raise ValueError(
+                            f"It looks like {old_target} is a tensor subclass. "
+                            f"Preserving a submodule that takes a subclass parameter"
+                            f" is not supported in inference IR because we desugar"
+                            f" them, resulting in more tensors"
+                        )
+
+                    x.name = provenance[old_target]
 
     return new_module_call_graph
 
@@ -864,7 +936,10 @@ def _decompose_exported_program(
     # new nodes due to decompositions. So we need to update these signatures
     # in the decomposed exported program's module_call_graph.
     new_module_call_graph = _get_updated_module_call_graph(
+        ep.graph_module,
+        ep.graph_signature,
         gm,
+        new_graph_signature,
         ep.module_call_graph,
    )
