Fix ONNX bug, add symbolic for full

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/12052

Differential Revision: D10044910

Pulled By: apaszke

fbshipit-source-id: 015ef372966d7594e1b450e348d457429f6ef20d
diff --git a/test/onnx/expect/TestOperators.test_full.expect b/test/onnx/expect/TestOperators.test_full.expect
new file mode 100644
index 0000000..db97532
--- /dev/null
+++ b/test/onnx/expect/TestOperators.test_full.expect
@@ -0,0 +1,148 @@
+ir_version: 3
+producer_name: "pytorch"
+producer_version: "0.4"
+graph {
+  node {
+    output: "1"
+    op_type: "Constant"
+    attribute {
+      name: "value"
+      t {
+        data_type: INT64
+        raw_data: "\000\000\000\000\000\000\000\000"
+      }
+      type: TENSOR
+    }
+  }
+  node {
+    input: "0"
+    output: "2"
+    op_type: "Shape"
+  }
+  node {
+    input: "2"
+    input: "1"
+    output: "3"
+    op_type: "Gather"
+    attribute {
+      name: "axis"
+      i: 0
+      type: INT
+    }
+  }
+  node {
+    output: "4"
+    op_type: "Constant"
+    attribute {
+      name: "value"
+      t {
+        data_type: INT64
+        raw_data: "\001\000\000\000\000\000\000\000"
+      }
+      type: TENSOR
+    }
+  }
+  node {
+    input: "0"
+    output: "5"
+    op_type: "Shape"
+  }
+  node {
+    input: "5"
+    input: "4"
+    output: "6"
+    op_type: "Gather"
+    attribute {
+      name: "axis"
+      i: 0
+      type: INT
+    }
+  }
+  node {
+    input: "3"
+    output: "7"
+    op_type: "Unsqueeze"
+    attribute {
+      name: "axes"
+      ints: 0
+      type: INTS
+    }
+  }
+  node {
+    input: "6"
+    output: "8"
+    op_type: "Unsqueeze"
+    attribute {
+      name: "axes"
+      ints: 0
+      type: INTS
+    }
+  }
+  node {
+    input: "7"
+    input: "8"
+    output: "9"
+    op_type: "Concat"
+    attribute {
+      name: "axis"
+      i: 0
+      type: INT
+    }
+  }
+  node {
+    input: "9"
+    output: "10"
+    op_type: "ConstantFill"
+    attribute {
+      name: "dtype"
+      i: 1
+      type: INT
+    }
+    attribute {
+      name: "input_as_shape"
+      i: 1
+      type: INT
+    }
+    attribute {
+      name: "value"
+      f: 2
+      type: FLOAT
+    }
+  }
+  name: "torch-jit-export"
+  input {
+    name: "0"
+    type {
+      tensor_type {
+        elem_type: FLOAT
+        shape {
+          dim {
+            dim_value: 3
+          }
+          dim {
+            dim_value: 4
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "10"
+    type {
+      tensor_type {
+        elem_type: FLOAT
+        shape {
+          dim {
+            dim_value: 3
+          }
+          dim {
+            dim_value: 4
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 9
+}
diff --git a/test/onnx/test_operators.py b/test/onnx/test_operators.py
index 2dfdd40..1e2d0ff 100644
--- a/test/onnx/test_operators.py
+++ b/test/onnx/test_operators.py
@@ -287,6 +287,10 @@
         x = Variable(torch.randn(3, 4), requires_grad=True)
         self.assertONNX(lambda x: torch.nn.Hardtanh(-0.5, 0.5)(x), x)
 
+    def test_full(self):
+        x = torch.randn(3, 4, requires_grad=True)
+        self.assertONNX(lambda x: torch.full(x.shape, 2), x)
+
     def test_max(self):
         x = Variable(torch.randn(3, 4), requires_grad=True)
         y = Variable(torch.randn(3, 4), requires_grad=True)
diff --git a/torch/csrc/jit/passes/onnx/peephole.cpp b/torch/csrc/jit/passes/onnx/peephole.cpp
index 8045a46..6e90780 100644
--- a/torch/csrc/jit/passes/onnx/peephole.cpp
+++ b/torch/csrc/jit/passes/onnx/peephole.cpp
@@ -159,7 +159,7 @@
     }
     if (n->kind() == onnx::Transpose) {
       if (isNopTranspose(n->is(attr::perm))) {
-        n->replaceAllUsesWith(n->input()->node());
+        n->output()->replaceAllUsesWith(n->input());
         it.destroyCurrent();
         continue;
       }
diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py
index e146863..91de860e 100644
--- a/torch/onnx/symbolic.py
+++ b/torch/onnx/symbolic.py
@@ -975,13 +975,24 @@
 ]
 
 
-@parse_args('v', 'i', 'i', 'v')
+@parse_args('v', 'i', 'v', 'v')
 def zeros(g, shape, scalar_type, layout, device):
     # NOTE: no way to set device in ONNX, so we ignore it
     return g.op("ConstantFill", shape, dtype_i=scalar_type_to_onnx[scalar_type],
                 input_as_shape_i=1, value_f=0)
 
 
+def full(g, shape, value, scalar_type, layout, device):
+    const_value = _maybe_get_const(value, 't')
+    if _is_value(const_value):
+        tmp = zeros(shape, scalar_type, layout, device)
+        return add(tmp, value, g.op("Constant", value_t=torch.tensor(1)))
+    else:
+        scalar_type = _get_const(scalar_type, 'i', 'dtype')
+        return g.op("ConstantFill", shape, dtype_i=scalar_type_to_onnx[scalar_type],
+                    input_as_shape_i=1, value_f=const_value)
+
+
 def full_like(g, input, fill_value):
     # TODO: a more efficient implementation (ConstantFill?)
     return add(g, zeros_like(g, input), fill_value, g.op("Constant", value_t=torch.tensor(1)))