Revert "[inductor] Fix bug handling output_strides in fx graph cache (#112041)"
This reverts commit 3d2041b34210bef3902f6ba86881b38ac0fbc57e.
Reverted https://github.com/pytorch/pytorch/pull/112041 on behalf of https://github.com/ZainRizvi due to fbcode failures ([comment](https://github.com/pytorch/pytorch/pull/112041#issuecomment-1785929233))
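For later readers of this revert: the PR being reverted changed where Inductor records output strides. With #112041, fx_codegen_and_compile stored them on the cached CompiledFxGraph and compile_fx copied them into TracingContext afterwards; this revert restores writing them into TracingContext directly while codegen runs. A minimal sketch of the two data flows, using stand-in types rather than PyTorch's real TracingContext/CompiledFxGraph (the field names follow the diff below; the stride values are illustrative):

    from dataclasses import dataclass, field
    from typing import List, Optional, Tuple

    OutputStrides = List[Optional[Tuple[int, ...]]]

    @dataclass
    class TracingContext:
        # The caller pre-creates this; compilation fills in output_strides.
        output_strides: OutputStrides = field(default_factory=list)

    @dataclass
    class CompiledFxGraph:
        # With #112041, strides lived here so the cached artifact carried them.
        output_strides: Optional[OutputStrides] = None

    def codegen_with_field(ctx: TracingContext) -> CompiledFxGraph:
        # Reverted behavior (#112041): strides ride along on the compiled
        # graph, then get copied into the TracingContext by the caller.
        graph = CompiledFxGraph(output_strides=[(4, 1), None])
        assert len(ctx.output_strides) == 0
        ctx.output_strides.extend(graph.output_strides or [])
        return graph

    def codegen_writes_context(ctx: TracingContext) -> CompiledFxGraph:
        # Restored behavior: strides are appended to the TracingContext while
        # codegen runs, and the compiled graph does not carry them.
        for strides in [(4, 1), None]:
            ctx.output_strides.append(strides)
        return CompiledFxGraph()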
diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py
index 28f7fde..1a30576 100644
--- a/test/inductor/test_codecache.py
+++ b/test/inductor/test_codecache.py
@@ -60,19 +60,6 @@
_run_codecache_test("fork")
-class MyModelConv2d(torch.nn.Module):
- def __init__(self, dim=512):
- super().__init__()
- self.conv1 = torch.nn.Conv2d(3, dim, kernel_size=3, stride=2, bias=False)
- self.conv2 = torch.nn.Conv2d(dim, dim, kernel_size=3, stride=2, bias=False)
-
- def forward(self, x):
- x = self.conv1(x)
- torch._dynamo.graph_break()
- x = self.conv2(x)
- return x
-
-
@instantiate_parametrized_tests
class TestFxGraphCache(TestCase):
@classmethod
@@ -137,40 +124,33 @@
@requires_triton()
@config.patch({"fx_graph_cache": True})
@parametrize("device", ("cuda", "cpu"))
- @parametrize("dtype", (torch.float32, torch.float16))
+ @parametrize("dtype", (torch.float32, torch.bfloat16))
def test_cache_load_model(self, device, dtype):
"""
Verify that we can populate and load models from the cache.
"""
if device == "cuda" and not HAS_CUDA:
raise unittest.SkipTest("requires CUDA")
+ if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
+ raise unittest.SkipTest("requires SM80 or later")
- def fn(mod, x):
- mod.zero_grad()
- mod(x).sum().backward()
- return [p.grad for p in mod.parameters()]
- compiled_fn = torch.compile(fn, dynamic=False)
- mod = MyModelConv2d().to(device=device, dtype=dtype)
- inp = torch.randn(2, 3, 16, 16, device=device, dtype=dtype)
+ model = MyModel().to(dtype=dtype, device=device)
+ a = torch.rand(10, 10, dtype=dtype, device=device)
+ compiled_model = torch.compile(model, dynamic=False)
- # The first call should see all cache misses.
- counters.clear()
- grads1 = compiled_fn(mod, inp)
- self.assertGreater(counters["inductor"]["fxgraph_cache_miss"], 0)
+ # A first call should miss in the cache.
+ self.assertEqual(model(a), compiled_model(a))
+ self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
- # The second should see all hits. (First reset so in-memory guards
+ # A second call should hit. (First reset so in-memory guards
# don't prevent compilation).
- counters.clear()
torch._dynamo.reset()
- grads2 = compiled_fn(mod, inp)
- self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
- self.assertGreater(counters["inductor"]["fxgraph_cache_hit"], 0)
-
- # And the results should be the same.
- self.assertEqual(grads1, grads2)
+ self.assertEqual(model(a), compiled_model(a))
+ self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
+ self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
class TestFxGraphCacheHashing(TestCase):
diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py
index 428555b..9193781 100644
--- a/torch/_inductor/codecache.py
+++ b/torch/_inductor/codecache.py
@@ -659,7 +659,6 @@
mutated_inputs: Set[str] = field(default_factory=set)
mutated_input_idxs: Set[int] = field(default_factory=set)
constants: Dict[str, torch.Tensor] = field(default_factory=dict)
- output_strides: Optional[List[Optional[Tuple[int, ...]]]] = None
_boxed_call: Optional[bool] = None
diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py
index 928574e..9510b19 100644
--- a/torch/_inductor/compile_fx.py
+++ b/torch/_inductor/compile_fx.py
@@ -8,17 +8,7 @@
import warnings
from itertools import count
-from typing import (
- Any,
- Callable,
- Dict,
- FrozenSet,
- List,
- Optional,
- Sequence,
- Tuple,
- Union,
-)
+from typing import Any, Callable, Dict, FrozenSet, List, Optional, Sequence, Union
from unittest import mock
from functorch.compile import min_cut_rematerialization_partition
@@ -394,12 +384,6 @@
log.debug("FX codegen and compilation took %.3fs", time.time() - start)
- # Return the output strides to the caller via TracingContext
- context = torch._guards.TracingContext.get()
- if context is not None and context.output_strides is not None:
- assert len(context.output_strides) == 0
- context.output_strides.extend(compiled_graph.output_strides)
-
if aot_mode:
return compiled_graph
@@ -598,19 +582,20 @@
)
with V.set_graph_handler(graph):
graph.run(*example_inputs)
- output_strides: List[Optional[Tuple[int, ...]]] = []
- if graph.graph_outputs is not None:
- # We'll put the output strides in the compiled graph so we
- # can later return them to the caller via TracingContext
+ context = torch._guards.TracingContext.get()
+ if context is not None and context.output_strides is not None:
+ # Return the output strides to the caller via TracingContext
+ assert len(context.output_strides) == 0
+ assert graph.graph_outputs is not None
for out in graph.graph_outputs:
if hasattr(out, "layout"):
- output_strides.append(
+ context.output_strides.append(
tuple( # type: ignore[arg-type]
V.graph.sizevars.size_hint(s) for s in out.layout.stride
)
)
else:
- output_strides.append(None)
+ context.output_strides.append(None)
compiled_fn = graph.compile_to_fn()
@@ -630,7 +615,6 @@
mutated_inputs=graph.mutated_inputs,
mutated_input_idxs=set(graph.mutated_input_idxs),
constants=graph.constants,
- output_strides=output_strides,
)
return compiled_graph
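
Note on the behavior being restored (inferred from the hunks above and the reverted PR's title, not from the fbcode failure itself): because the TracingContext write now happens only inside fx_codegen_and_compile, it only runs when codegen actually executes. A hedged sketch of the cache path #112041 was threading strides through; load here is a hypothetical stand-in, not the real FxGraphCache API:

    # Hypothetical cache wrapper illustrating why #112041 moved the strides
    # onto CompiledFxGraph: on a hit, codegen never runs.
    _cache = {}

    def load(key, ctx, codegen):
        if key in _cache:
            # Hit: codegen, and thus the TracingContext write restored by
            # this revert, is skipped entirely.
            return _cache[key]
        compiled = codegen(ctx)  # miss: populates ctx.output_strides
        _cache[key] = compiled
        return compiled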