Revert "[inductor] Fix bug handling output_strides in fx graph cache (#112041)"

This reverts commit 3d2041b34210bef3902f6ba86881b38ac0fbc57e.

Reverted https://github.com/pytorch/pytorch/pull/112041 on behalf of https://github.com/ZainRizvi due to fbcode failures ([comment](https://github.com/pytorch/pytorch/pull/112041#issuecomment-1785929233))
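
For reviewers, a rough sketch of what this revert changes in the output-stride plumbing (hedged: the class and function names below are hypothetical simplifications, not the real torch._inductor code). The reverted PR stored each output's strides on the cached compiled-graph object and copied them into TracingContext afterwards, so a cache hit could still report them; this revert restores the older flow, where strides are written into TracingContext directly during graph lowering, which only runs on a cache miss.

```python
# Hedged sketch of the two output-stride flows touched by this revert.
# The names below (TracingContextSketch, CompiledGraphSketch, and the two
# *_flow helpers) are simplified stand-ins, not real torch._inductor APIs.

from dataclasses import dataclass, field
from typing import List, Optional, Tuple

Strides = Optional[Tuple[int, ...]]


@dataclass
class TracingContextSketch:
    # Stands in for torch._guards.TracingContext.output_strides.
    output_strides: List[Strides] = field(default_factory=list)


@dataclass
class CompiledGraphSketch:
    # With the reverted PR, output_strides lived on the cacheable compiled
    # graph so they could be reported even when compilation was skipped.
    output_strides: Optional[List[Strides]] = None


def reverted_flow(strides: List[Strides], ctx: TracingContextSketch) -> CompiledGraphSketch:
    # PR #112041: record strides on the compiled graph, then copy them into
    # TracingContext after either a fresh compile or a cache hit.
    compiled = CompiledGraphSketch(output_strides=list(strides))
    if compiled.output_strides is not None:
        assert not ctx.output_strides
        ctx.output_strides.extend(compiled.output_strides)
    return compiled


def restored_flow(strides: List[Strides], ctx: TracingContextSketch) -> None:
    # After this revert: write strides straight into TracingContext while
    # lowering the graph, which only happens on a cache miss.
    assert not ctx.output_strides
    ctx.output_strides.extend(strides)
```

In the restored flow, a cache hit skips lowering entirely, so TracingContext.output_strides would stay empty; that appears to be the gap the reverted PR's title refers to as the fx graph cache bug.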
diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py
index 28f7fde..1a30576 100644
--- a/test/inductor/test_codecache.py
+++ b/test/inductor/test_codecache.py
@@ -60,19 +60,6 @@
     _run_codecache_test("fork")
 
 
-class MyModelConv2d(torch.nn.Module):
-    def __init__(self, dim=512):
-        super().__init__()
-        self.conv1 = torch.nn.Conv2d(3, dim, kernel_size=3, stride=2, bias=False)
-        self.conv2 = torch.nn.Conv2d(dim, dim, kernel_size=3, stride=2, bias=False)
-
-    def forward(self, x):
-        x = self.conv1(x)
-        torch._dynamo.graph_break()
-        x = self.conv2(x)
-        return x
-
-
 @instantiate_parametrized_tests
 class TestFxGraphCache(TestCase):
     @classmethod
@@ -137,40 +124,33 @@
     @requires_triton()
     @config.patch({"fx_graph_cache": True})
     @parametrize("device", ("cuda", "cpu"))
-    @parametrize("dtype", (torch.float32, torch.float16))
+    @parametrize("dtype", (torch.float32, torch.bfloat16))
     def test_cache_load_model(self, device, dtype):
         """
         Verify that we can populate and load models from the cache.
         """
         if device == "cuda" and not HAS_CUDA:
             raise unittest.SkipTest("requires CUDA")
+        if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
+            raise unittest.SkipTest("requires SM80 or later")
 
-        def fn(mod, x):
-            mod.zero_grad()
-            mod(x).sum().backward()
-            return [p.grad for p in mod.parameters()]
+        model = MyModel().to(dtype=dtype, device=device)
 
-        compiled_fn = torch.compile(fn, dynamic=False)
+        a = torch.rand(10, 10, dtype=dtype, device=device)
 
-        mod = MyModelConv2d().to(device=device, dtype=dtype)
-        inp = torch.randn(2, 3, 16, 16, device=device, dtype=dtype)
+        compiled_model = torch.compile(model, dynamic=False)
 
-        # The first call should see all cache misses.
-        counters.clear()
-        grads1 = compiled_fn(mod, inp)
-        self.assertGreater(counters["inductor"]["fxgraph_cache_miss"], 0)
+        # A first call should miss in the cache.
+        self.assertEqual(model(a), compiled_model(a))
+        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
         self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
 
-        # The second should see all hits. (First reset so in-memory guards
+        # A second call should hit. (First reset so in-memory guards
         # don't prevent compilation).
-        counters.clear()
         torch._dynamo.reset()
-        grads2 = compiled_fn(mod, inp)
-        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
-        self.assertGreater(counters["inductor"]["fxgraph_cache_hit"], 0)
-
-        # And the results should be the same.
-        self.assertEqual(grads1, grads2)
+        self.assertEqual(model(a), compiled_model(a))
+        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
+        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
 
 
 class TestFxGraphCacheHashing(TestCase):
diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py
index 428555b..9193781 100644
--- a/torch/_inductor/codecache.py
+++ b/torch/_inductor/codecache.py
@@ -659,7 +659,6 @@
     mutated_inputs: Set[str] = field(default_factory=set)
     mutated_input_idxs: Set[int] = field(default_factory=set)
     constants: Dict[str, torch.Tensor] = field(default_factory=dict)
-    output_strides: Optional[List[Optional[Tuple[int, ...]]]] = None
 
     _boxed_call: Optional[bool] = None
 
diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py
index 928574e..9510b19 100644
--- a/torch/_inductor/compile_fx.py
+++ b/torch/_inductor/compile_fx.py
@@ -8,17 +8,7 @@
 import warnings
 from itertools import count
 
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    FrozenSet,
-    List,
-    Optional,
-    Sequence,
-    Tuple,
-    Union,
-)
+from typing import Any, Callable, Dict, FrozenSet, List, Optional, Sequence, Union
 from unittest import mock
 
 from functorch.compile import min_cut_rematerialization_partition
@@ -394,12 +384,6 @@
 
     log.debug("FX codegen and compilation took %.3fs", time.time() - start)
 
-    # Return the output strides to the caller via TracingContext
-    context = torch._guards.TracingContext.get()
-    if context is not None and context.output_strides is not None:
-        assert len(context.output_strides) == 0
-        context.output_strides.extend(compiled_graph.output_strides)
-
     if aot_mode:
         return compiled_graph
 
@@ -598,19 +582,20 @@
         )
         with V.set_graph_handler(graph):
             graph.run(*example_inputs)
-            output_strides: List[Optional[Tuple[int, ...]]] = []
-            if graph.graph_outputs is not None:
-                # We'll put the output strides in the compiled graph so we
-                # can later return them to the caller via TracingContext
+            context = torch._guards.TracingContext.get()
+            if context is not None and context.output_strides is not None:
+                # Return the output strides to the caller via TracingContext
+                assert len(context.output_strides) == 0
+                assert graph.graph_outputs is not None
                 for out in graph.graph_outputs:
                     if hasattr(out, "layout"):
-                        output_strides.append(
+                        context.output_strides.append(
                             tuple(  # type: ignore[arg-type]
                                 V.graph.sizevars.size_hint(s) for s in out.layout.stride
                             )
                         )
                     else:
-                        output_strides.append(None)
+                        context.output_strides.append(None)
 
             compiled_fn = graph.compile_to_fn()
 
@@ -630,7 +615,6 @@
                 mutated_inputs=graph.mutated_inputs,
                 mutated_input_idxs=set(graph.mutated_input_idxs),
                 constants=graph.constants,
-                output_strides=output_strides,
             )
     return compiled_graph