| # Owner(s): ["oncall: jit"] | 
 |  | 
 | import operator | 
 | import unittest | 
 | from textwrap import dedent | 
 |  | 
 | import torch | 
 | from torch import nn | 
 | from torch.testing import FileCheck | 
 | from torch.testing._internal.common_methods_invocations import sample_inputs_cat_concat | 
 | from torch.testing._internal.common_utils import make_tensor | 
 | from torch.testing._internal.jit_utils import JitTestCase, execWrapper | 
 | from typing import List, Any | 
 |  | 
 | if __name__ == '__main__': | 
 |     raise RuntimeError("This test file is not meant to be run directly, use:\n\n" | 
 |                        "\tpython test/test_jit.py TESTNAME\n\n" | 
 |                        "instead.") | 
 |  | 
 | # XXX: still in prototype | 
 | class TestSymbolicShapeAnalysis(JitTestCase): | 
 |     def setUp(self): | 
 |         self.prev_symbolic_shapes_test_enabled = torch._C._jit_symbolic_shapes_test_mode_enabled() | 
 |         torch._C._jit_set_symbolic_shapes_test_mode(True) | 
 |  | 
 |     def tearDown(self): | 
 |         torch._C._jit_set_symbolic_shapes_test_mode(self.prev_symbolic_shapes_test_enabled) | 
 |  | 
 |     def test_shape_analysis(self): | 
 |         @torch.jit.script | 
 |         def foo(x, y): | 
 |             return x * y | 
 |  | 
 |         inputs = list(foo.graph.inputs()) | 
 |  | 
 |         def prop_shapes_on_graph(inp0, inp1): | 
 |             inputs[0].setType(inputs[0].type().with_sizes(inp0)) | 
 |             inputs[1].setType(inputs[1].type().with_sizes(inp1)) | 
 |             torch._C._jit_pass_propagate_shapes_on_graph(foo.graph) | 
 |  | 
 |         prop_shapes_on_graph([1, 6, 5], [1, 7, 1, 5]) | 
 |         FileCheck().check("1, 7, 6, 5").run(foo.graph) | 
 |  | 
 |         # None implicitly creates a new symbolic symbol | 
 |         prop_shapes_on_graph([None, None], [None, None, None]) | 
 |         output_shape = foo.graph.findNode("aten::mul").output().type().symbolic_sizes() | 
 |         inp0_shape = inputs[0].type().symbolic_sizes() | 
 |         inp1_shape = inputs[1].type().symbolic_sizes() | 
 |  | 
 |         # output shape dim 0 should be taken from the second inp dim0 | 
 |         # other two dims we cannot infer and are given a new symbolic shape | 
 |         self.assertEqual(output_shape[0], inp1_shape[0]) | 
 |         self.assertFalse(output_shape[1] in inp0_shape + inp1_shape) | 
 |         self.assertFalse(output_shape[2] in inp0_shape + inp1_shape) | 
 |  | 
 |         # XXX: symbolic shapes are represented with an increasing counter of unique | 
 |         # values, use `_new_symbolic_shape_symbol` api instead of specifying negative | 
 |         # dimensions directly so there is no chance of collision between manual number | 
 |         # and current counter value. | 
 |         sym1 = torch._C._new_symbolic_shape_symbol() | 
 |         sym2 = torch._C._new_symbolic_shape_symbol() | 
 |         sym3 = torch._C._new_symbolic_shape_symbol() | 
 |         prop_shapes_on_graph([sym1, 1, sym3], [1, sym2, sym3]) | 
 |         output_shape = foo.graph.findNode("aten::mul").output().type().symbolic_sizes() | 
 |         self.assertEqual(output_shape[0], sym1) | 
 |         self.assertEqual(output_shape[1], sym2) | 
 |         self.assertEqual(output_shape[2], sym3) | 
 |  | 
 |     def test_shared_shape_graph(self): | 
 |         @torch.jit.script | 
 |         def foo(x, y): | 
 |             return x * y, x / y | 
 |  | 
 |         mul_node = foo.graph.findNode("aten::mul") | 
 |         div_node = foo.graph.findNode("aten::div") | 
 |  | 
 |         mul_graph = torch._C._jit_shape_compute_graph_for_node(mul_node) | 
 |         div_graph = torch._C._jit_shape_compute_graph_for_node(div_node) | 
 |         self.assertIsNotNone(mul_graph) | 
 |         self.assertIs(mul_graph, div_graph) | 
 |  | 
 |     def test_write(self): | 
 |         @torch.jit.script | 
 |         def foo(a, b): | 
 |             return a * b | 
 |  | 
 |         # broadcast appends cant be removed, so we bail on propagation | 
 |         torch._C._jit_pass_propagate_shapes_on_graph(foo.graph) | 
 |         FileCheck().check("Tensor = aten::mul").run(foo.graph) | 
 |  | 
 |         @torch.jit.script | 
 |         def foo(y): | 
 |             x = [1, 2, 3, 4] | 
 |             x[0] = 5 | 
 |             return y.view(x) | 
 |  | 
 |         torch._C._jit_pass_propagate_shapes_on_graph(foo.graph) | 
 |         FileCheck().check("Tensor = aten::view").run(foo.graph) | 
 |  | 
 |     def test_if_propagation(self): | 
 |         @torch.jit.script | 
 |         def foo(i: int, z): | 
 |             x = torch.ones([2, 3, 4, 5]) | 
 |             y = z.view([z.size(i), 3, 2, z.size(i)]) | 
 |             if i == 4: | 
 |                 return x | 
 |             else: | 
 |                 return y | 
 |  | 
 |         torch._C._jit_pass_constant_propagation(foo.graph) | 
 |         torch._C._jit_pass_propagate_shapes_on_graph(foo.graph) | 
 |         view = foo.graph.findNode("aten::view") | 
 |  | 
 |         def neg_to_one(li): | 
 |             return [elem if elem >= 0 else -1 for elem in li] | 
 |  | 
 |         self.assertEqual(neg_to_one(view.output().type().symbolic_sizes()), [-1, 3, 2, -1]) | 
 |         if_out = next(foo.graph.findNode("prim::If").outputs()) | 
 |         self.assertEqual(neg_to_one(if_out.type().symbolic_sizes()), [-1, 3, -1, -1]) | 
 |  | 
 |     def test_unary_shape_functions(self): | 
 |         unary_ops = [ | 
 |             torch.nn.functional.hardtanh, | 
 |         ] | 
 |         for fn in unary_ops: | 
 |             t = torch.jit.trace(fn, (torch.rand([4, 4]))) | 
 |             ten_input = next(t.graph.inputs()) | 
 |             ten_input.setType(ten_input.type().with_sizes([2, 2])) | 
 |             torch._C._jit_pass_propagate_shapes_on_graph(t.graph) | 
 |             self.assertEqual(next(t.graph.outputs()).type().symbolic_sizes(), [2, 2]) | 
 |  | 
 |     def test_unary_shape_fns_inplace(self): | 
 |         def mul_inplace(x: torch.Tensor): | 
 |             y = x.mul_(2) | 
 |             return y | 
 |  | 
 |         unary_ops = [ | 
 |             mul_inplace | 
 |         ] | 
 |         for fn in unary_ops: | 
 |             # t = torch.jit.trace(fn, torch.rand([4, 4]))  # For some reason tracing is erroring out. | 
 |             t = torch.jit.script(fn) | 
 |             ten_input = next(t.graph.inputs()) | 
 |             ten_input.setType(ten_input.type().with_sizes([2, 2])) | 
 |             torch._C._jit_pass_propagate_shapes_on_graph(t.graph) | 
 |             self.assertEqual(next(t.graph.outputs()).type().symbolic_sizes(), [2, 2]) | 
 |  | 
 |     def test_binary_shape_functions(self): | 
 |         binary_ops = [ | 
 |             operator.__mul__, | 
 |             operator.__truediv__, | 
 |             operator.__gt__, | 
 |             operator.__add__, | 
 |         ] | 
 |  | 
 |         for fn in binary_ops: | 
 |             size_1 = [1, 4, 8] | 
 |             size_2 = [4, 1, 8] | 
 |             t = torch.jit.trace(fn, (torch.rand([4]), torch.rand([4]))) | 
 |             inputs = list(t.graph.inputs()) | 
 |             inputs[0].setType(inputs[0].type().with_sizes(size_1)) | 
 |             inputs[1].setType(inputs[1].type().with_sizes(size_2)) | 
 |             torch._C._jit_pass_propagate_shapes_on_graph(t.graph) | 
 |             self.assertEqual(next(t.graph.outputs()).type().symbolic_sizes(), [4, 4, 8]) | 
 |             break | 
 |  | 
 |     def test_binary_shape_fns_inplace(self): | 
 |         def div_inplace_tensor(x: torch.Tensor, y: torch.Tensor): | 
 |             z = x.div_(y) | 
 |             return z | 
 |  | 
 |         def add_inplace_tensor(x: torch.Tensor, y: torch.Tensor): | 
 |             z = x.add_(y) | 
 |             return z | 
 |  | 
 |         binary_ops = [ | 
 |             div_inplace_tensor, | 
 |             add_inplace_tensor, | 
 |         ] | 
 |  | 
 |         for fn in binary_ops: | 
 |             size_1 = [4, 4, 8]  # x (can't broadcast because it's an inplace op) | 
 |             t = torch.jit.script(fn) | 
 |             inputs = list(t.graph.inputs()) | 
 |             inputs[0].setType(inputs[0].type().with_sizes(size_1)) | 
 |             # Intentionally not populate the type of inputs[1] | 
 |             torch._C._jit_pass_propagate_shapes_on_graph(t.graph) | 
 |             self.assertEqual(next(t.graph.outputs()).type().symbolic_sizes(), [4, 4, 8]) | 
 |  | 
 |     def test_size_and_sizes(self): | 
 |         @torch.jit.script | 
 |         def foo(x, y): | 
 |             return x.view(y.size(0), 8, y.size(-1)) | 
 |  | 
 |         @torch.jit.script | 
 |         def foo2(x, y): | 
 |             return x.view(y.size()) | 
 |  | 
 |         for graph in [foo.graph, foo2.graph]: | 
 |             inputs = list(graph.inputs()) | 
 |             sym1 = torch._C._new_symbolic_shape_symbol() | 
 |  | 
 |             inputs[1].setType(inputs[1].type().with_sizes([5, 8, sym1])) | 
 |             torch._C._jit_pass_propagate_shapes_on_graph(graph) | 
 |             self.assertEqual(next(graph.outputs()).type().symbolic_sizes(), [5, 8, sym1]) | 
 |  | 
 |     def test_adaptive_avg_pool2d(self): | 
 |         inps = [ | 
 |             [(1, 64, 8, 9), (5, 7)], | 
 |             [(1, 64, 10, 9), (7)], | 
 |             [(1, 64, 10, 9), (5, None)], | 
 |             [(1, 8, 4, 3), (None, None)], | 
 |             [(1, 8, 4, 3), (None, 5)], | 
 |         ] | 
 |  | 
 |         for inp in inps: | 
 |             t = torch.randn(*inp[0]) | 
 |             out_size = torch.nn.functional.adaptive_avg_pool2d(t, inp[1]).size() | 
 |  | 
 |             def foo(x): | 
 |                 return torch.nn.functional.adaptive_avg_pool2d(x, inp[1]) | 
 |  | 
 |             fn = torch.jit.trace(foo, (t,)) | 
 |             torch._C._jit_erase_non_input_shape_information(fn.graph) | 
 |             torch._C._jit_pass_peephole(fn.graph) | 
 |             torch._C._jit_pass_constant_propagation(fn.graph) | 
 |             self.checkShapeAnalysis(out_size, fn.graph, assert_propagation=True) | 
 |  | 
 |     def test_arange_shape(self): | 
 |         # no opinfo for tensor constructors | 
 |         inps = [ | 
 |             (10,), | 
 |             (10, 10), | 
 |             (0, 10), | 
 |             (0, 1000), | 
 |             (1, -1, -1), | 
 |             (1, 0, -1), | 
 |             (1, 2, 1), | 
 |             (0.6, 0.89, 0.1), | 
 |             (1, 10, 0.3), | 
 |             (1, 10, 4), | 
 |             (0.6, 0.7, 0.8), | 
 |             (1, 10, 0.3), | 
 |             # (True,),  TODO: https://github.com/pytorch/pytorch/issues/63405 | 
 |             # (False,), TODO: https://github.com/pytorch/pytorch/issues/63405 | 
 |             (0, 5), | 
 |             (0, 5, 2), | 
 |             (0, 5 + 1e-6), | 
 |             (0, 5 - 1e-6), | 
 |             (10, -1 + 1e-6, -1), | 
 |             (10, -1, -1), | 
 |             (10, -1 - 1e-6, -1), | 
 |         ] | 
 |  | 
 |         for inp in inps: | 
 |             funcs_template = dedent(''' | 
 |             def func(): | 
 |                 return torch.arange({args}) | 
 |             ''') | 
 |  | 
 |             inp_s = str(inp)[1:-1]  # remove tuple parens | 
 |             funcs_str = funcs_template.format(args=inp_s) | 
 |             scope = {} | 
 |             execWrapper(funcs_str, globals(), scope) | 
 |             cu = torch.jit.CompilationUnit(funcs_str) | 
 |             self.checkShapeAnalysis(list(cu.func().size()), cu.func.graph, assert_propagation=True, constant_prop=False) | 
 |  | 
 |     def test_shape_embedding_bag(self): | 
 |         # TODO: merge into opinfos, having difficulties there | 
 |         with torch.no_grad(): | 
 |             def make_arg(shape, low=None, high=None): | 
 |                 return make_tensor(shape, device='cpu', dtype=torch.int64, | 
 |                                    low=low, high=high, requires_grad=False) | 
 |  | 
 |             nn_inps = ( | 
 |                 (make_arg((40,), 0, 9), torch.nn.Embedding(20, embedding_dim=64, max_norm=1.0)), | 
 |                 (make_arg((2, 4), 0, 9), torch.nn.Embedding(10, 20, sparse=True)), | 
 |                 (make_arg((0,)), torch.nn.Embedding(0, 0, sparse=True)), | 
 |                 (make_arg((2, 4), 0, 9), torch.nn.Embedding(10, 0, sparse=True)), | 
 |                 (make_arg((4,), 0, 21), torch.nn.Embedding(22, 5, max_norm=1.0)), | 
 |                 (make_arg((2,), 0, 1), torch.nn.Embedding.from_pretrained(torch.arange(6.).view(2, 3), max_norm=2., | 
 |                                                                           norm_type=.5, scale_grad_by_freq=False, sparse=True)), | 
 |             ) | 
 |  | 
 |             for inp, module in nn_inps: | 
 |                 kwargs = { | 
 |                     "weight": module.weight.detach(), | 
 |                     "padding_idx": module.padding_idx, | 
 |                     "max_norm": module.max_norm, | 
 |                     "norm_type": module.norm_type, | 
 |                     "scale_grad_by_freq": module.scale_grad_by_freq, | 
 |                     "sparse": module.sparse, | 
 |                 } | 
 |  | 
 |                 out_size = torch.nn.functional.embedding(inp, **kwargs).size() | 
 |  | 
 |                 def foo(x): | 
 |                     return torch.nn.functional.embedding(inp, **kwargs) | 
 |  | 
 |                 fn = torch.jit.trace(foo, (inp.detach(),), check_trace=False) | 
 |  | 
 |                 self.checkShapeAnalysis(out_size, fn.graph, assert_propagation=True, constant_prop=False) | 
 |  | 
 |     def test_shape_concat(self): | 
 |         # TODO: unify with opinfo tests, traces of lists dont preserve sizes in IR | 
 |         sample_inputs = sample_inputs_cat_concat(None, "cpu", torch.float, False) | 
 |  | 
 |         class CatMod(nn.Module): | 
 |             __constants__ = ['dim'] | 
 |  | 
 |             def __init__(self, dim=0): | 
 |                 super(CatMod, self).__init__() | 
 |                 self.dim = dim | 
 |  | 
 |             def forward(self, x, y): | 
 |                 return torch.cat([x, y], dim=self.dim) | 
 |  | 
 |         for inp in sample_inputs: | 
 |             mod = torch.jit.script(CatMod(**inp.kwargs).eval()) | 
 |  | 
 |             args = inp.input | 
 |             self.assertTrue(len(args) == 2) | 
 |             out_size = mod(*args).size() | 
 |             inps = list(mod.graph.inputs()) | 
 |             inps[1].setType(inps[1].type().with_sizes(args[0].size())) | 
 |             inps[2].setType(inps[2].type().with_sizes(args[1].size())) | 
 |             self.checkShapeAnalysis(out_size, mod.graph, assert_propagation=True) | 
 |  | 
 |     def assert_shape_equal_scripted(self, script_fn, given_ins): | 
 |         expected_res = script_fn(*given_ins) | 
 |         g = script_fn.graph | 
 |         graph_ins = list(g.inputs()) | 
 |         self.assertEqual(len(given_ins), len(graph_ins)) | 
 |         for inp, graph_in in zip(given_ins, graph_ins): | 
 |             graph_in.setType(graph_in.type().with_sizes(inp.size())) | 
 |  | 
 |         out_sizes = [out.size() for out in expected_res] | 
 |         self.checkShapeAnalysis(out_sizes, g, assert_propagation=True) | 
 |  | 
 |     def test_convolution_backward(self): | 
 |         # No opinfos for ops that are not part of the Python API | 
 |         # Also, as the return shapes are the input, weight, and bias shape, there is no point | 
 |         # in a really complicated test | 
 |  | 
 |         input = torch.randn((16, 16, 8, 8), dtype=torch.float32, device="cpu", requires_grad=True) | 
 |         weight = torch.randn((8, 4, 3, 3), dtype=torch.float32, device="cpu", requires_grad=True) | 
 |         out_grad = torch.randn((16, 8, 8, 8), dtype=torch.float32, device="cpu") | 
 |  | 
 |  | 
 |         @torch.jit.script | 
 |         def conv_bwd(input, weight, grad): | 
 |             bias_sizes = [8, ] | 
 |             args = ([1, 1], [1, 1], [1, 1], False, [0, 0], 4, [True, True, True]) | 
 |             return torch.ops.aten.convolution_backward(grad, input, weight, bias_sizes, *args) | 
 |  | 
 |         self.assert_shape_equal_scripted(conv_bwd, (input, weight, out_grad)) | 
 |  | 
 |         @torch.jit.script | 
 |         def conv_bwd_2(input, weight, grad): | 
 |             bias_sizes = None | 
 |             args = ([1, 1], [1, 1], [1, 1], False, [0, 0], 4, [True, True, True]) | 
 |             return torch.ops.aten.convolution_backward(grad, input, weight, bias_sizes, *args) | 
 |         self.assert_shape_equal_scripted(conv_bwd_2, (input, weight, out_grad)) | 
 |  | 
 |  | 
 |     def test_returning_input_symbolic_shapes(self): | 
 |         mm = torch.jit.freeze(torch.jit.script(nn.Conv2d(16, 33, 3, stride=2).eval())) | 
 |         inps = list(mm.graph.inputs()) | 
 |         inps[1].setType(inps[1].type().with_sizes([None, None, None, None])) | 
 |         shape_compute_graph = torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mm.graph) | 
 |         g = shape_compute_graph.partial_eval_shape_graph() | 
 |         # to make into a jit function cant have multiple outputs | 
 |         g.makeMultiOutputIntoTuple() | 
 |         func = torch._C._create_function_from_graph("partial_eval_graph", g) | 
 |         out = func([20, 16, 5, 10]) | 
 |         # first four outputs should be unknown symbolic shapes from input | 
 |         self.assertEqual(out[0:4], [20, 16, 5, 10]) | 
 |         # last two are two new symbolic dims - height and width | 
 |         self.assertEqual(out[4:], list(mm(torch.rand([20, 16, 5, 10])).size()[2:])) | 
 |  | 
 |     def test_partial_eval_graph_conv(self): | 
 |         mm = torch.jit.freeze(torch.jit.script(nn.Conv2d(16, 33, 3, stride=2).eval())) | 
 |         shape_compute_graph = torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mm.graph) | 
 |         output_sizes = mm.graph.findNode("aten::conv2d").output().type().symbolic_sizes() | 
 |         # calculating 0, 2 and 3 index | 
 |         for i in [0, 2, 3]: | 
 |             self.assertTrue(output_sizes[i] < 0) | 
 |         self.assertTrue(output_sizes[1] >= 0) | 
 |         g = shape_compute_graph.partial_eval_shape_graph() | 
 |         # to make into a jit function cant have multiple outputs | 
 |         g.makeMultiOutputIntoTuple() | 
 |         func = torch._C._create_function_from_graph("partial_eval_graph", g) | 
 |         inp = torch.randn(20, 16, 5, 10) | 
 |         output = func([20, 16, 5, 10]) | 
 |         output_eager = list(mm(inp).size()) | 
 |         for o, oe in zip(output, output_eager[0:1] + output_eager[2:]): | 
 |             self.assertEqual(o, oe) | 
 |  | 
 |     def checkSymShapeCompute(self, shape_compute_graph, nodes, node_output_sizes, shape_inputs): | 
 |         g = shape_compute_graph.partial_eval_shape_graph() | 
 |         self.assertTrue(len(list(g.inputs())) == len(shape_inputs)) | 
 |         output_sym_map = shape_compute_graph.graph_output_to_symbolic_shape_dim() | 
 |         # map from sym shape -> index | 
 |         sym_shape_to_index = {} | 
 |         for index, output in enumerate(g.outputs()): | 
 |             sym_shape_to_index[output_sym_map[output]] = index | 
 |  | 
 |         g.makeMultiOutputIntoTuple() | 
 |         func = torch._C._create_function_from_graph("partial_eval_graph", g) | 
 |         sym_outputs = func(*shape_inputs) | 
 |  | 
 |         for node, output_shape in zip(nodes, node_output_sizes): | 
 |             output_type_sizes = node.output().type().symbolic_sizes() | 
 |             for i, sym_shape in enumerate(output_type_sizes): | 
 |                 if sym_shape >= 0: | 
 |                     self.assertEqual(sym_shape, output_shape[i]) | 
 |                 else: | 
 |                     sym_shape_index = sym_shape_to_index[sym_shape] | 
 |                     self.assertEqual(sym_outputs[sym_shape_index], output_shape[i]) | 
 |  | 
 |     def test_partial_eval_stitching(self): | 
 |         conv1 = torch.nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) | 
 |         max_pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) | 
 |         conv2 = nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) | 
 |  | 
 |         mod = torch.jit.freeze(torch.jit.script(nn.Sequential(conv1, max_pool, conv2).eval())) | 
 |  | 
 |         conv1_output = conv1(torch.rand(1, 3, 224, 224)) | 
 |         max_pool_output = max_pool(conv1_output) | 
 |         conv2_output = conv2(max_pool_output) | 
 |  | 
 |         shape_compute_graph = torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mod.graph) | 
 |         nodes = [mod.graph.findNode("aten::max_pool2d")] + list(mod.graph.findAllNodes("aten::conv2d")) | 
 |         output_shapes = [max_pool_output.size(), conv1_output.size(), conv2_output.size()] | 
 |         self.checkSymShapeCompute(shape_compute_graph, nodes, output_shapes, ([1, 3, 224, 224],)) | 
 |  | 
 |     def test_refinement_through_graph_stitching(self): | 
 |         class TwoConvs(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(TwoConvs, self).__init__() | 
 |                 self.conv1 = torch.nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) | 
 |                 self.conv2 = torch.nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) | 
 |  | 
 |             def forward(self, x): | 
 |                 a = self.conv1(x) | 
 |                 b = self.conv2(x) | 
 |                 return a + b | 
 |  | 
 |         mod = torch.jit.freeze(torch.jit.script(TwoConvs()).eval()) | 
 |         inp_tensor = list(mod.graph.inputs())[1] | 
 |         inp_tensor.setType(inp_tensor.type().with_sizes([None, None, None, None])) | 
 |         torch._C._jit_pass_propagate_shapes_on_graph(mod.graph) | 
 |         outs = list(next(mod.graph.outputs()).node().inputs()) | 
 |         out1 = outs[0].type().symbolic_sizes() | 
 |         out2 = outs[1].type().symbolic_sizes() | 
 |         self.assertTrue(out1[2] != out2[2]) | 
 |         self.assertTrue(out1[3] != out2[3]) | 
 |         # by joining partial eval graphs of both convs we are able to recognize the output shapes | 
 |         # are equivalent | 
 |         torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mod.graph) | 
 |         out1 = outs[0].type().symbolic_sizes() | 
 |         out2 = outs[1].type().symbolic_sizes() | 
 |         self.assertEqual(out1, out2) | 
 |  | 
 |     def test_stitching_multi_output(self): | 
 |         max_pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False, return_indices=True) | 
 |         tensor = torch.rand(1, 3, 224, 224) | 
 |         mod = torch.jit.trace(max_pool, (tensor,)) | 
 |         mod = torch.jit.freeze(mod.eval()) | 
 |         inp = list(mod.graph.inputs())[1] | 
 |         inp.setType(inp.type().with_sizes([None, None, None, None])) | 
 |         output_tensor = list(mod(tensor)[0].size()) | 
 |         self.run_pass('lower_all_tuples', mod.graph) | 
 |         shape_compute_graph = torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mod.graph) | 
 |         max_pool_node = mod.graph.findNode("aten::max_pool2d_with_indices") | 
 |         outs = list(max_pool_node.outputs()) | 
 |         self.assertEqual(outs[0].type().symbolic_sizes(), outs[1].type().symbolic_sizes()) | 
 |         g = shape_compute_graph.partial_eval_shape_graph() | 
 |         # to make into a jit function cant have multiple outputs | 
 |         g.makeMultiOutputIntoTuple() | 
 |         func = torch._C._create_function_from_graph("partial_eval_graph", g) | 
 |         mapping = shape_compute_graph.graph_output_to_symbolic_shape_dim() | 
 |         output_shape = func(tensor.size()) | 
 |         # the first 4 dims are input sym dimensions, then the , | 
 |         self.assertEqual(list(output_shape[0:4]), list(tensor.size())) | 
 |         self.assertEqual(list(output_shape[4:]), output_tensor[2:]) | 
 |  | 
 |     def test_sym_ir_parsing(self): | 
 |         graph_str1 = """graph(%x.1 : Float(SS(-2), SS(-3))): | 
 |                         %3 : int = prim::Constant[value=1]() | 
 |                         %4 : Tensor = aten::add(%x.1, %x.1, %3) | 
 |                         return (%4)""" | 
 |         g = torch._C.parse_ir(graph_str1) | 
 |         inp = next(g.inputs()) | 
 |         out = inp.type().symbolic_sizes() | 
 |         self.assertEqual(out, [-2, -3]) | 
 |  | 
 |     def test_stitching_concat(self): | 
 |  | 
 |         @torch.jit.script | 
 |         def foo1(a, b, x, y): | 
 |             return (a / b) + torch.cat([x, y]) | 
 |  | 
 |         @torch.jit.script | 
 |         def foo2(a, b, x, y): | 
 |             return (a / b) + torch.cat([x, y], dim=-2) | 
 |  | 
 |         for foo in [foo1, foo2]: | 
 |             g = foo.graph | 
 |             for inp in foo.graph.inputs(): | 
 |                 inp.setType(inp.type().with_sizes([None, None])) | 
 |  | 
 |             shape_compute_graph = torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(foo.graph) | 
 |             nodes = [g.findNode("aten::div")] + [g.findNode("aten::add")] + [g.findNode("aten::cat")] | 
 |  | 
 |             inps = [1, 10], [20, 10], [15, 1], [5, 1] | 
 |             output_shapes = [[20, 10], [20, 10], [20, 1]] | 
 |  | 
 |             self.checkSymShapeCompute(shape_compute_graph, nodes, output_shapes, inps) | 
 |  | 
 |     @unittest.skipIf(not hasattr(torch.jit, "_shapes"), "shape functions not loaded in python") | 
 |     def test_shape_function_includes(self): | 
 |         inp_shape = [1, 16, 5, 10] | 
 |         weight_shape = [33, 16, 3, 3] | 
 |         bias = None | 
 |         stride = [2, 2] | 
 |         padding = [0, 0] | 
 |         dilation = [1, 1] | 
 |         groups = 1 | 
 |         res = torch.jit._shapes.conv2d(inp_shape, weight_shape, bias, stride, padding, dilation, groups) | 
 |         self.assertEqual(res, [1, 33, 2, 4]) | 
 |  | 
 |         m1_shape = [10, 20] | 
 |         m2_shape = [20, 10] | 
 |         res = torch.jit._shapes.matmul(m1_shape, m2_shape) | 
 |         self.assertEqual(res, [10, 10]) | 
 |  | 
 |     def test_register_function_error_checking(self): | 
 |         # this will error before registering on global map, so | 
 |         # no issue in overwriting schema mappings | 
 |         @torch.jit.script | 
 |         def foo(x, y): | 
 |             return x + y | 
 |  | 
 |         node = foo.graph.findNode("aten::add") | 
 |  | 
 |         @torch.jit.script | 
 |         def wrong_input_types(x, y): | 
 |             x: List[int] = [] | 
 |             return x | 
 |         with self.assertRaisesRegex(RuntimeError, "Expected supertype of int"): | 
 |             torch._C._jit_register_shape_compute_graph_for_node(node, wrong_input_types.graph) | 
 |  | 
 |         @torch.jit.script | 
 |         def wrong_output_types(x: List[int], y: List[int]): | 
 |             x: List[Tensor] = [] | 
 |             return x | 
 |  | 
 |         with self.assertRaisesRegex(RuntimeError, "but got graph_type"): | 
 |             torch._C._jit_register_shape_compute_graph_for_node(node, wrong_output_types.graph) | 
 |  | 
 |         @torch.jit.script | 
 |         def too_many_inputs(x: List[int], y: List[int], z: Any, z2: Any): | 
 |             x: List[int] = [] | 
 |             return x | 
 |  | 
 |         with self.assertRaises(RuntimeError) as error: | 
 |             torch._C._jit_register_shape_compute_graph_for_node(node, too_many_inputs.graph) | 
 |  | 
 |         self.assertTrue("fewer arguments than schema" in str(error.exception)) |