|  | # Owner(s): ["oncall: fx"] | 
|  |  | 
|  | import torch | 
|  |  | 
|  | from torch.testing._internal.common_utils import ( | 
|  | TestCase, run_tests) | 
|  | from torch.fx.experimental.proxy_tensor import make_fx | 
|  | from torch.fx.passes.dialect.common.cse_pass import CSEPass, get_CSE_banned_ops | 
|  | from torch.fx import symbolic_trace | 
|  |  | 
|  | import random | 
|  |  | 
|  |  | 
# Ops returned by ``get_CSE_banned_ops`` are handed to the pass as its banned list.
banned_ops = get_CSE_banned_ops()

# Pass instance used by ``check`` whenever a test does not supply its own ``P``.
P_default = CSEPass(banned_ops=banned_ops)
|  |  | 
def check(self, f, t, delta, check_val=True, graph_input=False, P=None):
    """
    Check that the CSE-modified graph of ``f``
    1) has ``delta`` fewer nodes than the original graph, and
    2) is not reduced any further by a second run of the pass, and
    3) is reported as modified if and only if the number of nodes decreased.

    Args:
        self: the running test case; its ``assertTrue`` is used to report failures
        f: function to be checked (or an already traced graph module, see ``graph_input``)
        t: tensor to be passed to f
        delta: an integer >= -1.
            If delta = -1, only check that the new graph has fewer or the same number of nodes
        check_val: if True, check that the output of the CSE-modified graph matches the original
        graph_input: True if f is already a GraphModule; otherwise f is traced with ``make_fx``
        P: the pass to use. If None, use P_default
    """
    fx_g = f if graph_input else make_fx(f)(t)

    if P is None:
        P = P_default

    # Count the original nodes BEFORE running the pass, so the comparison stays
    # meaningful even for a pass that rewrites its input graph in place.
    old_num_nodes = len(fx_g.graph.nodes)

    res = P(fx_g)
    new_g = res.graph_module
    new_graph = new_g.graph
    modified = res.modified

    # the number of nodes decreases or stays the same
    new_num_nodes = len(new_graph.nodes)

    # Use the test case's assertion (not a bare ``assert``) so the check is
    # never stripped under ``python -O``.
    self.assertTrue(
        (new_num_nodes < old_num_nodes) == modified,
        "modified should be True if the number of nodes decrease",
    )

    if delta == -1:
        self.assertTrue(old_num_nodes >= new_num_nodes, (
            f"number of nodes increased {old_num_nodes}, {new_num_nodes}"))
    else:
        self.assertTrue(old_num_nodes == new_num_nodes + delta, (
            f"number of nodes not the same {old_num_nodes - delta}, {new_num_nodes}\n {fx_g.graph} \n {new_graph}"))

    # a second pass should not remove any more nodes
    res = P(new_g)
    pass_2_graph = res.graph_module.graph
    pass_2_num_nodes = len(pass_2_graph.nodes)
    self.assertTrue(pass_2_num_nodes == new_num_nodes, (
        f"second pass graph has less node {pass_2_num_nodes}, {new_num_nodes}\n {new_graph} \n {pass_2_graph}"))

    # check correctness
    if check_val:
        true_result = fx_g(t)
        our_result = new_g(t)
        if true_result is None:  # both must return None
            self.assertTrue(our_result is None, f"true result is None, CSE result is {our_result}")
        else:  # results returned must be element-wise identical
            self.assertTrue(torch.all(true_result == our_result), (
                f"results are different {true_result}, {our_result}"))
|  |  | 
class TestCSEPass(TestCase):
    """Tests for ``CSEPass``; each test states how many graph nodes the pass should remove."""

    def test_nochange(self):
        """No expression is repeated, so the node count must not change."""
        def f(x):
            a = x + 1
            b = x + a
            a = x
            d = x + a
            return b + d
        t = torch.randn(2, 2)
        check(self, f, t, 0)

    def test_empty(self):
        """A function with an empty body (returning None) must be left unchanged."""
        def f(x):
            pass
        t = torch.randn(2, 2)
        check(self, f, t, 0)

    def test_immutable_list_type(self):
        """Repeated ``sum(dim=1)`` and repeated ``sum()`` each collapse: two nodes removed."""
        def f(x):
            a = x.sum(dim=1)
            b = x.sum(dim=1)
            c = x.sum()
            d = x.sum()
            return a + b + c + d
        t = torch.randn(2, 2)
        check(self, f, t, 2)

    def test_immutable_list_multiple_entries(self):
        """Repeated ``sum`` calls with a multi-entry ``dim`` list are merged as well."""
        def f(x):
            a = x.sum(dim=[0, 1])
            b = x.sum(dim=[0, 1])
            c = x.sum(dim=1)
            d = x.sum(dim=1)
            return a + b + c + d
        t = torch.randn(2, 2)
        check(self, f, t, 2)

    def test_simple(self):
        """Merging the duplicate ``cos`` also makes the two adds identical: two nodes removed."""
        def f(x):
            a = x.cos()
            b = x.cos()
            c = a + a
            d = b + b
            return c + d
        t = torch.randn(2, 2)
        check(self, f, t, 2)

    def test_simple_2(self):
        """A duplicated two-op chain plus the dependent add: three nodes removed."""
        def f(x):
            a = x.cos().sin()
            b = x.cos().sin()
            c = a + a
            d = b + b
            return c + d
        t = torch.randn(1)
        check(self, f, t, 3)

    def test_two_args_default(self):
        """Calls that spell out ``keepdim=False`` and calls that omit it are all merged into one."""
        def f(x):
            a = x.sum(dim=1)
            b = x.sum(dim=1, keepdim=False)
            c = x.sum(dim=1, keepdim=False)
            d = x.sum(dim=1)
            return a + b + c + d
        t = torch.randn(2, 2)
        check(self, f, t, 3)

    def test_two_args(self):
        """Calls with ``keepdim=True`` are merged only with each other, not with the plain ones."""
        def f(x):
            a = x.sum(dim=1)
            b = x.sum(dim=1, keepdim=True)
            c = x.sum(dim=1, keepdim=True)
            d = x.sum(dim=1)
            return a + b + c + d
        t = torch.randn(2, 2)
        check(self, f, t, 2)

    def test_simple_multiple_same_ops(self):
        """Four identical ``sum()`` calls collapse into one: three nodes removed."""
        def f(x):
            a = x.sum()
            b = x.sum()
            c = x.sum()
            d = x.sum()
            return a + b + c + d
        t = torch.randn(2, 2)
        check(self, f, t, 3)

    def test_nested_immutable_list_type(self):
        """Identical ``torch.cat`` calls on the same tuple of tensors are merged."""
        def f(x):
            a = torch.cat((x, x))
            b = torch.cat((x, x))
            return a + b
        t = torch.randn(2, 2)
        check(self, f, t, 1)

    def test_kwarg(self):
        """Identical ``torch.ones_like`` calls are merged."""
        def f(x):
            a = torch.ones_like(x)
            b = torch.ones_like(x)
            return a + b
        t = torch.randn(2, 2)
        check(self, f, t, 1)

    def test_random(self):
        """Generate functions with random ops and check if the result is the same."""
        def f(x):
            vals = [x]
            ops = [torch.clone, torch.cos, torch.tanh, torch.nn.functional.gelu]
            for _ in range(100):
                new_val = random.choice(ops)(random.choice(vals))
                vals.append(new_val)
            return vals[-1]

        t = torch.randn(2, 2)

        for _ in range(30):
            # The random choices inside ``f`` are made while tracing, so trace
            # inside the loop: each iteration then checks a freshly generated
            # graph instead of re-checking the same one 30 times.
            fx_g = symbolic_trace(f)
            fx_g.graph.eliminate_dead_code()
            fx_g.recompile()
            check(self, fx_g, t, -1, graph_input=True)

    def test_banned_list(self):
        """Test that the banned list bans ops as expected."""
        def f(x):
            a = x + 1
            b = x + 1
            return a + b

        t = torch.randn(2, 2)
        P_ban_add = CSEPass(banned_ops=[torch.ops.aten.add])
        check(self, f, t, 0, P=P_ban_add)  # check that add is banned
        check(self, f, t, 1)  # check that add is not banned by default

    def test_rand_like(self):
        """Two ``rand_like`` calls must not be merged; outputs are random, so values are not compared."""
        def f(x):
            a = torch.rand_like(x)
            b = torch.rand_like(x)
            return a + b
        t = torch.randn(2, 2)
        check(self, f, t, 0, check_val=False)

    def test_rand_n(self):
        """Two ``randn`` calls must not be merged; outputs are random, so values are not compared."""
        def f(x):
            a = torch.randn(4)
            b = torch.randn(4)
            return a + b
        t = torch.randn(2, 2)
        check(self, f, t, 0, check_val=False)
|  |  | 
|  |  | 
# Script entry point: calls ``run_tests`` only when the file is executed directly.
if __name__ == "__main__":
    run_tests()