import operator
from typing import List

import torch
from torch.testing import FileCheck
from torch.testing._internal.jit_utils import JitTestCase

if __name__ == '__main__':
    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                       "\tpython test/test_jit.py TESTNAME\n\n"
                       "instead.")

# XXX: still in prototype
class TestSymbolicShapeAnalysis(JitTestCase):
    def setUp(self):
        self.prev_symbolic_shapes_test_enabled = torch._C._jit_symbolic_shapes_test_mode_enabled()
        torch._C._jit_set_symbolic_shapes_test_mode(True)

    def tearDown(self):
        torch._C._jit_set_symbolic_shapes_test_mode(self.prev_symbolic_shapes_test_enabled)

    def test_shape_analysis(self):
        @torch.jit.script
        def foo(x, y):
            return x * y

        inputs = list(foo.graph.inputs())

        def prop_shapes_on_graph(inp0, inp1):
            inputs[0].setType(inputs[0].type().with_sizes(inp0))
            inputs[1].setType(inputs[1].type().with_sizes(inp1))
            torch._C._jit_pass_propagate_shapes_on_graph(foo.graph)
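
        # [1, 6, 5] and [1, 7, 1, 5] broadcast to [1, 7, 6, 5]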
        prop_shapes_on_graph([1, 6, 5], [1, 7, 1, 5])
        FileCheck().check("1, 7, 6, 5").run(foo.graph)

        # None implicitly creates a new symbolic symbol
        prop_shapes_on_graph([None, None], [None, None, None])
        output_shape = foo.graph.findNode("aten::mul").output().type().symbolic_sizes()
        inp0_shape = inputs[0].type().symbolic_sizes()
        inp1_shape = inputs[1].type().symbolic_sizes()

        # output dim 0 is broadcast from the second input's dim 0; the other
        # two dims cannot be inferred and are assigned fresh symbolic shapes
        self.assertEqual(output_shape[0], inp1_shape[0])
        self.assertFalse(output_shape[1] in inp0_shape + inp1_shape)
        self.assertFalse(output_shape[2] in inp0_shape + inp1_shape)

        # XXX: symbolic shapes are represented by an increasing counter of
        # unique values; use the `_new_symbolic_shape_symbol` API instead of
        # specifying negative dimensions directly, so a manually chosen number
        # cannot collide with the current counter value
        sym1 = torch._C._new_symbolic_shape_symbol()
        sym2 = torch._C._new_symbolic_shape_symbol()
        sym3 = torch._C._new_symbolic_shape_symbol()

        prop_shapes_on_graph([sym1, 1, sym3], [1, sym2, sym3])
        output_shape = foo.graph.findNode("aten::mul").output().type().symbolic_sizes()
        self.assertEqual(output_shape[0], sym1)
        self.assertEqual(output_shape[1], sym2)
        self.assertEqual(output_shape[2], sym3)

    def test_sharing_of_list_len(self):
        @torch.jit.script
        def foo(x, out: List[int]):
            return torch.nn.functional.adaptive_avg_pool2d(x, out)
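
        # inline the functional wrapper so the graph contains the aten op directly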
        self.run_pass("inline", foo.graph)
        torch._C._jit_pass_propagate_shapes_on_graph(foo.graph)
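        # the shape function fixes the length of `out`, so the output is
        # inferred to be 2-D even though its sizes are unknown: Tensor(*, *)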
FileCheck().check("Tensor(*, *)").check_same("adaptive_avg_pool2d").run(foo.graph)

    def test_shared_shape_graph(self):
        @torch.jit.script
        def foo(x, y):
            return x * y, x / y
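
        # mul and div have the same broadcasting shape semantics, so they
        # should share a single cached shape compute graph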
        mul_node = foo.graph.findNode("aten::mul")
        div_node = foo.graph.findNode("aten::div")

        mul_graph = torch._C._jit_shape_compute_graph_for_node(mul_node)
        div_graph = torch._C._jit_shape_compute_graph_for_node(div_node)
        self.assertIsNotNone(mul_graph)
        self.assertIs(mul_graph, div_graph)

    def test_unary_shape_functions(self):
        def apply(fn):
            return lambda x: fn(x)

        unary_ops = [
            torch.nn.functional.hardtanh,
        ]
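        # each op is elementwise and shape-preserving, so the propagated output
        # sizes should match the sizes set on the input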
        for fn in unary_ops:
            t = torch.jit.trace(fn, (torch.rand([4, 4]),))
            ten_input = next(t.graph.inputs())
            ten_input.setType(ten_input.type().with_sizes([2, 2]))
            torch._C._jit_pass_propagate_shapes_on_graph(t.graph)
            self.assertEqual(next(t.graph.outputs()).type().symbolic_sizes(), [2, 2])

    def test_binary_shape_functions(self):
        def apply(fn):
            return lambda x, y: fn(x, y)

        binary_ops = [
            operator.__mul__,
            operator.__truediv__,
            operator.__gt__,
            operator.__add__,
        ]
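        # each op broadcasts its inputs, so [1, 4, 8] with [4, 1, 8] should
        # propagate to an output of [4, 4, 8]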
        for fn in binary_ops:
            size_1 = [1, 4, 8]
            size_2 = [4, 1, 8]
            t = torch.jit.trace(fn, (torch.rand([4]), torch.rand([4])))
            inputs = list(t.graph.inputs())
            inputs[0].setType(inputs[0].type().with_sizes(size_1))
            inputs[1].setType(inputs[1].type().with_sizes(size_2))
            torch._C._jit_pass_propagate_shapes_on_graph(t.graph)
            self.assertEqual(next(t.graph.outputs()).type().symbolic_sizes(), [4, 4, 8])