blob: 642da7255b681bb7ffc24727923754c992abf3e2 [file] [log] [blame]
# Owner(s): ["oncall: fx"]
import itertools
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.graph_module import GraphModule
from torch.fx.passes.dialect.common.cse_pass import CSEPass
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
TestCase,
)
def FactoryFunctionCall(x, device):
    """Add ``x`` to a freshly created tensor of 3s shaped like ``x``.

    Exercises a factory call (``torch.full``) with an explicit ``device``
    argument inside the traced program.
    """
    return torch.add(torch.full(x.shape, 3, device=device), x)
def TorchTensorCall(x):
    """Add a scalar tensor constructed with ``torch.tensor(3)`` to ``x``."""
    return x + torch.tensor(3)
def TakeList(x):
    """Concatenate ``x`` with itself, passing the operands as a Python list."""
    operands = [x] * 2
    return torch.cat(operands)
def ReturnList(x):
    """Return the tensors produced by splitting a constant 5x2 matrix.

    ``x`` is unused; it exists only so the function fits the single-argument
    calling convention shared by the other cases. The result is the rows of
    ``arange(10).reshape(5, 2)`` split into chunks of 1 and 4 rows.
    """
    grid = torch.arange(10).reshape(5, 2)
    return grid.split([1, 4])
def Mutation(x):
    """Mutate an intermediate (never the input) in place, then combine with ``x``."""
    tmp = x + 2
    tmp.add_(1)
    return x + tmp
def MutationInput(x):
    """Mutate the input ``x`` in place, then compute from its new value.

    Note: the caller's tensor is modified (incremented by 1).
    """
    x.add_(1)
    shifted = x + 2
    return x + shifted
def MutationFactory(x, device):
    """Increment a ``torch.full`` result in place, then add it to ``x``."""
    filled = torch.full(x.shape, 3, device=device)
    filled.add_(1)
    return x + filled
def MutationTorchTensorCall(x):
    """Increment a ``torch.tensor`` constant in place, then add it to ``x``."""
    scalar = torch.tensor(3)
    scalar.add_(1)
    return x + scalar
def MutationMetadata(x):
    """Change the size of ``x`` in place (a metadata mutation) and return it.

    ``Tensor.resize_`` returns the tensor it was called on, so the value
    returned here is ``x`` itself.
    """
    return x.resize_(2)
# Pass classes under test; the tests create a fresh instance for every case.
Passes = [CSEPass]

# Programs called as ``f(x)`` with a single tensor argument.
Test_Cases = [
    # No mutation: list-typed argument and list/tuple-typed output.
    TakeList,
    ReturnList,
    # In-place mutation of an intermediate, of the input, of the input's
    # metadata, and of a ``torch.tensor`` constant.
    Mutation,
    MutationInput,
    MutationMetadata,
    MutationTorchTensorCall,
]

# Programs called as ``f(x, device)`` that build tensors with factory calls.
Factory_Test_Cases = [
    FactoryFunctionCall,
    MutationFactory,
]
# CPU is always tested; CUDA is appended only when torch.cuda.is_available().
Devices = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
def name_fn(common_pass, f, device):
    """Return the name suffix for one parametrized test case.

    The suffix has the form ``<pass type>_<function name>_<device>``. The pass
    type name is read from an instance created by calling ``common_pass()``.
    """
    pass_type_name = type(common_pass()).__name__
    return "{}_{}_{}".format(pass_type_name, f.__name__, device)
@instantiate_parametrized_tests
class TestCommonPass(TestCase):
    """Check that common FX passes preserve the semantics of traced programs.

    Each test traces ``f`` with ``make_fx``, runs the pass over the resulting
    ``GraphModule``, and asserts that the transformed module returns the same
    result as calling ``f`` eagerly.
    """

    def _check_pass_preserves_semantics(self, common_pass, f, inp, *extra_args):
        """Trace ``f``, apply ``common_pass``, and compare against eager ``f``.

        Args:
            common_pass: Zero-argument callable (a pass class here). Its
                instance is called on the traced module, and the result must
                expose ``graph_module``.
            f: Eager function under test. It is called as
                ``f(inp, *extra_args)``.
            inp: Tensor passed as the first argument of ``f``.
            *extra_args: Further positional arguments, forwarded unchanged to
                the traced call, to the eager call, and to the transformed
                module.
        """
        traced_m = make_fx(f)(inp, *extra_args)
        res = common_pass()(traced_m)
        modified_m = res.graph_module
        # Use a TestCase assertion instead of a bare ``assert`` so the check
        # still runs (and reports clearly) under ``python -O``.
        self.assertIsInstance(modified_m, GraphModule)

        # Several cases mutate their input in place (``add_``, ``resize_``).
        # Give the eager call and the transformed module separate tensors that
        # hold the same starting values, so neither run sees the other's
        # mutation.
        inp_copy = inp.clone()
        expected = f(inp, *extra_args)
        result = modified_m(inp_copy, *extra_args)
        self.assertEqual(result, expected)

    @parametrize(
        "common_pass,f,device", itertools.product(Passes, Test_Cases, Devices), name_fn
    )
    def test_correctness(self, common_pass, f, device):
        """Single-input cases: ``f(x)`` on a random 10-element tensor."""
        inp = torch.randn(10, device=device)
        self._check_pass_preserves_semantics(common_pass, f, inp)

    @parametrize(
        "common_pass,f,device",
        itertools.product(Passes, Factory_Test_Cases, Devices),
        name_fn,
    )
    def test_correctness_factory(self, common_pass, f, device):
        """Factory cases: ``f(x, device)`` also receives the target device."""
        inp = torch.randn(10, device=device)
        self._check_pass_preserves_semantics(common_pass, f, inp, device)
if __name__ == "__main__":
    # Hand off to the shared PyTorch test runner imported from
    # torch.testing._internal.common_utils.
    run_tests()