Remove compile cache and static argnums from AOTAutograd (#85783)
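
This removes the C++ CompileCache (torch/csrc/functorch/CompileCache.h/.cpp) and its Python bindings, deletes functorch/test/test_compile_cache.py, and drops clear_compile_cache/num_of_recompilations from functorch.compile. The hasher_type and static_argnums parameters of aot_function remain only as deprecated placeholders; passing static_argnums now raises a RuntimeError asking callers to wrap the function manually or use torchdynamo.

A minimal migration sketch for code that relied on static_argnums, reusing the dropout example from the old docstring. This is illustrative only and not part of the diff; it assumes the caller can bake the non-tensor argument in with functools.partial (or any wrapper) and build a new aot_function if that value changes:

```python
import torch
from functools import partial
from functorch.compile import aot_function, nop

def fn(input, bias, p: float):
    a = torch.add(input, bias)
    # p is a plain Python float, so it is traced into the graph as a constant
    b = torch.nn.functional.dropout(a, p, training=True)
    return b

# Before this change: aot_fn = aot_function(fn, nop, static_argnums=(2,))
# After: fix p up front; re-wrap with aot_function if a different p is needed
aot_fn = aot_function(partial(fn, p=0.1), nop)

x = torch.randn(4, 5, requires_grad=True)
bias = torch.randn(5, requires_grad=True)
out = aot_fn(x, bias)
out.sum().backward()
```

With the global cache gone, each aot_function call holds exactly one compiled result, which is why the updated tests below re-create the wrapped function whenever they need a fresh compilation.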
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85783
Approved by: https://github.com/wconstab
diff --git a/build_variables.bzl b/build_variables.bzl
index b7fc16e..b26cca8 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -898,7 +898,6 @@
"torch/csrc/autograd/python_variable.cpp",
"torch/csrc/autograd/python_variable_indexing.cpp",
"torch/csrc/functorch/init.cpp",
- "torch/csrc/functorch/CompileCache.cpp",
"torch/csrc/jit/backends/backend_init.cpp",
"torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp",
"torch/csrc/jit/python/init.cpp",
diff --git a/functorch/_src/aot_autograd.py b/functorch/_src/aot_autograd.py
index bcd4b34..08953b6 100644
--- a/functorch/_src/aot_autograd.py
+++ b/functorch/_src/aot_autograd.py
@@ -17,7 +17,6 @@
from torch.nn.utils import stateless
from functorch import make_fx
-from torch._C._functorch import CompileCache
from functorch.experimental import functionalize
from torch._dispatch.python import enable_python_dispatcher
from . import config
@@ -521,14 +520,6 @@
return aot_dispatch_base(flat_fn, fake_flat_tensor_args, aot_config)
-class _CompileCache(CompileCache):
- pass
-
-
-# using a C++-based pytree reduces the overhead by about 50%
-compile_cache = None
-
-
# Inspired by autodidax (thanks!)
class PytreeThunk:
spec = None
@@ -555,53 +546,6 @@
return x
return pytree.tree_unflatten(x, self.spec)
-
-def filter_tensor_and_static_args(args, static_argnums):
- """
- Separate out the tensor and static args. Also, for the static args, store
- the hash.
- """
- tensor_args = []
- static_args = []
- static_args_hashed = []
- for idx, arg in enumerate(args):
- if idx not in static_argnums:
- tensor_args.append(arg)
- else:
- static_args.append(arg)
- static_args_hashed.append(arg.__hash__())
- return tensor_args, static_args, static_args_hashed
-
-
-def rearrange(tensor_args, static_args, static_argnums):
- """
- Generate the args as per the original spec. static_argnums is sorted.
- """
- tensor_index = 0
- static_index = 0
- index = 0
- args = []
- assert len(static_args) == len(static_argnums)
- while tensor_index < len(tensor_args) and static_index < len(static_args):
- if index == static_argnums[static_index]:
- args.append(static_args[static_index])
- static_index += 1
- else:
- args.append(tensor_args[tensor_index])
- tensor_index += 1
- index += 1
-
- while tensor_index < len(tensor_args):
- args.append(tensor_args[tensor_index])
- tensor_index += 1
-
- while static_index < len(static_args):
- args.append(static_args[static_index])
- static_index += 1
-
- return args
-
-
KNOWN_TYPES = [torch.Tensor, int, str, float, bool]
@@ -611,8 +555,8 @@
bw_compiler: Optional[Callable] = None,
partition_fn: Callable = default_partition,
decompositions: Optional[Dict] = None,
- hasher_type: str = "StaticShapeHasher",
- static_argnums: Optional[Tuple[int]] = None,
+ hasher_type=None, # deprecated
+ static_argnums: Optional[Tuple[int]] = None, # deprecated
) -> Callable:
"""
Traces the forward and backward graph of :attr:`fn` using torch dispatch
@@ -627,14 +571,7 @@
of core or simpler operators supported by the backend compilers.
:func:`aot_function` uses a compilation cache, based on input tensor
- properties, to detect when there is a need of recompilation. By default, its
- behavior is static, i.e., it recompiles if shape of any input tensor
- changes.
-
- :attr:`static_argnums` allows user to mark the arguments of the original
- :attr:`fn` as static. This is useful when an argument is a non-tensor, e.g.,
- ``int`` or ``bool``. A change in the actual value of static arg causes
- recompilation.
+ properties, to detect when there is a need for recompilation.
.. warning::
This API is experimental and likely to change.
@@ -654,8 +591,6 @@
backward graphs.
decompositions (Dict): A dictionary to define the decomposition of
larger Aten ops into simpler or core Aten ops.
- static_argnums (Optional[Tuple[Int]]): An option tuple of ints to mark
- the arguments of the function as static.
Returns:
Returns a ``Callable`` that retains the eager behavior of the original
@@ -672,22 +607,10 @@
>>> aot_fn = aot_function(fn, print_compile_fn)
>>> x = torch.randn(4, 5, requires_grad=True)
>>> aot_fn(x)
-
- The static argnums are used to mark the non-tensor arguments as static. An
- example is as follows where the dropout probability is as argument to the
- original function.
-
- >>> def fn(input, bias, residual, p: float):
- >>> a = torch.add(input, bias)
- >>> b = torch.nn.functional.dropout(a, p, training=True)
- >>> c = b + residual
- >>> return c
- >>> aot_fn = aot_function(fn, print_compile_fn, static_argnums=(3,))
-
"""
- global compile_cache
- if compile_cache is None:
- compile_cache = CompileCache()
+ if static_argnums is not None:
+ raise RuntimeError("static_argnums has been deprecated - manually wrap your function or use torchdynamo.")
+
if bw_compiler is None:
bw_compiler = fw_compiler
aot_config = AOTConfig(
@@ -698,69 +621,26 @@
)
cached_res = None
- fn_id = id(fn)
- fw_compiler_id = id(fw_compiler)
- bw_compiler_id = id(bw_compiler)
-
- if isinstance(static_argnums, int):
- static_argnums = [static_argnums]
- elif static_argnums is not None and len(static_argnums) == 0:
- static_argnums = None
- elif static_argnums is not None:
- static_argnums = list(static_argnums)
- static_argnums.sort()
-
@wraps(fn)
def returned_function(*args, **kwargs):
- global compile_cache
nonlocal cached_res
-
- # Separate out static args if static_argnums is present
- tensor_args = args
- static_args = []
- # TODO - move the hashing part of static_args to C++.
- static_args_hashed = []
- if static_argnums is not None:
- (
- tensor_args,
- static_args,
- static_args_hashed,
- ) = filter_tensor_and_static_args(args, static_argnums)
-
# Now flatten the tensor args
- flat_tensor_args, _ = pytree.tree_flatten((tensor_args, kwargs))
-
- # Check if the fn is already compiled
- num_tensor_args = len(flat_tensor_args)
- flat_args_for_cache = flat_tensor_args + static_args_hashed
- cached_res = compile_cache.at(
- fn_id,
- fw_compiler_id,
- bw_compiler_id,
- num_tensor_args,
- hasher_type,
- *flat_args_for_cache,
- )
+ flat_args, _ = pytree.tree_flatten((args, kwargs))
# Compile the function and save it in the cache
if cached_res is None:
# Save the args_spec for flat_tensor_args to unflatten while tracing
- _, tensor_args_spec = pytree.tree_flatten((tensor_args, kwargs))
+ _, tensor_args_spec = pytree.tree_flatten((args, kwargs))
out_spec = PytreeThunk()
- def flat_fn(*flat_tensor_args):
+ def flat_fn(*flat_args):
# The input are flattened tensor args. Prepare the args in the
# order that original function expects. Add static args as well.
# They will appear as tensor constants in the traced graph.
- nonlocal out_spec, static_args
-
- tensor_args, kwargs = pytree.tree_unflatten(
- flat_tensor_args, tensor_args_spec
+ nonlocal out_spec
+ args, kwargs = pytree.tree_unflatten(
+ flat_args, tensor_args_spec
)
- if static_argnums is None:
- args = tensor_args
- else:
- args = rearrange(tensor_args, static_args, static_argnums)
tree_out = fn(*args, **kwargs)
flat_out, spec = pytree.tree_flatten(tree_out)
for i in flat_out:
@@ -783,50 +663,18 @@
compiled_fn = create_aot_dispatcher_function(
flat_fn,
- flat_tensor_args,
+ flat_args,
aot_config,
)
cached_res = (compiled_fn, out_spec)
- # Save the compiled_fn in the cache
- compile_cache.insert(
- fn_id,
- fw_compiler_id,
- bw_compiler_id,
- num_tensor_args,
- hasher_type,
- cached_res,
- *flat_args_for_cache,
- )
-
cached_fn, out_spec = cached_res
- out = cached_fn(flat_tensor_args)
+ out = cached_fn(flat_args)
return out_spec.unflatten(out)
return returned_function
-def num_of_recompilations():
- """
- Returns the numbers of recompilations since the last time cache was cleared.
- This is equivalent to the number of entries in the compilation cache.
- """
- global compile_cache
- if compile_cache is None:
- return 0
- return compile_cache.size()
-
-
-def clear_compile_cache():
- """
- Clears the compilation cache.
- """
- global compile_cache
- if compile_cache is not None:
- compile_cache.clear()
- compile_cache = None
-
-
def aot_module(mod: nn.Module, *args, **kwargs) -> nn.Module:
"""
Traces the forward and backward graph of :attr:`mod` using torch dispatch
@@ -921,8 +769,8 @@
bw_compiler: Optional[Callable] = None,
partition_fn: Callable = default_partition,
decompositions: Optional[Dict] = None,
- hasher_type: str = "StaticShapeHasher",
- static_argnums: Optional[Tuple[int]] = None,
+ hasher_type=None,
+ static_argnums=None,
) -> Callable:
assert static_argnums is None
if bw_compiler is None:
diff --git a/functorch/_src/compilers.py b/functorch/_src/compilers.py
index 3dd8455..18deafa 100644
--- a/functorch/_src/compilers.py
+++ b/functorch/_src/compilers.py
@@ -193,7 +193,6 @@
"fw_compiler": ts_compile,
"bw_compiler": ts_compile,
"partition_fn": min_cut_rematerialization_partition,
- "hasher_type": "StaticShapeHasher",
"decompositions": default_decompositions,
"static_argnums": static_argnums,
}
diff --git a/functorch/benchmarks/transformer_fusion_patterns/benchmark.py b/functorch/benchmarks/transformer_fusion_patterns/benchmark.py
index a6646e1..f799422 100644
--- a/functorch/benchmarks/transformer_fusion_patterns/benchmark.py
+++ b/functorch/benchmarks/transformer_fusion_patterns/benchmark.py
@@ -1,5 +1,5 @@
import torch
-from functorch.compile import memory_efficient_fusion, clear_compile_cache
+from functorch.compile import memory_efficient_fusion
import benchmark_helper
@@ -159,7 +159,6 @@
for cl in [DropoutResBias, BiasReluDropout, DropoutResBiasScalar, BiasDropoutResLayerNorm, LayerNormSigmoid]:
# Clear the compile cache
- clear_compile_cache()
# Get the function and inputs
obj = cl()
diff --git a/functorch/benchmarks/transformer_fusion_patterns/bias_gelu_dropout.py b/functorch/benchmarks/transformer_fusion_patterns/bias_gelu_dropout.py
index b231806..26c6d7c 100644
--- a/functorch/benchmarks/transformer_fusion_patterns/bias_gelu_dropout.py
+++ b/functorch/benchmarks/transformer_fusion_patterns/bias_gelu_dropout.py
@@ -1,5 +1,5 @@
import torch
-from functorch.compile import memory_efficient_pointwise_fusion, clear_compile_cache
+from functorch.compile import memory_efficient_pointwise_fusion
import benchmark_helper
# ALL comments regarding the patetrns
@@ -21,7 +21,6 @@
fn = bias_gelu_dropout
-clear_compile_cache()
# Set inputs
device = "cuda"
diff --git a/functorch/compile/__init__.py b/functorch/compile/__init__.py
index b489eb4..12549dc 100644
--- a/functorch/compile/__init__.py
+++ b/functorch/compile/__init__.py
@@ -5,8 +5,6 @@
aot_module,
compiled_function,
compiled_module,
- num_of_recompilations,
- clear_compile_cache,
aot_module_simplified,
get_graph_being_compiled,
get_aot_graph_name,
diff --git a/functorch/notebooks/aot_autograd_optimizations.ipynb b/functorch/notebooks/aot_autograd_optimizations.ipynb
index f12c2a7..beef4cb 100644
--- a/functorch/notebooks/aot_autograd_optimizations.ipynb
+++ b/functorch/notebooks/aot_autograd_optimizations.ipynb
@@ -127,10 +127,7 @@
"\n",
"# Run the aot_print_fn once to trigger the compilation and print the graphs\n",
"res = aot_print_fn(a, b, c, d).sum().backward()\n",
- "assert torch.allclose(ref, res)\n",
- "\n",
- "from functorch.compile import clear_compile_cache\n",
- "clear_compile_cache()"
+ "assert torch.allclose(ref, res)"
]
},
{
diff --git a/functorch/test/common_utils.py b/functorch/test/common_utils.py
index 65b9dcb..91cba9e 100644
--- a/functorch/test/common_utils.py
+++ b/functorch/test/common_utils.py
@@ -351,7 +351,7 @@
matching_opinfos = [o for o in all_opinfos
if o.name == decorate_meta.op_name and
o.variant_test_name == decorate_meta.variant_name]
- assert len(matching_opinfos) > 0, f"Couldn't find OpInfo for {xfail}"
+ assert len(matching_opinfos) > 0, f"Couldn't find OpInfo for {decorate_meta}"
assert len(matching_opinfos) == 1, (
"OpInfos should be uniquely determined by their (name, variant_name). "
f"Got more than one result for ({decorate_meta.op_name}, {decorate_meta.variant_name})"
diff --git a/functorch/test/test_aotdispatch.py b/functorch/test/test_aotdispatch.py
index 3cce940..bfcabd2 100644
--- a/functorch/test/test_aotdispatch.py
+++ b/functorch/test/test_aotdispatch.py
@@ -24,8 +24,8 @@
from functorch.compile import (
nnc_jit, compiled_function, compiled_module,
min_cut_rematerialization_partition, aot_function, aot_module,
- nop, num_of_recompilations, default_partition, default_decompositions,
- memory_efficient_fusion, clear_compile_cache, get_aot_compilation_context
+ nop, default_partition, default_decompositions,
+ memory_efficient_fusion, get_aot_compilation_context
)
from torch._decomp import decomposition_table
@@ -61,9 +61,6 @@
class AOTTestCase(TestCase):
def setUp(self):
super().setUp()
- # NB: We cache on function id, which is unreliable
- # Can fix by using weakrefs, but not sure if it matters
- clear_compile_cache()
class TestPythonKey(AOTTestCase):
def test_make_fx(self, device):
@@ -292,18 +289,17 @@
graph_size = len(fx_g.graph.nodes)
return fx_g
- start_recompilations = num_of_recompilations()
f = aot_function(foo, nop, get_graph_size)
with torch.set_grad_enabled(False):
f(*inps)
self.assertIsNone(graph_size)
+ f = aot_function(foo, nop, get_graph_size)
with torch.set_grad_enabled(True):
out = f(*inps)
self.assertIsNone(graph_size)
out.sum().backward()
self.assertTrue(graph_size > 2)
- self.assertEqual(num_of_recompilations() - start_recompilations, 2)
def test_output_dict(self):
def f(x):
@@ -366,6 +362,7 @@
f = aot_function(f, compiler)
out = f(torch.randn(5, requires_grad=True))
+ f = aot_function(f, compiler)
f(torch.randn(5))
out.sum().backward()
self.assertEqual(count, [(['forward'], 4), (['inference'], 4), (['backward'], 8)])
@@ -731,7 +728,6 @@
res = aot_mod(*inputs)
res[0].sum().backward()
-
only_for = ("cpu")
instantiate_device_type_tests(
TestPythonKey,
diff --git a/functorch/test/test_compile_cache.py b/functorch/test/test_compile_cache.py
deleted file mode 100644
index 2115e58..0000000
--- a/functorch/test/test_compile_cache.py
+++ /dev/null
@@ -1,686 +0,0 @@
-# Owner(s): ["module: functorch"]
-
-import torch
-
-import functorch
-from torch.testing._internal.common_utils import run_tests, TestCase, IS_WINDOWS
-import unittest
-
-from functorch.compile import aot_function, nop
-
-
-class TestCompileCache(TestCase):
- def check(self, a, b, aot_fn, fn):
- a_clone = a.clone().detach().requires_grad_(True)
- b_clone = b.clone().detach().requires_grad_(True)
- ref = fn(a, b)
- ref.sum().backward()
-
- res = aot_fn(a_clone, b_clone)
- res.sum().backward()
- assert torch.allclose(res, ref)
- assert torch.allclose(a.grad, a_clone.grad)
- assert torch.allclose(b.grad, b_clone.grad)
-
- def test_recompilation_on_broadcast(self):
- def fn(x, bias):
- return x + bias
-
- for hasher_type in ["DynamicShapeHasher", "StaticShapeHasher"]:
- functorch.compile.clear_compile_cache()
- start_num_recomps = functorch.compile.num_of_recompilations()
- aot_autograd_fn = aot_function(fn, nop, nop, hasher_type=hasher_type)
-
- a = torch.randn(10, 20, requires_grad=True)
- b = torch.randn(20, requires_grad=True)
- self.check(a, b, aot_autograd_fn, fn)
-
- a = torch.randn(10, 20, requires_grad=True)
- b = torch.randn(10, 20, requires_grad=True)
- self.check(a, b, aot_autograd_fn, fn)
-
- end_num_recomps = functorch.compile.num_of_recompilations()
-
- total_recomps = end_num_recomps - start_num_recomps
- assert total_recomps == 2
-
- def test_compilation_for_dynamic_shape(self):
- def fn(x, bias):
- return x + bias
-
- for hasher_type in ["DynamicShapeHasher", "StaticShapeHasher"]:
- functorch.compile.clear_compile_cache()
- start_num_recomps = functorch.compile.num_of_recompilations()
- aot_autograd_fn = aot_function(fn, nop, nop, hasher_type=hasher_type)
-
- for s in range(10, 20):
- a = torch.randn(s, requires_grad=True)
- b = torch.randn(s, requires_grad=True)
- self.check(a, b, aot_autograd_fn, fn)
-
- for s in range(10, 20):
- a = torch.randn(s, requires_grad=True)
- b = torch.randn(s, requires_grad=True)
- self.check(a, b, aot_autograd_fn, fn)
-
- end_num_recomps = functorch.compile.num_of_recompilations()
-
- total_recomps = end_num_recomps - start_num_recomps
- if hasher_type == "DynamicShapeHasher":
- assert total_recomps == 1
- elif hasher_type == "StaticShapeHasher":
- assert total_recomps == 10
-
- for s in range(10, 20):
- a = torch.randn(s, s, requires_grad=True)
- b = torch.randn(s, s, requires_grad=True)
- self.check(a, b, aot_autograd_fn, fn)
-
- end_num_recomps = functorch.compile.num_of_recompilations()
-
- total_recomps = end_num_recomps - start_num_recomps
- if hasher_type == "DynamicShapeHasher":
- assert total_recomps == 2
- elif hasher_type == "StaticShapeHasher":
- assert total_recomps == 20
-
- def test_global_cache_no_recompilations(self):
- def f(x, bias):
- return x + bias
-
- def g(x, bias):
- return aot_function(f, nop, nop, hasher_type="DynamicShapeHasher")(x, bias)
-
- start_num_recomps = functorch.compile.num_of_recompilations()
- for _ in range(10):
- a = torch.randn(10, 20, requires_grad=True)
- b = torch.randn(10, 20, requires_grad=True)
- self.check(a, b, g, f)
-
- end_num_recomps = functorch.compile.num_of_recompilations()
- total_recomps = end_num_recomps - start_num_recomps
- assert total_recomps == 1
-
- def test_multiple_functions(self):
- def f(x, bias):
- return x + bias
-
- def g(x, y):
- return x * y
-
- for hasher_type in ["DynamicShapeHasher", "StaticShapeHasher"]:
- functorch.compile.clear_compile_cache()
- aot_autograd_f = aot_function(f, nop, nop, hasher_type=hasher_type)
- aot_autograd_g = aot_function(g, nop, nop, hasher_type=hasher_type)
-
- start_num_recomps = functorch.compile.num_of_recompilations()
- a = torch.randn(10, requires_grad=True)
- b = torch.randn(10, requires_grad=True)
- self.check(a, b, aot_autograd_f, f)
-
- a = torch.randn(10, requires_grad=True)
- b = torch.randn(10, requires_grad=True)
- self.check(a, b, aot_autograd_g, g)
-
- end_num_recomps = functorch.compile.num_of_recompilations()
- total_recomps = end_num_recomps - start_num_recomps
- assert total_recomps == 2
-
- # Force recompilation for function f and check num of recompilations again
- a = torch.randn(10, 20, requires_grad=True)
- b = torch.randn(10, 20, requires_grad=True)
- self.check(a, b, aot_autograd_f, f)
-
- end_num_recomps = functorch.compile.num_of_recompilations()
- total_recomps = end_num_recomps - start_num_recomps
- assert total_recomps == 3
-
- def test_high_number_of_args(self):
- def f(*args):
- res = args[0]
- for arg in args:
- res = res * arg
- return res
-
- def check(args, aot_autograd_fn, fn):
- args_clone = [arg.clone().detach().requires_grad_(True) for arg in args]
- ref = fn(*args)
- ref.sum().backward()
-
- res = aot_autograd_fn(*args_clone)
- res.sum().backward()
- assert torch.allclose(res, ref)
- for (arg, arg_clone) in zip(args, args_clone):
- assert torch.allclose(arg.grad, arg_clone.grad)
-
- for hasher_type in ["DynamicShapeHasher", "StaticShapeHasher"]:
- functorch.compile.clear_compile_cache()
-
- aot_autograd_f = aot_function(f, nop, nop, hasher_type=hasher_type)
-
- args = [torch.randn(10, requires_grad=True) for _ in range(100)]
- check(args, aot_autograd_f, f)
-
- def test_multiple_compiler(self):
- def fn(x, bias):
- return x + bias
-
- def nop_duplicate(fx_g, _):
- return fx_g
-
- for hasher_type in ["DynamicShapeHasher", "StaticShapeHasher"]:
- functorch.compile.clear_compile_cache()
- start_num_recomps = functorch.compile.num_of_recompilations()
- nop_fn = aot_function(fn, nop, nop, hasher_type=hasher_type)
- nop_duplicate_fn = aot_function(
- fn, nop_duplicate, nop_duplicate, hasher_type=hasher_type
- )
-
- a = torch.randn(10, 20, requires_grad=True)
- b = torch.randn(20, requires_grad=True)
- nop_fn(a, b)
- nop_duplicate_fn(a, b)
-
- end_num_recomps = functorch.compile.num_of_recompilations()
-
- total_recomps = end_num_recomps - start_num_recomps
- assert total_recomps == 2
-
-
-@unittest.skipIf(IS_WINDOWS, 'test broken on windows')
-class TestCompileCacheStaticArgs(TestCase):
- def check(self, a, b, aot_autograd_fn, fn):
- a_clone = a.clone().detach().requires_grad_(True)
- ref = fn(a, b)
- ref.sum().backward()
-
- res = aot_autograd_fn(a_clone, b)
- res.sum().backward()
- assert torch.allclose(res, ref)
- assert torch.allclose(a.grad, a_clone.grad)
-
- def test_failure(self):
- # Test that not setting up static_argnums should raise exception
- def fn(x, p):
- return x * p
-
- aot_autograd_f = aot_function(fn, nop, nop)
-
- a = torch.randn(2, 2, requires_grad=True)
- b = 2
- try:
- # Since b is not marked as static, it should raise exception
- aot_autograd_f(a, b)
- raise AssertionError()
- except RuntimeError:
- pass
-
- def test_simple(self):
- def fn(x, static_arg):
- return x * static_arg
-
- functorch.compile.clear_compile_cache()
-
- start_num_recomps = functorch.compile.num_of_recompilations()
-
- aot_autograd_f = aot_function(fn, nop, nop, static_argnums=1)
-
- a = torch.randn(2, 2, requires_grad=True)
- b = 2
- self.check(a, b, aot_autograd_f, fn)
-
- # Same type of args, so no recompilation
- a = torch.randn(2, 2, requires_grad=True)
- b = 2
- self.check(a, b, aot_autograd_f, fn)
-
- # Trigger recompilation
- a = torch.randn(2, 2, requires_grad=True)
- b = 3
- self.check(a, b, aot_autograd_f, fn)
-
- end_num_recomps = functorch.compile.num_of_recompilations()
-
- total_recomps = end_num_recomps - start_num_recomps
- assert total_recomps == 2
-
- def test_static_arg_before_tensor_arg(self):
- def fn(static_arg, x):
- return static_arg - x
-
- def check(a, b, aot_autograd_fn, fn):
- b_clone = b.clone().detach().requires_grad_(True)
- ref = fn(a, b)
- ref.sum().backward()
-
- res = aot_autograd_fn(a, b_clone)
- res.sum().backward()
- assert torch.allclose(res, ref)
- assert torch.allclose(b.grad, b_clone.grad)
-
- functorch.compile.clear_compile_cache()
-
- start_num_recomps = functorch.compile.num_of_recompilations()
-
- aot_autograd_f = aot_function(fn, nop, nop, static_argnums=0)
-
- a = 2
- b = torch.randn(2, 2, requires_grad=True)
- check(a, b, aot_autograd_f, fn)
-
- a = 3
- b = torch.randn(2, 2, requires_grad=True)
- check(a, b, aot_autograd_f, fn)
-
- end_num_recomps = functorch.compile.num_of_recompilations()
-
- total_recomps = end_num_recomps - start_num_recomps
- assert total_recomps == 2
-
- def test_interleaved_static_args(self):
- def fn(static_arg1, x, static_arg2):
- return static_arg1 - x - static_arg2
-
- def check(a, b, c, aot_autograd_fn, fn):
- b_clone = b.clone().detach().requires_grad_(True)
- ref = fn(a, b, c)
- ref.sum().backward()
-
- res = aot_autograd_fn(a, b_clone, c)
- res.sum().backward()
- assert torch.allclose(res, ref)
- assert torch.allclose(b.grad, b_clone.grad)
-
- functorch.compile.clear_compile_cache()
-
- start_num_recomps = functorch.compile.num_of_recompilations()
-
- aot_autograd_f = aot_function(fn, nop, nop, static_argnums=(0, 2))
-
- a = 2
- b = torch.randn(2, 2, requires_grad=True)
- c = 0.1
- check(a, b, c, aot_autograd_f, fn)
-
- a = 3
- b = torch.randn(2, 2, requires_grad=True)
- c = 0.1
- check(a, b, c, aot_autograd_f, fn)
-
- end_num_recomps = functorch.compile.num_of_recompilations()
-
- total_recomps = end_num_recomps - start_num_recomps
- assert total_recomps == 2
-
- def test_dropout(self):
- def fn(x, prob):
- return torch.nn.functional.dropout(x, p=prob)
-
- functorch.compile.clear_compile_cache()
-
- start_num_recomps = functorch.compile.num_of_recompilations()
-
- aot_autograd_f = aot_function(fn, nop, nop, static_argnums=[1])
-
- a = torch.randn(2, 2, requires_grad=True)
- b = 0.3
- aot_autograd_f(a, b)
-
- # Setting the prob to 0. This should cause recompilation.
- a = torch.randn(2, 2, requires_grad=True)
- b = 0
- self.check(a, b, aot_autograd_f, fn)
-
- end_num_recomps = functorch.compile.num_of_recompilations()
-
- total_recomps = end_num_recomps - start_num_recomps
- assert total_recomps == 2
-
- def test_if_condition(self):
- def fn(x, state: bool):
- if state:
- return torch.sin(x)
- else:
- return torch.cos(x)
-
- functorch.compile.clear_compile_cache()
-
- start_num_recomps = functorch.compile.num_of_recompilations()
-
- aot_autograd_f = aot_function(fn, nop, nop, static_argnums=[1])
-
- a = torch.randn(2, 2, requires_grad=True)
- b = True
- self.check(a, b, aot_autograd_f, fn)
-
- a = torch.randn(2, 2, requires_grad=True)
- b = True
- self.check(a, b, aot_autograd_f, fn)
-
- a = torch.randn(2, 2, requires_grad=True)
- b = False
- self.check(a, b, aot_autograd_f, fn)
-
- end_num_recomps = functorch.compile.num_of_recompilations()
-
- total_recomps = end_num_recomps - start_num_recomps
- assert total_recomps == 2
-
- def test_custom(self):
- class Record:
- def __init__(self, name, multiplier):
- self.name = name
- self.multiplier = multiplier
-
- def __eq__(self, other):
- return self.name == other.name and self.multiplier == other.multiplier
-
- def __hash__(self):
- return hash((self.name, self.multiplier))
-
- def fn(x, record):
- return x * record.multiplier
-
- functorch.compile.clear_compile_cache()
-
- start_num_recomps = functorch.compile.num_of_recompilations()
-
- aot_autograd_f = aot_function(fn, nop, nop, static_argnums=[1])
-
- a = torch.randn(2, 2, requires_grad=True)
- b = Record("Foo", 0.5)
- self.check(a, b, aot_autograd_f, fn)
-
- a = torch.randn(2, 2, requires_grad=True)
- b = Record("Bar", 10.2)
- self.check(a, b, aot_autograd_f, fn)
-
- end_num_recomps = functorch.compile.num_of_recompilations()
-
- total_recomps = end_num_recomps - start_num_recomps
- assert total_recomps == 2
-
- def test_tuple(self):
- def fn(a_tuple, static_arg):
- return torch.sin(a_tuple[0]) - a_tuple[1] - static_arg
-
- def check(a_tuple, b, aot_autograd_fn, fn):
- a0 = a_tuple[0]
- a1 = a_tuple[1]
-
- a0_clone = a0.clone().detach().requires_grad_(True)
- a1_clone = a1.clone().detach().requires_grad_(True)
- ref = fn(a, b)
- ref.sum().backward()
-
- res = aot_autograd_fn((a0_clone, a1_clone), b)
- res.sum().backward()
- assert torch.allclose(res, ref)
- assert torch.allclose(a0.grad, a0_clone.grad)
- assert torch.allclose(a1.grad, a1_clone.grad)
-
- functorch.compile.clear_compile_cache()
-
- start_num_recomps = functorch.compile.num_of_recompilations()
-
- aot_autograd_f = aot_function(fn, nop, nop, static_argnums=(1,))
-
- a = (
- torch.randn(2, 2, requires_grad=True),
- torch.randn(2, 2, requires_grad=True),
- )
- b = 0.1
- check(a, b, aot_autograd_f, fn)
-
- a = (
- torch.randn(2, 2, requires_grad=True),
- torch.randn(2, 2, requires_grad=True),
- )
- b = 1
- check(a, b, aot_autograd_f, fn)
-
- end_num_recomps = functorch.compile.num_of_recompilations()
-
- total_recomps = end_num_recomps - start_num_recomps
- assert total_recomps == 2
-
- def test_tuple_with_first_arg_as_static(self):
- def fn(static_arg, a_tuple):
- return torch.sin(a_tuple[0]) - a_tuple[1] - static_arg
-
- def check(a, b_tuple, aot_autograd_fn, fn):
- b0 = b_tuple[0]
- b1 = b_tuple[1]
-
- b0_clone = b0.clone().detach().requires_grad_(True)
- b1_clone = b1.clone().detach().requires_grad_(True)
- ref = fn(a, b_tuple)
- ref.sum().backward()
-
- res = aot_autograd_fn(a, (b0_clone, b1_clone))
- res.sum().backward()
- assert torch.allclose(res, ref)
- assert torch.allclose(b0.grad, b0_clone.grad)
- assert torch.allclose(b1.grad, b1_clone.grad)
-
- functorch.compile.clear_compile_cache()
-
- start_num_recomps = functorch.compile.num_of_recompilations()
-
- aot_autograd_f = aot_function(fn, nop, nop, static_argnums=(0,))
-
- a = 0.1
- b = (
- torch.randn(2, 2, requires_grad=True),
- torch.randn(2, 2, requires_grad=True),
- )
- check(a, b, aot_autograd_f, fn)
-
- a = 1
- b = (
- torch.randn(2, 2, requires_grad=True),
- torch.randn(2, 2, requires_grad=True),
- )
- check(a, b, aot_autograd_f, fn)
-
- end_num_recomps = functorch.compile.num_of_recompilations()
-
- total_recomps = end_num_recomps - start_num_recomps
- assert total_recomps == 2
-
- def test_dict(self):
- def fn(a_dict, static_arg):
- return torch.sin(a_dict["foo"]) - a_dict["bar"] - static_arg
-
- def check(a_dict, b, aot_autograd_fn, fn):
-
- a0 = a_dict["foo"]
- a1 = a_dict["bar"]
-
- a0_clone = a0.clone().detach().requires_grad_(True)
- a1_clone = a1.clone().detach().requires_grad_(True)
- ref = fn(a_dict, b)
- ref.sum().backward()
-
- a_clone = {}
- a_clone["foo"] = a0_clone
- a_clone["bar"] = a1_clone
- res = aot_autograd_fn(a_clone, b)
- res.sum().backward()
- assert torch.allclose(res, ref)
- assert torch.allclose(a0.grad, a0_clone.grad)
- assert torch.allclose(a1.grad, a1_clone.grad)
-
- functorch.compile.clear_compile_cache()
-
- start_num_recomps = functorch.compile.num_of_recompilations()
-
- aot_autograd_f = aot_function(fn, nop, nop, static_argnums=(1,))
-
- a = {}
- a["foo"] = torch.zeros(2, 2, requires_grad=True)
- a["bar"] = torch.ones(2, 2, requires_grad=True)
- b = 0
- check(a, b, aot_autograd_f, fn)
-
- a = {}
- a["foo"] = torch.randn(2, 2, requires_grad=True)
- a["bar"] = torch.randn(2, 2, requires_grad=True)
- b = 0.2
- check(a, b, aot_autograd_f, fn)
-
- end_num_recomps = functorch.compile.num_of_recompilations()
-
- total_recomps = end_num_recomps - start_num_recomps
- assert total_recomps == 2
-
- def test_dict_with_static_arg_before_dict(self):
- def fn(static_arg, a_dict):
- return torch.sin(a_dict["foo"]) - a_dict["bar"] - static_arg
-
- def check(a, b_dict, aot_autograd_fn, fn):
-
- ref = fn(a, b_dict)
- res = aot_autograd_fn(a, b_dict)
- assert torch.allclose(res, ref)
-
- b0 = b_dict["foo"]
- b1 = b_dict["bar"]
-
- b0_clone = b0.clone().detach().requires_grad_(True)
- b1_clone = b1.clone().detach().requires_grad_(True)
- ref.sum().backward()
-
- b_clone = {}
- b_clone["foo"] = b0_clone
- b_clone["bar"] = b1_clone
- res = aot_autograd_fn(a, b_clone)
- res.sum().backward()
- assert torch.allclose(res, ref)
- assert torch.allclose(b0.grad, b0_clone.grad)
- assert torch.allclose(b1.grad, b1_clone.grad)
-
- functorch.compile.clear_compile_cache()
-
- start_num_recomps = functorch.compile.num_of_recompilations()
-
- aot_autograd_f = aot_function(fn, nop, nop, static_argnums=(0,))
-
- a = 0.1
- b = {}
- b["foo"] = torch.randn(2, 2, requires_grad=True)
- b["bar"] = torch.randn(2, 2, requires_grad=True)
- check(a, b, aot_autograd_f, fn)
-
- a = 0.2
- b = {}
- b["foo"] = torch.randn(2, 2, requires_grad=True)
- b["bar"] = torch.randn(2, 2, requires_grad=True)
- check(a, b, aot_autograd_f, fn)
-
- end_num_recomps = functorch.compile.num_of_recompilations()
-
- total_recomps = end_num_recomps - start_num_recomps
- assert total_recomps == 2
-
- def test_tuple_static_args(self):
- def fn(x, tuple_static_arg):
- return x * tuple_static_arg[0] * tuple_static_arg[1]
-
- functorch.compile.clear_compile_cache()
-
- start_num_recomps = functorch.compile.num_of_recompilations()
-
- aot_autograd_f = aot_function(fn, nop, nop, static_argnums=1)
-
- a = torch.randn(2, 2, requires_grad=True)
- b = (2, 3)
- self.check(a, b, aot_autograd_f, fn)
-
- # Same type of args, so no recompilation
- a = torch.randn(2, 2, requires_grad=True)
- b = (2, 3)
- self.check(a, b, aot_autograd_f, fn)
-
- # Trigger recompilation
- a = torch.randn(2, 2, requires_grad=True)
- b = (3, 4)
- self.check(a, b, aot_autograd_f, fn)
-
- end_num_recomps = functorch.compile.num_of_recompilations()
-
- total_recomps = end_num_recomps - start_num_recomps
- assert total_recomps == 2
-
- def test_arg_none(self):
- def check(a, b, c, aot_autograd_fn, fn):
- def cloner(x):
- if x is not None:
- return x.clone().detach().requires_grad_(True)
- return None
-
- def check_grad(x, x_clone):
- if x is not None:
- return torch.allclose(x.grad, x_clone.grad)
- return True
-
- ref = fn(a, b, c)
- res = aot_autograd_fn(a, b, c)
- assert torch.allclose(res, ref)
-
- a_clone = cloner(a)
- b_clone = cloner(b)
- c_clone = cloner(c)
- ref.sum().backward()
- res = aot_autograd_fn(a_clone, b_clone, c_clone)
- res.sum().backward()
-
- check_grad(a, a_clone)
- check_grad(b, b_clone)
- check_grad(c, c_clone)
-
- def fn(a, b, c):
- if a is None and b is None:
- return c
- elif a is None and c is None:
- return b
- elif b is None and c is None:
- return a
- elif a is None:
- return b + c
- elif b is None:
- return a + c
- elif c is None:
- return a + b
- return a + b + c
-
- functorch.compile.clear_compile_cache()
-
- start_num_recomps = functorch.compile.num_of_recompilations()
-
- aot_autograd_f = aot_function(fn, nop, nop)
-
- t1 = torch.randn(2, 2, requires_grad=True)
- check(t1, None, None, aot_autograd_f, fn)
- check(None, t1, None, aot_autograd_f, fn)
- check(None, None, t1, aot_autograd_f, fn)
-
- t2 = torch.randn(2, 2, requires_grad=True)
- check(t1, t2, None, aot_autograd_f, fn)
- check(t1, None, t2, aot_autograd_f, fn)
- check(None, t1, t2, aot_autograd_f, fn)
-
- t3 = torch.randn(2, 2, requires_grad=True)
- check(t1, t2, t3, aot_autograd_f, fn)
-
- # Same type of args, so no recompilation
- check(t1, t2, None, aot_autograd_f, fn)
-
- end_num_recomps = functorch.compile.num_of_recompilations()
-
- total_recomps = end_num_recomps - start_num_recomps
- assert total_recomps == 7
-
-
-if __name__ == "__main__":
- run_tests()
diff --git a/functorch/test/test_functionalize.py b/functorch/test/test_functionalize.py
index 0ae8d5c..a04bd49 100644
--- a/functorch/test/test_functionalize.py
+++ b/functorch/test/test_functionalize.py
@@ -4,7 +4,6 @@
from unittest.mock import patch
import functools
from torch.testing._internal.common_utils import run_tests
-import test_compile_cache
import test_aotdispatch
@@ -38,8 +37,6 @@
return FunctionalizeTest
-FunctionalizeTestCompileCache = make_functionalize_test(test_compile_cache.TestCompileCache)
-FunctionalizeTestCompileCacheStaticArgs = make_functionalize_test(test_compile_cache.TestCompileCacheStaticArgs)
FunctionalizeTestPythonKeyAOT = make_functionalize_test(test_aotdispatch.TestAOTAutograd)
FunctionalizeTestPythonKeyPartitioning = make_functionalize_test(test_aotdispatch.TestPartitioning)
diff --git a/torch/csrc/functorch/CompileCache.cpp b/torch/csrc/functorch/CompileCache.cpp
deleted file mode 100644
index 0366e38..0000000
--- a/torch/csrc/functorch/CompileCache.cpp
+++ /dev/null
@@ -1,345 +0,0 @@
-// Copyright (c) Facebook, Inc. and its affiliates.
-// All rights reserved.
-//
-// This source code is licensed under the BSD-style license found in the
-// LICENSE file in the root directory of this source tree.
-
-///
-/// This design stemmed of from the PointwiseOperatorCompileCache with the
-/// purpose of making it more generic for AOTAutograd. This is Compile Cache
-/// allowing different types of hashing functions, and is agnostic of the
-/// compiler.
-///
-#include <torch/csrc/autograd/custom_function.h>
-#include <torch/csrc/functorch/CompileCache.h>
-#include <torch/csrc/jit/python/pybind_utils.h>
-#include <torch/csrc/jit/tensorexpr/codegen.h>
-#include <torch/csrc/utils/pybind.h>
-#include <torch/csrc/utils/python_numbers.h>
-
-using namespace torch::jit::tensorexpr;
-
-namespace {
-/// Record of thread-local state that changes operator behavior.
-struct LocalState {
- c10::impl::LocalDispatchKeySet dispatchModifier;
- bool gradModeEnabled;
-
- at::DispatchKeySet apply(at::DispatchKeySet ks) const {
- return (ks | dispatchModifier.included_) - dispatchModifier.excluded_;
- }
-
- LocalState()
- : dispatchModifier(c10::impl::tls_local_dispatch_key_set()),
- gradModeEnabled(at::GradMode::is_enabled()) {}
-};
-
-/// Helper to pack tensor (dtype, requires grad) into an 8-bit key.
-static uint8_t packFlags(const LocalState& state, const at::Tensor& v) {
- static_assert(
- static_cast<int>(at::ScalarType::NumOptions) < 128, "overflow possible");
- at::ScalarType dtype = v.dtype().toScalarType();
- bool requires_grad = state.gradModeEnabled && v.requires_grad();
- return static_cast<uint8_t>(requires_grad) |
- (static_cast<uint8_t>(dtype) << 1);
-}
-
-using hash_key_t = std::vector<int64_t>;
-/// Per-tensor cache specialization key targetting dynamic shapes. Records
-/// dtype, dispatch options, aliasing, and per-dim contiguity/broadcasting
-/// information.
-
-enum DimFlags {
- /// A leading dimension implicitly added by broadcasting.
- SIZE_MISSING = 1 << 0,
-
- /// Size == 1.
- SIZE_ONE = 1 << 1,
-
- /// Size > 1.
- SIZE_OTHER = 1 << 2,
-
- /// Stride == 0; broadcasting.
- STRIDE_ZERO = 1 << 3,
-
- /// Stride == 1; packed contiguously in memory.
- STRIDE_ONE = 1 << 4,
-
- /// Stride = Stride[i + 1] * Size[i + 1].
- /// Used to collapse dimensions.
- STRIDE_CONTIGUOUS = 1 << 5,
-
- /// Stride = Stride[i - 1] * Size[i - 1].
- /// Used to collapse dimensions in the other direction.
- STRIDE_TRANSPOSED_CONTIGUOUS = 1 << 6, // stride[i-1] * sizes[i-1]
-
- /// Stride must be provided as an argument.
- STRIDE_AS_ARG = 1 << 7,
-};
-
-/// Unique hasher id values to uniquely identify the type of hash. NONE_HASH is
-/// used when a tensor is undefined.
-enum HasherFlags {
- NONE_HASH,
- STATIC_HASH,
- DYNAMIC_HASH,
-};
-
-std::vector<int> genDimFlags(c10::IntArrayRef sizes, c10::IntArrayRef strides) {
- // Pack all the properties for each dimension into a uint8.
- int nDims = sizes.size();
- std::vector<int> dimflags(nDims);
- for (int64_t dim = 0; dim < nDims; ++dim) {
- uint8_t flag =
- (sizes[dim] == 0 ? SIZE_MISSING
- : (sizes[dim] == 1 ? SIZE_ONE : SIZE_OTHER));
- if (strides[dim] == 0) {
- flag |= STRIDE_ZERO;
- } else if (strides[dim] == 1) {
- flag |= STRIDE_ONE;
- } else if (
- dim + 1 < (int64_t)sizes.size() &&
- strides[dim] == strides[dim + 1] * sizes[dim + 1]) {
- flag |= STRIDE_CONTIGUOUS;
- } else if (
- dim > 0 && strides[dim] == strides[dim - 1] * sizes[dim - 1] &&
- (dimflags[dim - 1] & STRIDE_CONTIGUOUS) == 0) {
- flag |= STRIDE_TRANSPOSED_CONTIGUOUS;
- } else {
- flag |= STRIDE_AS_ARG;
- }
- dimflags[dim] = flag;
- }
- return dimflags;
-}
-
-hash_key_t dynamic_hasher(const LocalState& state, const at::Tensor& v) {
- hash_key_t hash = {
- DYNAMIC_HASH,
- static_cast<int>(packFlags(state, v)),
- static_cast<int>(state.apply(v.key_set()).raw_repr()),
- static_cast<int>(v.ndimension())};
- auto dimFlags = genDimFlags(v.sizes(), v.strides());
- hash.insert(hash.end(), dimFlags.begin(), dimFlags.end());
- return hash;
-}
-
-/// Per-tensor cache specialization key targetting static shapes. Recordsdtype,
-/// dispatch options, aliasing, and full shapes and strides.
-hash_key_t static_hasher(const LocalState& state, const at::Tensor& v) {
- hash_key_t hash = {
- STATIC_HASH,
- static_cast<int>(packFlags(state, v)),
- static_cast<int>(state.apply(v.key_set()).raw_repr()),
- static_cast<int>(v.ndimension())};
- hash.insert(hash.end(), v.sizes().begin(), v.sizes().end());
- hash.insert(hash.end(), v.strides().begin(), v.strides().end());
- return hash;
-}
-
-/// ArgCompileCache is a templated class allowing plugging of different types of
-/// Hasher/Specialization Keys.
-struct CompileCache {
- public:
- CompileCache() = default;
- ~CompileCache() = default;
-
- /// Array defining groups of aliased tensors.
-
- /// Cache type mapping specialization keys to compiled kernels.
- class vector_hasher {
- public:
- std::size_t operator()(hash_key_t const& vec) const {
- std::size_t seed = vec.size();
- for (auto& i : vec) {
- seed ^= i + 0x9e3779b9 + (seed << 6) + (seed >> 2);
- }
- return seed;
- }
- };
- using Cache = std::unordered_map<hash_key_t, py::object, vector_hasher>;
-
- /// Compute the set of specialization keys based on the inputs to
- /// the kernel.
- hash_key_t computeCacheKey(
- PyObject* args,
- const std::vector<at::Tensor>& tensorArgs,
- int numTensorArgs,
- const std::string& hasherType,
- int64_t id,
- int64_t fw_compiler_id,
- int64_t bw_compiler_id) {
- LocalState state;
- hash_key_t cacheKey;
- for (int i = 0; i < numTensorArgs; ++i) {
- if (tensorArgs[i].defined()) {
- // Only hash the tensor when its defined.
- if (hasherType == "StaticShapeHasher") {
- auto res = static_hasher(state, tensorArgs[i]);
- cacheKey.insert(cacheKey.end(), res.begin(), res.end());
- } else if (hasherType == "DynamicShapeHasher") {
- auto res = dynamic_hasher(state, tensorArgs[i]);
- cacheKey.insert(cacheKey.end(), res.begin(), res.end());
- }
- } else {
- // Add a value to the cacheKey to indicate a None tensor.
- cacheKey.push_back(NONE_HASH);
- }
- }
- cacheKey.push_back(id);
- cacheKey.push_back(fw_compiler_id);
- cacheKey.push_back(bw_compiler_id);
- cacheKey.push_back(numTensorArgs);
-
- // Cache the non-tensor args. Currently, all the non-tensor args are cached.
- for (int i = numTensorArgs; i < PyTuple_Size(args); i++) {
- PyObject* arg = PyTuple_GET_ITEM(args, i);
- assert(PyLong_Check(arg));
- cacheKey.push_back(THPUtils_unpackLong(arg));
- }
- return cacheKey;
- }
-
- std::vector<at::Tensor> parsePythonArgs(int numTensorArgs, PyObject* args) {
- // Convert to Tensor Args
- std::vector<at::Tensor> tensorArgs(numTensorArgs);
- for (int i = 0; i < numTensorArgs; ++i) {
- PyObject* arg = PyTuple_GET_ITEM(args, i);
- if (arg == Py_None) {
- // If an input tensor is None, add it as an undefined tensor.
- tensorArgs[i] = at::Tensor();
- } else if (!THPVariable_Check(arg)) {
- // Fail if its a non-tensor arg. It should be marked static.
- std::string dtype = Py_TYPE(arg)->tp_name;
- std::string index = std::to_string(i);
- throw std::runtime_error(
- "Found an argument of type " + dtype + " at index " + index +
- ". Non-tensor arguments must be marked static."
- " Please set the static_argnums correctly to "
- "mark the argument at index " +
- index + " static.");
- } else {
- tensorArgs[i] = THPVariable_Unpack(arg);
- }
- }
- return tensorArgs;
- }
-
- /// Check if the function has already been compiled.
- py::object at(
- int64_t id,
- int64_t fw_compiler_id,
- int64_t bw_compiler_id,
- int numTensorArgs,
- const std::string& hasherType,
- PyObject* args) {
- std::vector<at::Tensor> tensorArgs = parsePythonArgs(numTensorArgs, args);
- hash_key_t cacheKey = computeCacheKey(
- args,
- tensorArgs,
- numTensorArgs,
- hasherType,
- id,
- fw_compiler_id,
- bw_compiler_id);
-
- auto item = cache_.find(cacheKey); // protected by GIL
-
- if (C10_LIKELY(item != cache_.end())) {
- return item->second;
- }
- return py::none();
- }
-
- /// Insert a new compiled functions for new tensor properties.
- void insert(
- int64_t id,
- int64_t fw_compiler_id,
- int64_t bw_compiler_id,
- int numTensorArgs,
- const std::string& hasherType,
- const py::object& compileFn,
- PyObject* args) {
- std::vector<at::Tensor> tensorArgs = parsePythonArgs(numTensorArgs, args);
- LocalState state;
- hash_key_t cacheKey = computeCacheKey(
- args,
- tensorArgs,
- numTensorArgs,
- hasherType,
- id,
- fw_compiler_id,
- bw_compiler_id);
- cache_.emplace(cacheKey, compileFn);
- }
-
- int64_t size() const {
- return cache_.size();
- }
-
- /// Clear the cache.
- void clear() {
- cache_.clear();
- }
-
- private:
- /// Compilation cache holding key and the compiled function.
- Cache cache_;
-};
-
-static CompileCache* createCompileCache() {
- return new CompileCache();
-}
-
-} // namespace
-
-namespace torch {
-namespace functorch {
-
-void initCompileCacheBindings(PyObject* module) {
- py::handle te(module);
- py::class_<CompileCache>(te, "CompileCache")
- .def(py::init(&createCompileCache))
- .def(
- "at",
- [](CompileCache& self,
- int64_t id,
- int64_t fw_compiler_id,
- int64_t bw_compiler_id,
- int numTensorArgs,
- const std::string& hasherType,
- py::args args) {
- return self.at(
- id,
- fw_compiler_id,
- bw_compiler_id,
- numTensorArgs,
- hasherType,
- args.ptr());
- })
- .def(
- "insert",
- [](CompileCache& self,
- int64_t id,
- int64_t fw_compiler_id,
- int64_t bw_compiler_id,
- int numTensorArgs,
- const std::string& hasherType,
- const py::object& compileFn,
- py::args args,
- py::kwargs kwargs) {
- self.insert(
- id,
- fw_compiler_id,
- bw_compiler_id,
- numTensorArgs,
- hasherType,
- compileFn,
- args.ptr());
- })
- .def("clear", [](CompileCache& self) { self.clear(); })
- .def("size", [](CompileCache& self) { return self.size(); });
-}
-
-} // namespace functorch
-} // namespace torch
diff --git a/torch/csrc/functorch/CompileCache.h b/torch/csrc/functorch/CompileCache.h
deleted file mode 100644
index c205c18..0000000
--- a/torch/csrc/functorch/CompileCache.h
+++ /dev/null
@@ -1,21 +0,0 @@
-// Copyright (c) Facebook, Inc. and its affiliates.
-// All rights reserved.
-//
-// This source code is licensed under the BSD-style license found in the
-// LICENSE file in the root directory of this source tree.
-#pragma once
-
-#include <torch/csrc/utils/pybind.h>
-
-namespace torch {
-namespace functorch {
-
-// CompileCache is the compilation cache used by the AOTAutograd frontend.
-// We're planning on deleting this in favor of torchdynamo's caching mechanism
-// (CompilerCache predates torchdynamo).
-
-/// Initialize python bindings for kernel compilation cache.
-TORCH_API void initCompileCacheBindings(PyObject* module);
-
-} // namespace functorch
-} // namespace torch
diff --git a/torch/csrc/functorch/init.cpp b/torch/csrc/functorch/init.cpp
index 4089a93..45df25c 100644
--- a/torch/csrc/functorch/init.cpp
+++ b/torch/csrc/functorch/init.cpp
@@ -16,7 +16,6 @@
#include <ATen/functorch/PlumbingHelper.h>
#include <ATen/functorch/TensorWrapper.h>
#include <c10/core/AutogradState.h>
-#include <torch/csrc/functorch/CompileCache.h>
// This file contains functorch's Python bindings.
@@ -457,8 +456,6 @@
m.def("is_functorch_wrapped_tensor", [](const Tensor& tensor) {
return maybe_get_level(tensor) != -1;
});
-
- torch::functorch::initCompileCacheBindings(m.ptr());
}
} // namespace impl
diff --git a/torch/cuda/_dynamo_graphs.py b/torch/cuda/_dynamo_graphs.py
index 1d8ae67..56973e9 100644
--- a/torch/cuda/_dynamo_graphs.py
+++ b/torch/cuda/_dynamo_graphs.py
@@ -137,7 +137,6 @@
# these are taken from memory_efficient_fusion()
"fw_compiler": cudagraphs,
"bw_compiler": cudagraphs,
- "hasher_type": "StaticShapeHasher",
}
def _wrapped_bw_compiler(*args, **kwargs):