Fix aten::to symbolic + add expand_as (#13325)

Summary:
https://github.com/pytorch/pytorch/pull/13146 broke some cases of ONNX export; this PR fixes them by handling the five-argument aten::to overload in the symbolic and adding an expand_as symbolic.
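
The new test below captures one such case: tracing a function that builds a tensor with new_tensor() and moves it with .cpu(), then exporting the trace to ONNX. A minimal standalone repro sketch based on that test (the function name foo and the input shape are illustrative):

    import io
    import torch

    def foo(x):
        # new_tensor(x[0]) plus .cpu() exercises the tensor-option-carrying
        # aten::to path that this PR's symbolic changes handle
        return x.new_tensor(x[0]).cpu() + x

    traced = torch.jit.trace(foo, (torch.rand([2]),))
    example_outputs = traced(torch.rand([2]))

    f = io.BytesIO()
    print(torch.onnx._export_to_pretty_string(traced, (torch.rand([2]),), f,
                                              example_outputs=example_outputs))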
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13325

Differential Revision: D12844294

Pulled By: jamesr66a

fbshipit-source-id: f98dd0685820b2a1e5fcd49733cfa5c19c48a4e7
diff --git a/test/expect/TestJit.test_export_tensoroption_to.expect b/test/expect/TestJit.test_export_tensoroption_to.expect
new file mode 100644
index 0000000..fc7ab97
--- /dev/null
+++ b/test/expect/TestJit.test_export_tensoroption_to.expect
@@ -0,0 +1,22 @@
+ModelProto {
+  producer_name: "pytorch"
+  domain: ""
+  doc_string: ""
+  graph:
+    GraphProto {
+      name: "torch-jit-export"
+      inputs: [{name: "0", type:Tensor dims: 2}]
+      outputs: [{name: "7", type:Tensor dims: 2}]
+      initializers: []
+      nodes: [
+        Node {type: "Constant", inputs: [], outputs: [1], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
+        Node {type: "Gather", inputs: [0,1], outputs: [2], attributes: [{ name: 'axis', type: int, value: 0}]},
+        Node {type: "Constant", inputs: [], outputs: [3], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
+        Node {type: "Shape", inputs: [3], outputs: [4], attributes: []},
+        Node {type: "Expand", inputs: [2,4], outputs: [5], attributes: []},
+        Node {type: "Cast", inputs: [5], outputs: [6], attributes: [{ name: 'to', type: int, value: 1}]},
+        Node {type: "Add", inputs: [6,0], outputs: [7], attributes: []}
+      ]
+    }
+  opset_import: [OperatorSetIdProto { domain: }],
+}
diff --git a/test/test_jit.py b/test/test_jit.py
index c2b7d98..5d6007f 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -1917,6 +1917,17 @@
         x = torch.rand(3, 4)
         self.assertEqual(random_bar(x), (x + 1)[0:1])
 
+    def test_export_tensoroption_to(self):
+        def foo(x):
+            return x.new_tensor(x[0]).cpu() + x
+
+        traced = torch.jit.trace(foo, (torch.rand([2])))
+        example_outputs = traced(torch.rand([2]))
+
+        f = io.BytesIO()
+        self.assertExpected(torch.onnx._export_to_pretty_string(traced, (torch.rand([2]),), f,
+                                                                example_outputs=example_outputs))
+
     def test_pretty_printer(self):
         @torch.jit.script
         def if_test(a, b):
diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py
index 7efdf20..f3963b3 100644
--- a/torch/onnx/symbolic.py
+++ b/torch/onnx/symbolic.py
@@ -350,6 +350,11 @@
     return None
 
 
+def expand_as(g, self, other):
+    shape = g.op("Shape", other)
+    return g.op("Expand", self, shape)
+
+
 def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse):
     return g.op("Gather", weight, indices)
 
@@ -1097,6 +1102,11 @@
         # aten::to(Tensor, Device, ScalarType, bool, bool)
         dtype = _get_const(args[1], 'i', 'dtype')
         return g.op("Cast", self, to_i=scalar_type_to_onnx[dtype])
+    elif len(args) == 5:
+        # aten::to(Tensor, ScalarType, Layout, Device, bool, bool) -> Tensor
+        dtype = _get_const(args[0], 'i', 'dtype')
+        # Layout and device are ignored
+        return g.op("Cast", self, to_i=scalar_type_to_onnx[dtype])
     else:
         raise NotImplementedError("Unknown aten::to signature")