Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit dfeb6ba

Browse filesBrowse files
committed
[reland][dynamo] Mark a vt unspecialized nn module variable source earlier
Reland of #154780
1 parent 8f08f90 commit dfeb6ba
Copy full SHA for dfeb6ba

File tree

Expand file treeCollapse file tree

8 files changed

+43
-19
lines changed
Filter options
Expand file treeCollapse file tree

8 files changed

+43
-19
lines changed

‎benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv

Copy file name to clipboardExpand all lines: benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ hf_Bert_large,pass,0
138138

139139

140140

141-
hf_BigBird,pass,18
141+
hf_BigBird,pass,24
142142

143143

144144

‎benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv

Copy file name to clipboardExpand all lines: benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ hf_Bert_large,pass,0
122122

123123

124124

125-
hf_BigBird,pass,18
125+
hf_BigBird,pass,24
126126

127127

128128

‎test/dynamo/test_modules.py

Copy file name to clipboardExpand all lines: test/dynamo/test_modules.py
+1Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1299,6 +1299,7 @@ def test_unsupportedmodule(self):
12991299
self.assertTrue(torch._dynamo.testing.same(r, m(i)))
13001300
self.assertEqual(cnt.op_count, 6)
13011301

1302+
@patch.object(torch._dynamo.config, "allow_unspec_int_on_nn_module", True)
13021303
def test_self_mutating1(self):
13031304
m1 = torch.nn.Linear(10, 10)
13041305
m2 = SelfMutatingModule(m1)

‎test/functorch/test_control_flow.py

Copy file name to clipboardExpand all lines: test/functorch/test_control_flow.py
+2-3Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7454,14 +7454,13 @@ def forward(self, a, b):
74547454
self.assertExpectedInline(
74557455
backend.graphs[0].code.strip(),
74567456
"""\
7457-
def forward(self, s97 : torch.SymInt, L_a_ : torch.Tensor, L_b_ : torch.Tensor, L_self_num : torch.SymInt):
7457+
def forward(self, s97 : torch.SymInt, L_a_ : torch.Tensor, L_b_ : torch.Tensor):
74587458
l_a_ = L_a_
74597459
l_b_ = L_b_
7460-
l_self_num = L_self_num
74617460
tensor = torch.tensor([True])
74627461
cond_true_0 = self.cond_true_0
74637462
cond_false_0 = self.cond_false_0
7464-
cond = torch.ops.higher_order.cond(tensor, cond_true_0, cond_false_0, (l_a_, l_b_, l_self_num, s97)); tensor = cond_true_0 = cond_false_0 = l_a_ = l_b_ = l_self_num = s97 = None
7463+
cond = torch.ops.higher_order.cond(tensor, cond_true_0, cond_false_0, (l_a_, l_b_, s97)); tensor = cond_true_0 = cond_false_0 = l_a_ = l_b_ = s97 = None
74657464
getitem = cond[0]; cond = None
74667465
return (getitem,)""", # noqa: B950
74677466
)

‎torch/_dynamo/utils.py

Copy file name to clipboardExpand all lines: torch/_dynamo/utils.py
+4Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2402,6 +2402,10 @@ def is_int_specialization_case(value, source):
24022402
source.guard_source().is_unspecialized_builtin_nn_module()
24032403
and not config.allow_unspec_int_on_nn_module
24042404
)
2405+
or (
2406+
source.guard_source().is_unspecialized_nn_module()
2407+
and not config.allow_unspec_int_on_nn_module
2408+
)
24052409
or is_from_defaults(source)
24062410
# TODO: Delete this condition when rollout is done. NB: this
24072411
# condition never evaluates True in open source

‎torch/_dynamo/variables/builder.py

Copy file name to clipboardExpand all lines: torch/_dynamo/variables/builder.py
+30-7Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,8 @@
115115
Source,
116116
SubclassAttrListSource,
117117
TupleIteratorGetItemSource,
118+
UnspecializedBuiltinNNModuleSource,
119+
UnspecializedNNModuleSource,
118120
)
119121
from ..utils import (
120122
_extract_tensor_dict,
@@ -434,7 +436,10 @@ def __call__(self, value):
434436
return cached_vt
435437

436438
vt = self._wrap(value)
437-
vt.source = self.source
439+
440+
if vt.source is None:
441+
vt.source = self.source
442+
438443
if (
439444
self._can_lift_attrs_to_inputs(vt)
440445
and value not in self.tx.output.side_effects
@@ -1714,7 +1719,6 @@ def wrap_module(self, value: torch.nn.Module):
17141719
value = value.get_base()
17151720
self.source = AttrProxySource(self.source)
17161721

1717-
self.install_guards(GuardBuilder.TYPE_MATCH)
17181722
if torch._dynamo.config.inline_inbuilt_nn_modules:
17191723
freezing = is_parameter_freezing()
17201724

@@ -1749,12 +1753,27 @@ def wrap_module(self, value: torch.nn.Module):
17491753
# this will get cleaned up once compile ends
17501754
self.tx.output.nn_modules[self.name] = value
17511755

1752-
if value.__module__.startswith(("torch.nn.", "torch.ao.")) or getattr(
1753-
value.__class__, "_dynamo_marked_static", False
1754-
):
1755-
result = UnspecializedBuiltinNNModuleVariable(value, source=self.source)
1756+
if (
1757+
value.__module__.startswith(("torch.nn.modules", "torch.ao."))
1758+
and not value.__module__.startswith("torch.nn.modules.container")
1759+
) or getattr(value.__class__, "_dynamo_marked_static", False):
1760+
new_source = self.source
1761+
if config.inline_inbuilt_nn_modules and (
1762+
not self.tx.output.export or config.install_free_tensors
1763+
):
1764+
# Export corner case - look at test_repros.py test_inlining_cornercase
1765+
new_source = UnspecializedBuiltinNNModuleSource(self.source)
1766+
result = UnspecializedBuiltinNNModuleVariable(value, source=new_source)
1767+
install_guard(new_source.make_guard(GuardBuilder.TYPE_MATCH))
17561768
else:
1757-
result = UnspecializedNNModuleVariable(value, source=self.source)
1769+
new_source = self.source
1770+
if config.inline_inbuilt_nn_modules and (
1771+
not self.tx.output.export or config.install_free_tensors
1772+
):
1773+
# Export corner case - look at test_repros.py test_inlining_cornercase
1774+
new_source = UnspecializedNNModuleSource(self.source)
1775+
result = UnspecializedNNModuleVariable(value, source=new_source)
1776+
install_guard(new_source.make_guard(GuardBuilder.TYPE_MATCH))
17581777

17591778
if not SideEffects.cls_supports_mutation_side_effects(type(value)):
17601779
# don't allow STORE_ATTR mutation with custom __setattr__
@@ -2127,6 +2146,10 @@ def wrap_numpy_ndarray(self, value):
21272146
)
21282147
proxy.node.meta["grapharg"] = grapharg
21292148

2149+
# TODO - Why do we need to set the source of the np ndarray vt back to
2150+
# original source. Many tests fails.
2151+
numpy_ndarray_variable.source = self.source
2152+
21302153
return numpy_ndarray_variable
21312154

21322155
def wrap_symint(

‎torch/_dynamo/variables/higher_order_ops.py

Copy file name to clipboardExpand all lines: torch/_dynamo/variables/higher_order_ops.py
+2-2Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2658,8 +2658,8 @@ def call_function(
26582658

26592659
class FlexAttentionBackwardHighOrderVariable(TorchHigherOrderOperatorVariable):
26602660
def proxy_submod(self, tx, arg):
2661-
assert isinstance(arg.source, DictGetItemSource)
2662-
submod_name = tx.output.install_subgraph(arg.source.index, arg.value)
2661+
assert isinstance(arg.source.base, DictGetItemSource)
2662+
submod_name = tx.output.install_subgraph(arg.source.base.index, arg.value)
26632663
p_submod = make_attr(tx, submod_name)
26642664
set_example_value(p_submod.node, arg.value)
26652665
return p_submod

‎torch/_dynamo/variables/nn_module.py

Copy file name to clipboardExpand all lines: torch/_dynamo/variables/nn_module.py
+2-5Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@
4848
FSDPNNModuleSource,
4949
GetItemSource,
5050
NNModuleSource,
51-
UnspecializedBuiltinNNModuleSource,
5251
UnspecializedNNModuleSource,
5352
)
5453
from ..utils import (
@@ -891,8 +890,7 @@ def __init__(self, value, **kwargs) -> None:
891890
self.nn_module_stack_source = self.source
892891

893892
def _wrap_source(self, attr_source):
894-
if not isinstance(attr_source, UnspecializedNNModuleSource):
895-
return UnspecializedNNModuleSource(attr_source)
893+
# the vt is already wrapped with UnspecializedNNModuleSource
896894
return attr_source
897895

898896
def get_nn_module_stack_source(self):
@@ -1193,8 +1191,7 @@ class UnspecializedBuiltinNNModuleVariable(UnspecializedNNModuleVariable):
11931191
"""
11941192

11951193
def _wrap_source(self, attr_source):
1196-
if not isinstance(attr_source, UnspecializedBuiltinNNModuleSource):
1197-
return UnspecializedBuiltinNNModuleSource(attr_source)
1194+
# vt is already wrapped with the UnspecializedBuiltinNNModuleSource
11981195
return attr_source
11991196

12001197

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.