import torch
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA, _inline_everything
from torch import nn
from torch.testing import FileCheck
from typing import List, Optional, Tuple
import unittest
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
class TestPeephole(JitTestCase):
def test_peephole_with_writes(self):
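        # the peephole pass must not change the result of code that
        # repeatedly writes to the same variable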
def test_write(x):
s = 0
s += x
s += x
return s
self.checkScript(test_write, (torch.ones(4, 4),))
def test_peephole_with_non_output_writes(self):
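        # y = z + 0 must not become an alias of z, since z is mutated
        # in place afterwards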
@torch.jit.ignore
def nomnom(x):
pass
def test_write(x):
t = torch.ones_like(x)
z = x.clone()
y = z + 0
z.add_(t)
            # this makes sure z isn't optimized out of existence just
            # because it isn't returned or used in a side-effectful way
nomnom(z)
return y + y
a = torch.ones(4, 4)
        self.checkScript(test_write, (a,))
def test_peephole_no_output_aliasing(self):
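        # x + 0 may be simplified, but the returned value must not
        # alias the graph input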
def test_peephole(x):
y = x + 0
return x, y
a = torch.ones(4, 4)
j = self.checkScript(test_peephole, (a,))
r1, r2 = j(a)
self.assertNotEqual(r1.data_ptr(), r2.data_ptr())
def test_peephole(self):
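        # type_as is a no-op when source and target dtypes match and is
        # removed; with mismatched dtypes (float vs. int32) the graph is untouched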
a = torch.tensor([0.4])
b = torch.tensor([0.7])
c = torch.tensor([0], dtype=torch.int32)
def f(x, y):
return x.type_as(y)
tf = torch.jit.trace(f, (a, b))
FileCheck().check("type_as").run(str(tf.graph))
self.run_pass('peephole', tf.graph)
FileCheck().check_not("type_as").run(str(tf.graph))
tf2 = torch.jit.trace(f, (a, c))
s = str(tf2.graph)
self.run_pass('peephole', tf2.graph)
        self.assertEqual(s, str(tf2.graph))
def test_peephole_dynamic(self):
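        # with no static type information, peephole cannot prove that
        # type_as is a no-op, so the graph must be left unchanged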
def f(x, y):
return x.type_as(y)
fn = torch.jit.script(f)
s = str(fn.graph)
torch._C._jit_pass_peephole(fn.graph)
self.assertEqual(s, str(fn.graph))
def test_peephole_list_ops(self):
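        # len() and indexing on list literals with statically known contents
        # fold to constants; mutation or out-of-range indices block the fold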
@torch.jit.script
def foo(x, y, z):
return len([x, y, z])
self.run_pass('peephole', foo.graph)
FileCheck().check("value=3").check_next("return").run(foo.graph)
@torch.jit.script
def foo(x, y, z):
li = [x, y, z]
for i in range(len(x)):
li.append(x)
return len([x, y, z])
self.run_pass('peephole', foo.graph)
FileCheck().check_not("aten::len").run(foo.graph)
@torch.jit.script
def foo(x, y, z):
li = [x, y, z]
return li[1], li[-2]
FileCheck().check("aten::__getitem__").run(foo.graph)
self.run_pass('peephole', foo.graph)
FileCheck().check_not("aten::__getitem__").run(foo.graph)
@torch.jit.script
def foo(x, y, z):
li = [x, y, z]
return li[-7]
self.run_pass('peephole', foo.graph)
FileCheck().check("aten::__getitem__").run(foo.graph)
@torch.jit.script
def foo(x, y, z):
li = [x, y, z]
for i in range(len(x)):
li.append(x)
return li[-2]
self.run_pass('peephole', foo.graph)
FileCheck().check("aten::__getitem__").run(foo.graph)
    @unittest.skipIf(not RUN_CUDA, "test requires CUDA")
def test_peephole_cuda(self):
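        # type_as across devices (cpu -> cuda) is not a no-op and must stay;
        # between two cuda tensors it is removed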
a = torch.tensor([0.4], device='cpu')
b = torch.tensor([0.7], device='cuda')
c = torch.tensor([0.7], device='cuda')
def f(x, y):
return x.type_as(y)
trace = torch.jit.trace(f, (a, c))
s = str(trace.graph)
self.run_pass('peephole', trace.graph)
self.assertEqual(s, str(trace.graph))
trace = torch.jit.trace(f, (b, c))
self.run_pass('peephole', trace.graph)
self.run_pass('dce', trace.graph)
FileCheck().check_not("type_as").run(str(trace.graph))
@_inline_everything
def test_peephole_type_refinements(self):
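        # a refinement on a value known to be non-None lets peephole drop the
        # prim::unchecked_cast after inlining; a refinement that can actually
        # fail at runtime must keep it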
def refine(x):
# type: (Optional[Tensor]) -> Tensor
return x if x is not None else torch.tensor(3)
@torch.jit.script
def test():
return refine(torch.tensor(4))
FileCheck().check("prim::unchecked_cast").run(test.graph)
self.run_pass('peephole', test.graph)
FileCheck().check_not("prim::unchecked_cast").run(test.graph)
        # refinement not optimized out
def is_int_tensor(x):
scalar = x.item()
if isinstance(scalar, int):
return scalar + 3
else:
return 8
self.checkScript(is_int_tensor, (torch.tensor(2),))
self.checkScript(is_int_tensor, (torch.tensor(2.5),))
graph = torch.jit.script(is_int_tensor).graph
self.run_pass('peephole', graph)
FileCheck().check("prim::unchecked_cast").run(graph)
def test_short_circuit_optimization(self):
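        # `and False` / `or True` make the comparison dead, so constant
        # propagation removes both the If and the eq; `and True` / `or False`
        # are merely redundant, so peephole keeps the eq but drops the If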
@torch.jit.script
def const_expressions(x):
# type: (int) -> Tuple[bool, bool]
return x == 1 and False, x == 1 or True
self.run_pass('constant_propagation', const_expressions.graph)
FileCheck().check_not("prim::If").check_not("aten::eq").run(const_expressions.graph)
self.assertEqual(const_expressions(1), (False, True))
@torch.jit.script
def redundant_expressions(x):
# type: (int) -> Tuple[bool, bool]
return x == 1 and True, x == 1 or False
self.run_pass('peephole', redundant_expressions.graph)
self.assertEqual(redundant_expressions(1), (True, True))
self.assertEqual(redundant_expressions(0), (False, False))
# and True / or False are removed from graph
FileCheck().check("aten::eq").check_not("prim::If").run(redundant_expressions.graph)
def test_conv_dim_folding(self):
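        # the output rank of ConvNd is statically known, so x.dim() folds to
        # a constant; an in-place resize_ invalidates the rank and blocks folding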
modules = [nn.Conv1d, nn.Conv2d, nn.Conv3d]
for mod in modules:
class ConvDim(torch.nn.Module):
def __init__(self):
super(ConvDim, self).__init__()
self.conv = mod(3, 32, kernel_size=3, stride=2, bias=False)
def forward(self, x):
x = self.conv(x)
return x.dim()
conv_dim = torch.jit.script(ConvDim())
self.run_pass("inline", conv_dim.graph)
self.run_pass("peephole", conv_dim.graph)
FileCheck().check_not("conv").check_not("dim").run(conv_dim.graph)
class ConvDimMutate(torch.nn.Module):
def __init__(self):
super(ConvDimMutate, self).__init__()
self.conv = mod(3, 32, kernel_size=3, stride=2, bias=False)
def forward(self, x):
x = self.conv(x)
x.resize_([4, 4])
return x.dim()
conv_dim = torch.jit.script(ConvDimMutate())
self.run_pass("inline", conv_dim.graph)
self.run_pass("peephole", conv_dim.graph)
FileCheck().check("conv").check("dim").run(conv_dim.graph)
def test_normalized_is_op(self):
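        # `is` comparisons on bools are normalized to aten::eq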
def convertible_is_op(x: bool, y: bool):
return x is True, False is x, x is y
self.checkScript(convertible_is_op, (True, False))
op_graph = torch.jit.script(convertible_is_op).graph
FileCheck().check_count("aten::eq", 3, exactly=True).run(op_graph)
FileCheck().check_count("aten::__is__", 0, exactly=True).run(op_graph)
def test_normalized_isnot_op(self):
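        # `is not` comparisons on bools are normalized to aten::ne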
def convertible_isnot_op(x: bool, y: bool):
return x is not True, False is not x, x is not y
self.checkScript(convertible_isnot_op, (True, False))
op_graph = torch.jit.script(convertible_isnot_op).graph
FileCheck().check_count("aten::ne", 3, exactly=True).run(op_graph)
FileCheck().check_count("aten::__isnot__", 0, exactly=True).run(op_graph)
def test_peephole_list_len(self):
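        # guards that raise on unexpected list lengths let the list-idiom
        # peephole (refine_list_len=True) propagate the known lengths, so
        # later len() calls fold to constants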
def run_peephole_and_check_const_value(graph, const_string):
torch._C._jit_pass_peephole_list_idioms(graph, refine_list_len=True)
self.run_pass("constant_propagation", graph)
FileCheck().check(const_string).check_next("return").run(graph)
def gen_li(inp_len: int):
            return [0 for _ in range(inp_len)]
@torch.jit.script
def foo(x: List[int], y: List[int]):
if len(x) != 4 or len(y) != 5:
raise Exception("")
return len(x) + len(y)
run_peephole_and_check_const_value(foo.graph, "value=9")
self.assertEqual(foo(gen_li(4), gen_li(5)), 9)
with self.assertRaises(Exception):
            foo(gen_li(2), gen_li(4))
@torch.jit.script
def foo(x: List[int], y: List[int]):
if len(x) == 4 and len(y) == 5:
pass
else:
raise Exception("hi")
return len(x) + len(y)
run_peephole_and_check_const_value(foo.graph, "value=9")
self.assertEqual(foo(gen_li(4), gen_li(5)), 9)
with self.assertRaises(Exception):
            foo(gen_li(2), gen_li(4))
@torch.jit.script
def foo(x: List[int], y: List[int], z: List[int]):
if len(x) != 4:
raise Exception("..")
else:
if len(y) != 8:
raise Exception("...")
else:
if len(z) == 3:
pass
else:
raise Exception("...")
return len(x) + len(y) * len(z)
run_peephole_and_check_const_value(foo.graph, "value=28")
self.assertEqual(foo(gen_li(4), gen_li(8), gen_li(3)), 28)
with self.assertRaises(Exception):
            foo(gen_li(1), gen_li(2), gen_li(3))
# refinement should persist in second len(x) call
@torch.jit.script
def foo(x: List[int], cond: bool):
if len(x) == 4:
if cond:
return len(x)
return 4
return 4
run_peephole_and_check_const_value(foo.graph, "value=4")
def test_const_tuple_output(graph, const_inputs):
tup = graph.findNode("prim::TupleConstruct")
for i, elem in enumerate(tup.inputs()):
if i in const_inputs:
self.assertIsNotNone(elem.toIValue())
else:
self.assertIsNone(elem.toIValue())
# testing combinations of x1 : {True, False} x
# {then/else branch} x assert {True/False}
@torch.jit.script
def foo(x: List[int], b: List[int]):
if len(x) == 5:
x1 = True
else:
x1 = len(b) != 4
assert x1 == False # noqa: E712 TODO: canonicalize x is False to aten::eq
return len(x), len(b)
torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True)
torch._C._jit_pass_constant_propagation(foo.graph)
# we can only infer len(b) == 4 here
test_const_tuple_output(foo.graph, [1])
@torch.jit.script
def foo(x: List[int], b: List[int]):
if len(x) == 5:
x1 = False
else:
x1 = len(b) != 4
assert x1 == False # noqa: E712 TODO: canonicalize x is False to aten::eq
return len(x), len(b)
torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True)
torch._C._jit_pass_constant_propagation(foo.graph)
        # can't infer anything: the assert is satisfiable in both branches
test_const_tuple_output(foo.graph, [])
@torch.jit.script
def foo(x: List[int], b: List[int]):
if len(x) == 5:
x1 = True
else:
x1 = len(b) == 4
assert x1 == False # noqa: E712 TODO: canonicalize x is False to aten::eq
return len(x), len(b)
torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True)
torch._C._jit_pass_constant_propagation(foo.graph)
        # we can only infer len(b) != 4, which doesn't make anything constant
test_const_tuple_output(foo.graph, [])
@torch.jit.script
def foo(x: List[int], b: List[int]):
if len(x) == 5:
x1 = True
else:
x1 = len(b) != 4
assert x1 == False # noqa: E712 TODO: canonicalize x is False to aten::eq
return len(x), len(b)
torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True)
torch._C._jit_pass_constant_propagation(foo.graph)
# can infer len(b) == 4
test_const_tuple_output(foo.graph, [1])
# swap branches
@torch.jit.script
def foo(x: List[int], b: List[int]):
if len(x) != 5:
x1 = len(b) != 4
else:
x1 = True
assert x1 == False # noqa: E712 TODO: canonicalize x is False to aten::eq
return len(x), len(b)
torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True)
torch._C._jit_pass_constant_propagation(foo.graph)
# can infer len(b) == 4
test_const_tuple_output(foo.graph, [1])
# use __not__
@torch.jit.script
def foo(x: List[int], b: List[int]):
if len(x) != 5:
x1 = len(b) != 4
else:
x1 = True
assert not x1
return len(x), len(b)
torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True)
torch._C._jit_pass_constant_propagation(foo.graph)
# can infer len(b) == 4
test_const_tuple_output(foo.graph, [1])
# Test unsuccessful optimizations
@torch.jit.script
def foo(x: List[int]):
assert len(x) == 4
x.append(3)
return len(x)
torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True)
self.run_pass("constant_propagation", foo.graph)
FileCheck().check_count("aten::len", 2).run(foo.graph)
@torch.jit.script
def foo(x: List[int], y: List[int]):
assert len(x) == 4 or len(y) == 5
return len(x) + len(y)
torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True)
self.run_pass("constant_propagation", foo.graph)
FileCheck().check_count("aten::len", 4).run(foo.graph)
def test_integer_refinement(self):
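        # equality facts learned from raising guards are propagated to later
        # uses of the ints, folding the arithmetic to a constant; `or`-combined
        # guards refine nothing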
        def run_refinement_and_check_const_value(graph, const_string):
self.run_pass("refine_integer_values", graph)
self.run_pass("constant_propagation", graph)
self.run_pass("dce", graph)
FileCheck().check(const_string).check_next("return").run(graph)
@torch.jit.script
def foo(x: int, y: int):
if x != 4 or y != 5:
raise Exception("")
return x + y
        run_refinement_and_check_const_value(foo.graph, "value=9")
self.assertEqual(foo(4, 5), 9)
with self.assertRaises(Exception):
foo(2, 4)
@torch.jit.script
def foo(x: int, y: int):
if x == 4 and y == 5:
pass
else:
raise Exception("hi")
return x + y
        run_refinement_and_check_const_value(foo.graph, "value=9")
self.assertEqual(foo(4, 5), 9)
with self.assertRaises(Exception):
foo(2, 4)
@torch.jit.script
def foo(x: int, y: int, z: int):
if x != 4:
raise Exception("..")
else:
if y != 8:
raise Exception("...")
else:
if z == 3:
pass
else:
raise Exception("...")
return x + y * z
        run_refinement_and_check_const_value(foo.graph, "value=28")
self.assertEqual(foo(4, 8, 3), 28)
with self.assertRaises(Exception):
foo(1, 2, 3)
        # refinement should persist into the second use of x
@torch.jit.script
def foo(x: int, cond: bool):
if x == 4:
if cond:
return x
return 4
return 4
        run_refinement_and_check_const_value(foo.graph, "value=4")
@torch.jit.script
def foo(x: int, y: int):
assert x == 4 or y == 5
return x + y
        self.run_pass("refine_integer_values", foo.graph)
self.run_pass("constant_propagation", foo.graph)
FileCheck().check("aten::add").run(foo.graph)
def test_optimize_out_comparison_same_value(self):
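        # comparisons of a value against itself fold to constants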
@torch.jit.script
def foo(x: int):
return x == x, x != x
self.run_pass("peephole", foo.graph)
        FileCheck().check_not("aten::eq").check_not("aten::ne").run(foo.graph)
self.assertEqual(foo(1), (True, False))
def test_refine_integer_values(self):
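        # in the x == 1 branch the returned y (1) equals the refined x, so
        # the whole body folds down to returning x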
@torch.jit.script
def foo(x: int):
y = 1
if x == 1:
return y
else:
return x
self.run_pass("refine_integer_values", foo.graph)
self.run_pass("constant_propagation", foo.graph)
self.run_pass("dce", foo.graph)
FileCheck().check("graph").check_next("return").run(foo.graph)
self.assertEqual(foo(2), 2)
self.assertEqual(foo(1), 1)
def test_peephole_len_list(self):
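        # len(x.size()) folds only once the input's rank is known; appending
        # to the size list invalidates the static length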
@torch.jit.script
def foo(x):
return len(x.size())
self.run_pass("peephole", foo.graph)
FileCheck().check("aten::len").run(foo.graph)
inputs = list(foo.graph.inputs())
inputs[0].setType(inputs[0].type().with_sizes([None, None]))
self.run_pass("peephole", foo.graph)
FileCheck().check_not("aten::len").run(foo.graph)
self.assertEqual(2, foo(torch.rand([3, 1])))
@torch.jit.script
def foo(x):
li = x.size()
li.append(4)
return len(li)
inputs = list(foo.graph.inputs())
inputs[0].setType(inputs[0].type().with_sizes([None, None]))
self.run_pass("peephole", foo.graph)
FileCheck().check("aten::len").run(foo.graph)
self.assertEqual(3, foo(torch.rand([3, 1])))