[inductor] Fix bug handling output_strides in fx graph cache (#112041)
Summary: The current implementation does not properly attach output strides to the tracing context when an fx graph is loaded from the cache. That bug leads to assertion failures like `AssertionError: expected size 3==3, stride 1==9 at dim=1`. This change saves the output strides in the serialized object cached on disk and inserts them into the tracing context whether the graph is loaded from the cache or freshly compiled.
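For intuition, here is a minimal, self-contained sketch of the flow this change establishes. The names below (`CachedGraph`, `TracingCtx`, `load_or_compile`) are illustrative stand-ins, not real Inductor APIs; the actual classes touched are `CompiledFxGraph`, `FxGraphCache`, and `torch._guards.TracingContext` in the diff further down.
```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple

# One entry per graph output: a stride tuple, or None for non-tensor outputs.
OutputStrides = List[Optional[Tuple[int, ...]]]


@dataclass
class CachedGraph:
    # Hypothetical stand-in for CompiledFxGraph: the output strides are now
    # serialized alongside the compiled artifact on disk.
    output_strides: Optional[OutputStrides] = None


@dataclass
class TracingCtx:
    # Hypothetical stand-in for torch._guards.TracingContext.output_strides.
    output_strides: OutputStrides = field(default_factory=list)


def load_or_compile(
    cache: Dict[str, CachedGraph],
    key: str,
    compile_fn: Callable[[], CachedGraph],
    ctx: Optional[TracingCtx],
) -> CachedGraph:
    # Whether the graph comes from the cache (hit) or a fresh compile (miss),
    # the strides are copied into the tracing context afterwards; the cache-hit
    # path is the one the previous implementation skipped.
    graph = cache.get(key)
    if graph is None:
        graph = compile_fn()
        cache[key] = graph
    if ctx is not None and graph.output_strides is not None:
        assert len(ctx.output_strides) == 0
        ctx.output_strides.extend(graph.output_strides)
    return graph
```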
Test Plan:
* New unit test using resnet18 (which repros the problem)
* Ran the timm benchmark suite with `--training`
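As a rough guide to exercising both cache paths outside the test harness, a sketch along the lines of the new unit test in the diff below (a small Conv2d module stands in for the resnet18 mentioned above; `fx_graph_cache`, `counters`, and `torch._dynamo.reset()` are the same knobs the test uses):
```python
import torch
from torch._dynamo.utils import counters
from torch._inductor import config

config.fx_graph_cache = True  # enable the on-disk fx graph cache


def fn(mod, x):
    mod.zero_grad()
    mod(x).sum().backward()
    return [p.grad for p in mod.parameters()]


mod = torch.nn.Conv2d(3, 8, kernel_size=3, stride=2, bias=False)
inp = torch.randn(2, 3, 16, 16)
compiled_fn = torch.compile(fn, dynamic=False)

counters.clear()
grads1 = compiled_fn(mod, inp)  # first call: expect fxgraph_cache_miss > 0

torch._dynamo.reset()  # clear in-memory guards so the cached graph is looked up again
counters.clear()
grads2 = compiled_fn(mod, inp)  # second call: expect fxgraph_cache_hit > 0
```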
Differential Revision: [D50756653](https://our.internmc.facebook.com/intern/diff/D50756653)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112041
Approved by: https://github.com/ezyang
diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py
index 1a30576..28f7fde 100644
--- a/test/inductor/test_codecache.py
+++ b/test/inductor/test_codecache.py
@@ -60,6 +60,19 @@
         _run_codecache_test("fork")


+class MyModelConv2d(torch.nn.Module):
+    def __init__(self, dim=512):
+        super().__init__()
+        self.conv1 = torch.nn.Conv2d(3, dim, kernel_size=3, stride=2, bias=False)
+        self.conv2 = torch.nn.Conv2d(dim, dim, kernel_size=3, stride=2, bias=False)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        torch._dynamo.graph_break()
+        x = self.conv2(x)
+        return x
+
+
 @instantiate_parametrized_tests
 class TestFxGraphCache(TestCase):
     @classmethod
@@ -124,33 +137,40 @@
     @requires_triton()
     @config.patch({"fx_graph_cache": True})
     @parametrize("device", ("cuda", "cpu"))
-    @parametrize("dtype", (torch.float32, torch.bfloat16))
+    @parametrize("dtype", (torch.float32, torch.float16))
     def test_cache_load_model(self, device, dtype):
         """
         Verify that we can populate and load models from the cache.
         """
         if device == "cuda" and not HAS_CUDA:
             raise unittest.SkipTest("requires CUDA")
-        if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
-            raise unittest.SkipTest("requires SM80 or later")

-        model = MyModel().to(dtype=dtype, device=device)
+        def fn(mod, x):
+            mod.zero_grad()
+            mod(x).sum().backward()
+            return [p.grad for p in mod.parameters()]

-        a = torch.rand(10, 10, dtype=dtype, device=device)
+        compiled_fn = torch.compile(fn, dynamic=False)

-        compiled_model = torch.compile(model, dynamic=False)
+        mod = MyModelConv2d().to(device=device, dtype=dtype)
+        inp = torch.randn(2, 3, 16, 16, device=device, dtype=dtype)

-        # A first call shold miss in the cache.
-        self.assertEqual(model(a), compiled_model(a))
-        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
+        # The first call should see all cache misses.
+        counters.clear()
+        grads1 = compiled_fn(mod, inp)
+        self.assertGreater(counters["inductor"]["fxgraph_cache_miss"], 0)
         self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)

-        # A second call should hit. (First reset so in-memory guards
+        # The second should see all hits. (First reset so in-memory guards
         # don't prevent compilation).
+        counters.clear()
         torch._dynamo.reset()
-        self.assertEqual(model(a), compiled_model(a))
-        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
-        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
+        grads2 = compiled_fn(mod, inp)
+        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
+        self.assertGreater(counters["inductor"]["fxgraph_cache_hit"], 0)
+
+        # And the results should be the same.
+        self.assertEqual(grads1, grads2)


 class TestFxGraphCacheHashing(TestCase):
diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py
index 9193781..428555b 100644
--- a/torch/_inductor/codecache.py
+++ b/torch/_inductor/codecache.py
@@ -659,6 +659,7 @@
     mutated_inputs: Set[str] = field(default_factory=set)
     mutated_input_idxs: Set[int] = field(default_factory=set)
     constants: Dict[str, torch.Tensor] = field(default_factory=dict)
+    output_strides: Optional[List[Optional[Tuple[int, ...]]]] = None
     _boxed_call: Optional[bool] = None
diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py
index b3b59f8..9cfd4cd 100644
--- a/torch/_inductor/compile_fx.py
+++ b/torch/_inductor/compile_fx.py
@@ -8,7 +8,17 @@
 import warnings
 from itertools import count

-from typing import Any, Callable, Dict, FrozenSet, List, Optional, Sequence, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    FrozenSet,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)
 from unittest import mock

 from functorch.compile import min_cut_rematerialization_partition
@@ -384,6 +394,12 @@
     log.debug("FX codegen and compilation took %.3fs", time.time() - start)

+    # Return the output strides to the caller via TracingContext
+    context = torch._guards.TracingContext.get()
+    if context is not None and context.output_strides is not None:
+        assert len(context.output_strides) == 0
+        context.output_strides.extend(compiled_graph.output_strides)
+
     if aot_mode:
         return compiled_graph
@@ -582,20 +598,19 @@
         )
         with V.set_graph_handler(graph):
             graph.run(*example_inputs)
-            context = torch._guards.TracingContext.get()
-            if context is not None and context.output_strides is not None:
-                # Return the output strides to the caller via TracingContext
-                assert len(context.output_strides) == 0
-                assert graph.graph_outputs is not None
+            output_strides: List[Optional[Tuple[int, ...]]] = []
+            if graph.graph_outputs is not None:
+                # We'll put the output strides in the compiled graph so we
+                # can later return them to the caller via TracingContext
                 for out in graph.graph_outputs:
                     if hasattr(out, "layout"):
-                        context.output_strides.append(
+                        output_strides.append(
                             tuple(  # type: ignore[arg-type]
                                 V.graph.sizevars.size_hint(s) for s in out.layout.stride
                             )
                         )
                     else:
-                        context.output_strides.append(None)
+                        output_strides.append(None)

             compiled_fn = graph.compile_to_fn()
@@ -615,6 +630,7 @@
                 mutated_inputs=graph.mutated_inputs,
                 mutated_input_idxs=set(graph.mutated_input_idxs),
                 constants=graph.constants,
+                output_strides=output_strides,
             )

     return compiled_graph