Commit 9e30f87

[inductor][invoke_subgraph] Mark invoke_subgraph outputs as user_visible to constrain output strides
ghstack-source-id: 3e0ad90 Pull Request resolved: #155395
1 parent cc87d3a commit 9e30f87
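
How the change works: Inductor only pins an output's strides when that output's index is recorded in the graph's output-node meta under "user_visible_output_idxs". The new _recursive_record_user_visible_output_idxs helper (see the compile_fx.py diff below) writes those indices into every invoke_subgraph subgraph, so the subgraph outputs keep the strides they had at trace time even when a later graph pass changes layouts inside the region. Below is a minimal sketch of that marking step on a standalone FX graph, assuming a recent PyTorch where torch.fx.Graph.find_nodes is available; fn is a toy function used for illustration, not code from this commit.

    import torch
    import torch.fx


    def fn(x):
        # Produces a non-default layout, loosely mirroring gn() in the new test.
        y = x.transpose(2, 3).contiguous().transpose(2, 3)
        return y, y.sin()


    gm = torch.fx.symbolic_trace(fn)
    output_node = next(iter(gm.graph.find_nodes(op="output")))

    # Inductor later reads this meta entry and keeps the recorded strides for
    # every listed output instead of choosing a fresh layout for it.
    output_node.meta["user_visible_output_idxs"] = [
        idx
        for idx, out in enumerate(output_node.args[0])
        if isinstance(out, torch.fx.Node)
    ]
    print(output_node.meta["user_visible_output_idxs"])  # [0, 1]

The commit applies this marking recursively because invoke_subgraph subgraphs can themselves contain further invoke_subgraph calls.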

2 files changed: 147 additions, 1 deletion

test/higher_order_ops/test_invoke_subgraph.py

114 additions, 0 deletions
@@ -2102,6 +2102,120 @@ def f(x, other):
                 self.assertEqual(f(x, other), f_compile(x, other))
                 self.assertTrue(called)

+    @requires_gpu
+    def test_preserves_output_strides(self):
+        # Have a graph pass that changes strides for the output op of the
+        # invoke_subgraph, and check if the output strides are preserved
+        import triton
+        import triton.language as tl
+
+        @triton.jit
+        def add_kernel(
+            in_ptr0,
+            in_ptr1,
+            out_ptr,
+            n_elements,
+            BLOCK_SIZE: "tl.constexpr",
+        ):
+            pid = tl.program_id(axis=0)
+            block_start = pid * BLOCK_SIZE
+            offsets = block_start + tl.arange(0, BLOCK_SIZE)
+            mask = offsets < n_elements
+            x = tl.load(in_ptr0 + offsets, mask=mask)
+            y = tl.load(in_ptr1 + offsets, mask=mask)
+            output = x + y
+            tl.store(out_ptr + offsets, output, mask=mask)
+
+        x = torch.randn(4, 4, 2, 2, device=GPU_TYPE)
+        other = torch.randn(4, 4, 2, 2, device=GPU_TYPE)
+
+        def add_triton(y, z):
+            grid = (z.numel(),)
+            out = torch.empty_like(z, memory_format=torch.contiguous_format)
+            add_kernel[grid](y, z, out, z.numel(), BLOCK_SIZE=16)
+            return out
+
+        class _CustomPass(PatternMatcherPass):
+            def __init__(self) -> None:
+                super().__init__()
+
+            def __call__(self, g: torch.fx.Graph):
+                self.apply(g)
+
+        g = _CustomPass()
+        called = False
+
+        @register_graph_pattern(
+            CallFunctionVarArgs(torch.ops.aten.permute),
+            pass_dict=g,
+        )
+        def _(match, *args, **kwargs):
+            flat_args, spec = pytree.tree_flatten((args, kwargs))
+
+            def decomp(*flat_args):
+                args, kwargs = pytree.tree_unflatten(flat_args, spec)
+                return torch.ops.mylib.force_channels_last(
+                    torch.ops.aten.permute(*args, **kwargs)
+                )
+
+            nonlocal called
+            called = True
+            match.replace_by_example(decomp, flat_args)
+
+        from torch._inductor import config
+
+        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
+            lib.define(
+                "force_channels_last(Tensor x) -> Tensor",
+                tags=[torch._C.Tag.flexible_layout],
+            )
+
+            def impl2(x):
+                return x.clone(memory_format=torch.channels_last)
+
+            lib.impl("force_channels_last", impl2, "CompositeExplicitAutograd")
+
+            lib.define(
+                "add_op(Tensor x, Tensor y) -> Tensor",
+            )
+
+            def impl(x, y):
+                return add_triton(x, y)
+
+            def meta(x, y):
+                return torch.empty_like(y, memory_format=torch.contiguous_format)
+
+            lib.impl("add_op", impl, "CompositeExplicitAutograd")
+            lib.impl("add_op", meta, "Meta")
+
+            lib.define(
+                "add_out_op(Tensor x, Tensor y, Tensor(a!) out) -> ()",
+            )
+
+            def impl_out(x, y, out):
+                grid = (y.numel(),)
+                add_kernel[grid](x, y, out, y.numel(), BLOCK_SIZE=16)
+
+            lib.impl("add_out_op", impl_out, "CompositeExplicitAutograd")
+            lib.impl("add_out_op", lambda x, y, out: None, "Meta")
+
+            @mark_compile_region
+            def gn(x, other):
+                y = x.transpose(2, 3).contiguous().transpose(2, 3)
+                z = y.sin().transpose(2, 3)
+                return y, z
+
+            def f(x, other):
+                y, z = gn(x, other)
+                return torch.ops.mylib.add_op.default(y, z)
+
+            with config.patch(
+                post_grad_custom_post_pass=g,
+            ):
+                f_compile = torch.compile(f, fullgraph=True)
+                self.assertEqual(f(x, other), f_compile(x, other))
+                self.assertTrue(called)
+

     @skipIfTorchDynamo("Not a torch._dynamo test")
     @parameterized_class(
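Why strides matter in this test: the registered pattern wraps aten.permute in mylib.force_channels_last, which can change the memory layout produced inside the compiled region, while add_kernel and the add_op meta function both assume contiguous output. For the (4, 4, 2, 2) tensors used here the two layouts index memory quite differently; a quick eager-mode illustration (plain PyTorch, nothing from this commit required):

    import torch

    t = torch.randn(4, 4, 2, 2)
    print(t.stride())                                           # (16, 4, 2, 1): contiguous
    print(t.clone(memory_format=torch.channels_last).stride())  # (16, 1, 8, 4): channels_last

If the compiled invoke_subgraph were free to hand back a different layout than the one recorded at trace time, the flat pointer arithmetic in add_kernel would combine elements in the wrong order and f(x, other) would no longer match f_compile(x, other), which is exactly what the assertEqual in the test guards against.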
torch/_inductor/compile_fx.py

33 additions, 1 deletion
@@ -238,6 +238,33 @@ def record_original_output_strides(gm: GraphModule) -> None:
     output_node.meta["original_output_strides"] = output_strides


+def _recursive_record_original_output_strides(gm: GraphModule) -> None:
+    # invoke_subgraph HOP requires output strides to be respected
+    for node in gm.graph.find_nodes(
+        op="call_function", target=torch.ops.higher_order.invoke_subgraph
+    ):
+        subgraph = getattr(gm, node.args[0].target)
+        _recursive_record_original_output_strides(subgraph)
+
+    record_original_output_strides(gm)
+
+
+def _recursive_record_user_visible_output_idxs(gm: GraphModule) -> None:
+    # invoke_subgraph HOP requires output strides to be respected
+    for node in gm.graph.find_nodes(
+        op="call_function", target=torch.ops.higher_order.invoke_subgraph
+    ):
+        subgraph = getattr(gm, node.args[0].target)
+
+        for node in subgraph.graph.find_nodes(op="output"):
+            node.meta["user_visible_output_idxs"] = [
+                idx
+                for idx in range(len(node.args))
+                if isinstance(node.args[0][idx], torch.fx.Node)
+            ]
+        _recursive_record_user_visible_output_idxs(subgraph)
+
+
 @functools.lru_cache(None)
 def _step_logger() -> Callable[..., None]:
     return dynamo_logging.get_step_logger(log)

@@ -1167,7 +1194,7 @@ def codegen_and_compile(
         with torch.no_grad():
             fake_mode = fake_tensor_prop(gm, example_inputs)

-            record_original_output_strides(gm)
+            _recursive_record_original_output_strides(gm)

             # pattern matcher passes might not preserve striding information
             # on node.meta["val"]. if in the future we rely on these being

@@ -2211,6 +2238,11 @@ def fw_compiler_base(
         else:
             model_outputs_node.meta["user_visible_output_idxs"] = []

+        # We also mark the invoke_subgraph outputs as user_visible to
+        # force the outputs of invoke_subgraph subgraph to follow the
+        # original strides
+        _recursive_record_user_visible_output_idxs(gm)
+
         return inner_compile(
             gm,
             example_inputs,