# blob: f814605891ce1eb0d4859d7e8d4e7195cb1b0d14 [file] [log] [blame]
# Owner(s): ["module: ProxyTensor"]
from torch.testing._internal.common_utils import TestCase, run_tests
import torch
import unittest
import warnings
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_methods_invocations import DecorateInfo
from torch.testing._internal.common_methods_invocations import op_db, wrapper_set_seed
from torch._subclasses.fake_tensor import DynamicOutputShapeException
from torch._decomp import decomposition_table
from torch.testing._internal.common_device_type import ops
from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter
from torch.utils._pytree import tree_map
# Copied from functorch
def xfail(op_name, variant_name='', *, device_type=None, dtypes=None):
    """Build a skip-list entry that marks an OpInfo test as an expected failure.

    Returns the tuple ``(op_name, variant_name, device_type, dtypes, True)``;
    the trailing ``True`` tells ``skipOps`` to use ``unittest.expectedFailure``.
    """
    expected_failure = True
    entry = (op_name, variant_name, device_type, dtypes, expected_failure)
    return entry
def skip(op_name, variant_name='', *, device_type=None, dtypes=None):
    """Build a skip-list entry that unconditionally skips an OpInfo test.

    Returns the tuple ``(op_name, variant_name, device_type, dtypes, False)``;
    the trailing ``False`` tells ``skipOps`` to use ``unittest.skip``.
    """
    expected_failure = False
    entry = (op_name, variant_name, device_type, dtypes, expected_failure)
    return entry
def skipOps(test_case_name, base_test_name, to_skip):
    """Register skip / expected-failure decorators on matching OpInfos in ``op_db``.

    Args:
        test_case_name: name of the generic test class the decorators target.
        base_test_name: name of the test method the decorators target.
        to_skip: iterable of entries built by ``xfail(...)`` or ``skip(...)``,
            i.e. ``(op_name, variant_name, device_type, dtypes, expected_failure)``.

    Returns:
        A decorator that returns the test function unchanged. The effect of
        this call is the in-place extension of each matching
        ``opinfo.decorators`` tuple.

    Raises:
        AssertionError: (when assertions are enabled) if an entry matches no
            OpInfo by name and variant.
    """
    all_opinfos = op_db
    # Named ``entry`` (not ``xfail``) so the module-level ``xfail`` helper is not shadowed.
    for entry in to_skip:
        op_name, variant_name, device_type, dtypes, expected_failure = entry
        matching_opinfos = [o for o in all_opinfos
                            if o.name == op_name and o.variant_test_name == variant_name]
        assert len(matching_opinfos) >= 1, f"Couldn't find OpInfo for {entry}"
        for opinfo in matching_opinfos:
            # expectedFailure makes an unexpected pass count as a failure;
            # skip simply does not run the test.
            if expected_failure:
                marker = unittest.expectedFailure
            else:
                marker = unittest.skip("Skipped!")
            decorator = DecorateInfo(marker, test_case_name, base_test_name,
                                     device_type=device_type, dtypes=dtypes)
            opinfo.decorators = (*opinfo.decorators, decorator)

    # This decorator doesn't modify fn in any way
    def wrapped(fn):
        return fn
    return wrapped
# Optional dependency: tests that need torchvision check this flag and skip
# themselves when it is False.
try:
    import torchvision
except ImportError:
    USE_TORCHVISION = False
    warnings.warn(
        "Couldn't import torchvision. Some of our tests use it, try "
        "to install it with commands from pytorch.org, post-fixed with "
        "`--no-deps` to avoid overwriting the pytorch installation",
        category=UserWarning,
    )
else:
    USE_TORCHVISION = True
def _create_new_input(x):
if not isinstance(x, torch.Tensor):
return x
if x.dtype != torch.float:
return x + 1
if x.is_leaf:
return torch.rand_like(x, requires_grad=True)
else:
return torch.rand_like(x)
class TestProxyTensor(TestCase):
    """Tests for tracing Python functions into FX graphs with ``make_fx``."""

    def _test(self, f, inps):
        """Trace ``f`` on ``inps``, then compare the trace with eager ``f`` on new inputs.

        The new inputs come from ``tree_map(_create_new_input, inps)``, so the
        graph is exercised on values different from those it was traced with.
        """
        fx_f = make_fx(f)(*inps)
        new_inps = tree_map(_create_new_input, inps)
        self.assertEqual(fx_f(*new_inps), f(*new_inps))

    def test_make_fx_simple(self, device):
        def f(x):
            return torch.sin(x)

        self._test(f, (torch.randn(3, device=device),))

    def test_scalar_device(self, device):
        def f(a, b):
            return a + b

        # Mixes a tensor on ``device`` with a 0-dim tensor created on the default device.
        self._test(f, [torch.randn(3, device=device), torch.tensor(5)])

    @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
    def test_resnet18_backward_trace(self, device):
        mod = torchvision.models.resnet18()

        def f(x):
            # Reset grads so the returned grads come only from this backward pass.
            for a in mod.parameters():
                a.grad = None
            out = mod(x)
            out.sum().backward()
            return [a.grad for a in mod.parameters()]

        inp = torch.randn(3, 3, 250, 250, requires_grad=True)
        self._test(f, [inp])

    def test_proxy_tensor(self):
        # Both autograd entry points (autograd.grad and .backward) are traced.
        def f_grad(x):
            val = x.cos().cos().sum()
            return torch.autograd.grad(val, x)

        def f_backward(x):
            val = x.cos().cos().sum()
            val.backward()
            return x.grad

        for f in [f_grad, f_backward]:
            self._test(f, [torch.randn(3, requires_grad=True)])

    def test_inplace_metadata(self):
        def f(x):
            x = x.clone()
            x.unsqueeze_(-1)
            # The in-place op must update the shape observed while tracing.
            assert x.shape[-1] == 1
            return x

        self._test(f, [torch.randn(5)])

    def test_mode_tracing_factory_function(self):
        def f(x):
            return x + torch.randn(x.shape)

        # default behavior should trace factory functions
        traced = make_fx(f)(torch.randn(3))
        self.assertTrue(
            any(node.target == torch.ops.aten.randn.default for node in traced.graph.nodes)
        )

    def test_mode_tracing_factory_function_no_factory_function(self):
        def f(x):
            return x + torch.randn(x.shape)

        # setting the flag to false should not trace factory functions
        traced = make_fx(f, trace_factory_functions=False)(torch.randn(3))
        self.assertFalse(
            any(node.target == torch.ops.aten.randn.default for node in traced.graph.nodes)
        )

    def test_make_fx_overloads(self):
        def f(x):
            return x.cos() + torch.randn(x.shape)

        traced = make_fx(f)(torch.randn(3))
        # Every call_function node should target an OpOverload (e.g. aten.randn.default).
        self.assertTrue(all(isinstance(node.target, torch._ops.OpOverload)
                            for node in traced.graph.nodes if node.op == 'call_function'))

    def test_tensor_constants(self):
        def f():
            val = torch.tensor(float('inf'))
            return torch.full((100, 100), val)

        self._test(f, [])

    def test_constant_proxy_tensor(self):
        def f():
            val = torch.tensor(float('inf'))
            return torch.full((100, 100), val)

        g = make_fx(f)()
        self.assertEqual(g(), f())

    def test_constant_proxy_tensor_mut(self):
        def f():
            val = torch.tensor(float(1))
            val.add_(2)
            return torch.full((100, 100), val)

        g = make_fx(f)()
        self.assertEqual(g(), f())
        # In case we mutated shared state in the g graph!
        self.assertEqual(g(), f())
        g = make_fx(f, use_fake=True)()
        self.assertEqual(g(), f())
        # In case we mutated shared state in the g graph!
        self.assertEqual(g(), f())

    def test_use_fake_and_tensor(self):
        def f(x, y):
            z = torch.tensor([2.0, 3.0])
            return x + y + z

        g = make_fx(f, use_fake=True)(torch.randn(2), torch.randn(2))
        x, y = torch.randn(2), torch.randn(2)
        self.assertEqual(g(x, y), f(x, y))

    def test_decomposition_interpreter(self):
        def fn(x):
            return torch.nn.functional.silu(x)

        x = torch.rand((4, 4))
        fx_module = make_fx(fn, decomposition_table=None)(x)
        # Without a decomposition table, silu must survive tracing.
        silu_targets = (torch.ops.aten.silu, torch.ops.aten.silu.default)
        self.assertTrue(any(n.target in silu_targets for n in fx_module.graph.nodes))

        new_graph = torch.fx.Graph()
        silu_decomp_table = {
            torch.ops.aten.silu.default: decomposition_table[torch.ops.aten.silu.default],
        }
        DecompositionInterpreter(
            fx_module,
            new_graph=new_graph,
            decomposition_table=silu_decomp_table,
        ).run(x)
        decomposed_module = torch.fx.GraphModule(fx_module, new_graph)
        # After re-interpreting with the silu decomposition, silu must be gone.
        decomposed_targets = [n.target for n in decomposed_module.graph.nodes]
        for silu_target in silu_targets:
            self.assertNotIn(silu_target, decomposed_targets)
        self.assertEqual(fx_module(x), decomposed_module(x))

    def test_make_fx_reentrant_dispatch(self):
        def f(x):
            return torch.ops.aten.norm.Scalar(x, 2.0)

        def norm_decomp(x, p=2.0):
            if p != 2.0:
                raise RuntimeError("can't handle with p != 2")
            return torch.sqrt(torch.sum(torch.square(x)))

        decomp = {torch.ops.aten.norm.Scalar: norm_decomp}
        traced = make_fx(f, decomposition_table=decomp)(torch.rand(3))
        # Neither the decomposed op nor the ``square`` call inside the
        # decomposition should appear in the traced graph.
        for n in traced.graph.nodes:
            self.assertNotIn("square", str(n.target))
            self.assertNotIn("norm", str(n.target))
# OpInfo tests that fail (xfail -> unittest.expectedFailure) or must not run
# (skip -> unittest.skip) when traced with make_fx.
make_fx_failures = {
    # unknown
    xfail('allclose'),
    xfail('equal'),
    xfail('linalg.eigvals'),
    xfail('nn.functional.max_pool1d', device_type='cpu'),
    # These return uninitialized memory, so traced and eager outputs
    # cannot be compared.
    skip('new_empty'),
    skip('empty_like'),
    skip('empty'),
    # flaky
    skip('linalg.lstsq'),  # probably just a precision issue
    skip('linalg.lstsq', 'grad_oriented'),
    skip('nn.functional.max_unpool1d', device_type='cpu'),
    skip('nn.functional.max_unpool2d', device_type='cpu'),
    skip('nn.functional.max_unpool3d', device_type='cpu'),
    # data-dependent control flow
    xfail('cov'),
    xfail('istft'),
    xfail('nanquantile'),
    xfail('nn.functional.gaussian_nll_loss'),
    xfail('quantile'),
    xfail('tensor_split'),
    xfail('corrcoef'),
    # Appears to create a sparse tensor that tensor.is_sparse does not detect.
    xfail('sparse.sampled_addmm'),
    # NOTE(review): failure cause not yet diagnosed.
    xfail('nn.functional.ctc_loss'),
    # Sparse tensors are not supported with faketensors for now
    xfail('to_sparse'),
    # segfaults
    skip('block_diag'),
}
# Extra failures for the fake-tensor variant of the exhaustive test; that
# test applies these together with make_fx_failures.
fake_tensor_failures = {
    # Needs complex-value support
    xfail('complex'),
    xfail('linalg.eig'),
    xfail('polar'),
    # FakeTensor fallback doesn't work
    xfail('cholesky'),
    xfail('cholesky_inverse'),
    xfail('linalg.matrix_power'),
    xfail('multinomial'),
    *(xfail('mvlgamma', f'mvlgamma_p_{p}') for p in (1, 3, 5)),
    xfail('segment_reduce', 'lengths'),
    # ASAN failures due to divide by 0
    skip('nn.functional.nll_loss'),
}
def _test_make_fx_helper(self, device, dtype, op, use_fake):
    """Trace ``op`` with ``make_fx`` on each OpInfo sample and compare with eager.

    Args:
        self: the running ``TestCase`` (provides ``skipTest`` / ``assertEqual``).
        device: forwarded to ``op.sample_inputs``.
        dtype: forwarded to ``op.sample_inputs``.
        op: the OpInfo under test.
        use_fake: forwarded to ``make_fx``.

    The whole test is skipped as soon as tracing raises
    ``DynamicOutputShapeException``. After tracing, top-level float tensor
    arguments are re-randomized in place, so the graph is checked on data it
    was not traced with. Samples on which eager execution raises are skipped.
    """
    def f(args, kwargs):
        return op.op(*args, **kwargs)

    for sample_input in op.sample_inputs(device, dtype, requires_grad=False):
        args = [sample_input.input, *sample_input.args]
        kwargs = sample_input.kwargs
        try:
            new_f = make_fx(f, use_fake=use_fake)(args, kwargs)
        except DynamicOutputShapeException:
            self.skipTest("Dynamic output shape operation in trace")
        for arg in args:
            if isinstance(arg, torch.Tensor) and arg.dtype == torch.float:
                arg.uniform_(0, 1)
        try:
            old_out = f(args, kwargs)
        except Exception:
            # If eager raises on this sample, there is nothing to compare against.
            continue
        new_out = wrapper_set_seed(new_f, args, kwargs)
        self.assertEqual(new_out, old_out)
class TestProxyTensorOpInfo(TestCase):
    """Exhaustive ``make_fx`` tracing checks over every OpInfo in ``op_db`` (torch.float only)."""

    @ops(op_db, allowed_dtypes=(torch.float,))
    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_exhaustive', make_fx_failures)
    def test_make_fx_exhaustive(self, device, dtype, op):
        # Trace with real tensors.
        _test_make_fx_helper(self, device, dtype, op, use_fake=False)

    @ops(op_db, allowed_dtypes=(torch.float,))
    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_fake_exhaustive',
             make_fx_failures | fake_tensor_failures)
    def test_make_fx_fake_exhaustive(self, device, dtype, op):
        # Trace with fake tensors; failures known for real-tensor tracing apply too.
        _test_make_fx_helper(self, device, dtype, op, use_fake=True)
# Device types to instantiate the tests for. The trailing comma matters:
# ("cpu") is just the string "cpu", which would make membership checks
# against it substring checks instead of element checks.
only_for = ("cpu",)
# Generate the per-device test classes for each generic test class into this
# module's namespace.
for _generic_cls in (TestProxyTensor, TestProxyTensorOpInfo):
    instantiate_device_type_tests(_generic_cls, globals(), only_for=only_for)
# Drop the loop variable so no module-level name keeps a generic class visible
# to test discovery.
del _generic_cls
# Run the test suite only when this file is executed directly.
if __name__ == "__main__":
    run_tests()