Remove compile cache and static_argnums (#85783)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85783
Approved by: https://github.com/wconstab
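
Removes the C++ CompileCache (torch/csrc/functorch/CompileCache.{cpp,h} and its
Python bindings) and the hasher_type / static_argnums options of aot_function,
along with the num_of_recompilations and clear_compile_cache helpers and
test_compile_cache.py. Each callable returned by aot_function now compiles once
and reuses the result; guarding and recompilation are left to the caller (e.g.
torchdynamo). Passing static_argnums raises a RuntimeError telling users to wrap
their function manually or use torchdynamo.

Call sites that relied on static_argnums for non-tensor arguments can instead
close over those values before handing the function to aot_function. A minimal
sketch, reusing the dropout example from the removed docstring; the
make_aot_dropout helper name is illustrative, not part of any API:

    import torch
    from functorch.compile import aot_function, nop

    def fn(input, bias, residual, p: float):
        a = torch.add(input, bias)
        b = torch.nn.functional.dropout(a, p, training=True)
        return b + residual

    # Close over the non-tensor argument instead of marking it static;
    # each distinct value of p gets its own aot_function-wrapped callable.
    def make_aot_dropout(p: float):
        return aot_function(
            lambda input, bias, residual: fn(input, bias, residual, p), nop
        )

    aot_fn = make_aot_dropout(0.1)
    out = aot_fn(
        torch.randn(4, 5, requires_grad=True),
        torch.randn(5, requires_grad=True),
        torch.randn(4, 5, requires_grad=True),
    )

Recompiling for a new value of p then amounts to calling make_aot_dropout again
with that value.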
diff --git a/build_variables.bzl b/build_variables.bzl
index b7fc16e..b26cca8 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -898,7 +898,6 @@
     "torch/csrc/autograd/python_variable.cpp",
     "torch/csrc/autograd/python_variable_indexing.cpp",
     "torch/csrc/functorch/init.cpp",
-    "torch/csrc/functorch/CompileCache.cpp",
     "torch/csrc/jit/backends/backend_init.cpp",
     "torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp",
     "torch/csrc/jit/python/init.cpp",
diff --git a/functorch/_src/aot_autograd.py b/functorch/_src/aot_autograd.py
index bcd4b34..08953b6 100644
--- a/functorch/_src/aot_autograd.py
+++ b/functorch/_src/aot_autograd.py
@@ -17,7 +17,6 @@
 from torch.nn.utils import stateless
 
 from functorch import make_fx
-from torch._C._functorch import CompileCache
 from functorch.experimental import functionalize
 from torch._dispatch.python import enable_python_dispatcher
 from . import config
@@ -521,14 +520,6 @@
             return aot_dispatch_base(flat_fn, fake_flat_tensor_args, aot_config)
 
 
-class _CompileCache(CompileCache):
-    pass
-
-
-# using a C++-based pytree reduces the overhead by about 50%
-compile_cache = None
-
-
 # Inspired by autodidax (thanks!)
 class PytreeThunk:
     spec = None
@@ -555,53 +546,6 @@
             return x
         return pytree.tree_unflatten(x, self.spec)
 
-
-def filter_tensor_and_static_args(args, static_argnums):
-    """
-    Separate out the tensor and static args. Also, for the static args, store
-    the hash.
-    """
-    tensor_args = []
-    static_args = []
-    static_args_hashed = []
-    for idx, arg in enumerate(args):
-        if idx not in static_argnums:
-            tensor_args.append(arg)
-        else:
-            static_args.append(arg)
-            static_args_hashed.append(arg.__hash__())
-    return tensor_args, static_args, static_args_hashed
-
-
-def rearrange(tensor_args, static_args, static_argnums):
-    """
-    Generate the args as per the original spec. static_argnums is sorted.
-    """
-    tensor_index = 0
-    static_index = 0
-    index = 0
-    args = []
-    assert len(static_args) == len(static_argnums)
-    while tensor_index < len(tensor_args) and static_index < len(static_args):
-        if index == static_argnums[static_index]:
-            args.append(static_args[static_index])
-            static_index += 1
-        else:
-            args.append(tensor_args[tensor_index])
-            tensor_index += 1
-        index += 1
-
-    while tensor_index < len(tensor_args):
-        args.append(tensor_args[tensor_index])
-        tensor_index += 1
-
-    while static_index < len(static_args):
-        args.append(static_args[static_index])
-        static_index += 1
-
-    return args
-
-
 KNOWN_TYPES = [torch.Tensor, int, str, float, bool]
 
 
@@ -611,8 +555,8 @@
     bw_compiler: Optional[Callable] = None,
     partition_fn: Callable = default_partition,
     decompositions: Optional[Dict] = None,
-    hasher_type: str = "StaticShapeHasher",
-    static_argnums: Optional[Tuple[int]] = None,
+    hasher_type=None,  # deprecated
+    static_argnums: Optional[Tuple[int]] = None,  # deprecated
 ) -> Callable:
     """
     Traces the forward and backward graph of :attr:`fn` using torch dispatch
@@ -627,14 +571,7 @@
     of core or simpler operators supported by the backend compilers.
 
     :func:`aot_function` uses a compilation cache, based on input tensor
-    properties, to detect when there is a need of recompilation. By default, its
-    behavior is static, i.e., it recompiles if shape of any input tensor
-    changes.
-
-    :attr:`static_argnums` allows user to mark the arguments of the original
-    :attr:`fn` as static. This is useful when an argument is a non-tensor, e.g.,
-    ``int`` or ``bool``. A change in the actual value of static arg causes
-    recompilation.
+    properties, to detect when recompilation is needed.
 
     .. warning::
         This API is experimental and likely to change.
@@ -654,8 +591,6 @@
             backward graphs.
         decompositions (Dict): A dictionary to define the decomposition of
             larger Aten ops into simpler or core Aten ops.
-        static_argnums (Optional[Tuple[Int]]): An option tuple of ints to mark
-            the arguments of the function as static.
 
     Returns:
         Returns a ``Callable`` that retains the eager behavior of the original
@@ -672,22 +607,10 @@
         >>> aot_fn = aot_function(fn, print_compile_fn)
         >>> x = torch.randn(4, 5, requires_grad=True)
         >>> aot_fn(x)
-
-    The static argnums are used to mark the non-tensor arguments as static. An
-    example is as follows where the dropout probability is as argument to the
-    original function.
-
-        >>> def fn(input, bias, residual, p: float):
-        >>>     a = torch.add(input, bias)
-        >>>     b = torch.nn.functional.dropout(a, p, training=True)
-        >>>     c = b + residual
-        >>>     return c
-        >>> aot_fn = aot_function(fn, print_compile_fn, static_argnums=(3,))
-
     """
-    global compile_cache
-    if compile_cache is None:
-        compile_cache = CompileCache()
+    if static_argnums is not None:
+        raise RuntimeError("static_argnums has been deprecated - manually wrap your function or use torchdynamo.")
+
     if bw_compiler is None:
         bw_compiler = fw_compiler
     aot_config = AOTConfig(
@@ -698,69 +621,26 @@
     )
     cached_res = None
 
-    fn_id = id(fn)
-    fw_compiler_id = id(fw_compiler)
-    bw_compiler_id = id(bw_compiler)
-
-    if isinstance(static_argnums, int):
-        static_argnums = [static_argnums]
-    elif static_argnums is not None and len(static_argnums) == 0:
-        static_argnums = None
-    elif static_argnums is not None:
-        static_argnums = list(static_argnums)
-        static_argnums.sort()
-
     @wraps(fn)
     def returned_function(*args, **kwargs):
-        global compile_cache
         nonlocal cached_res
-
-        # Separate out static args if static_argnums is present
-        tensor_args = args
-        static_args = []
-        # TODO - move the hashing part of static_args to C++.
-        static_args_hashed = []
-        if static_argnums is not None:
-            (
-                tensor_args,
-                static_args,
-                static_args_hashed,
-            ) = filter_tensor_and_static_args(args, static_argnums)
-
         # Now flatten the tensor args
-        flat_tensor_args, _ = pytree.tree_flatten((tensor_args, kwargs))
-
-        # Check if the fn is already compiled
-        num_tensor_args = len(flat_tensor_args)
-        flat_args_for_cache = flat_tensor_args + static_args_hashed
-        cached_res = compile_cache.at(
-            fn_id,
-            fw_compiler_id,
-            bw_compiler_id,
-            num_tensor_args,
-            hasher_type,
-            *flat_args_for_cache,
-        )
+        flat_args, _ = pytree.tree_flatten((args, kwargs))
 
         # Compile the function and save it in the cache
         if cached_res is None:
             # Save the args_spec for flat_tensor_args to unflatten while tracing
-            _, tensor_args_spec = pytree.tree_flatten((tensor_args, kwargs))
+            _, tensor_args_spec = pytree.tree_flatten((args, kwargs))
             out_spec = PytreeThunk()
 
-            def flat_fn(*flat_tensor_args):
+            def flat_fn(*flat_args):
                 # The input are flattened tensor args. Prepare the args in the
                 # order that original function expects. Add static args as well.
                 # They will appear as tensor constants in the traced graph.
-                nonlocal out_spec, static_args
-
-                tensor_args, kwargs = pytree.tree_unflatten(
-                    flat_tensor_args, tensor_args_spec
+                nonlocal out_spec
+                args, kwargs = pytree.tree_unflatten(
+                    flat_args, tensor_args_spec
                 )
-                if static_argnums is None:
-                    args = tensor_args
-                else:
-                    args = rearrange(tensor_args, static_args, static_argnums)
                 tree_out = fn(*args, **kwargs)
                 flat_out, spec = pytree.tree_flatten(tree_out)
                 for i in flat_out:
@@ -783,50 +663,18 @@
 
             compiled_fn = create_aot_dispatcher_function(
                 flat_fn,
-                flat_tensor_args,
+                flat_args,
                 aot_config,
             )
             cached_res = (compiled_fn, out_spec)
 
-            # Save the compiled_fn in the cache
-            compile_cache.insert(
-                fn_id,
-                fw_compiler_id,
-                bw_compiler_id,
-                num_tensor_args,
-                hasher_type,
-                cached_res,
-                *flat_args_for_cache,
-            )
-
         cached_fn, out_spec = cached_res
-        out = cached_fn(flat_tensor_args)
+        out = cached_fn(flat_args)
         return out_spec.unflatten(out)
 
     return returned_function
 
 
-def num_of_recompilations():
-    """
-    Returns the numbers of recompilations since the last time cache was cleared.
-    This is equivalent to the number of entries in the compilation cache.
-    """
-    global compile_cache
-    if compile_cache is None:
-        return 0
-    return compile_cache.size()
-
-
-def clear_compile_cache():
-    """
-    Clears the compilation cache.
-    """
-    global compile_cache
-    if compile_cache is not None:
-        compile_cache.clear()
-        compile_cache = None
-
-
 def aot_module(mod: nn.Module, *args, **kwargs) -> nn.Module:
     """
     Traces the forward and backward graph of :attr:`mod` using torch dispatch
@@ -921,8 +769,8 @@
         bw_compiler: Optional[Callable] = None,
         partition_fn: Callable = default_partition,
         decompositions: Optional[Dict] = None,
-        hasher_type: str = "StaticShapeHasher",
-        static_argnums: Optional[Tuple[int]] = None,
+        hasher_type=None,
+        static_argnums=None,
     ) -> Callable:
         assert static_argnums is None
         if bw_compiler is None:
diff --git a/functorch/_src/compilers.py b/functorch/_src/compilers.py
index 3dd8455..18deafa 100644
--- a/functorch/_src/compilers.py
+++ b/functorch/_src/compilers.py
@@ -193,7 +193,6 @@
         "fw_compiler": ts_compile,
         "bw_compiler": ts_compile,
         "partition_fn": min_cut_rematerialization_partition,
-        "hasher_type": "StaticShapeHasher",
         "decompositions": default_decompositions,
         "static_argnums": static_argnums,
     }
diff --git a/functorch/benchmarks/transformer_fusion_patterns/benchmark.py b/functorch/benchmarks/transformer_fusion_patterns/benchmark.py
index a6646e1..f799422 100644
--- a/functorch/benchmarks/transformer_fusion_patterns/benchmark.py
+++ b/functorch/benchmarks/transformer_fusion_patterns/benchmark.py
@@ -1,5 +1,5 @@
 import torch
-from functorch.compile import memory_efficient_fusion, clear_compile_cache
+from functorch.compile import memory_efficient_fusion
 import benchmark_helper
 
 
@@ -159,7 +159,6 @@
 
 for cl in [DropoutResBias, BiasReluDropout, DropoutResBiasScalar, BiasDropoutResLayerNorm, LayerNormSigmoid]:
     # Clear the compile cache
-    clear_compile_cache()
 
     # Get the function and inputs
     obj = cl()
diff --git a/functorch/benchmarks/transformer_fusion_patterns/bias_gelu_dropout.py b/functorch/benchmarks/transformer_fusion_patterns/bias_gelu_dropout.py
index b231806..26c6d7c 100644
--- a/functorch/benchmarks/transformer_fusion_patterns/bias_gelu_dropout.py
+++ b/functorch/benchmarks/transformer_fusion_patterns/bias_gelu_dropout.py
@@ -1,5 +1,5 @@
 import torch
-from functorch.compile import memory_efficient_pointwise_fusion, clear_compile_cache
+from functorch.compile import memory_efficient_pointwise_fusion
 import benchmark_helper
 
 # ALL comments regarding the patterns
@@ -21,7 +21,6 @@
 
 fn = bias_gelu_dropout
 
-clear_compile_cache()
 
 # Set inputs
 device = "cuda"
diff --git a/functorch/compile/__init__.py b/functorch/compile/__init__.py
index b489eb4..12549dc 100644
--- a/functorch/compile/__init__.py
+++ b/functorch/compile/__init__.py
@@ -5,8 +5,6 @@
     aot_module,
     compiled_function,
     compiled_module,
-    num_of_recompilations,
-    clear_compile_cache,
     aot_module_simplified,
     get_graph_being_compiled,
     get_aot_graph_name,
diff --git a/functorch/notebooks/aot_autograd_optimizations.ipynb b/functorch/notebooks/aot_autograd_optimizations.ipynb
index f12c2a7..beef4cb 100644
--- a/functorch/notebooks/aot_autograd_optimizations.ipynb
+++ b/functorch/notebooks/aot_autograd_optimizations.ipynb
@@ -127,10 +127,7 @@
     "\n",
     "# Run the aot_print_fn once to trigger the compilation and print the graphs\n",
     "res = aot_print_fn(a, b, c, d).sum().backward()\n",
-    "assert torch.allclose(ref, res)\n",
-    "\n",
-    "from functorch.compile import clear_compile_cache\n",
-    "clear_compile_cache()"
+    "assert torch.allclose(ref, res)"
    ]
   },
   {
diff --git a/functorch/test/common_utils.py b/functorch/test/common_utils.py
index 65b9dcb..91cba9e 100644
--- a/functorch/test/common_utils.py
+++ b/functorch/test/common_utils.py
@@ -351,7 +351,7 @@
         matching_opinfos = [o for o in all_opinfos
                             if o.name == decorate_meta.op_name and
                             o.variant_test_name == decorate_meta.variant_name]
-        assert len(matching_opinfos) > 0, f"Couldn't find OpInfo for {xfail}"
+        assert len(matching_opinfos) > 0, f"Couldn't find OpInfo for {decorate_meta}"
         assert len(matching_opinfos) == 1, (
             "OpInfos should be uniquely determined by their (name, variant_name). "
             f"Got more than one result for ({decorate_meta.op_name}, {decorate_meta.variant_name})"
diff --git a/functorch/test/test_aotdispatch.py b/functorch/test/test_aotdispatch.py
index 3cce940..bfcabd2 100644
--- a/functorch/test/test_aotdispatch.py
+++ b/functorch/test/test_aotdispatch.py
@@ -24,8 +24,8 @@
 from functorch.compile import (
     nnc_jit, compiled_function, compiled_module,
     min_cut_rematerialization_partition, aot_function, aot_module,
-    nop, num_of_recompilations, default_partition, default_decompositions,
-    memory_efficient_fusion, clear_compile_cache, get_aot_compilation_context
+    nop, default_partition, default_decompositions,
+    memory_efficient_fusion, get_aot_compilation_context
 )
 from torch._decomp import decomposition_table
 
@@ -61,9 +61,6 @@
 class AOTTestCase(TestCase):
     def setUp(self):
         super().setUp()
-        # NB: We cache on function id, which is unreliable
-        # Can fix by using weakrefs, but not sure if it matters
-        clear_compile_cache()
 
 class TestPythonKey(AOTTestCase):
     def test_make_fx(self, device):
@@ -292,18 +289,17 @@
             graph_size = len(fx_g.graph.nodes)
             return fx_g
 
-        start_recompilations = num_of_recompilations()
         f = aot_function(foo, nop, get_graph_size)
         with torch.set_grad_enabled(False):
             f(*inps)
         self.assertIsNone(graph_size)
 
+        f = aot_function(foo, nop, get_graph_size)
         with torch.set_grad_enabled(True):
             out = f(*inps)
             self.assertIsNone(graph_size)
             out.sum().backward()
             self.assertTrue(graph_size > 2)
-        self.assertEqual(num_of_recompilations() - start_recompilations, 2)
 
     def test_output_dict(self):
         def f(x):
@@ -366,6 +362,7 @@
 
         f = aot_function(f, compiler)
         out = f(torch.randn(5, requires_grad=True))
+        f = aot_function(f, compiler)
         f(torch.randn(5))
         out.sum().backward()
         self.assertEqual(count, [(['forward'], 4), (['inference'], 4), (['backward'], 8)])
@@ -731,7 +728,6 @@
         res = aot_mod(*inputs)
         res[0].sum().backward()
 
-
 only_for = ("cpu")
 instantiate_device_type_tests(
     TestPythonKey,
diff --git a/functorch/test/test_compile_cache.py b/functorch/test/test_compile_cache.py
deleted file mode 100644
index 2115e58..0000000
--- a/functorch/test/test_compile_cache.py
+++ /dev/null
@@ -1,686 +0,0 @@
-# Owner(s): ["module: functorch"]
-
-import torch
-
-import functorch
-from torch.testing._internal.common_utils import run_tests, TestCase, IS_WINDOWS
-import unittest
-
-from functorch.compile import aot_function, nop
-
-
-class TestCompileCache(TestCase):
-    def check(self, a, b, aot_fn, fn):
-        a_clone = a.clone().detach().requires_grad_(True)
-        b_clone = b.clone().detach().requires_grad_(True)
-        ref = fn(a, b)
-        ref.sum().backward()
-
-        res = aot_fn(a_clone, b_clone)
-        res.sum().backward()
-        assert torch.allclose(res, ref)
-        assert torch.allclose(a.grad, a_clone.grad)
-        assert torch.allclose(b.grad, b_clone.grad)
-
-    def test_recompilation_on_broadcast(self):
-        def fn(x, bias):
-            return x + bias
-
-        for hasher_type in ["DynamicShapeHasher", "StaticShapeHasher"]:
-            functorch.compile.clear_compile_cache()
-            start_num_recomps = functorch.compile.num_of_recompilations()
-            aot_autograd_fn = aot_function(fn, nop, nop, hasher_type=hasher_type)
-
-            a = torch.randn(10, 20, requires_grad=True)
-            b = torch.randn(20, requires_grad=True)
-            self.check(a, b, aot_autograd_fn, fn)
-
-            a = torch.randn(10, 20, requires_grad=True)
-            b = torch.randn(10, 20, requires_grad=True)
-            self.check(a, b, aot_autograd_fn, fn)
-
-            end_num_recomps = functorch.compile.num_of_recompilations()
-
-            total_recomps = end_num_recomps - start_num_recomps
-            assert total_recomps == 2
-
-    def test_compilation_for_dynamic_shape(self):
-        def fn(x, bias):
-            return x + bias
-
-        for hasher_type in ["DynamicShapeHasher", "StaticShapeHasher"]:
-            functorch.compile.clear_compile_cache()
-            start_num_recomps = functorch.compile.num_of_recompilations()
-            aot_autograd_fn = aot_function(fn, nop, nop, hasher_type=hasher_type)
-
-            for s in range(10, 20):
-                a = torch.randn(s, requires_grad=True)
-                b = torch.randn(s, requires_grad=True)
-                self.check(a, b, aot_autograd_fn, fn)
-
-            for s in range(10, 20):
-                a = torch.randn(s, requires_grad=True)
-                b = torch.randn(s, requires_grad=True)
-                self.check(a, b, aot_autograd_fn, fn)
-
-            end_num_recomps = functorch.compile.num_of_recompilations()
-
-            total_recomps = end_num_recomps - start_num_recomps
-            if hasher_type == "DynamicShapeHasher":
-                assert total_recomps == 1
-            elif hasher_type == "StaticShapeHasher":
-                assert total_recomps == 10
-
-            for s in range(10, 20):
-                a = torch.randn(s, s, requires_grad=True)
-                b = torch.randn(s, s, requires_grad=True)
-                self.check(a, b, aot_autograd_fn, fn)
-
-            end_num_recomps = functorch.compile.num_of_recompilations()
-
-            total_recomps = end_num_recomps - start_num_recomps
-            if hasher_type == "DynamicShapeHasher":
-                assert total_recomps == 2
-            elif hasher_type == "StaticShapeHasher":
-                assert total_recomps == 20
-
-    def test_global_cache_no_recompilations(self):
-        def f(x, bias):
-            return x + bias
-
-        def g(x, bias):
-            return aot_function(f, nop, nop, hasher_type="DynamicShapeHasher")(x, bias)
-
-        start_num_recomps = functorch.compile.num_of_recompilations()
-        for _ in range(10):
-            a = torch.randn(10, 20, requires_grad=True)
-            b = torch.randn(10, 20, requires_grad=True)
-            self.check(a, b, g, f)
-
-        end_num_recomps = functorch.compile.num_of_recompilations()
-        total_recomps = end_num_recomps - start_num_recomps
-        assert total_recomps == 1
-
-    def test_multiple_functions(self):
-        def f(x, bias):
-            return x + bias
-
-        def g(x, y):
-            return x * y
-
-        for hasher_type in ["DynamicShapeHasher", "StaticShapeHasher"]:
-            functorch.compile.clear_compile_cache()
-            aot_autograd_f = aot_function(f, nop, nop, hasher_type=hasher_type)
-            aot_autograd_g = aot_function(g, nop, nop, hasher_type=hasher_type)
-
-            start_num_recomps = functorch.compile.num_of_recompilations()
-            a = torch.randn(10, requires_grad=True)
-            b = torch.randn(10, requires_grad=True)
-            self.check(a, b, aot_autograd_f, f)
-
-            a = torch.randn(10, requires_grad=True)
-            b = torch.randn(10, requires_grad=True)
-            self.check(a, b, aot_autograd_g, g)
-
-            end_num_recomps = functorch.compile.num_of_recompilations()
-            total_recomps = end_num_recomps - start_num_recomps
-            assert total_recomps == 2
-
-            # Force recompilation for function f and check num of recompilations again
-            a = torch.randn(10, 20, requires_grad=True)
-            b = torch.randn(10, 20, requires_grad=True)
-            self.check(a, b, aot_autograd_f, f)
-
-            end_num_recomps = functorch.compile.num_of_recompilations()
-            total_recomps = end_num_recomps - start_num_recomps
-            assert total_recomps == 3
-
-    def test_high_number_of_args(self):
-        def f(*args):
-            res = args[0]
-            for arg in args:
-                res = res * arg
-            return res
-
-        def check(args, aot_autograd_fn, fn):
-            args_clone = [arg.clone().detach().requires_grad_(True) for arg in args]
-            ref = fn(*args)
-            ref.sum().backward()
-
-            res = aot_autograd_fn(*args_clone)
-            res.sum().backward()
-            assert torch.allclose(res, ref)
-            for (arg, arg_clone) in zip(args, args_clone):
-                assert torch.allclose(arg.grad, arg_clone.grad)
-
-        for hasher_type in ["DynamicShapeHasher", "StaticShapeHasher"]:
-            functorch.compile.clear_compile_cache()
-
-            aot_autograd_f = aot_function(f, nop, nop, hasher_type=hasher_type)
-
-            args = [torch.randn(10, requires_grad=True) for _ in range(100)]
-            check(args, aot_autograd_f, f)
-
-    def test_multiple_compiler(self):
-        def fn(x, bias):
-            return x + bias
-
-        def nop_duplicate(fx_g, _):
-            return fx_g
-
-        for hasher_type in ["DynamicShapeHasher", "StaticShapeHasher"]:
-            functorch.compile.clear_compile_cache()
-            start_num_recomps = functorch.compile.num_of_recompilations()
-            nop_fn = aot_function(fn, nop, nop, hasher_type=hasher_type)
-            nop_duplicate_fn = aot_function(
-                fn, nop_duplicate, nop_duplicate, hasher_type=hasher_type
-            )
-
-            a = torch.randn(10, 20, requires_grad=True)
-            b = torch.randn(20, requires_grad=True)
-            nop_fn(a, b)
-            nop_duplicate_fn(a, b)
-
-            end_num_recomps = functorch.compile.num_of_recompilations()
-
-            total_recomps = end_num_recomps - start_num_recomps
-            assert total_recomps == 2
-
-
-@unittest.skipIf(IS_WINDOWS, 'test broken on windows')
-class TestCompileCacheStaticArgs(TestCase):
-    def check(self, a, b, aot_autograd_fn, fn):
-        a_clone = a.clone().detach().requires_grad_(True)
-        ref = fn(a, b)
-        ref.sum().backward()
-
-        res = aot_autograd_fn(a_clone, b)
-        res.sum().backward()
-        assert torch.allclose(res, ref)
-        assert torch.allclose(a.grad, a_clone.grad)
-
-    def test_failure(self):
-        # Test that not setting up static_argnums should raise exception
-        def fn(x, p):
-            return x * p
-
-        aot_autograd_f = aot_function(fn, nop, nop)
-
-        a = torch.randn(2, 2, requires_grad=True)
-        b = 2
-        try:
-            # Since b is not marked as static, it should raise exception
-            aot_autograd_f(a, b)
-            raise AssertionError()
-        except RuntimeError:
-            pass
-
-    def test_simple(self):
-        def fn(x, static_arg):
-            return x * static_arg
-
-        functorch.compile.clear_compile_cache()
-
-        start_num_recomps = functorch.compile.num_of_recompilations()
-
-        aot_autograd_f = aot_function(fn, nop, nop, static_argnums=1)
-
-        a = torch.randn(2, 2, requires_grad=True)
-        b = 2
-        self.check(a, b, aot_autograd_f, fn)
-
-        # Same type of args, so no recompilation
-        a = torch.randn(2, 2, requires_grad=True)
-        b = 2
-        self.check(a, b, aot_autograd_f, fn)
-
-        # Trigger recompilation
-        a = torch.randn(2, 2, requires_grad=True)
-        b = 3
-        self.check(a, b, aot_autograd_f, fn)
-
-        end_num_recomps = functorch.compile.num_of_recompilations()
-
-        total_recomps = end_num_recomps - start_num_recomps
-        assert total_recomps == 2
-
-    def test_static_arg_before_tensor_arg(self):
-        def fn(static_arg, x):
-            return static_arg - x
-
-        def check(a, b, aot_autograd_fn, fn):
-            b_clone = b.clone().detach().requires_grad_(True)
-            ref = fn(a, b)
-            ref.sum().backward()
-
-            res = aot_autograd_fn(a, b_clone)
-            res.sum().backward()
-            assert torch.allclose(res, ref)
-            assert torch.allclose(b.grad, b_clone.grad)
-
-        functorch.compile.clear_compile_cache()
-
-        start_num_recomps = functorch.compile.num_of_recompilations()
-
-        aot_autograd_f = aot_function(fn, nop, nop, static_argnums=0)
-
-        a = 2
-        b = torch.randn(2, 2, requires_grad=True)
-        check(a, b, aot_autograd_f, fn)
-
-        a = 3
-        b = torch.randn(2, 2, requires_grad=True)
-        check(a, b, aot_autograd_f, fn)
-
-        end_num_recomps = functorch.compile.num_of_recompilations()
-
-        total_recomps = end_num_recomps - start_num_recomps
-        assert total_recomps == 2
-
-    def test_interleaved_static_args(self):
-        def fn(static_arg1, x, static_arg2):
-            return static_arg1 - x - static_arg2
-
-        def check(a, b, c, aot_autograd_fn, fn):
-            b_clone = b.clone().detach().requires_grad_(True)
-            ref = fn(a, b, c)
-            ref.sum().backward()
-
-            res = aot_autograd_fn(a, b_clone, c)
-            res.sum().backward()
-            assert torch.allclose(res, ref)
-            assert torch.allclose(b.grad, b_clone.grad)
-
-        functorch.compile.clear_compile_cache()
-
-        start_num_recomps = functorch.compile.num_of_recompilations()
-
-        aot_autograd_f = aot_function(fn, nop, nop, static_argnums=(0, 2))
-
-        a = 2
-        b = torch.randn(2, 2, requires_grad=True)
-        c = 0.1
-        check(a, b, c, aot_autograd_f, fn)
-
-        a = 3
-        b = torch.randn(2, 2, requires_grad=True)
-        c = 0.1
-        check(a, b, c, aot_autograd_f, fn)
-
-        end_num_recomps = functorch.compile.num_of_recompilations()
-
-        total_recomps = end_num_recomps - start_num_recomps
-        assert total_recomps == 2
-
-    def test_dropout(self):
-        def fn(x, prob):
-            return torch.nn.functional.dropout(x, p=prob)
-
-        functorch.compile.clear_compile_cache()
-
-        start_num_recomps = functorch.compile.num_of_recompilations()
-
-        aot_autograd_f = aot_function(fn, nop, nop, static_argnums=[1])
-
-        a = torch.randn(2, 2, requires_grad=True)
-        b = 0.3
-        aot_autograd_f(a, b)
-
-        # Setting the prob to 0. This should cause recompilation.
-        a = torch.randn(2, 2, requires_grad=True)
-        b = 0
-        self.check(a, b, aot_autograd_f, fn)
-
-        end_num_recomps = functorch.compile.num_of_recompilations()
-
-        total_recomps = end_num_recomps - start_num_recomps
-        assert total_recomps == 2
-
-    def test_if_condition(self):
-        def fn(x, state: bool):
-            if state:
-                return torch.sin(x)
-            else:
-                return torch.cos(x)
-
-        functorch.compile.clear_compile_cache()
-
-        start_num_recomps = functorch.compile.num_of_recompilations()
-
-        aot_autograd_f = aot_function(fn, nop, nop, static_argnums=[1])
-
-        a = torch.randn(2, 2, requires_grad=True)
-        b = True
-        self.check(a, b, aot_autograd_f, fn)
-
-        a = torch.randn(2, 2, requires_grad=True)
-        b = True
-        self.check(a, b, aot_autograd_f, fn)
-
-        a = torch.randn(2, 2, requires_grad=True)
-        b = False
-        self.check(a, b, aot_autograd_f, fn)
-
-        end_num_recomps = functorch.compile.num_of_recompilations()
-
-        total_recomps = end_num_recomps - start_num_recomps
-        assert total_recomps == 2
-
-    def test_custom(self):
-        class Record:
-            def __init__(self, name, multiplier):
-                self.name = name
-                self.multiplier = multiplier
-
-            def __eq__(self, other):
-                return self.name == other.name and self.multiplier == other.multiplier
-
-            def __hash__(self):
-                return hash((self.name, self.multiplier))
-
-        def fn(x, record):
-            return x * record.multiplier
-
-        functorch.compile.clear_compile_cache()
-
-        start_num_recomps = functorch.compile.num_of_recompilations()
-
-        aot_autograd_f = aot_function(fn, nop, nop, static_argnums=[1])
-
-        a = torch.randn(2, 2, requires_grad=True)
-        b = Record("Foo", 0.5)
-        self.check(a, b, aot_autograd_f, fn)
-
-        a = torch.randn(2, 2, requires_grad=True)
-        b = Record("Bar", 10.2)
-        self.check(a, b, aot_autograd_f, fn)
-
-        end_num_recomps = functorch.compile.num_of_recompilations()
-
-        total_recomps = end_num_recomps - start_num_recomps
-        assert total_recomps == 2
-
-    def test_tuple(self):
-        def fn(a_tuple, static_arg):
-            return torch.sin(a_tuple[0]) - a_tuple[1] - static_arg
-
-        def check(a_tuple, b, aot_autograd_fn, fn):
-            a0 = a_tuple[0]
-            a1 = a_tuple[1]
-
-            a0_clone = a0.clone().detach().requires_grad_(True)
-            a1_clone = a1.clone().detach().requires_grad_(True)
-            ref = fn(a, b)
-            ref.sum().backward()
-
-            res = aot_autograd_fn((a0_clone, a1_clone), b)
-            res.sum().backward()
-            assert torch.allclose(res, ref)
-            assert torch.allclose(a0.grad, a0_clone.grad)
-            assert torch.allclose(a1.grad, a1_clone.grad)
-
-        functorch.compile.clear_compile_cache()
-
-        start_num_recomps = functorch.compile.num_of_recompilations()
-
-        aot_autograd_f = aot_function(fn, nop, nop, static_argnums=(1,))
-
-        a = (
-            torch.randn(2, 2, requires_grad=True),
-            torch.randn(2, 2, requires_grad=True),
-        )
-        b = 0.1
-        check(a, b, aot_autograd_f, fn)
-
-        a = (
-            torch.randn(2, 2, requires_grad=True),
-            torch.randn(2, 2, requires_grad=True),
-        )
-        b = 1
-        check(a, b, aot_autograd_f, fn)
-
-        end_num_recomps = functorch.compile.num_of_recompilations()
-
-        total_recomps = end_num_recomps - start_num_recomps
-        assert total_recomps == 2
-
-    def test_tuple_with_first_arg_as_static(self):
-        def fn(static_arg, a_tuple):
-            return torch.sin(a_tuple[0]) - a_tuple[1] - static_arg
-
-        def check(a, b_tuple, aot_autograd_fn, fn):
-            b0 = b_tuple[0]
-            b1 = b_tuple[1]
-
-            b0_clone = b0.clone().detach().requires_grad_(True)
-            b1_clone = b1.clone().detach().requires_grad_(True)
-            ref = fn(a, b_tuple)
-            ref.sum().backward()
-
-            res = aot_autograd_fn(a, (b0_clone, b1_clone))
-            res.sum().backward()
-            assert torch.allclose(res, ref)
-            assert torch.allclose(b0.grad, b0_clone.grad)
-            assert torch.allclose(b1.grad, b1_clone.grad)
-
-        functorch.compile.clear_compile_cache()
-
-        start_num_recomps = functorch.compile.num_of_recompilations()
-
-        aot_autograd_f = aot_function(fn, nop, nop, static_argnums=(0,))
-
-        a = 0.1
-        b = (
-            torch.randn(2, 2, requires_grad=True),
-            torch.randn(2, 2, requires_grad=True),
-        )
-        check(a, b, aot_autograd_f, fn)
-
-        a = 1
-        b = (
-            torch.randn(2, 2, requires_grad=True),
-            torch.randn(2, 2, requires_grad=True),
-        )
-        check(a, b, aot_autograd_f, fn)
-
-        end_num_recomps = functorch.compile.num_of_recompilations()
-
-        total_recomps = end_num_recomps - start_num_recomps
-        assert total_recomps == 2
-
-    def test_dict(self):
-        def fn(a_dict, static_arg):
-            return torch.sin(a_dict["foo"]) - a_dict["bar"] - static_arg
-
-        def check(a_dict, b, aot_autograd_fn, fn):
-
-            a0 = a_dict["foo"]
-            a1 = a_dict["bar"]
-
-            a0_clone = a0.clone().detach().requires_grad_(True)
-            a1_clone = a1.clone().detach().requires_grad_(True)
-            ref = fn(a_dict, b)
-            ref.sum().backward()
-
-            a_clone = {}
-            a_clone["foo"] = a0_clone
-            a_clone["bar"] = a1_clone
-            res = aot_autograd_fn(a_clone, b)
-            res.sum().backward()
-            assert torch.allclose(res, ref)
-            assert torch.allclose(a0.grad, a0_clone.grad)
-            assert torch.allclose(a1.grad, a1_clone.grad)
-
-        functorch.compile.clear_compile_cache()
-
-        start_num_recomps = functorch.compile.num_of_recompilations()
-
-        aot_autograd_f = aot_function(fn, nop, nop, static_argnums=(1,))
-
-        a = {}
-        a["foo"] = torch.zeros(2, 2, requires_grad=True)
-        a["bar"] = torch.ones(2, 2, requires_grad=True)
-        b = 0
-        check(a, b, aot_autograd_f, fn)
-
-        a = {}
-        a["foo"] = torch.randn(2, 2, requires_grad=True)
-        a["bar"] = torch.randn(2, 2, requires_grad=True)
-        b = 0.2
-        check(a, b, aot_autograd_f, fn)
-
-        end_num_recomps = functorch.compile.num_of_recompilations()
-
-        total_recomps = end_num_recomps - start_num_recomps
-        assert total_recomps == 2
-
-    def test_dict_with_static_arg_before_dict(self):
-        def fn(static_arg, a_dict):
-            return torch.sin(a_dict["foo"]) - a_dict["bar"] - static_arg
-
-        def check(a, b_dict, aot_autograd_fn, fn):
-
-            ref = fn(a, b_dict)
-            res = aot_autograd_fn(a, b_dict)
-            assert torch.allclose(res, ref)
-
-            b0 = b_dict["foo"]
-            b1 = b_dict["bar"]
-
-            b0_clone = b0.clone().detach().requires_grad_(True)
-            b1_clone = b1.clone().detach().requires_grad_(True)
-            ref.sum().backward()
-
-            b_clone = {}
-            b_clone["foo"] = b0_clone
-            b_clone["bar"] = b1_clone
-            res = aot_autograd_fn(a, b_clone)
-            res.sum().backward()
-            assert torch.allclose(res, ref)
-            assert torch.allclose(b0.grad, b0_clone.grad)
-            assert torch.allclose(b1.grad, b1_clone.grad)
-
-        functorch.compile.clear_compile_cache()
-
-        start_num_recomps = functorch.compile.num_of_recompilations()
-
-        aot_autograd_f = aot_function(fn, nop, nop, static_argnums=(0,))
-
-        a = 0.1
-        b = {}
-        b["foo"] = torch.randn(2, 2, requires_grad=True)
-        b["bar"] = torch.randn(2, 2, requires_grad=True)
-        check(a, b, aot_autograd_f, fn)
-
-        a = 0.2
-        b = {}
-        b["foo"] = torch.randn(2, 2, requires_grad=True)
-        b["bar"] = torch.randn(2, 2, requires_grad=True)
-        check(a, b, aot_autograd_f, fn)
-
-        end_num_recomps = functorch.compile.num_of_recompilations()
-
-        total_recomps = end_num_recomps - start_num_recomps
-        assert total_recomps == 2
-
-    def test_tuple_static_args(self):
-        def fn(x, tuple_static_arg):
-            return x * tuple_static_arg[0] * tuple_static_arg[1]
-
-        functorch.compile.clear_compile_cache()
-
-        start_num_recomps = functorch.compile.num_of_recompilations()
-
-        aot_autograd_f = aot_function(fn, nop, nop, static_argnums=1)
-
-        a = torch.randn(2, 2, requires_grad=True)
-        b = (2, 3)
-        self.check(a, b, aot_autograd_f, fn)
-
-        # Same type of args, so no recompilation
-        a = torch.randn(2, 2, requires_grad=True)
-        b = (2, 3)
-        self.check(a, b, aot_autograd_f, fn)
-
-        # Trigger recompilation
-        a = torch.randn(2, 2, requires_grad=True)
-        b = (3, 4)
-        self.check(a, b, aot_autograd_f, fn)
-
-        end_num_recomps = functorch.compile.num_of_recompilations()
-
-        total_recomps = end_num_recomps - start_num_recomps
-        assert total_recomps == 2
-
-    def test_arg_none(self):
-        def check(a, b, c, aot_autograd_fn, fn):
-            def cloner(x):
-                if x is not None:
-                    return x.clone().detach().requires_grad_(True)
-                return None
-
-            def check_grad(x, x_clone):
-                if x is not None:
-                    return torch.allclose(x.grad, x_clone.grad)
-                return True
-
-            ref = fn(a, b, c)
-            res = aot_autograd_fn(a, b, c)
-            assert torch.allclose(res, ref)
-
-            a_clone = cloner(a)
-            b_clone = cloner(b)
-            c_clone = cloner(c)
-            ref.sum().backward()
-            res = aot_autograd_fn(a_clone, b_clone, c_clone)
-            res.sum().backward()
-
-            check_grad(a, a_clone)
-            check_grad(b, b_clone)
-            check_grad(c, c_clone)
-
-        def fn(a, b, c):
-            if a is None and b is None:
-                return c
-            elif a is None and c is None:
-                return b
-            elif b is None and c is None:
-                return a
-            elif a is None:
-                return b + c
-            elif b is None:
-                return a + c
-            elif c is None:
-                return a + b
-            return a + b + c
-
-        functorch.compile.clear_compile_cache()
-
-        start_num_recomps = functorch.compile.num_of_recompilations()
-
-        aot_autograd_f = aot_function(fn, nop, nop)
-
-        t1 = torch.randn(2, 2, requires_grad=True)
-        check(t1, None, None, aot_autograd_f, fn)
-        check(None, t1, None, aot_autograd_f, fn)
-        check(None, None, t1, aot_autograd_f, fn)
-
-        t2 = torch.randn(2, 2, requires_grad=True)
-        check(t1, t2, None, aot_autograd_f, fn)
-        check(t1, None, t2, aot_autograd_f, fn)
-        check(None, t1, t2, aot_autograd_f, fn)
-
-        t3 = torch.randn(2, 2, requires_grad=True)
-        check(t1, t2, t3, aot_autograd_f, fn)
-
-        # Same type of args, so no recompilation
-        check(t1, t2, None, aot_autograd_f, fn)
-
-        end_num_recomps = functorch.compile.num_of_recompilations()
-
-        total_recomps = end_num_recomps - start_num_recomps
-        assert total_recomps == 7
-
-
-if __name__ == "__main__":
-    run_tests()
diff --git a/functorch/test/test_functionalize.py b/functorch/test/test_functionalize.py
index 0ae8d5c..a04bd49 100644
--- a/functorch/test/test_functionalize.py
+++ b/functorch/test/test_functionalize.py
@@ -4,7 +4,6 @@
 from unittest.mock import patch
 import functools
 from torch.testing._internal.common_utils import run_tests
-import test_compile_cache
 import test_aotdispatch
 
 
@@ -38,8 +37,6 @@
     return FunctionalizeTest
 
 
-FunctionalizeTestCompileCache = make_functionalize_test(test_compile_cache.TestCompileCache)
-FunctionalizeTestCompileCacheStaticArgs = make_functionalize_test(test_compile_cache.TestCompileCacheStaticArgs)
 FunctionalizeTestPythonKeyAOT = make_functionalize_test(test_aotdispatch.TestAOTAutograd)
 FunctionalizeTestPythonKeyPartitioning = make_functionalize_test(test_aotdispatch.TestPartitioning)
 
diff --git a/torch/csrc/functorch/CompileCache.cpp b/torch/csrc/functorch/CompileCache.cpp
deleted file mode 100644
index 0366e38..0000000
--- a/torch/csrc/functorch/CompileCache.cpp
+++ /dev/null
@@ -1,345 +0,0 @@
-// Copyright (c) Facebook, Inc. and its affiliates.
-// All rights reserved.
-//
-// This source code is licensed under the BSD-style license found in the
-// LICENSE file in the root directory of this source tree.
-
-///
-/// This design stemmed of from the PointwiseOperatorCompileCache with the
-/// purpose of making it more generic for AOTAutograd. This is Compile Cache
-/// allowing different types of hashing functions, and is agnostic of the
-/// compiler.
-///
-#include <torch/csrc/autograd/custom_function.h>
-#include <torch/csrc/functorch/CompileCache.h>
-#include <torch/csrc/jit/python/pybind_utils.h>
-#include <torch/csrc/jit/tensorexpr/codegen.h>
-#include <torch/csrc/utils/pybind.h>
-#include <torch/csrc/utils/python_numbers.h>
-
-using namespace torch::jit::tensorexpr;
-
-namespace {
-/// Record of thread-local state that changes operator behavior.
-struct LocalState {
-  c10::impl::LocalDispatchKeySet dispatchModifier;
-  bool gradModeEnabled;
-
-  at::DispatchKeySet apply(at::DispatchKeySet ks) const {
-    return (ks | dispatchModifier.included_) - dispatchModifier.excluded_;
-  }
-
-  LocalState()
-      : dispatchModifier(c10::impl::tls_local_dispatch_key_set()),
-        gradModeEnabled(at::GradMode::is_enabled()) {}
-};
-
-/// Helper to pack tensor (dtype, requires grad) into an 8-bit key.
-static uint8_t packFlags(const LocalState& state, const at::Tensor& v) {
-  static_assert(
-      static_cast<int>(at::ScalarType::NumOptions) < 128, "overflow possible");
-  at::ScalarType dtype = v.dtype().toScalarType();
-  bool requires_grad = state.gradModeEnabled && v.requires_grad();
-  return static_cast<uint8_t>(requires_grad) |
-      (static_cast<uint8_t>(dtype) << 1);
-}
-
-using hash_key_t = std::vector<int64_t>;
-/// Per-tensor cache specialization key targetting dynamic shapes. Records
-/// dtype, dispatch options, aliasing, and per-dim contiguity/broadcasting
-/// information.
-
-enum DimFlags {
-  /// A leading dimension implicitly added by broadcasting.
-  SIZE_MISSING = 1 << 0,
-
-  /// Size == 1.
-  SIZE_ONE = 1 << 1,
-
-  /// Size > 1.
-  SIZE_OTHER = 1 << 2,
-
-  /// Stride == 0; broadcasting.
-  STRIDE_ZERO = 1 << 3,
-
-  /// Stride == 1; packed contiguously in memory.
-  STRIDE_ONE = 1 << 4,
-
-  /// Stride = Stride[i + 1] * Size[i + 1].
-  /// Used to collapse dimensions.
-  STRIDE_CONTIGUOUS = 1 << 5,
-
-  /// Stride = Stride[i - 1] * Size[i - 1].
-  /// Used to collapse dimensions in the other direction.
-  STRIDE_TRANSPOSED_CONTIGUOUS = 1 << 6, // stride[i-1] * sizes[i-1]
-
-  /// Stride must be provided as an argument.
-  STRIDE_AS_ARG = 1 << 7,
-};
-
-/// Unique hasher id values to uniquely identify the type of hash. NONE_HASH is
-/// used when a tensor is undefined.
-enum HasherFlags {
-  NONE_HASH,
-  STATIC_HASH,
-  DYNAMIC_HASH,
-};
-
-std::vector<int> genDimFlags(c10::IntArrayRef sizes, c10::IntArrayRef strides) {
-  // Pack all the properties for each dimension into a uint8.
-  int nDims = sizes.size();
-  std::vector<int> dimflags(nDims);
-  for (int64_t dim = 0; dim < nDims; ++dim) {
-    uint8_t flag =
-        (sizes[dim] == 0 ? SIZE_MISSING
-                         : (sizes[dim] == 1 ? SIZE_ONE : SIZE_OTHER));
-    if (strides[dim] == 0) {
-      flag |= STRIDE_ZERO;
-    } else if (strides[dim] == 1) {
-      flag |= STRIDE_ONE;
-    } else if (
-        dim + 1 < (int64_t)sizes.size() &&
-        strides[dim] == strides[dim + 1] * sizes[dim + 1]) {
-      flag |= STRIDE_CONTIGUOUS;
-    } else if (
-        dim > 0 && strides[dim] == strides[dim - 1] * sizes[dim - 1] &&
-        (dimflags[dim - 1] & STRIDE_CONTIGUOUS) == 0) {
-      flag |= STRIDE_TRANSPOSED_CONTIGUOUS;
-    } else {
-      flag |= STRIDE_AS_ARG;
-    }
-    dimflags[dim] = flag;
-  }
-  return dimflags;
-}
-
-hash_key_t dynamic_hasher(const LocalState& state, const at::Tensor& v) {
-  hash_key_t hash = {
-      DYNAMIC_HASH,
-      static_cast<int>(packFlags(state, v)),
-      static_cast<int>(state.apply(v.key_set()).raw_repr()),
-      static_cast<int>(v.ndimension())};
-  auto dimFlags = genDimFlags(v.sizes(), v.strides());
-  hash.insert(hash.end(), dimFlags.begin(), dimFlags.end());
-  return hash;
-}
-
-/// Per-tensor cache specialization key targetting static shapes. Recordsdtype,
-/// dispatch options, aliasing, and full shapes and strides.
-hash_key_t static_hasher(const LocalState& state, const at::Tensor& v) {
-  hash_key_t hash = {
-      STATIC_HASH,
-      static_cast<int>(packFlags(state, v)),
-      static_cast<int>(state.apply(v.key_set()).raw_repr()),
-      static_cast<int>(v.ndimension())};
-  hash.insert(hash.end(), v.sizes().begin(), v.sizes().end());
-  hash.insert(hash.end(), v.strides().begin(), v.strides().end());
-  return hash;
-}
-
-/// ArgCompileCache is a templated class allowing plugging of different types of
-/// Hasher/Specialization Keys.
-struct CompileCache {
- public:
-  CompileCache() = default;
-  ~CompileCache() = default;
-
-  /// Array defining groups of aliased tensors.
-
-  /// Cache type mapping specialization keys to compiled kernels.
-  class vector_hasher {
-   public:
-    std::size_t operator()(hash_key_t const& vec) const {
-      std::size_t seed = vec.size();
-      for (auto& i : vec) {
-        seed ^= i + 0x9e3779b9 + (seed << 6) + (seed >> 2);
-      }
-      return seed;
-    }
-  };
-  using Cache = std::unordered_map<hash_key_t, py::object, vector_hasher>;
-
-  /// Compute the set of specialization keys based on the inputs to
-  /// the kernel.
-  hash_key_t computeCacheKey(
-      PyObject* args,
-      const std::vector<at::Tensor>& tensorArgs,
-      int numTensorArgs,
-      const std::string& hasherType,
-      int64_t id,
-      int64_t fw_compiler_id,
-      int64_t bw_compiler_id) {
-    LocalState state;
-    hash_key_t cacheKey;
-    for (int i = 0; i < numTensorArgs; ++i) {
-      if (tensorArgs[i].defined()) {
-        // Only hash the tensor when its defined.
-        if (hasherType == "StaticShapeHasher") {
-          auto res = static_hasher(state, tensorArgs[i]);
-          cacheKey.insert(cacheKey.end(), res.begin(), res.end());
-        } else if (hasherType == "DynamicShapeHasher") {
-          auto res = dynamic_hasher(state, tensorArgs[i]);
-          cacheKey.insert(cacheKey.end(), res.begin(), res.end());
-        }
-      } else {
-        // Add a value to the cacheKey to indicate a None tensor.
-        cacheKey.push_back(NONE_HASH);
-      }
-    }
-    cacheKey.push_back(id);
-    cacheKey.push_back(fw_compiler_id);
-    cacheKey.push_back(bw_compiler_id);
-    cacheKey.push_back(numTensorArgs);
-
-    // Cache the non-tensor args. Currently, all the non-tensor args are cached.
-    for (int i = numTensorArgs; i < PyTuple_Size(args); i++) {
-      PyObject* arg = PyTuple_GET_ITEM(args, i);
-      assert(PyLong_Check(arg));
-      cacheKey.push_back(THPUtils_unpackLong(arg));
-    }
-    return cacheKey;
-  }
-
-  std::vector<at::Tensor> parsePythonArgs(int numTensorArgs, PyObject* args) {
-    // Convert to Tensor Args
-    std::vector<at::Tensor> tensorArgs(numTensorArgs);
-    for (int i = 0; i < numTensorArgs; ++i) {
-      PyObject* arg = PyTuple_GET_ITEM(args, i);
-      if (arg == Py_None) {
-        // If an input tensor is None, add it as an undefined tensor.
-        tensorArgs[i] = at::Tensor();
-      } else if (!THPVariable_Check(arg)) {
-        // Fail if its a non-tensor arg. It should be marked static.
-        std::string dtype = Py_TYPE(arg)->tp_name;
-        std::string index = std::to_string(i);
-        throw std::runtime_error(
-            "Found an argument of type " + dtype + " at index " + index +
-            ". Non-tensor arguments must be marked static."
-            " Please set the static_argnums correctly to "
-            "mark the argument at index " +
-            index + " static.");
-      } else {
-        tensorArgs[i] = THPVariable_Unpack(arg);
-      }
-    }
-    return tensorArgs;
-  }
-
-  /// Check if the function has already been compiled.
-  py::object at(
-      int64_t id,
-      int64_t fw_compiler_id,
-      int64_t bw_compiler_id,
-      int numTensorArgs,
-      const std::string& hasherType,
-      PyObject* args) {
-    std::vector<at::Tensor> tensorArgs = parsePythonArgs(numTensorArgs, args);
-    hash_key_t cacheKey = computeCacheKey(
-        args,
-        tensorArgs,
-        numTensorArgs,
-        hasherType,
-        id,
-        fw_compiler_id,
-        bw_compiler_id);
-
-    auto item = cache_.find(cacheKey); // protected by GIL
-
-    if (C10_LIKELY(item != cache_.end())) {
-      return item->second;
-    }
-    return py::none();
-  }
-
-  /// Insert a new compiled functions for new tensor properties.
-  void insert(
-      int64_t id,
-      int64_t fw_compiler_id,
-      int64_t bw_compiler_id,
-      int numTensorArgs,
-      const std::string& hasherType,
-      const py::object& compileFn,
-      PyObject* args) {
-    std::vector<at::Tensor> tensorArgs = parsePythonArgs(numTensorArgs, args);
-    LocalState state;
-    hash_key_t cacheKey = computeCacheKey(
-        args,
-        tensorArgs,
-        numTensorArgs,
-        hasherType,
-        id,
-        fw_compiler_id,
-        bw_compiler_id);
-    cache_.emplace(cacheKey, compileFn);
-  }
-
-  int64_t size() const {
-    return cache_.size();
-  }
-
-  /// Clear the cache.
-  void clear() {
-    cache_.clear();
-  }
-
- private:
-  /// Compilation cache holding key and the compiled function.
-  Cache cache_;
-};
-
-static CompileCache* createCompileCache() {
-  return new CompileCache();
-}
-
-} // namespace
-
-namespace torch {
-namespace functorch {
-
-void initCompileCacheBindings(PyObject* module) {
-  py::handle te(module);
-  py::class_<CompileCache>(te, "CompileCache")
-      .def(py::init(&createCompileCache))
-      .def(
-          "at",
-          [](CompileCache& self,
-             int64_t id,
-             int64_t fw_compiler_id,
-             int64_t bw_compiler_id,
-             int numTensorArgs,
-             const std::string& hasherType,
-             py::args args) {
-            return self.at(
-                id,
-                fw_compiler_id,
-                bw_compiler_id,
-                numTensorArgs,
-                hasherType,
-                args.ptr());
-          })
-      .def(
-          "insert",
-          [](CompileCache& self,
-             int64_t id,
-             int64_t fw_compiler_id,
-             int64_t bw_compiler_id,
-             int numTensorArgs,
-             const std::string& hasherType,
-             const py::object& compileFn,
-             py::args args,
-             py::kwargs kwargs) {
-            self.insert(
-                id,
-                fw_compiler_id,
-                bw_compiler_id,
-                numTensorArgs,
-                hasherType,
-                compileFn,
-                args.ptr());
-          })
-      .def("clear", [](CompileCache& self) { self.clear(); })
-      .def("size", [](CompileCache& self) { return self.size(); });
-}
-
-} // namespace functorch
-} // namespace torch
diff --git a/torch/csrc/functorch/CompileCache.h b/torch/csrc/functorch/CompileCache.h
deleted file mode 100644
index c205c18..0000000
--- a/torch/csrc/functorch/CompileCache.h
+++ /dev/null
@@ -1,21 +0,0 @@
-// Copyright (c) Facebook, Inc. and its affiliates.
-// All rights reserved.
-//
-// This source code is licensed under the BSD-style license found in the
-// LICENSE file in the root directory of this source tree.
-#pragma once
-
-#include <torch/csrc/utils/pybind.h>
-
-namespace torch {
-namespace functorch {
-
-// CompileCache is the compilation cache used by the AOTAutograd frontend.
-// We're planning on deleting this in favor of torchdynamo's caching mechanism
-// (CompilerCache predates torchdynamo).
-
-/// Initialize python bindings for kernel compilation cache.
-TORCH_API void initCompileCacheBindings(PyObject* module);
-
-} // namespace functorch
-} // namespace torch
diff --git a/torch/csrc/functorch/init.cpp b/torch/csrc/functorch/init.cpp
index 4089a93..45df25c 100644
--- a/torch/csrc/functorch/init.cpp
+++ b/torch/csrc/functorch/init.cpp
@@ -16,7 +16,6 @@
 #include <ATen/functorch/PlumbingHelper.h>
 #include <ATen/functorch/TensorWrapper.h>
 #include <c10/core/AutogradState.h>
-#include <torch/csrc/functorch/CompileCache.h>
 
 // This file contains functorch's Python bindings.
 
@@ -457,8 +456,6 @@
   m.def("is_functorch_wrapped_tensor", [](const Tensor& tensor) {
     return maybe_get_level(tensor) != -1;
   });
-
-  torch::functorch::initCompileCacheBindings(m.ptr());
 }
 
 } // namespace impl
diff --git a/torch/cuda/_dynamo_graphs.py b/torch/cuda/_dynamo_graphs.py
index 1d8ae67..56973e9 100644
--- a/torch/cuda/_dynamo_graphs.py
+++ b/torch/cuda/_dynamo_graphs.py
@@ -137,7 +137,6 @@
         # these are taken from memory_efficient_fusion()
         "fw_compiler": cudagraphs,
         "bw_compiler": cudagraphs,
-        "hasher_type": "StaticShapeHasher",
     }
 
     def _wrapped_bw_compiler(*args, **kwargs):