Commit cd565bc

jamesjwu authored and pytorchmergebot committed
Refactor process_inputs outside of create_aot_dispatcher_function (#130962)
This PR refactors process_inputs so that it occurs earlier, outside of create_aot_dispatcher_function, for the purpose of calculating a cache key with the inputs after they have been processed. This way, if tensors have symint sizes/strides, we successfully factor that into the cache key instead of specializing on every possible size and stride. A test that exercises this is incoming.

# Guard behavior

Note that it's technically possible for tensors with symint arguments to introduce guards in aot_dispatch if they trace through decompositions that branch on tensor size/stride. This can result in multiple graph modules with differing guards having the same key in the cache.

FXGraphCache has this same issue, and the remote FXGraphCache intentionally does not handle it: it only saves the first result in the cache, and cache misses if the guards miss. The local FXGraphCache does handle this by storing multiple files and iterating through them, but we opt not to introduce that complexity just yet for AOTAutogradCache until we deem it necessary (i.e., models appear where saving multiple cache results with different guards under the same cache key becomes important). Instead, AOTAutogradCache saves a single entry per key, overwriting it when a lookup misses due to guards.

Pull Request resolved: #130962
Approved by: https://github.com/bdhirsh
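To make the single-entry policy concrete, here is a minimal, hypothetical sketch (not the actual AOTAutogradCache implementation; the class and helper names are invented for illustration): one saved result per key, a dedicated counter when the key matches but the saved entry's guards reject the new inputs, and a plain overwrite on the next save.

```python
# Hypothetical sketch of the "single entry per key" policy; SingleEntryCache,
# guards_hold, lookup, and save are illustrative names, not PyTorch APIs.
from collections import Counter, defaultdict
from typing import Callable, Dict, Optional, Tuple

counters: Dict[str, Counter] = defaultdict(Counter)


class SingleEntryCache:
    def __init__(self) -> None:
        # key -> (guard check, cached result); at most one entry per key
        self._entries: Dict[str, Tuple[Callable[..., bool], object]] = {}

    def lookup(self, key: str, *example_inputs) -> Optional[object]:
        entry = self._entries.get(key)
        if entry is None:
            counters["aot_autograd"]["autograd_cache_miss"] += 1
            return None
        guards_hold, result = entry
        if not guards_hold(*example_inputs):
            # Same cache key, but the saved graph's guards fail for these
            # inputs: count a guard miss on top of the regular miss.
            counters["aot_autograd"]["autograd_cache_miss"] += 1
            counters["aot_autograd"]["autograd_cache_guard_miss"] += 1
            return None
        counters["aot_autograd"]["autograd_cache_hit"] += 1
        return result

    def save(self, key: str, guards_hold: Callable[..., bool], result: object) -> None:
        # A guard miss is followed by a save that overwrites the old entry,
        # rather than accumulating multiple entries per key the way the
        # local FXGraphCache does.
        self._entries[key] = (guards_hold, result)
```

The local FXGraphCache instead keeps several entries per key and scans them for one whose guards pass; the trade-off taken here is simplicity, at the likely cost of repeated recompiles if two guard regimes alternate under one key.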
1 parent 4cca18d commit cd565bc

3 files changed: 242 additions, 85 deletions

test/dynamo/test_aot_autograd_cache.py

123 additions, 6 deletions
@@ -292,6 +292,114 @@ def fn(a, b):
         self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
         self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
 
+    @largeTensorTest("64GB", device=GPU_TYPE)
+    @parametrize("device", (GPU_TYPE,))
+    @parametrize("dtype", (torch.float16, torch.bfloat16))
+    @inductor_config.patch("fx_graph_cache", True)
+    @inductor_config.patch("fx_graph_remote_cache", False)
+    @functorch_config.patch({"enable_autograd_cache": True})
+    def test_autograd_guard_single_entry(self, device, dtype):
+        """
+        Test caching the same graph, but under conditions that introduce guards
+        for tensor sizes < int32. See test_codecache::TestFxGraphCache::test_cache_load_with_guards_int32_bounds.
+
+        This test in particular tests the behavior of a single entry cache. If we ever make AOTAutogradCache
+        support multiple entries under the same key, this test should be updated.
+        """
+        if device == GPU_TYPE and not HAS_GPU:
+            raise unittest.SkipTest(f"requires {GPU_TYPE}")
+        if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
+            raise unittest.SkipTest("requires CUDA SM80 or later")
+
+        def fn(x, y):
+            return (x + x, y + y)
+
+        def expect_miss(compiled_fn, a, b):
+            self._clear_dynamo_and_codecache()
+            counters.clear()
+            res = compiled_fn(a, b)
+            self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
+            self.assertEqual(
+                counters["aot_autograd"]["autograd_cache_guard_miss"],
+                0,
+            )
+            self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
+            return res
+
+        def expect_hit(compiled_fn, a, b):
+            self._clear_dynamo_and_codecache()
+            counters.clear()
+            res = compiled_fn(a, b)
+            self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 0)
+            self.assertEqual(
+                counters["aot_autograd"]["autograd_cache_guard_miss"],
+                0,
+            )
+            self.assertEqual(
+                counters["aot_autograd"]["autograd_cache_hit"],
+                1,
+            )
+            return res
+
+        def expect_guard_miss(compiled_fn, a, b):
+            self._clear_dynamo_and_codecache()
+            counters.clear()
+            res = compiled_fn(a, b)
+            self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
+            self.assertEqual(
+                counters["aot_autograd"]["autograd_cache_guard_miss"],
+                1,
+            )
+            self.assertEqual(
+                counters["aot_autograd"]["autograd_cache_hit"],
+                0,
+            )
+            return res
+
+        compiled_fn = torch.compile(fn, dynamic=True)
+
+        a_shape = (5, 6)
+        b_shape = (7, 8)
+        a = torch.rand(a_shape, device=device, dtype=dtype)
+        b = torch.rand(b_shape, device=device, dtype=dtype)
+        res1 = expect_miss(compiled_fn, a, b)
+
+        # Same shape, should cache hit
+        a2 = a.detach().clone()
+        b2 = b.detach().clone()
+
+        res2 = expect_hit(compiled_fn, a2, b2)
+
+        self.assertEqual(res1, res2)
+
+        # By changing the shape greatly, despite the same exact input
+        # graph, inductor should report a guard miss, leading
+        # to a cache miss on our end.
+        a_shape = (5, 6)
+        b_shape = (47000, 47001)
+        a3 = torch.rand(a_shape, device=device, dtype=dtype)
+        b3 = torch.rand(b_shape, device=device, dtype=dtype)
+
+        expect_guard_miss(compiled_fn, a3, b3)
+
+        # Wobble the shape a bit, but not enough
+        # to trigger a guard miss (since 6, 7 is still less than int32)
+        # Should result in a cache hit
+        a_shape = (6, 7)
+        b_shape = (47000, 47001)
+        a4 = torch.rand(a_shape, device=device, dtype=dtype)
+        b4 = torch.rand(b_shape, device=device, dtype=dtype)
+        expect_hit(compiled_fn, a4, b4)
+
+        # Change the shape back to the original,
+        # FXGraphCache should hit because it stores
+        # multiple entries
+        a_shape = (5, 6)
+        b_shape = (7, 8)
+        a5 = torch.rand(a_shape, device=device, dtype=dtype)
+        b5 = torch.rand(b_shape, device=device, dtype=dtype)
+        expect_hit(compiled_fn, a5, b5)
+
     @largeTensorTest("64GB", device=GPU_TYPE)
     @parametrize("device", (GPU_TYPE,))
     @parametrize("dtype", (torch.float16, torch.bfloat16))
@@ -301,7 +409,9 @@ def fn(a, b):
     @functorch_config.patch({"enable_autograd_cache": True})
     def test_autograd_inductor_guards(self, device, dtype, requires_grad):
         """
-        Tests that functions that would add inductor guards are cached properly
+        Test caching the same graph, but under conditions that introduce guards
+        for tensor sizes < int32.
+        See test_codecache::TestFxGraphCache::test_cache_load_with_guards_int32_bounds.
         """
         if device == GPU_TYPE and not HAS_GPU:
             raise unittest.SkipTest(f"requires {GPU_TYPE}")
@@ -323,6 +433,7 @@ def fn(x, y):
             ((47000, 47001), (5, 6)),
         )
         expected_hits = expected_misses = expected_saves = 0
+        expected_guard_misses = 0
         for a_shape, b_shape in shapes:
            a = torch.rand(
                a_shape, device=device, dtype=dtype, requires_grad=requires_grad
@@ -336,15 +447,15 @@ def fn(x, y):
             # see a recompilation (along with a cache miss).
             res1 = compiled_fn(a, b)
             # A first call should miss in the cache.
-            # NOTE: Currently, this cache miss is *not* due to guards,
-            # but instead because the AOTAutogradCache key calculation specializes on input shapes.
-            # Once we allow tensors with symints as part of the cache key calculation, it will
-            # instead cache miss because of guard failure.
             expected_misses += 1
-
             self.assertEqual(
                 counters["aot_autograd"]["autograd_cache_miss"], expected_misses
             )
+            self.assertEqual(
+                counters["aot_autograd"]["autograd_cache_guard_miss"],
+                expected_guard_misses,
+            )
+
             self.assertEqual(
                 counters["aot_autograd"]["autograd_cache_hit"], expected_hits
             )
@@ -375,6 +486,12 @@ def fn(x, y):
             self.assertEqual(
                 counters["aot_autograd"]["autograd_cache_miss"], expected_misses
             )
+            self.assertEqual(
+                counters["aot_autograd"]["autograd_cache_guard_miss"],
+                expected_guard_misses,
+            )
+            # First compile is a regular cache miss, subsequent are guard misses
+            expected_guard_misses += 1
             self.assertEqual(
                 counters["aot_autograd"]["autograd_cache_hit"], expected_hits
             )

torch/_functorch/_aot_autograd/autograd_cache.py

3 additions, 0 deletions
@@ -473,6 +473,9 @@ def load(
         # Count missing the FXGraphCache as a miss not a bypass
         except FXGraphCacheMiss as e:
             counters["aot_autograd"]["autograd_cache_miss"] += 1
+            # Special counter when we pass autograd cache but
+            # fail when on inductor guards
+            counters["aot_autograd"]["autograd_cache_guard_miss"] += 1
             if config.strict_autograd_cache:
                 raise e
         except BypassAOTAutogradCache as e:
