Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 5fd7004

Browse filesBrowse files
Revert "[dynamo, nested graph breaks] remove block stack graph break in output_graph (#153772)"
This reverts commit 9a66c30. Reverted #153772 on behalf of https://github.com/malfet due to Not sure which one, but it broke test_error_messages, see https://hud.pytorch.org/hud/pytorch/pytorch/203b0efd6395a419330cf5700889afd420a89b2f/1?per_page=50&name_filter=py3.13-clang10&mergeEphemeralLF=true ([comment](#151056 (comment)))
1 parent e86439e commit 5fd7004
Copy full SHA for 5fd7004

File tree

Expand file treeCollapse file tree

1 file changed

+13
-3
lines changed
Filter options
Expand file treeCollapse file tree

1 file changed

+13
-3
lines changed

‎torch/_dynamo/output_graph.py

Copy file name to clipboardExpand all lines: torch/_dynamo/output_graph.py
+13-3Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@
6767
from torch.utils._ordered_set import OrderedSet
6868
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
6969

70-
from . import config, exc, logging as torchdynamo_logging, variables
70+
from . import config, exc, graph_break_hints, logging as torchdynamo_logging, variables
7171
from .backends.registry import CompiledFn, CompilerFn
7272
from .bytecode_transformation import (
7373
create_call_function,
@@ -1253,6 +1253,18 @@ def compile_subgraph(
12531253

12541254
log.debug("COMPILING GRAPH due to %s", reason)
12551255

1256+
if not all(block.can_restore() for block in tx.block_stack):
1257+
unimplemented_v2(
1258+
gb_type="Attempt to compile graph with unrecoverable block in the block stack",
1259+
context="",
1260+
explanation="Dynamo does not support graph breaking on context managers in "
1261+
"nested function calls. For Python <= 3.10, this graph break may have instead been "
1262+
"caused by attempting to graph break in a try block.",
1263+
hints=[
1264+
*graph_break_hints.CAUSED_BY_EARLIER_GRAPH_BREAK,
1265+
],
1266+
)
1267+
12561268
# prefix instructions (Python 3.11+)
12571269
prefix_insts: list[Instruction] = []
12581270
if sys.version_info >= (3, 11):
@@ -1295,8 +1307,6 @@ def compile_subgraph(
12951307
cur_tx: Optional[InstructionTranslatorBase] = tx
12961308
while True:
12971309
assert cur_tx is not None
1298-
# this should have been checked by the caller
1299-
assert all(block.can_restore() for block in cur_tx.block_stack)
13001310
stack_values, restore_vars, meta = self._get_stack_values_to_restore(
13011311
cur_tx, stack_pops
13021312
)

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.