Export ones_like, zeros_like and full_like using ONNX ConstantLike op. (#14903)
Summary:
This PR does the following:
1) Updates the ONNX export for `torch.zeros_like` and `torch.full_like` ops to use ONNX op `ConstantLike`. This reduces the export of experimental op `ConstantFill`, which may possibly be removed in future, see https://github.com/onnx/onnx/pull/1434).
2) It also adds export support for `torch.ones_like`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14903
Differential Revision: D13383700
Pulled By: houseroad
fbshipit-source-id: 566d00a943e9497172fcd5a034b638a650ab13a2
diff --git a/test/onnx/expect/TestOperators.test_full_like.expect b/test/onnx/expect/TestOperators.test_full_like.expect
new file mode 100644
index 0000000..1932f54
--- /dev/null
+++ b/test/onnx/expect/TestOperators.test_full_like.expect
@@ -0,0 +1,56 @@
+ir_version: 3
+producer_name: "pytorch"
+producer_version: "0.4"
+graph {
+ node {
+ input: "0"
+ output: "1"
+ op_type: "ConstantLike"
+ attribute {
+ name: "dtype"
+ i: 1
+ type: INT
+ }
+ attribute {
+ name: "value"
+ f: 2
+ type: FLOAT
+ }
+ }
+ name: "torch-jit-export"
+ input {
+ name: "0"
+ type {
+ tensor_type {
+ elem_type: FLOAT
+ shape {
+ dim {
+ dim_value: 3
+ }
+ dim {
+ dim_value: 4
+ }
+ }
+ }
+ }
+ }
+ output {
+ name: "1"
+ type {
+ tensor_type {
+ elem_type: FLOAT
+ shape {
+ dim {
+ dim_value: 3
+ }
+ dim {
+ dim_value: 4
+ }
+ }
+ }
+ }
+ }
+}
+opset_import {
+ version: 9
+}
diff --git a/test/onnx/expect/TestOperators.test_ones_like.expect b/test/onnx/expect/TestOperators.test_ones_like.expect
new file mode 100644
index 0000000..96016f3
--- /dev/null
+++ b/test/onnx/expect/TestOperators.test_ones_like.expect
@@ -0,0 +1,56 @@
+ir_version: 3
+producer_name: "pytorch"
+producer_version: "0.4"
+graph {
+ node {
+ input: "0"
+ output: "1"
+ op_type: "ConstantLike"
+ attribute {
+ name: "dtype"
+ i: 1
+ type: INT
+ }
+ attribute {
+ name: "value"
+ f: 1
+ type: FLOAT
+ }
+ }
+ name: "torch-jit-export"
+ input {
+ name: "0"
+ type {
+ tensor_type {
+ elem_type: FLOAT
+ shape {
+ dim {
+ dim_value: 6
+ }
+ dim {
+ dim_value: 10
+ }
+ }
+ }
+ }
+ }
+ output {
+ name: "1"
+ type {
+ tensor_type {
+ elem_type: FLOAT
+ shape {
+ dim {
+ dim_value: 6
+ }
+ dim {
+ dim_value: 10
+ }
+ }
+ }
+ }
+ }
+}
+opset_import {
+ version: 9
+}
diff --git a/test/onnx/expect/TestOperators.test_zeros_like.expect b/test/onnx/expect/TestOperators.test_zeros_like.expect
new file mode 100644
index 0000000..c21b4e9
--- /dev/null
+++ b/test/onnx/expect/TestOperators.test_zeros_like.expect
@@ -0,0 +1,56 @@
+ir_version: 3
+producer_name: "pytorch"
+producer_version: "0.4"
+graph {
+ node {
+ input: "0"
+ output: "1"
+ op_type: "ConstantLike"
+ attribute {
+ name: "dtype"
+ i: 1
+ type: INT
+ }
+ attribute {
+ name: "value"
+ f: 0
+ type: FLOAT
+ }
+ }
+ name: "torch-jit-export"
+ input {
+ name: "0"
+ type {
+ tensor_type {
+ elem_type: FLOAT
+ shape {
+ dim {
+ dim_value: 5
+ }
+ dim {
+ dim_value: 8
+ }
+ }
+ }
+ }
+ }
+ output {
+ name: "1"
+ type {
+ tensor_type {
+ elem_type: FLOAT
+ shape {
+ dim {
+ dim_value: 5
+ }
+ dim {
+ dim_value: 8
+ }
+ }
+ }
+ }
+ }
+}
+opset_import {
+ version: 9
+}
diff --git a/test/onnx/test_operators.py b/test/onnx/test_operators.py
index 9a27011..d472e6a 100644
--- a/test/onnx/test_operators.py
+++ b/test/onnx/test_operators.py
@@ -292,6 +292,10 @@
x = torch.randn(3, 4, requires_grad=True)
self.assertONNX(lambda x: torch.full(x.shape, 2), x)
+ def test_full_like(self):
+ x = torch.randn(3, 4, requires_grad=True)
+ self.assertONNX(lambda x: torch.full_like(x, 2), x)
+
def test_max(self):
x = torch.randn(3, 4, requires_grad=True)
y = torch.randn(3, 4, requires_grad=True)
@@ -475,6 +479,13 @@
x = torch.randn(3, 4)
self.assertONNX(torch.nn.Linear(4, 5, bias=True), x)
+ def test_zeros_like(self):
+ x = torch.randn(5, 8, requires_grad=True)
+ self.assertONNX(lambda x: torch.zeros_like(x), x)
+
+ def test_ones_like(self):
+ x = torch.randn(6, 10, requires_grad=True)
+ self.assertONNX(lambda x: torch.ones_like(x), x)
if __name__ == '__main__':
no_onnx_dep_flag = '--no-onnx'
diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py
index aace3b6..303a183 100644
--- a/torch/onnx/symbolic.py
+++ b/torch/onnx/symbolic.py
@@ -1023,8 +1023,9 @@
return g.op("ConstantFill", sizes, dtype_i=scalar_type_to_onnx[dtype], input_as_shape_i=1, value_f=0)
-def zeros_like(g, input):
- return g.op("Sub", input, input).setType(input.type().contiguous())
+@parse_args('v', 'i', 'v', 'v')
+def zeros_like(g, input, dtype, layout, device):
+ return g.op("ConstantLike", input, dtype_i=scalar_type_to_onnx[dtype], value_f=0.0)
@parse_args('v', 'i', 'v', 'v')
@@ -1032,6 +1033,11 @@
return g.op("ConstantFill", sizes, dtype_i=scalar_type_to_onnx[dtype], input_as_shape_i=1, value_f=1)
+@parse_args('v', 'i', 'v', 'v')
+def ones_like(g, input, dtype, layout, device):
+ return g.op("ConstantLike", input, dtype_i=scalar_type_to_onnx[dtype], value_f=1.0)
+
+
def full(g, sizes, value, dtype, layout, device):
const_value = _maybe_get_const(value, 't')
if _is_value(const_value):
@@ -1043,9 +1049,9 @@
input_as_shape_i=1, value_f=const_value)
-def full_like(g, input, fill_value):
- # TODO: a more efficient implementation (ConstantFill?)
- return add(g, zeros_like(g, input), fill_value, g.op("Constant", value_t=torch.tensor(1)))
+@parse_args('v', 'f', 'i', 'v', 'v')
+def full_like(g, input, fill_value, dtype, layout, device):
+ return g.op("ConstantLike", input, dtype_i=scalar_type_to_onnx[dtype], value_f=fill_value)
@parse_args('v', 'v', 'v', 'v', 'i')