[ONNX] Add col2im for opset 18 (#84594)
Opset 18 will be used to introduce suport for ONNX's Col2Im-18 and resolve https://github.com/pytorch/pytorch/issues/84408
Depends: https://github.com/pytorch/pytorch/pull/83201 (CI will fail until ONNX submodule is updated)
as per Faith recommendation, this PR should be merged post ORT 1.13 only
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84594
Approved by: https://github.com/justinchuby, https://github.com/titaiwangms, https://github.com/abock, https://github.com/BowenBao
diff --git a/test/onnx/test_pytorch_onnx_no_runtime.py b/test/onnx/test_pytorch_onnx_no_runtime.py
index eea86b2..15d9337 100644
--- a/test/onnx/test_pytorch_onnx_no_runtime.py
+++ b/test/onnx/test_pytorch_onnx_no_runtime.py
@@ -1156,6 +1156,39 @@
dim,
)
+ def test_col2im(self):
+ # This test can be moved to test/onnx/test_pytorch_onnx_onnxruntime.py when ORT implement ::Col2Im
+
+ # Random batched RGB 32x32 image-shaped input tensor of batch size 64
+ original_image_inputs = torch.randn((64, 3, 32, 32))
+ output_size = tuple(original_image_inputs.shape[2:])
+ kernel_size = (1, 2)
+ dilation = 3
+ padding = 2
+ stride = 1
+ model_im2col = torch.nn.Unfold(
+ kernel_size, dilation=dilation, padding=padding, stride=stride
+ )
+ blocks = model_im2col(original_image_inputs)
+
+ model = torch.nn.Fold(
+ output_size=output_size,
+ kernel_size=kernel_size,
+ dilation=dilation,
+ padding=padding,
+ stride=stride,
+ )
+ f = io.BytesIO()
+ torch.onnx.export(model, (blocks,), f, opset_version=18)
+
+ onnx_model = onnx.load(io.BytesIO(f.getvalue()))
+ self.assertEqual(onnx_model.graph.node[-1].op_type, "Col2Im")
+ self.assertEqual(onnx_model.graph.node[-1].domain, "")
+ self.assertEqual(len(onnx_model.graph.node[-1].input), 3)
+ self.assertEqual(onnx_model.graph.node[-1].attribute[0].name, "dilations")
+ self.assertEqual(onnx_model.graph.node[-1].attribute[1].name, "pads")
+ self.assertEqual(onnx_model.graph.node[-1].attribute[2].name, "strides")
+
class TestQuantizeEagerONNXExport(common_utils.TestCase):
def _test_lower_graph_impl(self, model, data):
diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index 88c1819..80e530c 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -44,7 +44,9 @@
# The min onnx opset version to test for
MIN_ONNX_OPSET_VERSION = 9
# The max onnx opset version to test for
-MAX_ONNX_OPSET_VERSION = _constants.ONNX_MAX_OPSET
+MAX_ONNX_OPSET_VERSION = (
+ _constants.ONNX_MAX_OPSET - 1
+) # TODO: ORT does not support opset 18 yet
def _init_test_generalized_rcnn_transform():
diff --git a/torch/csrc/jit/serialization/export.cpp b/torch/csrc/jit/serialization/export.cpp
index f83bc9e..fe240c5 100644
--- a/torch/csrc/jit/serialization/export.cpp
+++ b/torch/csrc/jit/serialization/export.cpp
@@ -59,7 +59,7 @@
namespace onnx = ::ONNX_NAMESPACE;
const static int kInvalidOpsetVersion = -1;
-const static int kMainOpsetVersion = 17;
+const static int kMainOpsetVersion = 18;
// Based on OP_SET_ID_VERSION_MAP in
// https://github.com/onnx/onnx/blob/master/onnx/helper.py.
constexpr static std::array<int64_t, kMainOpsetVersion + 1>
@@ -82,6 +82,7 @@
8, // opset 15
8, // opset 16
8, // opset 17
+ 8, // opset 18
};
std::string getNodeStackTraceString(const Node* n) {
diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py
index da86811..3c6b90b 100644
--- a/torch/onnx/__init__.py
+++ b/torch/onnx/__init__.py
@@ -25,6 +25,7 @@
symbolic_opset15,
symbolic_opset16,
symbolic_opset17,
+ symbolic_opset18,
utils,
)
@@ -62,6 +63,7 @@
"symbolic_opset15",
"symbolic_opset16",
"symbolic_opset17",
+ "symbolic_opset18",
# Enums
"ExportTypes",
"OperatorExportTypes",
diff --git a/torch/onnx/_constants.py b/torch/onnx/_constants.py
index ed27f94..e264660 100644
--- a/torch/onnx/_constants.py
+++ b/torch/onnx/_constants.py
@@ -4,7 +4,7 @@
ONNX_BASE_OPSET = 9
ONNX_MIN_OPSET = 7
-ONNX_MAX_OPSET = 17
+ONNX_MAX_OPSET = 18
# ONNX_DEFAULT_OPSET generated by tools/onnx/update_default_opset_version.py
ONNX_DEFAULT_OPSET = 14
ONNX_CONSTANT_FOLDING_MIN_OPSET = 9
diff --git a/torch/onnx/symbolic_opset18.py b/torch/onnx/symbolic_opset18.py
new file mode 100644
index 0000000..dee3378
--- /dev/null
+++ b/torch/onnx/symbolic_opset18.py
@@ -0,0 +1,70 @@
+"""This file exports ONNX ops for opset 18.
+
+Note [ONNX Operators that are added/updated in opset 18]
+
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-18-of-the-default-onnx-operator-set
+New operators:
+ CenterCropPad
+ Col2Im
+ Mish
+ OptionalGetElement
+ OptionalHasElement
+ Pad
+ Resize
+ ScatterElements
+ ScatterND
+"""
+
+import functools
+from typing import Sequence
+
+from torch import _C
+from torch.onnx import symbolic_helper
+from torch.onnx._internal import _beartype, registration
+
+# EDITING THIS FILE? READ THIS FIRST!
+# see Note [Edit Symbolic Files] in symbolic_helper.py
+
+__all__ = ["col2im"]
+
+_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=18)
+
+
+@_onnx_symbolic("aten::col2im")
+@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is")
+@_beartype.beartype
+def col2im(
+ g,
+ input: _C.Value,
+ output_size: _C.Value,
+ kernel_size: _C.Value,
+ dilation: Sequence[int],
+ padding: Sequence[int],
+ stride: Sequence[int],
+):
+ # convert [i0, i1, ..., in] into [i0, i0, i1, i1, ..., in, in]
+ adjusted_padding = []
+ for pad in padding:
+ for _ in range(2):
+ adjusted_padding.append(pad)
+
+ num_dimensional_axis = symbolic_helper._get_tensor_sizes(output_size)[0]
+ if not adjusted_padding:
+ adjusted_padding = [0, 0] * num_dimensional_axis
+
+ if not dilation:
+ dilation = [1] * num_dimensional_axis
+
+ if not stride:
+ stride = [1] * num_dimensional_axis
+
+ return g.op(
+ "Col2Im",
+ input,
+ output_size,
+ kernel_size,
+ dilations_i=dilation,
+ pads_i=adjusted_padding,
+ strides_i=stride,
+ )