Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 5621b72

Browse filesBrowse files
aorenstefacebook-github-bot
authored andcommitted
FakeTensor SymInt caching fixes (#154390)
Summary: Two fixes: 1. In the FakeTensor crosscheck assertion it should be checking that for SymInts the SymNode is the same object, not the SymInt itself. 2. If constructing the result tensor results in a data-dependent error then bypass the cache. Test Plan: Unit tests pass Differential Revision: D75421864
1 parent 0a7eef1 commit 5621b72
Copy full SHA for 5621b72

File tree

Expand file treeCollapse file tree

1 file changed

+13
-4
lines changed
Filter options
Expand file treeCollapse file tree

1 file changed

+13
-4
lines changed

‎torch/_subclasses/fake_tensor.py

Copy file name to clipboardExpand all lines: torch/_subclasses/fake_tensor.py
+13-4Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1764,9 +1764,17 @@ def _get_output_info_for_cache_entry(
17641764
entry_for_synth_output = _DispatchCacheValidEntry(
17651765
output_infos=(entry,), is_output_tuple=False
17661766
)
1767-
synth_output = self._output_from_cache_entry(
1768-
state, entry_for_synth_output, key, func, args
1769-
)
1767+
1768+
from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode
1769+
1770+
try:
1771+
synth_output = self._output_from_cache_entry(
1772+
state, entry_for_synth_output, key, func, args
1773+
)
1774+
except GuardOnDataDependentSymNode as e:
1775+
# We can't cache this - it's data-dependent.
1776+
# An example: aten.select(t, 0, u0)
1777+
raise _BypassDispatchCache("data-dependent operation") from e
17701778

17711779
# Make sure the dispatch_key_set from the synthesized output tensor will
17721780
# be the same.
@@ -1975,7 +1983,7 @@ def assert_helper(a: Any, b: Any) -> None:
19751983
elif a is None:
19761984
assert b is None
19771985
elif isinstance(a, torch.SymInt):
1978-
assert a is b
1986+
assert isinstance(b, torch.SymInt) and a.node is b.node
19791987
elif isinstance(a, torch.Tensor):
19801988
assert isinstance(b, torch.Tensor)
19811989
assert_metadata_eq(assert_eq, a, b)
@@ -1992,6 +2000,7 @@ def assert_helper(a: Any, b: Any) -> None:
19922000
try:
19932001
assert_helper(true_output, output)
19942002
except Exception as e:
2003+
assert_helper(true_output, output)
19952004
raise RuntimeError(
19962005
f"FakeTensor cache crosscheck failure: func={func}, "
19972006
f"args={args}, kwargs={kwargs}"

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.