Revert "[inductor] Implement Fx graph caching to improve warm compilation time. (#103453)"

This reverts commit fc1105b2827ee2febc85a3c353470edfd70a66ed.

Reverted https://github.com/pytorch/pytorch/pull/103453 on behalf of https://github.com/kit1980 due to Same issue unfortunately, the newly added test fails on internal builds ([comment](https://github.com/pytorch/pytorch/pull/103453#issuecomment-1760202365))
diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py
index 8403d0d..e293afb 100644
--- a/test/inductor/test_codecache.py
+++ b/test/inductor/test_codecache.py
@@ -1,32 +1,12 @@
 # Owner(s): ["module: inductor"]
 import functools
-import pickle
-import tempfile
 import unittest
-from unittest.mock import patch
 
 import torch
-from torch._dynamo.test_case import run_tests, TestCase
-from torch._dynamo.utils import counters
-from torch._inductor import config
-from torch._inductor.codecache import (
-    AsyncCompile,
-    FxGraphCachePickler,
-    FxGraphHashDetails,
-    TensorMetadata,
-    TensorMetadataAndValues,
-)
-from torch.testing._internal.common_utils import (
-    instantiate_parametrized_tests,
-    parametrize,
-)
+from torch._inductor.codecache import AsyncCompile
 from torch.testing._internal.inductor_utils import HAS_CUDA
-from torch.utils._triton import has_triton
-
-HAS_TRITON = has_triton()
 
 requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda")
-requires_triton = functools.partial(unittest.skipIf, not HAS_TRITON, "requires triton")
 
 
 class MyModel(torch.nn.Module):
@@ -57,274 +37,3 @@
 @requires_cuda()
 def test_codecache_fork():
     _run_codecache_test("fork")
-
-
-@instantiate_parametrized_tests
-class TestFxGraphCache(TestCase):
-    @classmethod
-    def setUpClass(cls):
-        # Reroute all cache disk activity to a clean temporary directory to
-        # ensure isolation (and initial cache misses). Deliberately create the
-        # temp dir in setUpClass, however, so that individual tests reuse
-        # the same location. We don't expect different tests to share cache
-        # entries, so keeping one temp dir also verifies they don't collide.
-        cls.tmpdir = tempfile.TemporaryDirectory()
-        cls.cache_dir_patch = patch("torch._inductor.codecache.cache_dir")
-        cls.cache_dir_patch.start().return_value = cls.tmpdir.name
-
-    @classmethod
-    def tearDownClass(cls):
-        cls.cache_dir_patch.stop()
-        cls.tmpdir.cleanup()
-
-    def setUp(self):
-        counters.clear()
-
-    @requires_triton()
-    @config.patch({"fx_graph_cache": True})
-    @parametrize("device", ("cuda", "cpu"))
-    @parametrize("dtype", (torch.float, torch.bfloat16))
-    def test_cache_load_function(self, device, dtype):
-        """
-        Verify that we can populate and load functions from the cache.
-        """
-        if device == "cuda" and not HAS_CUDA:
-            raise unittest.SkipTest("requires CUDA")
-
-        def fn(x, y):
-            return (x * 2, y @ y)
-
-        a = torch.rand(25, dtype=dtype, device=device)
-        b = torch.rand(5, 5, dtype=dtype, device=device)
-        c = a.view(5, 5)
-
-        compiled_fn = torch.compile(fn, dynamic=False)
-
-        # A first call should miss in the cache.
-        self.assertEqual(fn(a, b), compiled_fn(a, b))
-        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
-        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
-
-        # A second call should hit. (First reset so in-memory guards
-        # don't prevent compilation).
-        torch._dynamo.reset()
-        self.assertEqual(fn(a, b), compiled_fn(a, b))
-        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
-        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
-
-        # But we expect different code if the tensors are aliased.
-        torch._dynamo.reset()
-        self.assertEqual(fn(a, c), compiled_fn(a, c))
-        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
-        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
-
-    @requires_triton()
-    @config.patch({"fx_graph_cache": True})
-    @parametrize("device", ("cuda", "cpu"))
-    @parametrize("dtype", (torch.float, torch.float16))
-    def test_cache_load_model(self, device, dtype):
-        """
-        Verify that we can populate and load models from the cache.
-        """
-        if device == "cuda" and not HAS_CUDA:
-            raise unittest.SkipTest("requires CUDA")
-
-        model = MyModel().to(dtype=dtype, device=device)
-
-        a = torch.rand(10, 10, dtype=dtype, device=device)
-
-        compiled_model = torch.compile(model, dynamic=False)
-
-        # A first call should miss in the cache.
-        self.assertEqual(model(a), compiled_model(a))
-        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
-        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
-
-        # A second call should hit. (First reset so in-memory guards
-        # don't prevent compilation).
-        torch._dynamo.reset()
-        self.assertEqual(model(a), compiled_model(a))
-        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
-        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
-
-
-class TestFxGraphCacheHashing(TestCase):
-    def test_tensor_constants(self):
-        """
-        Test the handling of small vs. large tensor constants.
-        """
-        data = FxGraphCachePickler.dumps(torch.tensor(list(range(9))))
-        self.assertIsInstance(pickle.loads(data), TensorMetadata)
-
-        data = FxGraphCachePickler.dumps(torch.tensor(list(range(8))))
-        self.assertIsInstance(pickle.loads(data), TensorMetadataAndValues)
-
-    def test_hash_fake_tensors(self):
-        """
-        Test hashing (pickling) FakeTensors with various characteristics.
-        """
-        with torch._subclasses.FakeTensorMode():
-            # Verify that FakeTensors get pickled into a TensorMetadata:
-            data = FxGraphCachePickler.dumps(torch.randn(1))
-            self.assertIsInstance(pickle.loads(data), TensorMetadata)
-
-            # Different shapes:
-            self.assertEqual(
-                FxGraphCachePickler.dumps(torch.randn(3)),
-                FxGraphCachePickler.dumps(torch.randn(3)),
-            )
-            self.assertNotEqual(
-                FxGraphCachePickler.dumps(torch.randn(3)),
-                FxGraphCachePickler.dumps(torch.randn(4)),
-            )
-            self.assertNotEqual(
-                FxGraphCachePickler.dumps(torch.randn(3)),
-                FxGraphCachePickler.dumps(torch.randn(3, 3)),
-            )
-
-            self.assertEqual(
-                FxGraphCachePickler.dumps(torch.randn(3, 3)),
-                FxGraphCachePickler.dumps(torch.randn(3, 3)),
-            )
-            self.assertNotEqual(
-                FxGraphCachePickler.dumps(torch.randn(3, 3)),
-                FxGraphCachePickler.dumps(torch.randn(3, 4)),
-            )
-            self.assertNotEqual(
-                FxGraphCachePickler.dumps(torch.randn(3, 3)),
-                FxGraphCachePickler.dumps(torch.randn(4, 3)),
-            )
-
-            # Different strides:
-            self.assertEqual(
-                FxGraphCachePickler.dumps(torch.randn(3, 3)),
-                FxGraphCachePickler.dumps(
-                    torch.randn(3, 3).transpose(0, 1).transpose(0, 1)
-                ),
-            )
-            self.assertNotEqual(
-                FxGraphCachePickler.dumps(torch.randn(3, 3)),
-                FxGraphCachePickler.dumps(torch.randn(3, 3).transpose(0, 1)),
-            )
-
-            # Different storage offsets:
-            self.assertEqual(
-                FxGraphCachePickler.dumps(torch.randn(3)[1:]),
-                FxGraphCachePickler.dumps(torch.randn(3)[1:]),
-            )
-            self.assertNotEqual(
-                FxGraphCachePickler.dumps(torch.randn(3)[1:]),
-                FxGraphCachePickler.dumps(torch.randn(2)),
-            )
-
-            # Different dtypes:
-            self.assertEqual(
-                FxGraphCachePickler.dumps(torch.randn(3, dtype=torch.float32)),
-                FxGraphCachePickler.dumps(torch.randn(3, dtype=torch.float32)),
-            )
-            self.assertNotEqual(
-                FxGraphCachePickler.dumps(torch.randn(3, dtype=torch.float32)),
-                FxGraphCachePickler.dumps(torch.randn(3, dtype=torch.float64)),
-            )
-
-            # Different 'requires_grad':
-            self.assertEqual(
-                FxGraphCachePickler.dumps(torch.randn(3, requires_grad=True)),
-                FxGraphCachePickler.dumps(torch.randn(3, requires_grad=True)),
-            )
-            self.assertNotEqual(
-                FxGraphCachePickler.dumps(torch.randn(3, requires_grad=True)),
-                FxGraphCachePickler.dumps(torch.randn(3, requires_grad=False)),
-            )
-
-            # Different memory formats:
-            self.assertNotEqual(
-                FxGraphCachePickler.dumps(torch.randn(1, 2, 3, 4)),
-                FxGraphCachePickler.dumps(
-                    torch.randn(1, 2, 3, 4).to(memory_format=torch.channels_last)
-                ),
-            )
-
-            # Different devices:
-            self.assertEqual(
-                FxGraphCachePickler.dumps(torch.randn(3, device="meta")),
-                FxGraphCachePickler.dumps(torch.randn(3, device="meta")),
-            )
-            self.assertNotEqual(
-                FxGraphCachePickler.dumps(torch.randn(3, device="meta")),
-                FxGraphCachePickler.dumps(torch.randn(3, device="cpu")),
-            )
-
-            if HAS_CUDA and torch.cuda.device_count() >= 2:
-                self.assertEqual(
-                    FxGraphCachePickler.dumps(torch.randn(3, device="cuda:1")),
-                    FxGraphCachePickler.dumps(torch.randn(3, device="cuda:1")),
-                )
-                self.assertNotEqual(
-                    FxGraphCachePickler.dumps(torch.randn(3, device="cuda:0")),
-                    FxGraphCachePickler.dumps(torch.randn(3, device="cuda:1")),
-                )
-
-    def test_hash_kwargs(self):
-        """
-        Test the special handling of the kwargs when hashing, i.e.,
-        ordering of the kwargs dict and any set arguments.
-        """
-        # Dict order of the kwargs should not affect hashes.
-        details1 = FxGraphHashDetails([], {"a": 0, "z": 1})
-        details2 = FxGraphHashDetails([], {"z": 1, "a": 0})
-        self.assertEqual(
-            FxGraphCachePickler.dumps(details1),
-            FxGraphCachePickler.dumps(details2),
-        )
-
-        # Different kwarg values should affect hashes.
-        details1 = FxGraphHashDetails([], {"a": 0})
-        details2 = FxGraphHashDetails([], {"a": 1})
-        self.assertNotEqual(
-            FxGraphCachePickler.dumps(details1),
-            FxGraphCachePickler.dumps(details2),
-        )
-
-        # Set order should not affect hashes. Sets are unordered, but
-        # rebuilding one from its sorted elements can change the iteration order.
-        set1 = {"a", "b", "c", "d", "e", "f", "g"}
-        set2 = set(sorted(set1))  # noqa: C414
-        details1 = FxGraphHashDetails([], {"a": set1})
-        details2 = FxGraphHashDetails([], {"a": set2})
-        self.assertEqual(
-            FxGraphCachePickler.dumps(details1),
-            FxGraphCachePickler.dumps(details2),
-        )
-
-        # But different set contents should affect hashes.
-        details1 = FxGraphHashDetails([], {"a": {1, 2, 3}})
-        details2 = FxGraphHashDetails([], {"a": {1, 2}})
-        self.assertNotEqual(
-            FxGraphCachePickler.dumps(details1),
-            FxGraphCachePickler.dumps(details2),
-        )
-
-    def test_hash_config_changes(self):
-        """
-        Test that different config settings affect hashes.
-        """
-        with config.patch({"max_autotune": False}):
-            details1 = FxGraphHashDetails([], {})
-            details2 = FxGraphHashDetails([], {})
-
-        with config.patch({"max_autotune": True}):
-            details3 = FxGraphHashDetails([], {})
-
-        self.assertEqual(
-            FxGraphCachePickler.dumps(details1),
-            FxGraphCachePickler.dumps(details2),
-        )
-        self.assertNotEqual(
-            FxGraphCachePickler.dumps(details1),
-            FxGraphCachePickler.dumps(details3),
-        )
-
-
-if __name__ == "__main__":
-    run_tests()
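For reference, the removed tests exercised the cache through the dynamo counters. A minimal sketch of that pattern, assuming the reverted `fx_graph_cache` config flag and `fxgraph_cache_*` counters are present:

```python
# Sketch only: relies on the fx_graph_cache flag and counters removed by this revert.
import torch
from torch._dynamo.utils import counters
from torch._inductor import config

with config.patch({"fx_graph_cache": True}):
    counters.clear()
    compiled = torch.compile(lambda x: x * 2, dynamic=False)

    compiled(torch.rand(8))  # first compilation: expect a cache miss
    assert counters["inductor"]["fxgraph_cache_miss"] == 1

    torch._dynamo.reset()    # drop in-memory guards so the graph recompiles
    compiled(torch.rand(8))  # second compilation: expect a cache hit
    assert counters["inductor"]["fxgraph_cache_hit"] == 1
```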
diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py
index 926e45f..4dde4c4 100644
--- a/torch/_inductor/codecache.py
+++ b/torch/_inductor/codecache.py
@@ -1,20 +1,16 @@
 from __future__ import annotations
 
 import base64
-import copyreg
 import dataclasses
 import functools
 import getpass
 import hashlib
 import importlib
-import io
 import json
 import logging
 import multiprocessing
 import os
 import pathlib
-import pickle
-import pkgutil
 import platform
 import re
 import shlex
@@ -29,7 +25,6 @@
 import weakref
 from bisect import bisect_right
 from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor
-from copy import copy
 from ctypes import c_void_p, cdll, CDLL
 from dataclasses import field
 from functools import partial
@@ -46,11 +41,9 @@
     get_interface_for_device,
     get_registered_device_interfaces,
 )
-from torch._dynamo.utils import counters
 from torch._inductor import config, exc
 from torch._inductor.codegen.cuda import cuda_env
 from torch._inductor.utils import developer_warning, is_linux
-from torch._prims_common import suggest_memory_format
 
 if TYPE_CHECKING:
     from torch._inductor.graph import GraphLowering
@@ -357,8 +350,6 @@
         return code_hash(content, extra)
     if hash_type == "cubin":
         return code_hash(repr(content))
-    if hash_type == "cg":
-        return extra
     raise AssertionError(f"Unknown hash type {hash_type}")
 
 
@@ -393,259 +384,6 @@
 
 
 @dataclasses.dataclass
-class TensorMetadata:
-    """
-    The Tensor metadata relevant when hashing FxGraph cache keys.
-    """
-
-    dtype: torch.dtype
-    shape: torch.Size
-    stride: Tuple[Any, ...]
-    device: torch.device
-    layout: torch.layout
-    memory_format: Optional[torch.memory_format]
-    storage_offset: int
-    requires_grad: bool
-    is_quantized: bool
-    is_conj: bool
-    is_neg: bool
-    is_coalesced: bool
-    dense_dim: int
-    sparse_dim: int
-
-
-@dataclasses.dataclass
-class TensorMetadataAndValues:
-    """
-    TensorMetadata plus the elements as a list of raw values.
-    Used for hashing inlined constants.
-    """
-
-    tensor_metadata: TensorMetadata
-    values: List[Any]
-
-
-class DynamicShapeError(BaseException):
-    """
-    TODO(masnesral): Support dynamic shapes and remove this.
-    """
-
-    pass
-
-
-def extract_tensor_metadata(t: torch.Tensor) -> TensorMetadata:
-    """
-    Extract the TensorMetadata of a tensor.
-    """
-    if not all(isinstance(e, int) for e in t.shape):
-        raise DynamicShapeError()
-
-    memory_format = suggest_memory_format(t)
-    if not t.is_contiguous(memory_format=memory_format):
-        memory_format = None
-
-    return TensorMetadata(
-        dtype=t.dtype,
-        shape=t.shape,
-        stride=t.stride() if t.layout == torch.strided else (),
-        device=t.device,
-        layout=t.layout,
-        memory_format=memory_format,
-        storage_offset=t.storage_offset(),
-        requires_grad=t.requires_grad,
-        is_quantized=t.is_quantized,
-        is_conj=t.is_conj(),
-        is_neg=t.is_neg(),
-        is_coalesced=t.is_coalesced() if t.is_sparse else False,
-        dense_dim=t.dense_dim() if t.is_sparse else False,
-        sparse_dim=t.sparse_dim() if t.is_sparse else False,
-    )
-
-
-def _ident(x: Any) -> Any:
-    return x
-
-
-def _reduce_fake_tensor(t):
-    """
-    See FxGraphCachePickler. Custom reducer to pickle FakeTensors.
-    """
-    metadata = extract_tensor_metadata(t)
-    return (_ident, (metadata,))
-
-
-def _reduce_tensor(t):
-    """
-    See FxGraphCachePickler. Custom reducer to pickle Tensors.
-    """
-    # If we see tensors, we know they're constants stored as attributes on
-    # the GraphModule. See tensor lowering; small constants are inlined. If
-    # we see a small tensor, therefore, no reference will ultimately remain
-    # in the generated code. So we need to include its value in the cache key.
-    # Large constants are effectively treated as inputs and we consider only
-    # their metadata.
-    metadata = extract_tensor_metadata(t)
-    if len(t.shape) == 0 or torch._inductor.graph.GraphLowering.can_inline_constant(t):
-        return (_ident, (TensorMetadataAndValues(metadata, t.tolist()),))
-    else:
-        return (_ident, (metadata,))
-
-
-class FxGraphCachePickler(pickle.Pickler):
-    """
-    Custom pickler to customize the pickling of some objects (Tensors), only for the
-    purpose of computing a hash for keying into the FxGraphCache. Tensors contain
-    objects that don't pickle and/or vary between runs, and we want to capture the
-    data that allow us to compute a stable, but safe hash.
-    """
-
-    dispatch_table = copyreg.dispatch_table.copy()
-    dispatch_table[torch._subclasses.fake_tensor.FakeTensor] = _reduce_fake_tensor
-    dispatch_table[torch.Tensor] = _reduce_tensor
-
-    @staticmethod
-    def dumps(obj) -> bytes:
-        """
-        Pickle an object using the FxGraphCachePickler.
-        """
-        with io.BytesIO() as stream:
-            pickler = FxGraphCachePickler(stream)
-            pickler.dump(obj)
-            return stream.getvalue()
-
-
-@functools.lru_cache(None)
-def get_inductor_code_hash() -> bytes:
-    """
-    Compute a hash of all inductor code modules. Used by the FxGraph cache
-    so any inductor code changes would result in new cache keys.
-    """
-    inductor_root = os.path.dirname(__file__)
-
-    contents: Dict[str, bytes] = {}
-    for lib in pkgutil.iter_modules([inductor_root]):
-        spec = lib.module_finder.find_spec(lib.name, None)
-        assert spec is not None
-        module = spec.origin
-        assert module is not None
-        with open(module, "rb") as f:
-            contents[module] = f.read()
-
-    return hashlib.sha256(pickle.dumps(contents)).digest()
-
-
-@dataclasses.dataclass
-class OrderedSetHolder:
-    """
-    See FxGraphHashDetails. Holds a sorted list to support stable hashing
-    of set kwargs.
-    """
-
-    items: List[Any]
-
-
-class FxGraphHashDetails:
-    """
-    Object to capture all the details for a compiled FX graph relevant to computing
-    a safe and stable cache key.
-    """
-
-    # Excluded kwargs params that are not stable between runs
-    EXCLUDED_KWARGS = ["graph_id"]
-
-    def __init__(self, fx_args: List[Any], fx_kwargs: Dict[str, Any]) -> None:
-        self.fx_args = fx_args
-
-        # Order kwargs so hashing is stable to changes in kwarg order.
-        self.fx_kwargs = {}
-        for k in sorted(fx_kwargs):
-            if k not in self.EXCLUDED_KWARGS:
-                if type(fx_kwargs[k]) is set:
-                    # Special case to handle set params. Python sets can't be
-                    # ordered, so sort the elements and store them in a proxy.
-                    self.fx_kwargs[k] = OrderedSetHolder(sorted(fx_kwargs[k]))
-                else:
-                    self.fx_kwargs[k] = fx_kwargs[k]
-
-        # Also hash on various system info (including the triton compiler version), as
-        # well as the inductor configuration and code.
-        self.torch_version = torch.__version__
-        self.system_info = CacheBase.get_system()
-
-        self.inductor_config = config.save_config()  # type: ignore[attr-defined]
-        self.inductor_code_hash = get_inductor_code_hash()
-
-
-def compiled_fx_graph_hash(fx_args: List[Any], fx_kwargs: Dict[str, Any]) -> str:
-    """
-    Generate a unique hash of the FX graph for caching.
-    """
-    details = FxGraphHashDetails(fx_args, fx_kwargs)
-    serialized_data = FxGraphCachePickler.dumps(details)
-    return (
-        "f"
-        + base64.b32encode(hashlib.sha256(serialized_data).digest())[:51]
-        .decode("utf-8")
-        .lower()
-    )
-
-
-class FxGraphCache:
-    """
-    Supports caching and reusing compiled Fx graphs.
-    """
-
-    # TODO(masnesral): Investigate whether it's beneficial to store compiled graphs
-    # in an in-memory cache after loading from disk.
-    @classmethod
-    def save_graph(cls, key: str, compiled_graph: CompiledFxGraph):
-        disk_compiled_graph = copy(compiled_graph)
-        # Important as compiled models are not pickleable
-        # TODO: Check status of PR #101651 as that might change the above statement
-        disk_compiled_graph.compiled_artifact = None
-        write(pickle.dumps(disk_compiled_graph), "cg", extra=key, hash_type="cg")
-
-    @classmethod
-    def load_graph(cls, cg_path: str) -> CompiledFxGraph:
-        with open(cg_path, "rb") as f:
-            return pickle.load(f)
-
-    @classmethod
-    def load(
-        cls,
-        compile_fx_fn: Callable[..., Any],
-        fx_args: List[Any],
-        fx_kwargs: Dict[str, Any],
-    ):
-        from filelock import FileLock
-
-        try:
-            key = compiled_fx_graph_hash(fx_args, fx_kwargs)
-        except DynamicShapeError:
-            # TODO(masnesral): support dynamic shapes
-            log.debug("fx graph cache skip due to dynamic shapes")
-            counters["inductor"]["fxgraph_cache_skip"] += 1
-            return compile_fx_fn(*fx_args, **fx_kwargs)
-
-        lock_dir = get_lock_dir()
-        lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
-        with lock:
-            _, _, cg_path = get_path(key, "cg")
-            if not os.path.exists(cg_path):
-                log.debug("fx graph cache miss for key %s", key)
-                counters["inductor"]["fxgraph_cache_miss"] += 1
-                compiled_graph: CompiledFxGraph = compile_fx_fn(*fx_args, **fx_kwargs)
-                cls.save_graph(key, compiled_graph)
-            else:
-                # Load required info from disk; the compiled artifact is recreated on first run
-                log.debug("fx graph cache hit for key %s", key)
-                counters["inductor"]["fxgraph_cache_hit"] += 1
-                compiled_graph = cls.load_graph(cg_path)
-
-            return compiled_graph
-
-
-@dataclasses.dataclass
 class CompiledFxGraph:
     """Class holding a compiled FX graph"""
 
@@ -658,7 +396,6 @@
     device_idxs: Set[int] = field(default_factory=set)
     mutated_inputs: Set[str] = field(default_factory=set)
     mutated_input_idxs: Set[int] = field(default_factory=set)
-    constants: Dict[str, torch.Tensor] = field(default_factory=dict)
 
     _boxed_call: Optional[bool] = None
 
@@ -688,7 +425,6 @@
             compiled_graph.cache_key,
             compiled_graph.artifact_path,
             compiled_graph.cache_linemap,
-            compiled_graph.constants,
         ).call
 
     return compiled_graph.compiled_artifact(inputs)
@@ -1589,10 +1325,9 @@
         source_code: str,
         extra: str = "",
         linemap: Optional[List[Tuple[int, str]]] = None,
-        attrs: Optional[Dict[str, Any]] = None,
     ) -> ModuleType:
         key, path = write(source_code, "py", extra=extra)
-        return cls.load_by_key_path(key, path, linemap, attrs)
+        return cls.load_by_key_path(key, path, linemap)
 
     @classmethod
     def load_by_key_path(
@@ -1600,7 +1335,6 @@
         key: str,
         path: str,
         linemap: Optional[List[Tuple[int, str]]] = None,
-        attrs: Optional[Dict[str, Any]] = None,
     ) -> ModuleType:
         if linemap is None:
             linemap = []
@@ -1622,10 +1356,6 @@
                 # unzip into separate lines/nodes lists
                 cls.linemaps[path] = list(zip(*linemap))
 
-                if attrs is not None:
-                    for k, v in attrs.items():
-                        setattr(mod, k, v)
-
         return cls.cache[key]
 
     @classmethod
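The removed hashing path centers on a `pickle.Pickler` subclass whose `dispatch_table` replaces tensors with their metadata before the bytes are hashed. A self-contained sketch of that pattern, with a toy class standing in for tensors (the names below are illustrative, not inductor API):

```python
import base64
import copyreg
import hashlib
import io
import pickle


class Point:
    """Toy stand-in for an object whose raw contents should not leak into the hash."""

    def __init__(self, x, y):
        self.x, self.y = x, y


def _ident(v):
    return v


def _reduce_point(p):
    # Reduce the object to a stable summary, mirroring how the reverted
    # FxGraphCachePickler reduced FakeTensors/Tensors to TensorMetadata.
    return (_ident, ("Point",))


class HashPickler(pickle.Pickler):
    dispatch_table = copyreg.dispatch_table.copy()
    dispatch_table[Point] = _reduce_point

    @staticmethod
    def dumps(obj) -> bytes:
        with io.BytesIO() as stream:
            HashPickler(stream).dump(obj)
            return stream.getvalue()


def stable_key(obj) -> str:
    # Same post-processing as the removed compiled_fx_graph_hash: sha256 the
    # pickled bytes, then base32-encode and truncate.
    digest = hashlib.sha256(HashPickler.dumps(obj)).digest()
    return "f" + base64.b32encode(digest)[:51].decode("utf-8").lower()


# Two Points with different values hash identically: only the summary is pickled.
assert stable_key({"p": Point(1, 2)}) == stable_key({"p": Point(3, 4)})
```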
diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py
index f25b9f5..8429b5a 100644
--- a/torch/_inductor/compile_fx.py
+++ b/torch/_inductor/compile_fx.py
@@ -4,7 +4,6 @@
 import itertools
 import logging
 import sys
-import time
 import warnings
 
 from typing import Any, Callable, Dict, FrozenSet, List, Optional, Sequence, Union
@@ -23,7 +22,7 @@
 )
 from torch._dynamo.utils import detect_fake_mode
 from torch._functorch.aot_autograd import make_boxed_func
-from torch._inductor.codecache import code_hash, CompiledFxGraph, FxGraphCache
+from torch._inductor.codecache import code_hash, CompiledFxGraph
 
 from torch._inductor.debug import save_args_for_compile_fx_inner
 from torch._ops import OpOverload
@@ -324,8 +323,6 @@
         cudagraphs = BoxedBool(config.triton.cudagraphs)
 
     # Inputs to fx_codegen_and_compile
-    # Anything that affects codegen should go here, so if the signature
-    # of fx_codegen_and_compile changes, the list and dict should be updated accordingly
     graph_args = [gm, example_inputs]
     graph_kwargs = {
         "cudagraphs": cudagraphs,
@@ -340,18 +337,9 @@
         "extern_node_serializer": extern_node_serializer,
     }
 
-    start = time.time()
-
-    if config.fx_graph_cache:
-        compiled_graph: CompiledFxGraph = FxGraphCache.load(
-            fx_codegen_and_compile, graph_args, graph_kwargs
-        )
-    else:
-        compiled_graph = fx_codegen_and_compile(
-            *graph_args, **graph_kwargs  # type: ignore[arg-type]
-        )
-
-    log.debug("FX codegen and compilation took %.3fs", time.time() - start)
+    compiled_graph: CompiledFxGraph = fx_codegen_and_compile(
+        *graph_args, **graph_kwargs  # type: ignore[arg-type]
+    )
 
     if aot_mode:
         return compiled_graph
@@ -559,7 +547,6 @@
                         )
                     else:
                         context.output_strides.append(None)
-
             compiled_fn = graph.compile_to_fn()
 
             if V.aot_compilation is True:
@@ -577,7 +564,6 @@
                 device_idxs=graph.device_idxs,
                 mutated_inputs=graph.mutated_inputs,
                 mutated_input_idxs=set(graph.mutated_input_idxs),
-                constants=graph.constants,
             )
     return compiled_graph
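The call-site change above removes the "load from disk or compile" wrapper around `fx_codegen_and_compile`. The shape of that wrapper, sketched generically (a hypothetical helper, not the inductor API; the real `FxGraphCache.load` additionally held a `FileLock` on the key and cleared the non-picklable `compiled_artifact` before writing):

```python
import os
import pickle
from typing import Any, Callable


def load_or_compile(compile_fn: Callable[..., Any], key: str, cache_dir: str,
                    *args: Any, **kwargs: Any) -> Any:
    # Generic sketch of the removed pattern: look up a pickled artifact by its
    # hash key, otherwise run the real compile step and persist the result.
    path = os.path.join(cache_dir, f"{key}.pkl")
    if os.path.exists(path):
        with open(path, "rb") as f:
            return pickle.load(f)            # cache hit
    result = compile_fn(*args, **kwargs)     # cache miss: compile for real
    with open(path, "wb") as f:
        pickle.dump(result, f)
    return result
```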
 
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index eba3029..b354009 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -15,9 +15,6 @@
 # Whether to enable printing the source code for each future
 verbose_progress = False
 
-# use fx aot graph codegen cache
-fx_graph_cache = os.environ.get("TORCHINDUCTOR_FX_GRAPH_CACHE") == "1"
-
 # use cpp wrapper instead of python wrapper
 cpp_wrapper = False
 
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index 0d207ca..5e09e5c 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -639,13 +639,6 @@
                 e.__traceback__
             ) from None
 
-    @staticmethod
-    def can_inline_constant(t: torch.Tensor) -> bool:
-        """
-        True if this is a small constant attr that will be inlined.
-        """
-        return len(t.shape) == 1 and t.shape[0] <= 8
-
     def get_attr(self, target, args, kwargs):
         # this is a constant
         value = getattr(self.module, target)
@@ -656,7 +649,7 @@
         with no_dispatch():
             if value.shape == ():
                 return Constant(value.item(), value.dtype, value.device)
-            if self.can_inline_constant(value):
+            if len(value.shape) == 1 and value.shape[0] <= 8:
                 # tensor lowering has constant inlining logic
                 from .lowering import tensor
 
@@ -974,13 +967,14 @@
         code, linemap = self.codegen()
         linemap = [(line_no, node.stack_trace) for line_no, node in linemap]
         key, path = PyCodeCache.write(code)
-        mod = PyCodeCache.load_by_key_path(
-            key, path, linemap=linemap, attrs=self.constants
-        )
+        mod = PyCodeCache.load_by_key_path(key, path, linemap=linemap)
         self.cache_key = key
         self.cache_path = path
         self.cache_linemap = linemap
 
+        for name, value in self.constants.items():
+            setattr(mod, name, value)
+
         # Logged twice as per https://github.com/pytorch/pytorch/pull/99038#discussion_r1167826029
         # TODO. Revisit this once the logging API is more mature
         assert mod.__file__ is not None
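The `graph.py` change above folds the small-constant check back inline. The removed `can_inline_constant` predicate is what tied the cache-key logic to the inlining threshold (and to the `range(8)` vs. `range(9)` test removed earlier):

```python
import torch


def can_inline_constant(t: torch.Tensor) -> bool:
    # Same predicate as the removed GraphLowering.can_inline_constant:
    # only small 1-D constants are inlined into the generated code.
    return len(t.shape) == 1 and t.shape[0] <= 8


assert can_inline_constant(torch.tensor(list(range(8))))      # inlined -> values must be in the cache key
assert not can_inline_constant(torch.tensor(list(range(9))))  # kept as an attr -> metadata alone suffices
```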
diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py
index 3bfffb7..7aa09a7 100644
--- a/torch/fx/graph_module.py
+++ b/torch/fx/graph_module.py
@@ -106,9 +106,7 @@
     import_strs: Set[str] = set()
     for name, obj in globals.items():
         import_strs.add(_format_import_statement(name, obj, importer))
-    # Sort the imports so we have a stable import block that allows us to
-    # hash the graph module and get a consistent key for use in a cache.
-    return "\n".join(sorted(import_strs))
+    return "\n".join(import_strs)
 
 
 @compatibility(is_backward_compatible=True)
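The final hunk drops the import sorting in `torch/fx/graph_module.py`. The sort only matters for hashing: set iteration order is not stable across processes, so an unsorted import block can differ from run to run even when the imports are identical. A small illustration:

```python
# Joining a set directly depends on hash-based iteration order (and hash
# randomization across processes); sorting first makes the block deterministic.
import_strs = {"import torch", "import math", "from torch import nn"}

unstable = "\n".join(import_strs)        # order may differ between runs
stable = "\n".join(sorted(import_strs))  # always the same text -> stable hash
assert stable == "from torch import nn\nimport math\nimport torch"
```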