# Owner(s): ["module: functorch"]
import unittest

import torch

import functorch
from functorch.compile import aot_function, nop
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase


class TestCompileCache(TestCase):
    def check(self, a, b, aot_fn, fn):
        a_clone = a.clone().detach().requires_grad_(True)
        b_clone = b.clone().detach().requires_grad_(True)
        ref = fn(a, b)
        ref.sum().backward()

        res = aot_fn(a_clone, b_clone)
        res.sum().backward()

        assert torch.allclose(res, ref)
        assert torch.allclose(a.grad, a_clone.grad)
        assert torch.allclose(b.grad, b_clone.grad)

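    # Changing the bias from shape (20,) to (10, 20) changes how it broadcasts
    # against x, so the test expects one recompilation per shape combination.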
    def test_recompilation_on_broadcast(self):
        def fn(x, bias):
            return x + bias

        for hasher_type in ["DynamicShapeHasher", "StaticShapeHasher"]:
            functorch.compile.clear_compile_cache()
            start_num_recomps = functorch.compile.num_of_recompilations()
            aot_autograd_fn = aot_function(fn, nop, nop, hasher_type=hasher_type)

            a = torch.randn(10, 20, requires_grad=True)
            b = torch.randn(20, requires_grad=True)
            self.check(a, b, aot_autograd_fn, fn)

            a = torch.randn(10, 20, requires_grad=True)
            b = torch.randn(10, 20, requires_grad=True)
            self.check(a, b, aot_autograd_fn, fn)

            end_num_recomps = functorch.compile.num_of_recompilations()
            total_recomps = end_num_recomps - start_num_recomps
            assert total_recomps == 2

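    # The test expects DynamicShapeHasher to reuse one graph across all 1-D
    # input sizes (and a second one for the 2-D inputs), while
    # StaticShapeHasher is expected to recompile once per distinct input shape.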
    def test_compilation_for_dynamic_shape(self):
        def fn(x, bias):
            return x + bias

        for hasher_type in ["DynamicShapeHasher", "StaticShapeHasher"]:
            functorch.compile.clear_compile_cache()
            start_num_recomps = functorch.compile.num_of_recompilations()
            aot_autograd_fn = aot_function(fn, nop, nop, hasher_type=hasher_type)

            for s in range(10, 20):
                a = torch.randn(s, requires_grad=True)
                b = torch.randn(s, requires_grad=True)
                self.check(a, b, aot_autograd_fn, fn)

            for s in range(10, 20):
                a = torch.randn(s, requires_grad=True)
                b = torch.randn(s, requires_grad=True)
                self.check(a, b, aot_autograd_fn, fn)

            end_num_recomps = functorch.compile.num_of_recompilations()
            total_recomps = end_num_recomps - start_num_recomps

            if hasher_type == "DynamicShapeHasher":
                assert total_recomps == 1
            elif hasher_type == "StaticShapeHasher":
                assert total_recomps == 10

            for s in range(10, 20):
                a = torch.randn(s, s, requires_grad=True)
                b = torch.randn(s, s, requires_grad=True)
                self.check(a, b, aot_autograd_fn, fn)

            end_num_recomps = functorch.compile.num_of_recompilations()
            total_recomps = end_num_recomps - start_num_recomps

            if hasher_type == "DynamicShapeHasher":
                assert total_recomps == 2
            elif hasher_type == "StaticShapeHasher":
                assert total_recomps == 20

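    # The compile cache is global, so wrapping f in a fresh aot_function call
    # on every iteration should still hit the previously compiled graph.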
    def test_global_cache_no_recompilations(self):
        def f(x, bias):
            return x + bias

        def g(x, bias):
            return aot_function(f, nop, nop, hasher_type="DynamicShapeHasher")(x, bias)

        start_num_recomps = functorch.compile.num_of_recompilations()

        for _ in range(10):
            a = torch.randn(10, 20, requires_grad=True)
            b = torch.randn(10, 20, requires_grad=True)
            self.check(a, b, g, f)

        end_num_recomps = functorch.compile.num_of_recompilations()
        total_recomps = end_num_recomps - start_num_recomps
        assert total_recomps == 1

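    # Each distinct user function gets its own cache entry, and changing the
    # input shape of one function recompiles only that function.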
    def test_multiple_functions(self):
        def f(x, bias):
            return x + bias

        def g(x, y):
            return x * y

        for hasher_type in ["DynamicShapeHasher", "StaticShapeHasher"]:
            functorch.compile.clear_compile_cache()
            aot_autograd_f = aot_function(f, nop, nop, hasher_type=hasher_type)
            aot_autograd_g = aot_function(g, nop, nop, hasher_type=hasher_type)

            start_num_recomps = functorch.compile.num_of_recompilations()
            a = torch.randn(10, requires_grad=True)
            b = torch.randn(10, requires_grad=True)
            self.check(a, b, aot_autograd_f, f)

            a = torch.randn(10, requires_grad=True)
            b = torch.randn(10, requires_grad=True)
            self.check(a, b, aot_autograd_g, g)

            end_num_recomps = functorch.compile.num_of_recompilations()
            total_recomps = end_num_recomps - start_num_recomps
            assert total_recomps == 2

            # Force recompilation for function f and check num of recompilations again
            a = torch.randn(10, 20, requires_grad=True)
            b = torch.randn(10, 20, requires_grad=True)
            self.check(a, b, aot_autograd_f, f)

            end_num_recomps = functorch.compile.num_of_recompilations()
            total_recomps = end_num_recomps - start_num_recomps
            assert total_recomps == 3

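    # Exercise cache hashing with a large number of tensor arguments (100) and
    # verify the compiled gradients still match eager mode.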
    def test_high_number_of_args(self):
        def f(*args):
            res = args[0]
            for arg in args:
                res = res * arg
            return res

        def check(args, aot_autograd_fn, fn):
            args_clone = [arg.clone().detach().requires_grad_(True) for arg in args]
            ref = fn(*args)
            ref.sum().backward()

            res = aot_autograd_fn(*args_clone)
            res.sum().backward()

            assert torch.allclose(res, ref)
            for arg, arg_clone in zip(args, args_clone):
                assert torch.allclose(arg.grad, arg_clone.grad)

        for hasher_type in ["DynamicShapeHasher", "StaticShapeHasher"]:
            functorch.compile.clear_compile_cache()
            aot_autograd_f = aot_function(f, nop, nop, hasher_type=hasher_type)

            args = [torch.randn(10, requires_grad=True) for _ in range(100)]
            check(args, aot_autograd_f, f)

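    # Compiling the same function with two different compiler callables should
    # produce two separate cache entries, hence two recompilations.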
    def test_multiple_compiler(self):
        def fn(x, bias):
            return x + bias

        def nop_duplicate(fx_g, _):
            return fx_g

        for hasher_type in ["DynamicShapeHasher", "StaticShapeHasher"]:
            functorch.compile.clear_compile_cache()
            start_num_recomps = functorch.compile.num_of_recompilations()
            nop_fn = aot_function(fn, nop, nop, hasher_type=hasher_type)
            nop_duplicate_fn = aot_function(
                fn, nop_duplicate, nop_duplicate, hasher_type=hasher_type
            )

            a = torch.randn(10, 20, requires_grad=True)
            b = torch.randn(20, requires_grad=True)
            nop_fn(a, b)
            nop_duplicate_fn(a, b)

            end_num_recomps = functorch.compile.num_of_recompilations()
            total_recomps = end_num_recomps - start_num_recomps
            assert total_recomps == 2

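
# Tests for static_argnums: arguments marked static are hashed by value, so the
# cache should hit for repeated values and miss when a static value changes.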
@unittest.skipIf(IS_WINDOWS, "test broken on windows")
class TestCompileCacheStaticArgs(TestCase):
    def check(self, a, b, aot_autograd_fn, fn):
        a_clone = a.clone().detach().requires_grad_(True)
        ref = fn(a, b)
        ref.sum().backward()

        res = aot_autograd_fn(a_clone, b)
        res.sum().backward()

        assert torch.allclose(res, ref)
        assert torch.allclose(a.grad, a_clone.grad)

    def test_failure(self):
        # Omitting static_argnums for a non-tensor argument should raise an exception.
        def fn(x, p):
            return x * p

        aot_autograd_f = aot_function(fn, nop, nop)

        a = torch.randn(2, 2, requires_grad=True)
        b = 2
        # Since b is not marked as static, the call should raise a RuntimeError.
        with self.assertRaises(RuntimeError):
            aot_autograd_f(a, b)

    def test_simple(self):
        def fn(x, static_arg):
            return x * static_arg

        functorch.compile.clear_compile_cache()

        start_num_recomps = functorch.compile.num_of_recompilations()
        aot_autograd_f = aot_function(fn, nop, nop, static_argnums=1)

        a = torch.randn(2, 2, requires_grad=True)
        b = 2
        self.check(a, b, aot_autograd_f, fn)

        # Same static arg value, so no recompilation
        a = torch.randn(2, 2, requires_grad=True)
        b = 2
        self.check(a, b, aot_autograd_f, fn)

        # A different static arg value triggers recompilation
        a = torch.randn(2, 2, requires_grad=True)
        b = 3
        self.check(a, b, aot_autograd_f, fn)

        end_num_recomps = functorch.compile.num_of_recompilations()
        total_recomps = end_num_recomps - start_num_recomps
        assert total_recomps == 2

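    # A static argument may also precede the tensor arguments; changing its
    # value should still trigger a recompilation.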
    def test_static_arg_before_tensor_arg(self):
        def fn(static_arg, x):
            return static_arg - x

        def check(a, b, aot_autograd_fn, fn):
            b_clone = b.clone().detach().requires_grad_(True)
            ref = fn(a, b)
            ref.sum().backward()

            res = aot_autograd_fn(a, b_clone)
            res.sum().backward()

            assert torch.allclose(res, ref)
            assert torch.allclose(b.grad, b_clone.grad)

        functorch.compile.clear_compile_cache()

        start_num_recomps = functorch.compile.num_of_recompilations()
        aot_autograd_f = aot_function(fn, nop, nop, static_argnums=0)

        a = 2
        b = torch.randn(2, 2, requires_grad=True)
        check(a, b, aot_autograd_f, fn)

        a = 3
        b = torch.randn(2, 2, requires_grad=True)
        check(a, b, aot_autograd_f, fn)

        end_num_recomps = functorch.compile.num_of_recompilations()
        total_recomps = end_num_recomps - start_num_recomps
        assert total_recomps == 2

    def test_interleaved_static_args(self):
        def fn(static_arg1, x, static_arg2):
            return static_arg1 - x - static_arg2

        def check(a, b, c, aot_autograd_fn, fn):
            b_clone = b.clone().detach().requires_grad_(True)
            ref = fn(a, b, c)
            ref.sum().backward()

            res = aot_autograd_fn(a, b_clone, c)
            res.sum().backward()

            assert torch.allclose(res, ref)
            assert torch.allclose(b.grad, b_clone.grad)

        functorch.compile.clear_compile_cache()

        start_num_recomps = functorch.compile.num_of_recompilations()
        aot_autograd_f = aot_function(fn, nop, nop, static_argnums=(0, 2))

        a = 2
        b = torch.randn(2, 2, requires_grad=True)
        c = 0.1
        check(a, b, c, aot_autograd_f, fn)

        a = 3
        b = torch.randn(2, 2, requires_grad=True)
        c = 0.1
        check(a, b, c, aot_autograd_f, fn)

        end_num_recomps = functorch.compile.num_of_recompilations()
        total_recomps = end_num_recomps - start_num_recomps
        assert total_recomps == 2

    def test_dropout(self):
        def fn(x, prob):
            return torch.nn.functional.dropout(x, p=prob)

        functorch.compile.clear_compile_cache()

        start_num_recomps = functorch.compile.num_of_recompilations()
        aot_autograd_f = aot_function(fn, nop, nop, static_argnums=[1])

        a = torch.randn(2, 2, requires_grad=True)
        b = 0.3
        # Dropout with p > 0 is random, so only run the compiled fn here
        # instead of comparing against eager.
        aot_autograd_f(a, b)

        # Setting prob to 0 changes the static arg value, so this should cause
        # a recompilation (and the output is now deterministic).
        a = torch.randn(2, 2, requires_grad=True)
        b = 0
        self.check(a, b, aot_autograd_f, fn)

        end_num_recomps = functorch.compile.num_of_recompilations()
        total_recomps = end_num_recomps - start_num_recomps
        assert total_recomps == 2

    def test_if_condition(self):
        def fn(x, state: bool):
            if state:
                return torch.sin(x)
            else:
                return torch.cos(x)

        functorch.compile.clear_compile_cache()

        start_num_recomps = functorch.compile.num_of_recompilations()
        aot_autograd_f = aot_function(fn, nop, nop, static_argnums=[1])

        a = torch.randn(2, 2, requires_grad=True)
        b = True
        self.check(a, b, aot_autograd_f, fn)

        a = torch.randn(2, 2, requires_grad=True)
        b = True
        self.check(a, b, aot_autograd_f, fn)

        a = torch.randn(2, 2, requires_grad=True)
        b = False
        self.check(a, b, aot_autograd_f, fn)

        end_num_recomps = functorch.compile.num_of_recompilations()
        total_recomps = end_num_recomps - start_num_recomps
        assert total_recomps == 2

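    # A user-defined object can serve as a static argument as long as it is
    # hashable and comparable; Record defines __eq__ and __hash__ for this.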
    def test_custom(self):
        class Record:
            def __init__(self, name, multiplier):
                self.name = name
                self.multiplier = multiplier

            def __eq__(self, other):
                return self.name == other.name and self.multiplier == other.multiplier

            def __hash__(self):
                return hash((self.name, self.multiplier))

        def fn(x, record):
            return x * record.multiplier

        functorch.compile.clear_compile_cache()

        start_num_recomps = functorch.compile.num_of_recompilations()
        aot_autograd_f = aot_function(fn, nop, nop, static_argnums=[1])

        a = torch.randn(2, 2, requires_grad=True)
        b = Record("Foo", 0.5)
        self.check(a, b, aot_autograd_f, fn)

        a = torch.randn(2, 2, requires_grad=True)
        b = Record("Bar", 10.2)
        self.check(a, b, aot_autograd_f, fn)

        end_num_recomps = functorch.compile.num_of_recompilations()
        total_recomps = end_num_recomps - start_num_recomps
        assert total_recomps == 2

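    # Tensor arguments nested in a tuple still get gradient-checked, and
    # changing only the static scalar should trigger the second compilation.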
    def test_tuple(self):
        def fn(a_tuple, static_arg):
            return torch.sin(a_tuple[0]) - a_tuple[1] - static_arg

        def check(a_tuple, b, aot_autograd_fn, fn):
            a0 = a_tuple[0]
            a1 = a_tuple[1]
            a0_clone = a0.clone().detach().requires_grad_(True)
            a1_clone = a1.clone().detach().requires_grad_(True)
            ref = fn(a_tuple, b)
            ref.sum().backward()

            res = aot_autograd_fn((a0_clone, a1_clone), b)
            res.sum().backward()

            assert torch.allclose(res, ref)
            assert torch.allclose(a0.grad, a0_clone.grad)
            assert torch.allclose(a1.grad, a1_clone.grad)

        functorch.compile.clear_compile_cache()

        start_num_recomps = functorch.compile.num_of_recompilations()
        aot_autograd_f = aot_function(fn, nop, nop, static_argnums=(1,))

        a = (
            torch.randn(2, 2, requires_grad=True),
            torch.randn(2, 2, requires_grad=True),
        )
        b = 0.1
        check(a, b, aot_autograd_f, fn)

        a = (
            torch.randn(2, 2, requires_grad=True),
            torch.randn(2, 2, requires_grad=True),
        )
        b = 1
        check(a, b, aot_autograd_f, fn)

        end_num_recomps = functorch.compile.num_of_recompilations()
        total_recomps = end_num_recomps - start_num_recomps
        assert total_recomps == 2

    def test_tuple_with_first_arg_as_static(self):
        def fn(static_arg, a_tuple):
            return torch.sin(a_tuple[0]) - a_tuple[1] - static_arg

        def check(a, b_tuple, aot_autograd_fn, fn):
            b0 = b_tuple[0]
            b1 = b_tuple[1]
            b0_clone = b0.clone().detach().requires_grad_(True)
            b1_clone = b1.clone().detach().requires_grad_(True)
            ref = fn(a, b_tuple)
            ref.sum().backward()

            res = aot_autograd_fn(a, (b0_clone, b1_clone))
            res.sum().backward()

            assert torch.allclose(res, ref)
            assert torch.allclose(b0.grad, b0_clone.grad)
            assert torch.allclose(b1.grad, b1_clone.grad)

        functorch.compile.clear_compile_cache()

        start_num_recomps = functorch.compile.num_of_recompilations()
        aot_autograd_f = aot_function(fn, nop, nop, static_argnums=(0,))

        a = 0.1
        b = (
            torch.randn(2, 2, requires_grad=True),
            torch.randn(2, 2, requires_grad=True),
        )
        check(a, b, aot_autograd_f, fn)

        a = 1
        b = (
            torch.randn(2, 2, requires_grad=True),
            torch.randn(2, 2, requires_grad=True),
        )
        check(a, b, aot_autograd_f, fn)

        end_num_recomps = functorch.compile.num_of_recompilations()
        total_recomps = end_num_recomps - start_num_recomps
        assert total_recomps == 2

    def test_dict(self):
        def fn(a_dict, static_arg):
            return torch.sin(a_dict["foo"]) - a_dict["bar"] - static_arg

        def check(a_dict, b, aot_autograd_fn, fn):
            a0 = a_dict["foo"]
            a1 = a_dict["bar"]
            a0_clone = a0.clone().detach().requires_grad_(True)
            a1_clone = a1.clone().detach().requires_grad_(True)
            ref = fn(a_dict, b)
            ref.sum().backward()

            a_clone = {}
            a_clone["foo"] = a0_clone
            a_clone["bar"] = a1_clone
            res = aot_autograd_fn(a_clone, b)
            res.sum().backward()

            assert torch.allclose(res, ref)
            assert torch.allclose(a0.grad, a0_clone.grad)
            assert torch.allclose(a1.grad, a1_clone.grad)

        functorch.compile.clear_compile_cache()

        start_num_recomps = functorch.compile.num_of_recompilations()
        aot_autograd_f = aot_function(fn, nop, nop, static_argnums=(1,))

        a = {}
        a["foo"] = torch.zeros(2, 2, requires_grad=True)
        a["bar"] = torch.ones(2, 2, requires_grad=True)
        b = 0
        check(a, b, aot_autograd_f, fn)

        a = {}
        a["foo"] = torch.randn(2, 2, requires_grad=True)
        a["bar"] = torch.randn(2, 2, requires_grad=True)
        b = 0.2
        check(a, b, aot_autograd_f, fn)

        end_num_recomps = functorch.compile.num_of_recompilations()
        total_recomps = end_num_recomps - start_num_recomps
        assert total_recomps == 2

    def test_dict_with_static_arg_before_dict(self):
        def fn(static_arg, a_dict):
            return torch.sin(a_dict["foo"]) - a_dict["bar"] - static_arg

        def check(a, b_dict, aot_autograd_fn, fn):
            ref = fn(a, b_dict)
            res = aot_autograd_fn(a, b_dict)
            assert torch.allclose(res, ref)

            b0 = b_dict["foo"]
            b1 = b_dict["bar"]
            b0_clone = b0.clone().detach().requires_grad_(True)
            b1_clone = b1.clone().detach().requires_grad_(True)
            ref.sum().backward()

            b_clone = {}
            b_clone["foo"] = b0_clone
            b_clone["bar"] = b1_clone
            res = aot_autograd_fn(a, b_clone)
            res.sum().backward()

            assert torch.allclose(res, ref)
            assert torch.allclose(b0.grad, b0_clone.grad)
            assert torch.allclose(b1.grad, b1_clone.grad)

        functorch.compile.clear_compile_cache()

        start_num_recomps = functorch.compile.num_of_recompilations()
        aot_autograd_f = aot_function(fn, nop, nop, static_argnums=(0,))

        a = 0.1
        b = {}
        b["foo"] = torch.randn(2, 2, requires_grad=True)
        b["bar"] = torch.randn(2, 2, requires_grad=True)
        check(a, b, aot_autograd_f, fn)

        a = 0.2
        b = {}
        b["foo"] = torch.randn(2, 2, requires_grad=True)
        b["bar"] = torch.randn(2, 2, requires_grad=True)
        check(a, b, aot_autograd_f, fn)

        end_num_recomps = functorch.compile.num_of_recompilations()
        total_recomps = end_num_recomps - start_num_recomps
        assert total_recomps == 2

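    # A tuple passed as the static argument is hashed by value: an equal tuple
    # is a cache hit, a different tuple forces a recompilation.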
    def test_tuple_static_args(self):
        def fn(x, tuple_static_arg):
            return x * tuple_static_arg[0] * tuple_static_arg[1]

        functorch.compile.clear_compile_cache()

        start_num_recomps = functorch.compile.num_of_recompilations()
        aot_autograd_f = aot_function(fn, nop, nop, static_argnums=1)

        a = torch.randn(2, 2, requires_grad=True)
        b = (2, 3)
        self.check(a, b, aot_autograd_f, fn)

        # Same static tuple value, so no recompilation
        a = torch.randn(2, 2, requires_grad=True)
        b = (2, 3)
        self.check(a, b, aot_autograd_f, fn)

        # A different tuple value triggers recompilation
        a = torch.randn(2, 2, requires_grad=True)
        b = (3, 4)
        self.check(a, b, aot_autograd_f, fn)

        end_num_recomps = functorch.compile.num_of_recompilations()
        total_recomps = end_num_recomps - start_num_recomps
        assert total_recomps == 2

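    # None arguments are part of the cache key, so each distinct None/tensor
    # pattern compiles once; repeating a pattern should not recompile.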
    def test_arg_none(self):
        def check(a, b, c, aot_autograd_fn, fn):
            def cloner(x):
                if x is not None:
                    return x.clone().detach().requires_grad_(True)
                return None

            def check_grad(x, x_clone):
                if x is not None:
                    return torch.allclose(x.grad, x_clone.grad)
                return True

            # Reset any gradient accumulated on the shared inputs by earlier
            # checks, so the comparison below is against a fresh backward pass.
            for x in (a, b, c):
                if x is not None:
                    x.grad = None

            ref = fn(a, b, c)
            res = aot_autograd_fn(a, b, c)
            assert torch.allclose(res, ref)

            a_clone = cloner(a)
            b_clone = cloner(b)
            c_clone = cloner(c)
            ref.sum().backward()

            res = aot_autograd_fn(a_clone, b_clone, c_clone)
            res.sum().backward()

            assert check_grad(a, a_clone)
            assert check_grad(b, b_clone)
            assert check_grad(c, c_clone)

        def fn(a, b, c):
            if a is None and b is None:
                return c
            elif a is None and c is None:
                return b
            elif b is None and c is None:
                return a
            elif a is None:
                return b + c
            elif b is None:
                return a + c
            elif c is None:
                return a + b
            return a + b + c

        functorch.compile.clear_compile_cache()

        start_num_recomps = functorch.compile.num_of_recompilations()
        aot_autograd_f = aot_function(fn, nop, nop)

        t1 = torch.randn(2, 2, requires_grad=True)
        check(t1, None, None, aot_autograd_f, fn)
        check(None, t1, None, aot_autograd_f, fn)
        check(None, None, t1, aot_autograd_f, fn)

        t2 = torch.randn(2, 2, requires_grad=True)
        check(t1, t2, None, aot_autograd_f, fn)
        check(t1, None, t2, aot_autograd_f, fn)
        check(None, t1, t2, aot_autograd_f, fn)

        t3 = torch.randn(2, 2, requires_grad=True)
        check(t1, t2, t3, aot_autograd_f, fn)

        # Same None/tensor pattern as an earlier call, so no recompilation
        check(t1, t2, None, aot_autograd_f, fn)

        end_num_recomps = functorch.compile.num_of_recompilations()
        total_recomps = end_num_recomps - start_num_recomps
        assert total_recomps == 7

if __name__ == "__main__":
    run_tests()