| # Owner(s): ["oncall: fx"] |
| |
| import itertools |
| |
| import torch |
| from torch.fx.experimental.proxy_tensor import make_fx |
| from torch.fx.graph_module import GraphModule |
| from torch.fx.passes.dialect.common.cse_pass import CSEPass |
| from torch.testing._internal.common_utils import ( |
| instantiate_parametrized_tests, |
| parametrize, |
| run_tests, |
| TestCase, |
| ) |
| |
| |
def FactoryFunctionCall(x, device):
    """Add ``x`` to a tensor of threes created by a factory function.

    Args:
        x: Input tensor; its shape determines the shape of the filled tensor.
        device: Device passed to ``torch.full`` for the filled tensor.

    Returns:
        The result of ``torch.add(filled, x)``.
    """
    threes = torch.full(x.shape, 3, device=device)
    return torch.add(threes, x)
| |
| |
def TorchTensorCall(x):
    """Return ``x`` plus a scalar tensor built with ``torch.tensor(3)``."""
    three = torch.tensor(3)
    total = x + three
    return total
| |
| |
def TakeList(x):
    """Concatenate ``x`` with itself, exercising an op that takes a list argument."""
    return torch.cat([x, x])
| |
| |
def ReturnList(x):
    """Split a fixed 5x2 ``arange`` tensor into chunks of 1 and 4 rows.

    ``x`` is accepted so the function has the same signature as the other
    test cases, but it is not used in the computation.

    Returns:
        The sequence of chunks produced by ``torch.split``.
    """
    grid = torch.arange(10).reshape(5, 2)
    return torch.split(grid, [1, 4])
| |
| |
def Mutation(x):
    """Mutate an intermediate tensor in place, then combine it with ``x``.

    The intermediate ``x + 2`` is incremented in place by 1 before being
    added back to ``x``; ``x`` itself is not modified.
    """
    shifted = x + 2
    shifted.add_(1)
    return x + shifted
| |
| |
def MutationInput(x):
    """Mutate the input in place, then compute ``x + (x + 2)`` from the updated value.

    Note that ``x`` is incremented by 1 in place, so the caller's tensor is
    modified as a side effect.
    """
    x.add_(1)
    shifted = x + 2
    return x + shifted
| |
| |
def MutationFactory(x, device):
    """Mutate a factory-created tensor in place and add it to ``x``.

    Args:
        x: Input tensor; its shape determines the shape of the filled tensor.
        device: Device passed to ``torch.full`` for the filled tensor.

    Returns:
        ``x`` plus the filled tensor after its in-place increment by 1.
    """
    filled = torch.full(x.shape, 3, device=device)
    filled.add_(1)
    return x + filled
| |
| |
def MutationTorchTensorCall(x):
    """Mutate a ``torch.tensor(3)`` scalar in place and add it to ``x``."""
    scalar = torch.tensor(3)
    scalar.add_(1)
    return x + scalar
| |
| |
def MutationMetadata(x):
    """Resize the input in place to 2 elements and return that same tensor.

    The caller's tensor is modified as a side effect.
    """
    new_size = 2
    # Kept as a separate statement followed by ``return x`` so the returned
    # object is the input itself rather than the value of the resize call.
    x.resize_(new_size)
    return x
| |
| |
# Pass classes under test; each entry is called with no arguments to build
# the pass instance.
Passes = [CSEPass]

# Functions taking a single tensor argument (used by ``test_correctness``).
# NOTE(review): ``TorchTensorCall`` is defined in this file but is not listed
# here — confirm whether that omission is intentional.
Test_Cases = [
    TakeList,
    ReturnList,
    Mutation,
    MutationInput,
    MutationMetadata,
    MutationTorchTensorCall,
]

# Functions taking ``(x, device)`` (used by ``test_correctness_factory``).
Factory_Test_Cases = [FactoryFunctionCall, MutationFactory]

# CPU is always tested; CUDA is added only when it is available.
Devices = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
| |
| |
def name_fn(common_pass, f, device):
    """Build the name suffix for a parameterized test case.

    Args:
        common_pass: Zero-argument callable; it is called once and the type
            name of the returned object is used.
        f: Function under test; its ``__name__`` is used.
        device: Device identifier, formatted into the name as-is.

    Returns:
        A string of the form ``"<PassType>_<function name>_<device>"``.
    """
    pass_name = type(common_pass()).__name__
    func_name = f.__name__
    return f"{pass_name}_{func_name}_{device}"
| |
| |
@instantiate_parametrized_tests
class TestCommonPass(TestCase):
    """Correctness tests for the passes listed in ``Passes``.

    Each test traces a function with ``make_fx``, applies a pass to the traced
    module, and checks that the transformed module returns the same result as
    calling the original function directly.
    """

    def _check_pass_preserves_result(self, common_pass, f, inp, *extra_args):
        """Trace ``f``, apply ``common_pass`` and compare against eager ``f``.

        Args:
            common_pass: Zero-argument callable returning the pass to apply.
            f: Function under test, called as ``f(inp, *extra_args)``.
            inp: Tensor input given to the tracer and to ``f``.
            *extra_args: Extra positional arguments forwarded unchanged to the
                tracer, to ``f`` and to the transformed module.
        """
        traced_m = make_fx(f)(inp, *extra_args)
        pass_instance = common_pass()

        pass_result = pass_instance(traced_m)
        modified_m = pass_result.graph_module
        # Use the unittest assertion rather than a bare ``assert`` so the check
        # is not stripped under ``python -O`` and failures are reported clearly.
        self.assertIsInstance(modified_m, GraphModule)

        # The copy is taken after tracing and before the eager call so both
        # runs start from the same tensor state; some test functions mutate
        # their input in place.
        # NOTE(review): this ordering presumably matters because tracing may
        # itself touch ``inp`` — confirm before reordering these statements.
        inp_copy = inp.clone()
        expected = f(inp, *extra_args)
        result = modified_m(inp_copy, *extra_args)

        self.assertEqual(result, expected)

    @parametrize(
        "common_pass,f,device", itertools.product(Passes, Test_Cases, Devices), name_fn
    )
    def test_correctness(self, common_pass, f, device):
        """Check single-argument test functions on ``device``."""
        inp = torch.randn(10, device=device)
        self._check_pass_preserves_result(common_pass, f, inp)

    @parametrize(
        "common_pass,f,device",
        itertools.product(Passes, Factory_Test_Cases, Devices),
        name_fn,
    )
    def test_correctness_factory(self, common_pass, f, device):
        """Check test functions that also take the device as an argument."""
        inp = torch.randn(10, device=device)
        self._check_pass_preserves_result(common_pass, f, inp, device)
| |
| |
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    run_tests()