7
7
import operator
8
8
import types
9
9
import warnings
10
- from collections import namedtuple
10
+ from collections import defaultdict , namedtuple
11
11
from collections .abc import Iterator
12
12
from contextlib import contextmanager
13
13
from typing import Any , Callable , final , Optional , TYPE_CHECKING , Union
@@ -816,25 +816,97 @@ def _common_getitem_elimination_pass(
816
816
817
817
818
818
def _get_updated_module_call_graph (
819
+ old_gm : torch .fx .GraphModule ,
820
+ old_graph_signature : ExportGraphSignature ,
819
821
gm : torch .fx .GraphModule ,
822
+ graph_signature : ExportGraphSignature ,
820
823
old_module_call_graph : list [ModuleCallEntry ],
821
824
):
822
825
new_module_call_graph = copy .deepcopy (old_module_call_graph )
823
826
827
+ old_nodes = {node .name : node for node in old_gm .graph .nodes }
828
+
829
+ old_graph_params_buffers = {
830
+ ** old_graph_signature .inputs_to_parameters ,
831
+ ** old_graph_signature .inputs_to_buffers ,
832
+ }
833
+ new_graph_params_buffers = {
834
+ ** graph_signature .inputs_to_parameters ,
835
+ ** graph_signature .inputs_to_buffers ,
836
+ }
837
+
824
838
# use node-level provenance metadata to create a map
825
839
# from old node names to new node names
826
840
provenance : dict [str , str ] = {}
827
841
for node in gm .graph .nodes :
828
842
if history := node .meta .get ("from_node" , []):
829
843
provenance [history [- 1 ].name ] = node .name
830
844
845
+ # For params and buffers, we might have applied parameterizaiton rule
846
+ # so that the names might have changed. But for user inputs, we know we
847
+ # must preserve the old name.
848
+ elif node .op == "placeholder" :
849
+ if not (
850
+ node .name in new_graph_params_buffers
851
+ or node .name in graph_signature .input_tokens
852
+ ):
853
+ assert node .name in old_nodes
854
+ provenance [node .name ] = node .name
855
+
856
+ # For all the parameters and buffers, we first see
857
+ # if they are result of paramerizaitons and if they
858
+ # are, we log them and error later
859
+ old_param_to_desugared = defaultdict (list )
860
+ for name , target in new_graph_params_buffers .items ():
861
+ # if the parameters are not parametrized, the naming won't change.
862
+ if not target .startswith ("parametrizations." ):
863
+ assert name in old_graph_params_buffers
864
+ provenance [name ] = name
865
+ else :
866
+ old_target = "." .join (target .split ("." )[1 :- 1 ])
867
+ old_param_to_desugared [old_target ].append (name )
868
+
831
869
# map old names to new names in module call signatures
832
870
for entry in new_module_call_graph :
833
871
signature = entry .signature
834
872
if signature is None :
835
873
continue
836
874
for x in [* signature .inputs , * signature .outputs ]:
837
- x .name = provenance .get (x .name , x .name )
875
+ # We noticed that submodule is taking subclass as input. we can't
876
+ # preserve signature here.
877
+ if x .name in old_param_to_desugared :
878
+ raise ValueError (
879
+ f"It looks like { old_target } is a tensor subclass. "
880
+ f"Preserving submodule that takes subclass parameter is not supported"
881
+ f" in inference IR because we desugar them, resulting in more tensors"
882
+ )
883
+
884
+ if x .name in provenance :
885
+ x .name = provenance [x .name ]
886
+
887
+ # This can happen when aten.to is called at graph boundaries.
888
+ # Basically aten.to at post-dispatch level can either be copy
889
+ # or alias. In the alias case, we will no-op it so it will
890
+ # disappear from the graph. If we detect such case, we should
891
+ # reuse the input to aten.to as the new input to the submodule.
892
+ # Technically this can happen for other maybe aliasing ops,
893
+ # but aten.to is probably the most common one.
894
+ elif x .name in old_nodes :
895
+ old_node = old_nodes [x .name ]
896
+ if old_node .op == "call_function" and old_node .target in [
897
+ torch .ops .aten .to .dtype_layout ,
898
+ torch .ops .aten .to .device ,
899
+ torch .ops .aten .to .dtype ,
900
+ ]:
901
+ old_target = old_node .args [0 ].name
902
+ if old_target not in provenance :
903
+ raise ValueError (
904
+ f"It looks like { old_target } is a tensor subclass. "
905
+ f"Preserving submodule that takes subclass parameter is not supported"
906
+ f" in inference IR because we desugar them, resulting in more tensors"
907
+ )
908
+
909
+ x .name = provenance [old_target ]
838
910
839
911
return new_module_call_graph
840
912
@@ -864,7 +936,10 @@ def _decompose_exported_program(
864
936
# new nodes due to decompositions. So we need to update these signatures
865
937
# in the decomposed exported program's module_call_graph.
866
938
new_module_call_graph = _get_updated_module_call_graph (
939
+ ep .graph_module ,
940
+ ep .graph_signature ,
867
941
gm ,
942
+ new_graph_signature ,
868
943
ep .module_call_graph ,
869
944
)
870
945
0 commit comments