[inductor] FX graph cache: Add support for symbolic shapes (#111421)

Summary: Add support for caching graphs that have tensor args with symbolic shapes. The high-level approach is to serialize guards alongside the on-disk cached object and validate that those guards pass before serving a cached entry.
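
Roughly, the new save/load flow looks like the sketch below (simplified; entry, shape_env, and example_inputs stand in for the real objects in torch/_inductor/codecache.py, and the two guard helpers are the ones added to ShapeEnv in this PR):

    import torch
    from torch.fx.experimental.symbolic_shapes import hint_int

    def save_guards(entry, shape_env, example_inputs) -> None:
        # Record a guards expression covering the SymInt graph inputs; Tensor
        # shapes are already captured by the cache key itself.
        symints = [s for s in example_inputs if isinstance(s, torch.SymInt)]
        entry.guards_expr = shape_env.produce_guards_expression(symints)

    def guards_hold(entry, shape_env, example_inputs) -> bool:
        # Re-evaluate the stored expression with the current hint values; a
        # failing evaluation means the entry cannot be served (cache miss).
        if not entry.guards_expr:
            return True
        hints = [hint_int(s) for s in example_inputs if isinstance(s, torch.SymInt)]
        return bool(shape_env.evaluate_guards_expression(entry.guards_expr, hints))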

Test Plan: New unit tests
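
Not part of the test plan itself, but a quick manual check looks roughly like this (a sketch assuming the existing fx_graph_cache inductor config flag; the fxgraph_cache_miss/fxgraph_cache_hit counters used by the new unit tests show whether the cache was consulted):

    import torch
    import torch._dynamo
    import torch._inductor.config as inductor_config

    with inductor_config.patch({"fx_graph_cache": True}):
        compiled = torch.compile(lambda x: x + x, dynamic=True)
        compiled(torch.rand(8, 8))   # first call: expect a cache miss
        torch._dynamo.reset()        # drop dynamo's in-memory state
        compiled(torch.rand(8, 8))   # recompilation: expect a cache hit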

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111421
Approved by: https://github.com/ezyang
diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py
index bad39cc..b2bbf01 100644
--- a/test/inductor/test_codecache.py
+++ b/test/inductor/test_codecache.py
@@ -17,6 +17,7 @@
     TensorMetadataAndValues,
 )
 from torch.testing._internal.common_cuda import SM80OrLater
+from torch.testing._internal.common_device_type import largeTensorTest
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
@@ -98,7 +99,8 @@
     @config.patch({"fx_graph_cache": True})
     @parametrize("device", ("cuda", "cpu"))
     @parametrize("dtype", (torch.float32, torch.bfloat16))
-    def test_cache_load_function(self, device, dtype):
+    @parametrize("dynamic", (False, True))
+    def test_cache_load_function(self, device, dtype, dynamic):
         """
         Verify that we can populate and load functions from the cache.
         """
@@ -114,7 +116,7 @@
         b = torch.rand(5, 5, dtype=dtype, device=device)
         c = a.view(5, 5)
 
-        compiled_fn = torch.compile(fn, dynamic=False)
+        compiled_fn = torch.compile(fn, dynamic=dynamic)
 
         # A first call shold miss in the cache.
         self.assertEqual(fn(a, b), compiled_fn(a, b))
@@ -138,7 +140,8 @@
     @config.patch({"fx_graph_cache": True})
     @parametrize("device", ("cuda", "cpu"))
     @parametrize("dtype", (torch.float32, torch.float64))
-    def test_cache_load_model(self, device, dtype):
+    @parametrize("dynamic", (False, True))
+    def test_cache_load_model(self, device, dtype, dynamic):
         """
         Verify that we can populate and load models from the cache.
         """
@@ -150,7 +153,7 @@
             mod(x).sum().backward()
             return [p.grad for p in mod.parameters()]
 
-        compiled_fn = torch.compile(fn, dynamic=False)
+        compiled_fn = torch.compile(fn, dynamic=dynamic)
 
         mod = MyModelConv2d().to(device=device, dtype=dtype)
         inp = torch.randn(2, 3, 16, 16, device=device, dtype=dtype)
@@ -172,6 +175,132 @@
         # And the results should be the same.
         self.assertEqual(grads1, grads2)
 
+    @largeTensorTest("64GB", device="cuda")
+    @config.patch({"fx_graph_cache": True})
+    @parametrize("device", ("cuda",))
+    @parametrize("dtype", (torch.float16, torch.bfloat16))
+    def test_cache_load_with_guards_int32_bounds(self, device, dtype):
+        """
+        Test caching the same graph, but under conditions that introduce guards
+        for tensor sizes < int32.
+        """
+        if device == "cuda" and not HAS_CUDA:
+            raise unittest.SkipTest("requires CUDA")
+        if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
+            raise unittest.SkipTest("requires SM80 or later")
+
+        def fn(x, y):
+            return (x + x, y + y)
+
+        compiled_fn = torch.compile(fn, dynamic=True)
+
+        # Iterate over different shapes, varying whether the total
+        # size is below or above int32. For each combination, we expect
+        # different guards around whether the symbolic sizes do or do
+        # not exceed int32.
+        shapes = (
+            ((5, 6), (7, 8)),
+            ((5, 6), (47000, 47001)),
+            ((47000, 47001), (5, 6)),
+        )
+        for a_shape, b_shape in shapes:
+            a = torch.rand(a_shape, device=device, dtype=dtype)
+            b = torch.rand(b_shape, device=device, dtype=dtype)
+
+            # AVOID a dynamo reset here. We expect guards to have been
+            # added that will be violated with the new shape. We should
+            # see a recompilation (along with a cache miss).
+            counters.clear()
+            res1 = compiled_fn(a, b)
+            self.assertGreater(counters["inductor"]["fxgraph_cache_miss"], 0)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
+
+            # A second call should hit. (Reset here to force compilation).
+            counters.clear()
+            torch._dynamo.reset()
+            res2 = compiled_fn(a, b)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
+            self.assertGreater(counters["inductor"]["fxgraph_cache_hit"], 0)
+
+            self.assertEqual(res1, res2)
+
+    @config.patch({"fx_graph_cache": True})
+    @parametrize("device", ("cuda", "cpu"))
+    @parametrize("dtype", (torch.float32, torch.bfloat16))
+    def test_cache_load_with_guards_static_bounds(self, device, dtype):
+        """
+        Test caching the same graph, but under conditions that introduce guards
+        for static bounds.
+        """
+        if device == "cuda" and not HAS_CUDA:
+            raise unittest.SkipTest("requires CUDA")
+        if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
+            raise unittest.SkipTest("requires SM80 or later")
+
+        # See lowering; for all of the pooling operators, we always guard and
+        # make the height/width static.
+        def fn(x):
+            return torch.nn.functional.adaptive_avg_pool2d(x, [5, 7])
+
+        compiled_fn = torch.compile(fn, dynamic=True)
+
+        # Iterate over different input shapes. Each new shape should cause
+        # a cache miss.
+        shapes = ((1, 64, 8, 9), (1, 64, 9, 10), (1, 64, 10, 11))
+        for shape in shapes:
+            x = torch.rand(shape, device=device, dtype=dtype)
+
+            # AVOID a dynamo reset here. For each cache hit, we expect guards
+            # to have been added that will be violated with each new shape.
+            # We should see a recompilation (along with a cache miss).
+            counters.clear()
+            res1 = compiled_fn(x)
+            self.assertGreater(counters["inductor"]["fxgraph_cache_miss"], 0)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
+
+            # A second call should hit.
+            counters.clear()
+            torch._dynamo.reset()
+            res2 = compiled_fn(x)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
+            self.assertGreater(counters["inductor"]["fxgraph_cache_hit"], 0)
+
+            self.assertEqual(res1, res2)
+
+    @config.patch({"fx_graph_cache": True})
+    def test_cache_clear(self):
+        """
+        Test clearing the cache.
+        """
+
+        def fn(x, y):
+            return (x * y,)
+
+        a = torch.rand(5, 5)
+        b = torch.rand(5, 5)
+
+        compiled_fn = torch.compile(fn)
+
+        # A first call should miss in the cache.
+        self.assertEqual(fn(a, b), compiled_fn(a, b))
+        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
+        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
+
+        # A second call should hit.
+        counters.clear()
+        torch._dynamo.reset()
+        self.assertEqual(fn(a, b), compiled_fn(a, b))
+        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
+        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
+
+        # Clear the cache; now we should miss.
+        counters.clear()
+        torch._dynamo.reset()
+        torch._inductor.codecache.FxGraphCache.clear()
+        self.assertEqual(fn(a, b), compiled_fn(a, b))
+        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
+        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
+
 
 class TestFxGraphCacheHashing(TestCase):
     def test_tensor_constants(self):
@@ -296,16 +425,16 @@
         ordering of the kwargs dict and any set arguments.
         """
         # Dict order of the kwargs should not affect hashes.
-        details1 = FxGraphHashDetails([], {"a": 0, "z": 1})
-        details2 = FxGraphHashDetails([], {"z": 1, "a": 0})
+        details1 = FxGraphHashDetails(None, [], {"a": 0, "z": 1})
+        details2 = FxGraphHashDetails(None, [], {"z": 1, "a": 0})
         self.assertEqual(
             FxGraphCachePickler.dumps(details1),
             FxGraphCachePickler.dumps(details2),
         )
 
         # Different kwarg values should affect hashes.
-        details1 = FxGraphHashDetails([], {"a": 0})
-        details2 = FxGraphHashDetails([], {"a": 1})
+        details1 = FxGraphHashDetails(None, [], {"a": 0})
+        details2 = FxGraphHashDetails(None, [], {"a": 1})
         self.assertNotEqual(
             FxGraphCachePickler.dumps(details1),
             FxGraphCachePickler.dumps(details2),
@@ -315,16 +444,16 @@
         # sorting and creating a new set seems to change the order.
         set1 = {"a", "b", "c", "d", "e", "f", "g"}
         set2 = set(sorted(set1))  # noqa: C414
-        details1 = FxGraphHashDetails([], {"a": set1})
-        details2 = FxGraphHashDetails([], {"a": set2})
+        details1 = FxGraphHashDetails(None, [], {"a": set1})
+        details2 = FxGraphHashDetails(None, [], {"a": set2})
         self.assertEqual(
             FxGraphCachePickler.dumps(details1),
             FxGraphCachePickler.dumps(details2),
         )
 
         # But different set contents should affect hashes.
-        details1 = FxGraphHashDetails([], {"a": {1, 2, 3}})
-        details2 = FxGraphHashDetails([], {"a": {1, 2}})
+        details1 = FxGraphHashDetails(None, [], {"a": {1, 2, 3}})
+        details2 = FxGraphHashDetails(None, [], {"a": {1, 2}})
         self.assertNotEqual(
             FxGraphCachePickler.dumps(details1),
             FxGraphCachePickler.dumps(details2),
@@ -335,11 +464,11 @@
         Test that different config settings affect hashes.
         """
         with config.patch({"max_autotune": False}):
-            details1 = FxGraphHashDetails([], {})
-            details2 = FxGraphHashDetails([], {})
+            details1 = FxGraphHashDetails(None, [], {})
+            details2 = FxGraphHashDetails(None, [], {})
 
         with config.patch({"max_autotune": True}):
-            details3 = FxGraphHashDetails([], {})
+            details3 = FxGraphHashDetails(None, [], {})
 
         self.assertEqual(
             FxGraphCachePickler.dumps(details1),
diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py
index 428555b..c92fe42 100644
--- a/torch/_inductor/codecache.py
+++ b/torch/_inductor/codecache.py
@@ -51,6 +51,7 @@
 from torch._inductor.codegen.cuda import cuda_env
 from torch._inductor.utils import developer_warning, is_linux
 from torch._prims_common import suggest_memory_format
+from torch.fx.experimental.symbolic_shapes import has_hint, hint_int, ShapeEnv
 
 if TYPE_CHECKING:
     from torch._inductor.graph import GraphLowering
@@ -326,16 +327,16 @@
     return lock_dir
 
 
+def sha256_hash(data: bytes) -> str:
+    # [:51] to strip off the "Q====" suffix common to every hash value.
+    return base64.b32encode(hashlib.sha256(data).digest())[:51].decode("utf-8").lower()
+
+
 def code_hash(code: Union[str, bytes], extra: str = ""):
     hashing_str = code if isinstance(code, bytes) else code.encode("utf-8")
     if extra != "":
         hashing_str = hashing_str + b"||" + extra.encode("utf-8")
-    return (
-        "c"
-        + base64.b32encode(hashlib.sha256(hashing_str).digest())[:51]
-        .decode("utf-8")
-        .lower()
-    )
+    return "c" + sha256_hash(hashing_str)
 
 
 def get_path(
@@ -357,8 +358,6 @@
         return code_hash(content, extra)
     if hash_type == "cubin":
         return code_hash(repr(content))
-    if hash_type == "cg":
-        return extra
     raise AssertionError(f"Unknown hash type {hash_type}")
 
 
@@ -425,21 +424,10 @@
     values: List[Any]
 
 
-class DynamicShapeError(BaseException):
-    """
-    TODO(masnesral): Support dynamic shapes and remove this.
-    """
-
-    pass
-
-
 def extract_tensor_metadata(t: torch.Tensor) -> TensorMetadata:
     """
     Extract the TensorMetadata of a tensor.
     """
-    if not all(isinstance(e, int) for e in t.shape):
-        raise DynamicShapeError()
-
     memory_format = suggest_memory_format(t)
     if not t.is_contiguous(memory_format=memory_format):
         memory_format = None
@@ -491,6 +479,16 @@
         return (_ident, (metadata,))
 
 
+def _reduce_symint(s):
+    """
+    See FxGraphCachePickler. Custom reducer to pickle SymInts.
+    """
+    # For hashing purposes, we only care about the name of the symbol and
+    # not the backed value. We evaluate guards stored with a cached graph
+    # to ensure a cached entity with SymInt args is safe to reuse.
+    return (_ident, (str(s),))
+
+
 class FxGraphCachePickler(pickle.Pickler):
     """
     Custom pickler to customize the pickling of some objects (Tensors), only for the
@@ -502,6 +500,7 @@
     dispatch_table = copyreg.dispatch_table.copy()
     dispatch_table[torch._subclasses.fake_tensor.FakeTensor] = _reduce_fake_tensor
     dispatch_table[torch.Tensor] = _reduce_tensor
+    dispatch_table[torch.SymInt] = _reduce_symint
 
     @staticmethod
     def dumps(obj) -> bytes:
@@ -513,6 +512,15 @@
             pickler.dump(obj)
             return stream.getvalue()
 
+    @staticmethod
+    def get_hash(obj: Any) -> str:
+        """
+        Serialize an object using the FxGraphCachePickler and return a hash
+        of the pickled object.
+        """
+        serialized_data = FxGraphCachePickler.dumps(obj)
+        return sha256_hash(serialized_data)
+
 
 @functools.lru_cache(None)
 def get_inductor_code_hash() -> bytes:
@@ -553,8 +561,14 @@
     # Excluded kwargs param that are not stable between runs
     EXCLUDED_KWARGS = ["graph_id"]
 
-    def __init__(self, fx_args: List[Any], fx_kwargs: Dict[str, Any]) -> None:
-        self.fx_args = fx_args
+    def __init__(
+        self,
+        gm: torch.fx.GraphModule,
+        example_inputs: List[torch.Tensor],
+        fx_kwargs: Dict[str, Any],
+    ):
+        self.gm = gm
+        self.example_inputs = example_inputs
 
         # Order kwargs so hashing is stable to changes in kwarg order.
         self.fx_kwargs = {}
@@ -575,79 +589,243 @@
         self.inductor_config = config.save_config()  # type: ignore[attr-defined]
         self.inductor_code_hash = get_inductor_code_hash()
 
+    def debug_str(self) -> str:
+        """
+        Get a printable string describing in more detail all the attributes
+        comprising this object. Useful for debugging when one graph hashes
+        to a different value than another.
+        """
 
-def compiled_fx_graph_hash(fx_args: List[Any], fx_kwargs: Dict[str, Any]) -> str:
+        def get_str(obj) -> str:
+            if isinstance(obj, torch.Tensor):
+                return str(extract_tensor_metadata(obj))
+            elif isinstance(obj, bytes):
+                return "<bytes>"
+            else:
+                return str(obj)
+
+        lines = []
+        for attr, obj in vars(self).items():
+            if isinstance(obj, list):
+                for ii in range(len(obj)):
+                    h = FxGraphCachePickler.get_hash(obj[ii])
+                    lines.append(f"[{h}] {attr}[{ii}]: {get_str(obj[ii])}")
+            elif isinstance(obj, dict):
+                for k, v in obj.items():
+                    h = FxGraphCachePickler.get_hash(v)
+                    lines.append(f"[{h}] {attr}[{k}]: {get_str(v)}")
+            else:
+                h = FxGraphCachePickler.get_hash(obj)
+                lines.append(f"[{h}] {attr}: {get_str(obj)}")
+        return "\n".join(lines)
+
+
+def compiled_fx_graph_hash(
+    gm: torch.fx.GraphModule,
+    example_inputs: List[torch.Tensor],
+    fx_kwargs: Dict[str, Any],
+) -> str:
     """
     Generate a unique hash of the FX graph for caching.
     """
-    details = FxGraphHashDetails(fx_args, fx_kwargs)
-    serialized_data = FxGraphCachePickler.dumps(details)
-    return (
-        "f"
-        + base64.b32encode(hashlib.sha256(serialized_data).digest())[:51]
-        .decode("utf-8")
-        .lower()
-    )
+    details = FxGraphHashDetails(gm, example_inputs, fx_kwargs)
+    # The prefix distinguishes among the other kinds of objects we
+    # cache in this module.
+    key = "f" + FxGraphCachePickler.get_hash(details)
+    log.debug("FX graph cache hash details for key %s:\n%s", key, details.debug_str())
+    return key
 
 
 class FxGraphCache:
     """
     Supports caching and reusing compiled Fx graphs.
+
+    The overall strategy is as follows:
+    - This cache stores entries on disk. When saving an entry, we can't
+      serialize callables (that could be C++, Triton, etc.), so we serialize
+      their own disk cache location. We then recreate the compiled artifact
+      after fetching from disk.
+    - For indexing the cache, we gather the fields relevant to identifying an
+      FxGraph (the graph module, graph inputs, system settings etc.) into an
+      FxGraphCacheDetails object, pickle it, and compute a hash for the key.
+      See FxGraphCachePickler.
+    - Among the metadata we store, we also include a guards expression that's
+      appropriate for validating any symbols for Tensor arguments that have
+      symbolic bounds. On cache lookup then, we evaluate those guards in the
+      current context to validate that a cached entry can be served.
+    - A given graph could have multiple compiled versions, corresponding to
+      different sets of guards. Therefore, we store cache entries in the form:
+          <temp dir>/<fx graph hash>/<serialized metadata>
+    - On lookup, we compute the key from the graph details, iterate over all
+      leaf files in the corresponding subdirectory, deserialize the entry, and
+      evaluate its guards expression. If the evaluation succeeds, we have a
+      cache hit. If it fails, we compile the graph and store a new entry.
+    - Finally, on a cache hit, we need to make sure any guards that would
+      have been created during compilation are added to the current context.
     """
 
     # TODO(masnesral): Investigate whether it's beneficial to store compiled graphs
     # in an in-memory cache after loading from disk.
-    @classmethod
-    def save_graph(cls, key: str, compiled_graph: CompiledFxGraph):
+    @staticmethod
+    def _get_tmp_dir() -> str:
+        """
+        Get the toplevel temporary directory for storing compiled graphs.
+        """
+        return os.path.join(cache_dir(), "fxgraph")
+
+    @staticmethod
+    def _get_tmp_dir_for_key(key: str) -> str:
+        """
+        Return the disk location for a given cache key.
+        """
+        return os.path.join(FxGraphCache._get_tmp_dir(), key[1:3], key)
+
+    @staticmethod
+    def _filter_symints(inputs: List[Any]) -> List[torch.SymInt]:
+        """
+        Get the SymInt objects from the input list.
+        """
+        return [s for s in inputs if isinstance(s, torch.SymInt)]
+
+    @staticmethod
+    def _get_shape_env() -> ShapeEnv:
+        """
+        Helper to get the shape env from the tracing context.
+        """
+        tracing_context = torch._guards.TracingContext.get()
+        assert tracing_context is not None
+        return tracing_context.fake_mode.shape_env
+
+    @staticmethod
+    def _lookup_graph(
+        key: str,
+        example_inputs: List[torch.Tensor],
+    ) -> Optional[CompiledFxGraph]:
+        """
+        Lookup a compiled graph in the cache by key. On a hit, return the
+        deserialized CompiledFxGraph object. On a miss, return None.
+        """
+        subdir = FxGraphCache._get_tmp_dir_for_key(key)
+        if not os.path.exists(subdir):
+            return None
+
+        # Iterate over any entries in the subdir for this key and evaluate
+        # their guards to determine whether there's a hit.
+        for path in sorted(os.listdir(subdir)):
+            with open(os.path.join(subdir, path), "rb") as f:
+                graph: CompiledFxGraph = pickle.load(f)
+
+            guards_expr = graph.guards_expr
+            if not guards_expr:
+                # No guards to evaluate
+                return graph
+
+            # Evaluate the guard expression in the current context.
+            shape_env = FxGraphCache._get_shape_env()
+            symints = FxGraphCache._filter_symints(example_inputs)
+
+            # If there's not a cache hit, we don't want the evaluation to
+            # affect the current env, e.g., cause the creation of new guards,
+            # so we evaluate with the hints instead of the symbols.
+            assert all(has_hint(s) for s in symints)
+            hints = [hint_int(s) for s in symints]
+            hit = bool(shape_env.evaluate_guards_expression(guards_expr, hints))
+            log.debug(
+                "fx graph cache key %s evaluating guards for %s with values %s => %s",
+                key,
+                guards_expr,
+                hints,
+                hit,
+            )
+            if hit:
+                # Now re-evaluate with the symints to add any guards to the current env.
+                check = bool(shape_env.evaluate_guards_expression(guards_expr, symints))
+                assert check is True
+                log.debug(
+                    "fx graph cache key %s post-load guards: %s",
+                    key,
+                    shape_env.guards,
+                )
+                return graph
+
+        return None
+
+    @staticmethod
+    def _save_graph(
+        key: str, compiled_graph: CompiledFxGraph, example_inputs: List[torch.Tensor]
+    ):
+        """
+        Store a serialized CompiledFxGraph on disk.
+        """
         disk_compiled_graph = copy(compiled_graph)
-        # Important as compiled models are not pickleable
-        # TODO: Check status of PR #101651 as that might change the above statement
+        # Important as compiled models are not pickleable:
         disk_compiled_graph.compiled_artifact = None
-        write(pickle.dumps(disk_compiled_graph), "cg", extra=key, hash_type="cg")
 
-    @classmethod
-    def load_graph(cls, cg_path: str) -> CompiledFxGraph:
-        with open(cg_path, "rb") as f:
-            return pickle.load(f)
+        # Before serializing, compute the guard expression that will be used to
+        # ensure that a CompiledFxGraph is valid when loaded from the cache. It's
+        # sufficient to consider only the SymInt args to the fx graph since the
+        # Tensor shapes are already captured in the hash for the cache key. Any
+        # Tensor arg with a symbolic shape will have a SymInt arg for the graph.
+        shape_env = FxGraphCache._get_shape_env()
+        symints = FxGraphCache._filter_symints(example_inputs)
+        disk_compiled_graph.guards_expr = shape_env.produce_guards_expression(symints)
 
-    @classmethod
+        content = pickle.dumps(disk_compiled_graph)
+
+        subdir = FxGraphCache._get_tmp_dir_for_key(key)
+        if not os.path.exists(subdir):
+            os.makedirs(subdir, exist_ok=True)
+
+        # Use a hash of the serialized CompiledFxGraph to get a unique file
+        # name. The specific name doesn't matter since a lookup involves
+        # iterating over all entries in the parent subdir.
+        path = os.path.join(subdir, sha256_hash(content))
+        write_atomic(path, content)
+
+    @staticmethod
     def load(
-        cls,
         compile_fx_fn: Callable[..., Any],
-        fx_args: List[Any],
+        gm: torch.fx.GraphModule,
+        example_inputs: List[torch.Tensor],
         fx_kwargs: Dict[str, Any],
     ):
+        """
+        Load a compiled graph from the cache. If a cached entry does not exist,
+        compile the graph and save it to the cache.
+        """
         from filelock import FileLock
 
-        try:
-            key = compiled_fx_graph_hash(fx_args, fx_kwargs)
-        except DynamicShapeError:
-            # TODO(masnresral): support dynamic shapes
-            log.debug("fx graph cache skip due to dynamic shapes")
-            counters["inductor"]["fxgraph_cache_skip"] += 1
-            return compile_fx_fn(*fx_args, **fx_kwargs)
+        key = compiled_fx_graph_hash(gm, example_inputs, fx_kwargs)
 
         lock_dir = get_lock_dir()
         lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
         with lock:
-            _, _, cg_path = get_path(key, "cg")
-            if not os.path.exists(cg_path):
+            compiled_graph = FxGraphCache._lookup_graph(key, example_inputs)
+            if compiled_graph is None:
                 log.debug("fx graph cache miss for key %s", key)
                 counters["inductor"]["fxgraph_cache_miss"] += 1
-                compiled_graph: CompiledFxGraph = compile_fx_fn(*fx_args, **fx_kwargs)
-                cls.save_graph(key, compiled_graph)
+                compiled_graph = compile_fx_fn(gm, example_inputs, **fx_kwargs)
+                FxGraphCache._save_graph(key, compiled_graph, example_inputs)
             else:
-                # Load required info from disk; recreation of compiled model will be on first run
                 log.debug("fx graph cache hit for key %s", key)
                 counters["inductor"]["fxgraph_cache_hit"] += 1
-                compiled_graph = cls.load_graph(cg_path)
 
             return compiled_graph
 
+    @staticmethod
+    def clear():
+        """
+        Clear out the on-disk cache.
+        """
+        shutil.rmtree(FxGraphCache._get_tmp_dir())
+
 
 @dataclasses.dataclass
 class CompiledFxGraph:
-    """Class holding a compiled FX graph"""
+    """
+    Class holding a compiled FX graph. This is the object serialized on disk
+    to support FxGraph caching.
+    """
 
     compiled_artifact: Optional[Callable[..., Any]] = None
     current_callable: Optional[Callable[..., Any]] = None
@@ -660,9 +838,33 @@
     mutated_input_idxs: Set[int] = field(default_factory=set)
     constants: Dict[str, torch.Tensor] = field(default_factory=dict)
     output_strides: Optional[List[Optional[Tuple[int, ...]]]] = None
+    # This is a string representation of an expression we serialize
+    # with the object so the guards can be evaluated in a different
+    # context in order to verify the validity of serving a cached
+    # fx graph. The expression must be generated by:
+    # ShapeEnv.produce_guards_expression()
+    guards_expr: Optional[str] = None
 
     _boxed_call: Optional[bool] = None
 
+    def __init__(
+        self,
+        compiled_artifact: Optional[Callable[..., Any]],
+        graph: GraphLowering,
+        output_strides: List[Optional[Tuple[int, ...]]],
+    ):
+        self.compiled_artifact = compiled_artifact
+        self.cache_key = graph.cache_key
+        self.artifact_path = graph.cache_path
+        self.cache_linemap = graph.cache_linemap
+        self.device_types = graph.device_types
+        self.device_idxs = graph.device_idxs
+        self.mutated_inputs = graph.mutated_inputs
+        self.mutated_input_idxs = set(graph.mutated_input_idxs)
+        self.constants = graph.constants
+        self.output_strides = output_strides
+        self.guards_expr = None
+
     def __call__(self, inputs: List[Any]) -> Any:
         return self.get_current_callable()(inputs)
 
diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py
index 32e3b60..b6aa7f1 100644
--- a/torch/_inductor/compile_fx.py
+++ b/torch/_inductor/compile_fx.py
@@ -366,8 +366,7 @@
 
     # Inputs to fx_codegen_and_compile
     # Anything that affects codegen should go here, so if the signature
-    # of fx_codegen_and_compile changes, the list and dict should be updated accordingly
-    graph_args = [gm, example_inputs]
+    # of fx_codegen_and_compile changes, the dict should be updated accordingly
     graph_kwargs = {
         "cudagraphs": cudagraphs,
         "num_fixed": num_fixed,
@@ -385,11 +384,11 @@
 
     if config.fx_graph_cache and not aot_mode:
         compiled_graph = FxGraphCache.load(
-            fx_codegen_and_compile, graph_args, graph_kwargs
+            fx_codegen_and_compile, gm, example_inputs, graph_kwargs
         )
     else:
         compiled_graph = fx_codegen_and_compile(
-            *graph_args, **graph_kwargs  # type: ignore[arg-type]
+            gm, example_inputs, **graph_kwargs  # type: ignore[arg-type]
         )
 
     log.debug("FX codegen and compilation took %.3fs", time.time() - start)
@@ -620,18 +619,8 @@
             if graph.disable_cudagraphs:
                 BoxedBool.disable(cudagraphs)
 
-            compiled_graph = CompiledFxGraph(
-                compiled_artifact=compiled_fn,
-                cache_key=graph.cache_key,
-                artifact_path=graph.cache_path,
-                cache_linemap=graph.cache_linemap,
-                device_types=graph.device_types,
-                device_idxs=graph.device_idxs,
-                mutated_inputs=graph.mutated_inputs,
-                mutated_input_idxs=set(graph.mutated_input_idxs),
-                constants=graph.constants,
-                output_strides=output_strides,
-            )
+            compiled_graph = CompiledFxGraph(compiled_fn, graph, output_strides)
+
     return compiled_graph
 
 
diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py
index 7008b15..94b8297 100644
--- a/torch/fx/experimental/symbolic_shapes.py
+++ b/torch/fx/experimental/symbolic_shapes.py
@@ -2735,13 +2735,31 @@
         self._check_translation_validate()
         return exprs
 
-    def evaluate_guards_for_args(self, placeholders, args, *, ignore_static=True):
+    def produce_guards_expression(self, placeholders, ignore_static=True):
+        """
+        Expected to be used with evaluate_guards_expression(). Produces the guards
+        for the given placeholders and returns a string expression to be evaluated
+        by evaluate_guards_expression given concrete values for the placeholders.
+        """
         from torch._dynamo.source import LocalSource
-        arg_names = [f"t{i}" for i in range(len(args))]
+        arg_names = [f"t{i}" for i in range(len(placeholders))]
         guards = self.produce_guards(placeholders, [LocalSource(a) for a in arg_names], ignore_static=ignore_static)
         if guards:
-            code = " and ".join(guards)
-            return eval(code, SYMPY_INTERP, {"L": dict(zip(arg_names, args))})
+            return " and ".join(guards)
+        return None
+
+    def evaluate_guards_expression(self, code, args):
+        """
+        Expected to be used with produce_guards_expression(). Evaluates an expression
+        generated by produce_guards_expression for the given concrete args.
+        """
+        arg_names = [f"t{i}" for i in range(len(args))]
+        return eval(code, SYMPY_INTERP, {"L": dict(zip(arg_names, args))})
+
+    def evaluate_guards_for_args(self, placeholders, args, *, ignore_static=True):
+        code = self.produce_guards_expression(placeholders, ignore_static=ignore_static)
+        if code:
+            return self.evaluate_guards_expression(code, args)
         return True
 
     def bind_symbols(self, placeholders, args):