| import torch |
| from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA, _inline_everything |
| from torch import nn |
| from torch.testing import FileCheck |
| from typing import List |
| |
| import unittest |
| |
if __name__ == '__main__':
    # Direct execution is rejected; the message points the user at the
    # test/test_jit.py entry point instead.
    _usage_hint = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_usage_hint)
| |
| class TestPeephole(JitTestCase): |
def test_peephole_with_writes(self):
    """Run ``checkScript`` on a function that accumulates its input with ``+=``."""
    def test_write(x):
        s = 0
        s += x
        s += x
        return s

    sample_inputs = (torch.ones(4, 4),)
    self.checkScript(test_write, sample_inputs)
| |
def test_peephole_with_non_output_writes(self):
    """Run ``checkScript`` on a function that mutates a tensor it never returns.

    ``y = z + 0`` is computed first, then ``z`` is modified in place with
    ``z.add_(t)``. ``z`` is only passed to the ignored ``nomnom`` function
    and is not part of the returned value ``y + y``.
    """
    @torch.jit.ignore
    def nomnom(x):
        pass

    def test_write(x):
        t = torch.ones_like(x)
        z = x.clone()
        y = z + 0
        z.add_(t)
        # this makes sure z isn't blasted out of existence
        # because it isn't returned or used in a side-effectful
        # way
        nomnom(z)
        return y + y

    a = torch.ones(4, 4)
    self.checkScript(test_write, (a,))
| |
def test_peephole_no_output_aliasing(self):
    """Assert that returning ``x`` and ``x + 0`` yields tensors with distinct data pointers."""
    def test_peephole(x):
        y = x + 0
        return x, y

    ones = torch.ones(4, 4)
    scripted = self.checkScript(test_peephole, (ones,))
    first, second = scripted(ones)
    self.assertNotEqual(first.data_ptr(), second.data_ptr())
| |
def test_peephole(self):
    """Check the effect of the peephole pass on ``type_as`` in traced graphs.

    * Traced with two default-dtype tensors: the graph contains
      ``type_as`` before the pass and must not contain it afterwards.
    * Traced with a default-dtype tensor and an ``int32`` tensor: the
      printed graph must be identical before and after the pass.
    """
    a = torch.tensor([0.4])
    b = torch.tensor([0.7])
    c = torch.tensor([0], dtype=torch.int32)

    def f(x, y):
        return x.type_as(y)

    tf = torch.jit.trace(f, (a, b))
    FileCheck().check("type_as").run(str(tf.graph))
    self.run_pass('peephole', tf.graph)
    FileCheck().check_not("type_as").run(str(tf.graph))

    tf2 = torch.jit.trace(f, (a, c))
    s = str(tf2.graph)
    self.run_pass('peephole', tf2.graph)
    # Compare the saved text against the graph *after* the pass; comparing
    # the saved string with itself would always succeed.
    self.assertEqual(s, str(tf2.graph))
| |
def test_peephole_dynamic(self):
    """Assert the peephole pass leaves a scripted ``type_as`` graph textually unchanged."""
    def f(x, y):
        return x.type_as(y)

    scripted = torch.jit.script(f)
    graph_before = str(scripted.graph)
    torch._C._jit_pass_peephole(scripted.graph)
    graph_after = str(scripted.graph)
    self.assertEqual(graph_before, graph_after)
| |
def test_peephole_list_ops(self):
    """Check how the peephole pass treats ``len`` and indexing on list literals."""
    # len() of a list literal: a constant 3 directly precedes the return.
    @torch.jit.script
    def len_of_literal(x, y, z):
        return len([x, y, z])

    graph = len_of_literal.graph
    self.run_pass('peephole', graph)
    FileCheck().check("value=3").check_next("return").run(graph)

    # A separate list is appended to in a loop, but the returned len() is
    # taken on a fresh literal; no aten::len may remain after the pass.
    @torch.jit.script
    def len_of_fresh_literal(x, y, z):
        li = [x, y, z]
        for i in range(len(x)):
            li.append(x)
        return len([x, y, z])

    graph = len_of_fresh_literal.graph
    self.run_pass('peephole', graph)
    FileCheck().check_not("aten::len").run(graph)

    # In-range constant indices (positive and negative): __getitem__ is
    # present before the pass and gone after it.
    @torch.jit.script
    def index_in_range(x, y, z):
        li = [x, y, z]
        return li[1], li[-2]

    graph = index_in_range.graph
    FileCheck().check("aten::__getitem__").run(graph)
    self.run_pass('peephole', graph)
    FileCheck().check_not("aten::__getitem__").run(graph)

    # Index -7 on a three-element list: __getitem__ must stay in the graph.
    @torch.jit.script
    def index_out_of_range(x, y, z):
        li = [x, y, z]
        return li[-7]

    graph = index_out_of_range.graph
    self.run_pass('peephole', graph)
    FileCheck().check("aten::__getitem__").run(graph)

    # The indexed list is appended to in a loop: __getitem__ must stay.
    @torch.jit.script
    def index_after_append(x, y, z):
        li = [x, y, z]
        for i in range(len(x)):
            li.append(x)
        return li[-2]

    graph = index_after_append.graph
    self.run_pass('peephole', graph)
    FileCheck().check("aten::__getitem__").run(graph)
| |
@unittest.skipIf(not RUN_CUDA, "cpp tests require CUDA")
def test_peephole_cuda(self):
    """Check the peephole pass on ``type_as`` for mixed-device and same-device traces."""
    cpu_tensor = torch.tensor([0.4], device='cpu')
    cuda_tensor = torch.tensor([0.7], device='cuda')
    other_cuda_tensor = torch.tensor([0.7], device='cuda')

    def f(x, y):
        return x.type_as(y)

    # cpu input, cuda target: the printed graph must not change.
    mixed_trace = torch.jit.trace(f, (cpu_tensor, other_cuda_tensor))
    text_before = str(mixed_trace.graph)
    self.run_pass('peephole', mixed_trace.graph)
    self.assertEqual(text_before, str(mixed_trace.graph))

    # Both tensors on cuda: after peephole + dce no type_as may remain.
    same_device_trace = torch.jit.trace(f, (cuda_tensor, other_cuda_tensor))
    for pass_name in ('peephole', 'dce'):
        self.run_pass(pass_name, same_device_trace.graph)
    FileCheck().check_not("type_as").run(str(same_device_trace.graph))
| |
@_inline_everything
def test_peephole_type_refinements(self):
    """Check which ``prim::unchecked_cast`` nodes the peephole pass removes."""
    def refine(x):
        # type: (Optional[Tensor]) -> Tensor
        return x if x is not None else torch.tensor(3)

    @torch.jit.script
    def test():
        return refine(torch.tensor(4))

    # The cast is present before the pass and must be gone afterwards.
    refine_graph = test.graph
    FileCheck().check("prim::unchecked_cast").run(refine_graph)
    self.run_pass('peephole', refine_graph)
    FileCheck().check_not("prim::unchecked_cast").run(refine_graph)

    # refinement not optimized out
    def is_int_tensor(x):
        scalar = x.item()
        if isinstance(scalar, int):
            return scalar + 3
        else:
            return 8

    for sample in (torch.tensor(2), torch.tensor(2.5)):
        self.checkScript(is_int_tensor, (sample,))
    isinstance_graph = torch.jit.script(is_int_tensor).graph
    self.run_pass('peephole', isinstance_graph)
    FileCheck().check("prim::unchecked_cast").run(isinstance_graph)
| |
def test_short_circuit_optimization(self):
    """Check simplification of ``and``/``or`` expressions with a constant operand."""
    @torch.jit.script
    def const_expressions(x):
        # type: (int) -> Tuple[bool, bool]
        return x == 1 and False, x == 1 or True

    # After constant propagation neither the branch nor the comparison remains.
    const_graph = const_expressions.graph
    self.run_pass('constant_propagation', const_graph)
    FileCheck().check_not("prim::If").check_not("aten::eq").run(const_graph)
    self.assertEqual(const_expressions(1), (False, True))

    @torch.jit.script
    def redundant_expressions(x):
        # type: (int) -> Tuple[bool, bool]
        return x == 1 and True, x == 1 or False

    redundant_graph = redundant_expressions.graph
    self.run_pass('peephole', redundant_graph)
    self.assertEqual(redundant_expressions(1), (True, True))
    self.assertEqual(redundant_expressions(0), (False, False))
    # and True / or False are removed from graph
    FileCheck().check("aten::eq").check_not("prim::If").run(redundant_graph)
| |
def test_conv_dim_folding(self):
    """Check folding of ``dim()`` on a convolution output for Conv1d/2d/3d.

    Without mutation, neither ``conv`` nor ``dim`` may appear in the graph
    after inline + peephole. When the output is resized in place first,
    both must still be present.
    """
    for conv_cls in (nn.Conv1d, nn.Conv2d, nn.Conv3d):
        class ConvDim(torch.nn.Module):
            def __init__(self):
                super(ConvDim, self).__init__()
                self.conv = conv_cls(3, 32, kernel_size=3, stride=2, bias=False)

            def forward(self, x):
                x = self.conv(x)
                return x.dim()

        plain = torch.jit.script(ConvDim())
        for pass_name in ("inline", "peephole"):
            self.run_pass(pass_name, plain.graph)
        FileCheck().check_not("conv").check_not("dim").run(plain.graph)

        class ConvDimMutate(torch.nn.Module):
            def __init__(self):
                super(ConvDimMutate, self).__init__()
                self.conv = conv_cls(3, 32, kernel_size=3, stride=2, bias=False)

            def forward(self, x):
                x = self.conv(x)
                x.resize_([4, 4])
                return x.dim()

        mutated = torch.jit.script(ConvDimMutate())
        for pass_name in ("inline", "peephole"):
            self.run_pass(pass_name, mutated.graph)
        FileCheck().check("conv").check("dim").run(mutated.graph)
| |
def test_normalized_is_op(self):
    """Assert ``is`` on bools appears as three ``aten::eq`` and no ``aten::__is__``."""
    def convertible_is_op(x: bool, y: bool):
        return x is True, False is x, x is y

    self.checkScript(convertible_is_op, (True, False))

    scripted_graph = torch.jit.script(convertible_is_op).graph
    for op_name, expected_count in (("aten::eq", 3), ("aten::__is__", 0)):
        FileCheck().check_count(op_name, expected_count, exactly=True).run(scripted_graph)
| |
def test_normalized_isnot_op(self):
    """Assert ``is not`` on bools appears as three ``aten::ne`` and no ``aten::__isnot__``."""
    def convertible_isnot_op(x: bool, y: bool):
        return x is not True, False is not x, x is not y

    self.checkScript(convertible_isnot_op, (True, False))

    scripted_graph = torch.jit.script(convertible_isnot_op).graph
    for op_name, expected_count in (("aten::ne", 3), ("aten::__isnot__", 0)):
        FileCheck().check_count(op_name, expected_count, exactly=True).run(scripted_graph)
| |
def test_peephole_list_len(self):
    """Check list-length refinement in ``_jit_pass_peephole_list_idioms``.

    Each scripted ``foo`` guards on ``len(...)`` and then uses ``len(...)``
    again. After the list-idiom pass (``refine_list_len=True``) and
    constant propagation, the later ``len`` calls are expected to become
    constants where the guard pins the length, and to stay in the graph
    where it does not.
    """
    def run_peephole_and_check_const_value(graph, const_string):
        # Refine, fold, then require the constant to directly precede the return.
        torch._C._jit_pass_peephole_list_idioms(graph, refine_list_len=True)
        self.run_pass("constant_propagation", graph)
        FileCheck().check(const_string).check_next("return").run(graph)

    def gen_li(inp_len: int):
        # List of `inp_len` zeros, used as a list argument of a given length.
        return [0 for i in range(inp_len)]

    @torch.jit.script
    def foo(x: List[int], y: List[int]):
        if len(x) != 4 or len(y) != 5:
            raise Exception("")

        return len(x) + len(y)

    run_peephole_and_check_const_value(foo.graph, "value=9")
    self.assertEqual(foo(gen_li(4), gen_li(5)), 9)
    # Pass real lists of the wrong length so the failure comes from the
    # length guard rather than from converting a non-list argument.
    with self.assertRaises(Exception):
        foo(gen_li(2), gen_li(4))

    @torch.jit.script
    def foo(x: List[int], y: List[int]):
        if len(x) == 4 and len(y) == 5:
            pass
        else:
            raise Exception("hi")

        return len(x) + len(y)

    run_peephole_and_check_const_value(foo.graph, "value=9")
    self.assertEqual(foo(gen_li(4), gen_li(5)), 9)
    with self.assertRaises(Exception):
        foo(gen_li(2), gen_li(4))

    @torch.jit.script
    def foo(x: List[int], y: List[int], z: List[int]):
        if len(x) != 4:
            raise Exception("..")
        else:
            if len(y) != 8:
                raise Exception("...")
            else:
                if len(z) == 3:
                    pass
                else:
                    raise Exception("...")

        return len(x) + len(y) * len(z)

    run_peephole_and_check_const_value(foo.graph, "value=28")
    self.assertEqual(foo(gen_li(4), gen_li(8), gen_li(3)), 28)
    with self.assertRaises(Exception):
        foo(gen_li(1), gen_li(2), gen_li(3))

    # refinement should persist in second len(x) call

    @torch.jit.script
    def foo(x: List[int], cond: bool):
        if len(x) == 4:
            if cond:
                return len(x)
            return 4

        return 4

    run_peephole_and_check_const_value(foo.graph, "value=4")

    def test_const_tuple_output(graph, const_inputs):
        # Exactly the tuple elements whose index is in `const_inputs` must
        # be constants (toIValue() is not None); all others must not be.
        tup = graph.findNode("prim::TupleConstruct")
        for i, elem in enumerate(tup.inputs()):
            if i in const_inputs:
                self.assertIsNotNone(elem.toIValue())
            else:
                self.assertIsNone(elem.toIValue())

    # testing combinations of x1 : {True, False} x
    # {then/else branch} x assert {True/False}

    @torch.jit.script
    def foo(x: List[int], b: List[int]):
        if len(x) == 5:
            x1 = True
        else:
            x1 = len(b) != 4
        assert x1 == False  # noqa: E712 TODO: canonicalize x is False to aten::eq
        return len(x), len(b)

    torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True)
    torch._C._jit_pass_constant_propagation(foo.graph)
    # we can only infer len(b) == 4 here
    test_const_tuple_output(foo.graph, [1])

    @torch.jit.script
    def foo(x: List[int], b: List[int]):
        if len(x) == 5:
            x1 = False
        else:
            x1 = len(b) != 4
        assert x1 == False  # noqa: E712 TODO: canonicalize x is False to aten::eq
        return len(x), len(b)

    torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True)
    torch._C._jit_pass_constant_propagation(foo.graph)
    # can't infer anything
    test_const_tuple_output(foo.graph, [])

    @torch.jit.script
    def foo(x: List[int], b: List[int]):
        if len(x) == 5:
            x1 = True
        else:
            x1 = len(b) == 4
        assert x1 == False  # noqa: E712 TODO: canonicalize x is False to aten::eq
        return len(x), len(b)

    torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True)
    torch._C._jit_pass_constant_propagation(foo.graph)
    # we can't infer anything, only len(b) != 4
    test_const_tuple_output(foo.graph, [])

    @torch.jit.script
    def foo(x: List[int], b: List[int]):
        if len(x) == 5:
            x1 = True
        else:
            x1 = len(b) != 4
        assert x1 == False  # noqa: E712 TODO: canonicalize x is False to aten::eq
        return len(x), len(b)

    torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True)
    torch._C._jit_pass_constant_propagation(foo.graph)
    # can infer len(b) == 4
    test_const_tuple_output(foo.graph, [1])

    # swap branches
    @torch.jit.script
    def foo(x: List[int], b: List[int]):
        if len(x) != 5:
            x1 = len(b) != 4
        else:
            x1 = True
        assert x1 == False  # noqa: E712 TODO: canonicalize x is False to aten::eq
        return len(x), len(b)

    torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True)
    torch._C._jit_pass_constant_propagation(foo.graph)
    # can infer len(b) == 4
    test_const_tuple_output(foo.graph, [1])

    # use __not__
    @torch.jit.script
    def foo(x: List[int], b: List[int]):
        if len(x) != 5:
            x1 = len(b) != 4
        else:
            x1 = True
        assert not x1
        return len(x), len(b)

    torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True)
    torch._C._jit_pass_constant_propagation(foo.graph)
    # can infer len(b) == 4
    test_const_tuple_output(foo.graph, [1])

    # Test unsuccessful optimizations

    # The list is appended to after the length assertion.
    @torch.jit.script
    def foo(x: List[int]):
        assert len(x) == 4
        x.append(3)
        return len(x)

    torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True)
    self.run_pass("constant_propagation", foo.graph)
    FileCheck().check_count("aten::len", 2).run(foo.graph)

    # An `or` of two length checks does not pin either length.
    @torch.jit.script
    def foo(x: List[int], y: List[int]):
        assert len(x) == 4 or len(y) == 5
        return len(x) + len(y)

    torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True)
    self.run_pass("constant_propagation", foo.graph)
    FileCheck().check_count("aten::len", 4).run(foo.graph)
| |
def test_integer_refinement(self):
    """Check the ``refine_integer_values`` pass on guarded integer arguments.

    Each scripted ``foo`` guards on integer equality and then uses the
    same integers again. After refinement, constant propagation and dce,
    the result is expected to be a constant where the guard pins the
    values, and to stay a computed ``aten::add`` where it does not.
    """
    def run_peephole_and_check_const_value(graph, const_string):
        # Refine, fold and clean up, then require the constant to directly
        # precede the return.
        self.run_pass("refine_integer_values", graph)
        self.run_pass("constant_propagation", graph)
        self.run_pass("dce", graph)
        FileCheck().check(const_string).check_next("return").run(graph)

    @torch.jit.script
    def foo(x: int, y: int):
        if x != 4 or y != 5:
            raise Exception("")

        return x + y

    run_peephole_and_check_const_value(foo.graph, "value=9")
    self.assertEqual(foo(4, 5), 9)
    with self.assertRaises(Exception):
        foo(2, 4)

    @torch.jit.script
    def foo(x: int, y: int):
        if x == 4 and y == 5:
            pass
        else:
            raise Exception("hi")

        return x + y

    run_peephole_and_check_const_value(foo.graph, "value=9")
    self.assertEqual(foo(4, 5), 9)
    with self.assertRaises(Exception):
        foo(2, 4)

    @torch.jit.script
    def foo(x: int, y: int, z: int):
        if x != 4:
            raise Exception("..")
        else:
            if y != 8:
                raise Exception("...")
            else:
                if z == 3:
                    pass
                else:
                    raise Exception("...")

        return x + y * z

    run_peephole_and_check_const_value(foo.graph, "value=28")
    self.assertEqual(foo(4, 8, 3), 28)
    with self.assertRaises(Exception):
        foo(1, 2, 3)

    # refinement should persist in the second use of x

    @torch.jit.script
    def foo(x: int, cond: bool):
        if x == 4:
            if cond:
                return x
            return 4

        return 4

    run_peephole_and_check_const_value(foo.graph, "value=4")

    # Unsuccessful optimization: an `or` of two equality checks does not
    # pin either value, so the addition must remain in the graph.
    @torch.jit.script
    def foo(x: int, y: int):
        assert x == 4 or y == 5
        return x + y

    # Run the integer refinement pass (the pass under test), not the
    # list-idiom pass, so this negative case is actually exercised.
    self.run_pass("refine_integer_values", foo.graph)
    self.run_pass("constant_propagation", foo.graph)
    FileCheck().check("aten::add").run(foo.graph)
| |
def test_optimize_out_comparison_same_value(self):
    """Assert ``x == x`` and ``x != x`` on an int are removed by the peephole pass.

    After the pass neither ``aten::eq`` nor ``aten::ne`` may remain in
    the graph, and the function must still return ``(True, False)``.
    """
    @torch.jit.script
    def foo(x: int):
        return x == x, x != x

    self.run_pass("peephole", foo.graph)
    # The inequality op is named `aten::ne`; checking for a non-existent
    # op name would pass regardless of what the graph contains.
    FileCheck().check_not("aten::eq").check_not("aten::ne").run(foo.graph)
    self.assertEqual(foo(1), (True, False))
| |
def test_refine_integer_values(self):
    """Check that, after refinement, the graph header is directly followed by the return."""
    @torch.jit.script
    def foo(x: int):
        y = 1
        if x == 1:
            return y
        else:
            return x

    refined_graph = foo.graph
    for pass_name in ("refine_integer_values", "constant_propagation", "dce"):
        self.run_pass(pass_name, refined_graph)
    FileCheck().check("graph").check_next("return").run(refined_graph)
    # The function must still return its argument for both branches.
    self.assertEqual(foo(2), 2)
    self.assertEqual(foo(1), 1)
| |
def test_peephole_len_list(self):
    """Check folding of ``len(x.size())`` once the input type carries two sizes."""
    def annotate_two_sizes(graph):
        # Replace the first graph input's type with one that has two
        # (unspecified) sizes.
        first_input = list(graph.inputs())[0]
        first_input.setType(first_input.type().with_sizes([None, None]))

    @torch.jit.script
    def size_len(x):
        return len(x.size())

    # Without size information aten::len stays in the graph.
    self.run_pass("peephole", size_len.graph)
    FileCheck().check("aten::len").run(size_len.graph)
    # With two sizes on the input type it must be gone after the pass.
    annotate_two_sizes(size_len.graph)
    self.run_pass("peephole", size_len.graph)
    FileCheck().check_not("aten::len").run(size_len.graph)
    self.assertEqual(2, size_len(torch.rand([3, 1])))

    @torch.jit.script
    def appended_size_len(x):
        li = x.size()
        li.append(4)
        return len(li)

    # The size list is appended to before len(): aten::len must remain.
    annotate_two_sizes(appended_size_len.graph)
    self.run_pass("peephole", appended_size_len.graph)
    FileCheck().check("aten::len").run(appended_size_len.graph)
    self.assertEqual(3, appended_size_len(torch.rand([3, 1])))