# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from torch.testing._internal.common_utils import TestCase, run_tests
import torch
import torch.nn as nn
import torch.utils._pytree as pytree
import unittest
import warnings
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from functorch import (
    grad, vjp, vmap, jacrev,
    make_fx
)
from functorch.compile import (
    nnc_jit, compiled_function, compiled_module,
    partition_with_recompute_fwd_in_bwd, pythonkey_decompose, aot_function, aot_module
)
from torch.testing._internal.common_device_type import ops
from functorch_lagging_op_db import functorch_lagging_op_db
from functorch_additional_op_db import additional_op_db
from common_utils import (
    xfail,
    skip,
    skipOps,
)
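
# torchvision is an optional test dependency; the resnet18 tracing test below is
# skipped when it is unavailable.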
USE_TORCHVISION = False
try:
    import torchvision
    USE_TORCHVISION = True
except ImportError:
    warnings.warn("Couldn't import torchvision. Some of our tests use it; try "
                  "installing it with the commands from pytorch.org, appending "
                  "`--no-deps` to avoid overwriting the PyTorch installation.",
                  UserWarning)

# NB: numpy is a testing dependency!
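
# Tests that trace functions (plain, or wrapped in functorch transforms such as
# grad/vmap/jacrev/vjp) with make_fx, and that compile simple functions with nnc_jit.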
class TestPythonKey(TestCase):
    def test_make_fx(self, device):
        def f(x):
            return torch.sin(x)
        inp = torch.randn(3)
        fx_f = make_fx(f)(inp)
        new_inp = torch.randn(3)
        self.assertEqual(fx_f(new_inp), f(new_inp))

    def test_make_fx_grad(self, device):
        def f(x):
            return torch.sin(x).sum()
        inp = torch.randn(3)
        f = grad(f)
        fx_f = make_fx(f)(inp)
        new_inp = torch.randn(3)
        self.assertEqual(fx_f(new_inp), f(new_inp))

    def test_scalar_device(self):
        def f(a, b):
            return a + b
        inps = [torch.randn(3, device='cuda'), torch.tensor(5)]
        fx_f = make_fx(f)(*inps)
        self.assertEqual(fx_f(*inps), f(*inps))

    def test_make_fx_vmap(self, device):
        def f(x):
            return torch.sin(x)
        inp = torch.randn(5, 3)
        f = vmap(f)
        fx_f = make_fx(f)(inp)
        new_inp = torch.randn(5, 3)
        self.assertEqual(fx_f(new_inp), f(new_inp))

    def test_make_fx_jacrev(self, device):
        def f(x):
            return x.sin().sum()
        inp = torch.randn(3)
        f = jacrev(jacrev(f))
        fx_f = make_fx(f)(inp)
        new_inp = torch.randn(3)
        self.assertEqual(fx_f(new_inp), f(new_inp))

    def test_make_fx_vjp(self, device):
        def f(x):
            return torch.sin(x).sum()
        primals = torch.randn(3)
        _, vjp_fn = vjp(f, primals)
        cotangent = torch.randn(())
        fx_f = make_fx(vjp_fn)(cotangent, True, True)
        new_cotangent = torch.randn(())
        self.assertEqual(fx_f(new_cotangent, True, True), vjp_fn(new_cotangent))

    def test_make_fx_no_decompose(self, device):
        def f(x):
            return torch.tanh(x).sum()
        fx_f = make_fx(grad(f))(torch.randn(5))
        ops = set([i.target for i in fx_f.graph.nodes])
        self.assertEqual(torch.ops.aten.tanh_backward in ops, True)
        with pythonkey_decompose():
            fx_f = make_fx(grad(f))(torch.randn(5))
        ops = set([i.target for i in fx_f.graph.nodes])
        self.assertEqual(torch.ops.aten.tanh_backward in ops, False)
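
    # nnc_jit compilation tests: most of these are currently marked expectedFailure.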
    @unittest.expectedFailure
    def test_nnc_jit(self, device):
        def f(x):
            return torch.sin(x)
        jit_f = nnc_jit(f)
        inp = torch.randn(3)
        self.assertEqual(jit_f(inp), f(inp))

    @unittest.expectedFailure
    def test_nnc_jit_warns_on_recompilation(self, device):
        def f(x):
            return torch.sin(x)
        jit_f = nnc_jit(f)
        inp = torch.randn(3)
        jit_f(inp)
        inp2 = torch.randn(5)
        with warnings.catch_warnings(record=True) as warns:
            warnings.simplefilter("always")
            jit_f(inp2)
            self.assertEqual(len(warns), 1)
            self.assertTrue("Recompiling" in str(warns[-1].message))

    @unittest.expectedFailure
    def test_nnc_scalar(self, device):
        def f(x):
            return torch.sin(x)
        jit_f = nnc_jit(f)
        inp = torch.randn(())
        self.assertEqual(jit_f(inp), f(inp))

    @unittest.expectedFailure
    def test_nnc_pytrees(self, device):
        def f(x):
            return [torch.sin(x[0])]
        jit_f = nnc_jit(f)
        inp = [torch.randn(3)]
        self.assertEqual(jit_f(inp), f(inp))

    def test_external_calls(self, device):
        def f(a, b):
            return torch.mv(a, b)
        jit_f = nnc_jit(f)
        inp = [torch.randn(3, 3), torch.randn(3)]
        self.assertEqual(jit_f(*inp), f(*inp))

    @unittest.expectedFailure
    def test_nnc_passthrough(self, device):
        def f(x, y):
            return x + y, y
        inp = (torch.randn(3), torch.randn(3))
        jit_f = nnc_jit(f)
        self.assertEqual(jit_f(*inp), f(*inp))

        def f(x):
            x['a'] = x['a'] * 2
            return x
        inp = ({'a': torch.randn(3), 'b': torch.randn(3)},)
        jit_f = nnc_jit(f)
        self.assertEqual(jit_f(*inp), f(*inp))

    @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
    def test_resnet18_backward_trace(self, device):
        mod = torchvision.models.resnet18()

        def f(x):
            out = mod(x)
            out.sum().backward()
            return [a.grad for a in mod.parameters()]

        inp = torch.randn(3, 3, 250, 250, requires_grad=True)
        grads = f(inp)

        mod.zero_grad()
        mod(inp).sum().backward()
        grads2 = [a.grad for a in mod.parameters()]
        self.assertEqual(grads, grads2)
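

# OpInfo entries that currently fail (or are nondeterministic/flaky) when traced
# exhaustively with make_fx; consumed by skipOps in TestPythonKeyOperatorsOpInfo.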
make_fx_failures = {
    xfail('to_sparse'),
    xfail('allclose'),
    xfail('rsub', 'rsub_scalar'),
    xfail('linalg.matrix_power'),
    xfail('linalg.inv'),
    xfail('linalg.cholesky'),
    xfail('nn.functional.dropout'),
    xfail('linalg.eigvals'),
    xfail('nn.functional.ctc_loss'),
    xfail('nn.functional.fractional_max_pool3d', device_type='cpu'),
    xfail('randn_like'),  # randomness
    xfail('rand_like'),  # randomness
    xfail('randint_like'),  # randomness
    skip('new_empty'),  # nondeterministic
    skip('empty_like'),  # nondeterministic
    skip('linalg.lstsq', 'grad_oriented'),  # flaky
}
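

# Exhaustive OpInfo coverage: trace each op with make_fx and check that the traced
# graph matches eager mode on freshly randomized inputs.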
class TestPythonKeyOperatorsOpInfo(TestCase):
    @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
    @skipOps('TestPythonKeyOperatorsOpInfo', 'test_make_fx_exhaustive', make_fx_failures)
    def test_make_fx_exhaustive(self, device, dtype, op):
        def f(args, kwargs):
            return op.op(*args, **kwargs)
        sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
        new_f = None
        for sample_input in sample_inputs_itr:
            args = [sample_input.input] + list(sample_input.args)
            kwargs = sample_input.kwargs
            t = f(args, kwargs)
            # skip: pytree doesn't handle torch.return_types outputs
            if isinstance(t, tuple):
                self.skipTest("output is a tuple that pytree doesn't work with")

            new_f = make_fx(f)(args, kwargs)
            for arg in args:
                if isinstance(arg, torch.Tensor) and arg.dtype == torch.float:
                    arg.uniform_(0, 1)
            try:
                old_out = f(args, kwargs)
            except Exception:
                continue
            new_out = new_f(args, kwargs)
            self.assertEqual(new_out, old_out)
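

# Helpers for the AOT Autograd tests below: a no-op "compiler" that returns the
# traced graph module unchanged, and a utility that runs a function and collects
# both its outputs and the gradients of its (flattened) inputs.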
def _nop_compile(x, _):
    return x


def _outs_and_grads(fn, inps):
    outs = fn(*inps)
    for out in pytree.tree_flatten(outs)[0]:
        out.sum().backward(retain_graph=True)
    grads = [inp.grad for inp in pytree.tree_flatten(inps)[0]]
    for inp in pytree.tree_flatten(inps)[0]:
        inp.grad = None
    return outs, grads
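

# AOT Autograd tests: compile a function (or module) with no-op compilers and check
# that outputs and input gradients match the eager reference.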
class TestAOTAutograd(TestCase):
    def verify_aot_autograd(self, f, inp):
        if isinstance(f, nn.Module):
            compiled_f = aot_module(f, _nop_compile, _nop_compile)
        else:
            compiled_f = aot_function(f, _nop_compile, _nop_compile)
        ref_out, ref_grad = _outs_and_grads(f, inp)
        test_out, test_grad = _outs_and_grads(compiled_f, inp)
        self.assertEqual(ref_out, test_out)
        self.assertEqual(ref_grad, test_grad)

    def test_single_output(self):
        def f(a, b):
            return a + b
        inp = [torch.randn(3, 3, requires_grad=True), torch.randn(3, 3)]
        self.verify_aot_autograd(f, inp)

    def test_multi_output(self):
        def f(a, b):
            return a + b, a - b
        inp = [torch.randn(3, 3, requires_grad=True), torch.randn(3, 3)]
        self.verify_aot_autograd(f, inp)

    def test_multi_output_list(self):
        def f(a, b):
            return [a + b, a - b]
        inp = [torch.randn(3, 3, requires_grad=True), torch.randn(3, 3)]
        self.verify_aot_autograd(f, inp)

    def test_output_dict(self):
        def f(x):
            return {'a': x, 'b': x}
        inp = [torch.randn(3, 3, requires_grad=True)]
        self.verify_aot_autograd(f, inp)

        def f(x, y):
            return {'a': x, 'b': y + x}
        inp = [torch.randn(3, requires_grad=True), torch.randn(3)]
        self.verify_aot_autograd(f, inp)

        def f(x):
            new_d = {}
            for k in x:
                new_d[k] = x[k] * 2
            return new_d
        inp = [{'a': torch.randn(3, requires_grad=True), 'b': torch.randn(3, requires_grad=True)}]
        self.verify_aot_autograd(f, inp)

    def test_module(self):
        mod = nn.Sequential(nn.Linear(32, 32), nn.ReLU())
        compiled_mod = compiled_module(mod, _nop_compile, _nop_compile)
        inp = torch.randn(32, 32)
        ref_out = mod(inp)
        ref_out.sum().backward()
        ref_grads = [p.grad for p in mod.parameters()]
        out = compiled_mod(inp)
        out.sum().backward()
        grads = [p.grad for p in compiled_mod.parameters()]
        self.assertEqual((out, grads), (ref_out, ref_grads))

    def test_batchnorm(self):
        mod = compiled_module(nn.BatchNorm2d(4), _nop_compile, _nop_compile)
        x = torch.ones(1, 4, 2, 2)
        mod(x).sum().backward()
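

# Exhaustive OpInfo coverage for AOT Autograd: compile each differentiable op with
# identity compilers and check that gradients match eager mode, both on the original
# sample inputs and on freshly randomized ones.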
class TestEagerFusionOpInfo(TestCase):
    @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
    # Entries here don't work and need to be fixed.
    # Each one of these is a bug (or needs to be investigated).
    @skipOps('TestEagerFusionOpInfo', 'test_aot_autograd_exhaustive', {
        xfail('__rmatmul__'),
        xfail('linalg.cholesky'),
        xfail('linalg.det'),
        xfail('linalg.inv'),
        xfail('matmul'),
        xfail('nn.functional.linear'),
        xfail('nn.functional.dropout'),
        xfail('polar'),
        xfail('special.zeta', 'grad'),
        xfail('to_sparse'),
        xfail('addcdiv'),
        xfail('angle'),
        xfail('cholesky'),
        xfail('cumulative_trapezoid'),
        xfail('diag_embed'),
        xfail('linalg.householder_product'),
        xfail('logit'),
        xfail('matrix_exp'),
        xfail('sgn'),
        xfail('trapezoid'),
        xfail('trapz'),
    })
    def test_aot_autograd_exhaustive(self, device, dtype, op):
        def f(args, kwargs):
            return op.op(*args, **kwargs)
        if not op.supports_autograd:
            return
        sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=True)
        for sample_input in sample_inputs_itr:
            args = [sample_input.input] + list(sample_input.args)
            kwargs = sample_input.kwargs
            if not all([isinstance(i, torch.Tensor) and i.dtype == torch.float for i in args]):
                self.skipTest("not all inputs are float tensors")
            if not all([isinstance(i, torch.Tensor) and i.dtype == torch.float for i in kwargs.values()]):
                self.skipTest("not all inputs are float tensors")
                continue
            t = f(args, kwargs)
            if isinstance(t, tuple):
                self.skipTest("output is a tuple")
                continue

            def reset_grads():
                def f(x):
                    x.grad = None
                pytree.tree_map(f, args)

            def get_grads(args):
                return pytree.tree_map(lambda x: x.grad, args)

            compiled_f = compiled_function(f, lambda x, _: x, lambda x, _: x)

            reset_grads()
            compiled_f(args, kwargs).sum().backward()
            compiled_grad = get_grads(args)

            reset_grads()
            f(args, kwargs).sum().backward()
            orig_grad = get_grads(args)
            self.assertEqual(orig_grad, compiled_grad)

            def create_new_arg(x):
                return x.detach().uniform_(0, 1).requires_grad_(x.requires_grad)

            args = pytree.tree_map(create_new_arg, args)

            reset_grads()
            compiled_f(args, kwargs).sum().backward()
            compiled_grad = get_grads(args)

            reset_grads()
            f(args, kwargs).sum().backward()
            orig_grad = get_grads(args)
            self.assertEqual(orig_grad, compiled_grad)
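

# Check that the recompute-in-backward partitioner produces the same outputs and
# gradients as eager mode.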
class TestPartitioning(TestCase):
    def test_recompute_partitioning(self):
        def fn(a, b):
            return torch.sin(torch.sin(a)) + b

        # Reference calculation
        ref_a = torch.rand(10, 10, requires_grad=True)
        ref_b = torch.rand(10, 10, requires_grad=True)
        ref = fn(ref_a, ref_b)
        ref.sum().backward()

        # Compiled function calculation
        res_a = ref_a.clone().detach().requires_grad_(True)
        res_b = ref_b.clone().detach().requires_grad_(True)

        def compile_fn(x, _):
            return x

        compiled_fn = compiled_function(fn, compile_fn, compile_fn, partition_with_recompute_fwd_in_bwd)
        res = compiled_fn(res_a, res_b)
        res.sum().backward()
        assert torch.allclose(ref, res, atol=1e-3, rtol=1e-3)
        assert torch.allclose(ref_a.grad, res_a.grad, atol=1e-3, rtol=1e-3)
        assert torch.allclose(ref_b.grad, res_b.grad, atol=1e-3, rtol=1e-3)
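

# Instantiate the device-parameterized tests above, restricted to CPU.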
only_for = ("cpu",)
instantiate_device_type_tests(
    TestPythonKey,
    globals(),
    only_for=only_for,
)
instantiate_device_type_tests(TestPythonKeyOperatorsOpInfo, globals(), only_for=only_for)
instantiate_device_type_tests(TestEagerFusionOpInfo, globals(), only_for=only_for)

if __name__ == '__main__':
    run_tests()