Commit db49182

[invoke_subgraph] Add logging (#155284)

Authored by anijain2305, committed by pytorchmergebot.

Pull Request resolved: #155284
Approved by: https://github.com/zou3519
ghstack dependencies: #155270

1 parent: 0f3f597

5 files changed: 72 additions, 7 deletions
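
For orientation, here is a minimal, hedged usage sketch (not part of the commit) of how the new logging artifact would be enabled and exercised. It mirrors the test added below; the backend and tensor shapes are illustrative assumptions:

    import torch
    from torch._higher_order_ops.invoke_subgraph import mark_compile_region

    # Turn on the artifact added by this PR (it is registered off by default).
    torch._logging.set_logs(hierarchical_compile=True)

    @mark_compile_region
    def gn(x):
        return x * 2

    def fn(x):
        return gn(x) + gn(x)

    fn_opt = torch.compile(fn, backend="inductor")
    fn_opt(torch.ones(1000, 1000))  # debug lines should appear during compilation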

‎test/dynamo/test_logging.py

17 additions, 0 deletions

@@ -169,6 +169,22 @@ def test_dynamo_debug_default_off_artifacts(self, records):
         self.assertEqual(len([r for r in records if ".__bytecode" in r.name]), 0)
         self.assertEqual(len([r for r in records if ".__output_code" in r.name]), 0)
 
+    @make_logging_test(hierarchical_compile=True)
+    def test_hierarchical_compile(self, records):
+        from torch._higher_order_ops.invoke_subgraph import mark_compile_region
+
+        @mark_compile_region
+        def gn(x):
+            return x * 2
+
+        def fn(x):
+            return gn(x)
+
+        fn_opt = torch.compile(fn, backend="inductor")
+        fn_opt(torch.ones(1000, 1000))
+        fn_opt(torch.ones(1000, 1000))
+        self.assertGreater(len(records), 0)
+
     @make_logging_test()
     def test_dynamo_error(self, records):
         try:
@@ -960,6 +976,7 @@ def bar():
             "loop_tiling",
             "autotuning",
             "graph_region_expansion",
+            "hierarchical_compile",
         }
         for name in torch._logging._internal.log_registry.artifact_names:
             if name not in exclusions:

‎torch/_dynamo/variables/higher_order_ops.py

26 additions, 6 deletions

@@ -65,6 +65,7 @@
 
 
 log = logging.getLogger(__name__)
+hc_log = torch._logging.getArtifactLogger(__name__, "hierarchical_compile")
 
 
 def raise_hard_error_if_graph_break(reason):
@@ -261,7 +262,7 @@ def _check_supported_callable_arg(
         )
 
 
-def are_same_graph_modules(a_mod, b_mod, fake_mode):
+def are_same_graph_modules(fn_name, a_mod, b_mod, fake_mode):
     from torch._subclasses._fake_tensor_utils import _CacheKeyState
     from torch._subclasses.fake_tensor import extract_tensor_metadata
 
@@ -322,21 +323,29 @@ def check_all_args(a_nodes, b_nodes):
             a_flat, _ = pytree.tree_flatten((a_node.args, a_node.kwargs))
             b_flat, _ = pytree.tree_flatten((b_node.args, b_node.kwargs))
             if not check_all_args(a_flat, b_flat):
-                # print("call_function args failed")
+                hc_log.debug(
+                    "%s: Graph comparison failed at node (call_function): %s",
+                    fn_name,
+                    a_node,
+                )
                 return False
         elif a_node.op == "call_method":
             if a_node.target != b_node.target:
                 return False
             a_flat, _ = pytree.tree_flatten((a_node.args, a_node.kwargs))
             b_flat, _ = pytree.tree_flatten((b_node.args, b_node.kwargs))
             if not check_all_args(a_flat, b_flat):
-                # print("call_method args failed")
+                hc_log.debug(
+                    "%s: Graph comparison failed at node (call_method) : %s",
+                    fn_name,
+                    a_node,
+                )
                 return False
         elif a_node.op == "output":
             a_flat, _ = pytree.tree_flatten((a_node.args, a_node.kwargs))
             b_flat, _ = pytree.tree_flatten((b_node.args, b_node.kwargs))
             if not check_all_args(a_flat, b_flat):
-                # print("output args failed")
+                hc_log.debug("%s: Graph comparison failed at the output node", fn_name)
                 return False
         elif a_node.op == "get_attr":
             a_attr = getattr(a_mod, a_node.target)
@@ -345,7 +354,7 @@ def check_all_args(a_nodes, b_nodes):
                 if not isinstance(b_attr, torch.fx.GraphModule):
                     return False
                 # This is an example of a HOP inside a HOP
-                if not are_same_graph_modules(a_attr, b_attr, fake_mode):
+                if not are_same_graph_modules(fn_name, a_attr, b_attr, fake_mode):
                     return False
             else:
                 # TODO - write an example with tensor as a graph attribute in
@@ -3359,9 +3368,11 @@ def install_subgraph_in_output_graph(
 
         if isinstance(fn_vt, UserFunctionVariable):
            fn_id = id(fn_vt.get_function())
+            fn_name = fn_vt.get_function().__name__
         else:
            assert isinstance(fn_vt, UnspecializedNNModuleVariable)
            fn_id = id(fn_vt.value.forward.__func__)
+            fn_name = fn_vt.value.forward.__name__
         previously_installed_submodules = []
         if invoke_subgraph_cache:
            previously_installed_submodules = (
@@ -3373,12 +3384,21 @@ def install_subgraph_in_output_graph(
         for submodule_name in reversed(previously_installed_submodules):
            assert submodule_name in tx.output.nn_modules
            previous_mod = tx.output.nn_modules[submodule_name]
-            if are_same_graph_modules(previous_mod, current_mod, tx.fake_mode):
+            if are_same_graph_modules(
+                fn_name, previous_mod, current_mod, tx.fake_mode
+            ):
                return submodule_name
 
         body_name = super().install_subgraph_in_output_graph(
            tx, fn_vt, fn_args_vt, kwargs, body_gmod, "subgraph"
         )
+        hc_log.debug(
+            "%s: Installing subgraph with identifier '%s', bringing total count for '%s' function to %s",
+            fn_name,
+            body_name,
+            fn_name,
+            len(previously_installed_submodules) + 1,
+        )
         if invoke_subgraph_cache:
            invoke_subgraph_cache.add_dynamo_installed_submodule(fn_id, body_name)
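
A hedged illustration (not from the PR) of the situation this logging is meant to surface: the same marked region traced into subgraphs that no longer compare equal, so deduplication falls through and a "Graph comparison failed ..." debug line may be emitted. The constant-argument trick below is an assumption about one way differing subgraphs could arise:

    import torch
    from torch._higher_order_ops.invoke_subgraph import mark_compile_region

    torch._logging.set_logs(hierarchical_compile=True)

    @mark_compile_region
    def gn(x, scale):
        return x * scale

    def fn(x):
        # Same region, different constant arguments; the traced subgraphs may
        # differ, so are_same_graph_modules() can return False and log why.
        return gn(x, 2) + gn(x, 3)

    fn_opt = torch.compile(fn, backend="eager")
    fn_opt(torch.randn(8, 8))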

‎torch/_logging/_internal.py

5 additions, 0 deletions

@@ -251,6 +251,7 @@ def set_logs(
     autotuning: bool = False,
     graph_region_expansion: bool = False,
     inductor_metrics: bool = False,
+    hierarchical_compile: bool = False,
 ) -> None:
     """
     Sets the log level for individual components and toggles individual log
@@ -448,6 +449,9 @@ def set_logs(
         inductor_metrics (:class:`bool`):
             Whether to estimate the runtimes of the nodes in a graph and log them to the metrics table. Default: ``False``
 
+        hierarchical_compile (:class:`bool`):
+            Whether to emit debug info for hierarchical compilation. Default: ``False``
+
     Example::
 
         >>> # xdoctest: +SKIP
@@ -560,6 +564,7 @@ def _set_logs(**kwargs) -> None:
         autotuning=autotuning,
         graph_region_expansion=graph_region_expansion,
         inductor_metrics=inductor_metrics,
+        hierarchical_compile=hierarchical_compile,
     )
 
 
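
For completeness, a hedged doctest-style sketch of the documented toggle, following the Example:: convention of set_logs; the TORCH_LOGS spelling is the usual environment-variable route and is an assumption here, not part of the diff:

    >>> # xdoctest: +SKIP
    >>> import torch
    >>> # Assumed equivalent to running with TORCH_LOGS="hierarchical_compile"
    >>> torch._logging.set_logs(hierarchical_compile=True)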

‎torch/_logging/_registrations.py

5 additions, 1 deletion

@@ -233,5 +233,9 @@
     "Logs Inductor metrics, such as num_bytes, nodes_num_elem, node_runtimes",
     off_by_default=True,
 )
-
+register_artifact(
+    "hierarchical_compile",
+    "Logs debug info for hierarchical compilation",
+    off_by_default=True,
+)
 register_artifact("custom_format_test_artifact", "Testing only", log_format="")

‎torch/_subclasses/fake_tensor.py

19 additions, 0 deletions

@@ -62,6 +62,7 @@
     from torch.fx.experimental.symbolic_shapes import ShapeEnv, SymbolicContext
 
 log = logging.getLogger(__name__)
+hc_log = torch._logging.getArtifactLogger(__name__, "hierarchical_compile")
 
 # TODO: Hack to unblock https://github.com/pytorch/pytorch/pull/108186
 # Proper fix tracked by https://github.com/pytorch/pytorch/issues/120105
@@ -1433,6 +1434,15 @@ def _cached_dispatch_impl(
             key = self._cache_key(state, func, args, kwargs)
         except _BypassDispatchCache as e:
             # We couldn't create the cache key at all
+            if (
+                isinstance(func, torch._ops.HigherOrderOperator)
+                and func.name() == "invoke_subgraph"
+            ):
+                hc_log.debug(
+                    "Fake tensor cache failed: identifier = %s, reason = %s",
+                    args[1],
+                    e.reason,
+                )
             FakeTensorMode.cache_bypasses[e.reason] += 1
 
         if key is None:
@@ -1477,6 +1487,15 @@ def _cached_dispatch_impl(
             # We ran "extra" checks on the cache key and determined that it's no
             # good. Record the reason and mark it so we don't bother validating
             # again.
+            if (
+                isinstance(func, torch._ops.HigherOrderOperator)
+                and func.name() == "invoke_subgraph"
+            ):
+                hc_log.debug(
+                    "Fake tensor cache failed: identifier = %s, reason = %s",
+                    args[1],
+                    e.reason,
+                )
             FakeTensorMode.cache_bypasses[e.reason] += 1
             set_cache_key(cache, key, _DispatchCacheBypassEntry(e.reason))
             return output