Fix ONNX bug, add symbolic for full
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/12052
Differential Revision: D10044910
Pulled By: apaszke
fbshipit-source-id: 015ef372966d7594e1b450e348d457429f6ef20d
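
This change does two things:

1. Fixes a bug in the ONNX peephole pass: when eliminating a no-op Transpose, the pass called Node::replaceAllUsesWith with the input's producing *node*, instead of replacing the transpose's output *value* with its input value.
2. Adds a symbolic for aten::full, so torch.full can be exported to ONNX (as a single ConstantFill when the fill value is a constant).

A minimal sketch of what the new symbolic enables; the module name Full and the BytesIO buffer are illustrative, not part of this patch:

    import io
    import torch

    class Full(torch.nn.Module):
        def forward(self, x):
            # sizes come from a traced tensor, so they appear in the
            # exported graph as Shape/Gather nodes rather than constants
            return torch.full(x.shape, 2)

    f = io.BytesIO()
    # before this patch, export failed because aten::full had no symbolic
    torch.onnx.export(Full(), (torch.randn(3, 4),), f)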
diff --git a/test/onnx/expect/TestOperators.test_full.expect b/test/onnx/expect/TestOperators.test_full.expect
new file mode 100644
index 0000000..db97532
--- /dev/null
+++ b/test/onnx/expect/TestOperators.test_full.expect
@@ -0,0 +1,148 @@
+ir_version: 3
+producer_name: "pytorch"
+producer_version: "0.4"
+graph {
+  node {
+    output: "1"
+    op_type: "Constant"
+    attribute {
+      name: "value"
+      t {
+        data_type: INT64
+        raw_data: "\000\000\000\000\000\000\000\000"
+      }
+      type: TENSOR
+    }
+  }
+  node {
+    input: "0"
+    output: "2"
+    op_type: "Shape"
+  }
+  node {
+    input: "2"
+    input: "1"
+    output: "3"
+    op_type: "Gather"
+    attribute {
+      name: "axis"
+      i: 0
+      type: INT
+    }
+  }
+  node {
+    output: "4"
+    op_type: "Constant"
+    attribute {
+      name: "value"
+      t {
+        data_type: INT64
+        raw_data: "\001\000\000\000\000\000\000\000"
+      }
+      type: TENSOR
+    }
+  }
+  node {
+    input: "0"
+    output: "5"
+    op_type: "Shape"
+  }
+  node {
+    input: "5"
+    input: "4"
+    output: "6"
+    op_type: "Gather"
+    attribute {
+      name: "axis"
+      i: 0
+      type: INT
+    }
+  }
+  node {
+    input: "3"
+    output: "7"
+    op_type: "Unsqueeze"
+    attribute {
+      name: "axes"
+      ints: 0
+      type: INTS
+    }
+  }
+  node {
+    input: "6"
+    output: "8"
+    op_type: "Unsqueeze"
+    attribute {
+      name: "axes"
+      ints: 0
+      type: INTS
+    }
+  }
+  node {
+    input: "7"
+    input: "8"
+    output: "9"
+    op_type: "Concat"
+    attribute {
+      name: "axis"
+      i: 0
+      type: INT
+    }
+  }
+  node {
+    input: "9"
+    output: "10"
+    op_type: "ConstantFill"
+    attribute {
+      name: "dtype"
+      i: 1
+      type: INT
+    }
+    attribute {
+      name: "input_as_shape"
+      i: 1
+      type: INT
+    }
+    attribute {
+      name: "value"
+      f: 2
+      type: FLOAT
+    }
+  }
+  name: "torch-jit-export"
+  input {
+    name: "0"
+    type {
+      tensor_type {
+        elem_type: FLOAT
+        shape {
+          dim {
+            dim_value: 3
+          }
+          dim {
+            dim_value: 4
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "10"
+    type {
+      tensor_type {
+        elem_type: FLOAT
+        shape {
+          dim {
+            dim_value: 3
+          }
+          dim {
+            dim_value: 4
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 9
+}
diff --git a/test/onnx/test_operators.py b/test/onnx/test_operators.py
index 2dfdd40..1e2d0ff 100644
--- a/test/onnx/test_operators.py
+++ b/test/onnx/test_operators.py
@@ -287,6 +287,10 @@
         x = Variable(torch.randn(3, 4), requires_grad=True)
         self.assertONNX(lambda x: torch.nn.Hardtanh(-0.5, 0.5)(x), x)
 
+    def test_full(self):
+        x = torch.randn(3, 4, requires_grad=True)
+        self.assertONNX(lambda x: torch.full(x.shape, 2), x)
+
     def test_max(self):
         x = Variable(torch.randn(3, 4), requires_grad=True)
         y = Variable(torch.randn(3, 4), requires_grad=True)
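As with the other tests in this file, the expect file above can be regenerated by rerunning the suite with the --accept flag (python test/onnx/test_operators.py --accept) if the exported graph changes.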
diff --git a/torch/csrc/jit/passes/onnx/peephole.cpp b/torch/csrc/jit/passes/onnx/peephole.cpp
index 8045a46..6e90780 100644
--- a/torch/csrc/jit/passes/onnx/peephole.cpp
+++ b/torch/csrc/jit/passes/onnx/peephole.cpp
@@ -159,7 +159,7 @@
     }
     if (n->kind() == onnx::Transpose) {
       if (isNopTranspose(n->is(attr::perm))) {
-        n->replaceAllUsesWith(n->input()->node());
+        n->output()->replaceAllUsesWith(n->input());
         it.destroyCurrent();
         continue;
       }
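This one-line change is the bug fix from the title. Node::replaceAllUsesWith rewires each output of the node to the corresponding output of the replacement node, so passing n->input()->node() is only correct when the transpose's input happens to be that node's output at the same index; replacing the transpose's single output Value with its input Value splices the no-op node out unconditionally.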
diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py
index e146863..91de860e 100644
--- a/torch/onnx/symbolic.py
+++ b/torch/onnx/symbolic.py
@@ -975,13 +975,24 @@
 ]
 
 
-@parse_args('v', 'i', 'i', 'v')
+@parse_args('v', 'i', 'v', 'v')
 def zeros(g, shape, scalar_type, layout, device):
     # NOTE: no way to set device in ONNX, so we ignore it
     return g.op("ConstantFill", shape, dtype_i=scalar_type_to_onnx[scalar_type],
                 input_as_shape_i=1, value_f=0)
 
 
+def full(g, shape, value, scalar_type, layout, device):
+    const_value = _maybe_get_const(value, 't')
+    if _is_value(const_value):
+        tmp = zeros(g, shape, scalar_type, layout, device)
+        return add(g, tmp, value, g.op("Constant", value_t=torch.tensor(1)))
+    else:
+        scalar_type = _get_const(scalar_type, 'i', 'dtype')
+        return g.op("ConstantFill", shape, dtype_i=scalar_type_to_onnx[scalar_type],
+                    input_as_shape_i=1, value_f=const_value)
+
+
 def full_like(g, input, fill_value):
     # TODO: a more efficient implementation (ConstantFill?)
     return add(g, zeros_like(g, input), fill_value, g.op("Constant", value_t=torch.tensor(1)))
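
In the new full symbolic, _maybe_get_const checks whether the fill value is a compile-time constant. If it is, a single ConstantFill is emitted (the path exercised by test_full above, since the fill value 2 is constant). If the value is itself a traced tensor, no single ONNX op covers it, so the fallback builds zeros of the requested shape and adds the value with broadcasting, mirroring full_like below it. Roughly, in eager terms (illustrative only):

    # torch.full(sizes, value) with a non-constant value exports as
    tmp = torch.zeros(sizes)   # ConstantFill(shape, value=0)
    out = tmp + value          # Add (alpha=1), broadcasting value

The @parse_args change on zeros relaxes layout from 'i' to 'v', leaving it as a raw graph value; like device, it is ignored by the ONNX lowering anyway.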