Support Exports to Multiple ONNX Opsets (#19294)
Summary:
Support exporting to multiple ONNX opsets (specifically opset 10 for now), following the proposal in https://gist.github.com/spandantiwari/99700e60919c43bd167838038d20f353.
Also add support for custom ops (merging https://github.com/pytorch/pytorch/pull/18297).
This PR will be followed by another PR containing the changes related to testing the ops for different opsets.
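A rough usage sketch of the two features this PR adds (illustration only, not part of the diff; the module and the custom-op name are hypothetical):

```python
import io
import torch
import torch.onnx

# 1) Export against a specific ONNX opset (opset 10 here; opset 9 stays the default).
class IsNanModel(torch.nn.Module):
    def forward(self, x):
        return torch.isnan(x)

f = io.BytesIO()
torch.onnx.export(IsNanModel(), torch.tensor([1.0, float('nan'), 2.0]), f, opset_version=10)

# 2) Register a symbolic for a custom op so the exporter can translate it.
#    "mynamespace::foo" is a hypothetical op; the symbolic maps it to an ONNX Relu node.
def foo_symbolic(g, input):
    return g.op("Relu", input)

torch.onnx.register_custom_op_symbolic("mynamespace::foo", foo_symbolic, 9)
```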
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19294
Reviewed By: zrphercule
Differential Revision: D15043951
Pulled By: houseroad
fbshipit-source-id: d336fc35b8827145639137bc348ae07e3c14bb1c
diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h
index d0f2a3c..9939dbf 100644
--- a/aten/src/ATen/core/interned_strings.h
+++ b/aten/src/ATen/core/interned_strings.h
@@ -204,7 +204,7 @@
// 'onnx' symbols correspond to ONNX operators. Their semantics
// are defined in https://github.com/onnx/onnx/blob/master/docs/Operators.md
// The particular version we are targeting is specified by '_onnx_opset_version'
-// in torch.onnx.symbolic
+// in torch.onnx.symbolic_helper
//
// In general, most ONNX operators won't get an entry here, because they
// are handled from the Python end. However, you may occasionally need
diff --git a/docs/source/onnx.rst b/docs/source/onnx.rst
index 90bf287..ef1dcad 100644
--- a/docs/source/onnx.rst
+++ b/docs/source/onnx.rst
@@ -212,10 +212,10 @@
If the operator is an ATen operator, which means you can find the declaration
of the function in ``torch/csrc/autograd/generated/VariableType.h``
(available in generated code in PyTorch install dir), you should add the symbolic
-function in ``torch/onnx/symbolic.py`` and follow the instructions listed as below:
+function in ``torch/onnx/symbolic_opset<version>.py`` and follow the instructions listed below:
-* Define the symbolic function in
- `torch/onnx/symbolic.py <https://github.com/pytorch/pytorch/blob/master/torch/onnx/symbolic.py>`_.
+* Define the symbolic function in ``torch/onnx/symbolic_opset<version>.py``, for example
+ `torch/onnx/symbolic_opset9.py <https://github.com/pytorch/pytorch/blob/master/torch/onnx/symbolic_opset9.py>`_.
Make sure the function has the same name as the ATen operator/function
defined in ``VariableType.h``.
* The first parameter is always the exported ONNX graph.
@@ -303,7 +303,7 @@
Here is an example of handling missing symbolic function for ``elu`` operator.
We try to export the model and see the error message as below::
- UserWarning: ONNX export failed on elu because torch.onnx.symbolic.elu does not exist
+ UserWarning: ONNX export failed on elu because torch.onnx.symbolic_opset9.elu does not exist
RuntimeError: ONNX export failed: Couldn't export operator elu
The export fails because PyTorch does not support exporting ``elu`` operator.
@@ -311,7 +311,7 @@
in ``VariableType.h``. This means ``elu`` is an ATen operator.
We check the `ONNX operator list <http://https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_,
and confirm that ``Elu`` is standardized in ONNX.
-We add the following lines to ``symbolic.py``::
+We add the following lines to ``symbolic_opset9.py``::
def elu(g, input, alpha, inplace=False):
return g.op("Elu", input, alpha_f=_scalar(alpha))
@@ -319,7 +319,7 @@
Now PyTorch is able to export ``elu`` operator.
There are more examples in
-`symbolic.py <https://github.com/pytorch/pytorch/blob/master/torch/onnx/symbolic.py>`_,
+`symbolic_opset9.py <https://github.com/pytorch/pytorch/blob/master/torch/onnx/symbolic_opset9.py>`_,
`tensor.py <https://github.com/pytorch/pytorch/blob/99037d627da68cdf53d3d0315deceddfadf03bba/torch/autograd/_functions/tensor.py#L24>`_,
`padding.py <https://github.com/pytorch/pytorch/blob/99037d627da68cdf53d3d0315deceddfadf03bba/torch/nn/_functions/padding.py#L8>`_.
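To illustrate the layout the updated documentation describes (a sketch with a hypothetical operator, not code from this PR): an operator defined in torch/onnx/symbolic_opset9.py is picked up for opset 10 automatically by the symbolic registry, and only needs an entry in torch/onnx/symbolic_opset10.py if its ONNX definition changed in opset 10.

```python
import torch
from torch.onnx.symbolic_helper import parse_args

# torch/onnx/symbolic_opset9.py -- baseline definition (hypothetical op "my_op")
@parse_args('v', 'f')
def my_op(g, input, alpha):
    return g.op("MyOp", input, alpha_f=alpha)  # hypothetical node, alpha as an attribute

# torch/onnx/symbolic_opset10.py -- override only because the op changed in opset 10
@parse_args('v', 'f')
def my_op(g, input, alpha):  # noqa: F811 (redefined on purpose: the sketch shows two files)
    # assume the opset-10 form takes alpha as a tensor input instead of an attribute
    alpha_const = g.op("Constant", value_t=torch.tensor(alpha))
    return g.op("MyOp", input, alpha_const)
```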
diff --git a/test/onnx/expect/TestOperators.test_maxpool_dilations.expect b/test/onnx/expect/TestOperators.test_maxpool_dilations.expect
index 61b6d01..470a5d9 100644
--- a/test/onnx/expect/TestOperators.test_maxpool_dilations.expect
+++ b/test/onnx/expect/TestOperators.test_maxpool_dilations.expect
@@ -7,6 +7,11 @@
output: "1"
op_type: "MaxPool"
attribute {
+ name: "ceil_mode"
+ i: 0
+ type: INT
+ }
+ attribute {
name: "dilations"
ints: 2
type: INTS
diff --git a/test/onnx/test_onnx_opset.py b/test/onnx/test_onnx_opset.py
new file mode 100644
index 0000000..37a87af
--- /dev/null
+++ b/test/onnx/test_onnx_opset.py
@@ -0,0 +1,111 @@
+from test_pytorch_common import TestCase, run_tests
+
+import torch
+import torch.onnx
+from torch.nn import Module
+
+import onnx
+
+import io
+
+from torch.onnx.symbolic_helper import _export_onnx_opset_version
+from torch.onnx import ir_version, producer_name, producer_version
+
+
+def check_onnx_opset_operator(model, ops, opset_version=_export_onnx_opset_version):
+ # check basic ONNX model components (IR version, producer, opset)
+ assert model.ir_version == ir_version and \
+ model.producer_name == producer_name and \
+ model.producer_version == producer_version and \
+ model.opset_import[0].version == opset_version
+
+ # check the schema with the onnx checker
+ onnx.checker.check_model(model)
+
+ # check target type and attributes
+ graph = model.graph
+ # ops should contain an object for each node
+ # in graph.node, in the right order.
+ # At least the op_name should be specified,
+ # but the op's attributes can optionally be
+ # specified as well
+ assert len(ops) == len(graph.node)
+ for i in range(0, len(ops)):
+ assert graph.node[i].op_type == ops[i]['op_name']
+ if "attributes" in ops[i]:
+ attributes = ops[i]['attributes']
+ assert len(attributes) == len(graph.node[i].attribute)
+ for j in range(0, len(attributes)):
+ for attribute_field in attributes[j].keys():
+ assert attributes[j][attribute_field] == getattr(graph.node[i].attribute[j], attribute_field)
+
+
+def check_onnx_opsets_operator(module, x, ops, opset_versions):
+ for opset_version in opset_versions:
+ f = io.BytesIO()
+ torch.onnx.export(module, x, f, opset_version=opset_version)
+ model = onnx.load(io.BytesIO(f.getvalue()))
+ check_onnx_opset_operator(model, ops[opset_version], opset_version)
+
+
+class TestONNXOpset(TestCase):
+
+ def test_opset_fallback(self):
+ class MyModule(Module):
+ def forward(self, x):
+ return torch.isnan(x)
+
+ ops = [{"op_name" : "IsNaN"},
+ {"op_name" : "Cast", "attributes" : [{"name" : "to", "i" : 2, "type" : 2}]}]
+ ops = {9 : ops, 10 : ops}
+ x = torch.tensor([1.0, float('nan'), 2.0])
+ check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10])
+
+ def test_topk(self):
+ class MyModule(Module):
+ def forward(self, x):
+ return torch.topk(x, 3)
+
+ ops_9 = [{"op_name" : "TopK", "attributes" : [{"name" : "axis", "i" : -1, "type" : 2},
+ {"name" : "k", "i" : 3, "type" : 2}]}]
+ ops_10 = [{"op_name" : "Constant", "attributes" : [{"name" : "value", "type" : 4}]},
+ {"op_name" : "Unsqueeze", "attributes" : [{"name" : "axes", "ints" : [0], "type" : 7}]},
+ {"op_name" : "TopK", "attributes" : [{"name" : "axis", "i" : -1, "type" : 2}]}]
+ ops = {9 : ops_9, 10 : ops_10}
+ x = torch.arange(1., 6., requires_grad=True)
+ check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10])
+
+ def test_maxpool(self):
+ module = torch.nn.MaxPool1d(2, stride=1)
+
+ ops_9 = [{"op_name" : "MaxPool",
+ "attributes" :
+ [{"name": "kernel_shape", "ints": [2], "type": 7},
+ {"name": "pads", "ints": [0, 0], "type": 7},
+ {"name": "strides", "ints": [1], "type": 7}]}]
+ ops_10 = [{"op_name" : "MaxPool",
+ "attributes" :
+ [{"name": "ceil_mode", "i": 0, "type": 2},
+ {"name": "kernel_shape", "ints": [2], "type": 7},
+ {"name": "pads", "ints": [0, 0], "type": 7},
+ {"name": "strides", "ints": [1], "type": 7}]}]
+ ops = {9 : ops_9, 10 : ops_10}
+ x = torch.randn(20, 16, 50)
+ check_onnx_opsets_operator(module, x, ops, opset_versions=[10])
+
+ # add test with dilations
+ module = torch.nn.MaxPool1d(2, stride=1, dilation=2)
+
+ ops_10 = [{"op_name" : "MaxPool",
+ "attributes" :
+ [{"name": "ceil_mode", "i": 0, "type": 2},
+ {"name": "dilations", "ints": [2], "type": 7},
+ {"name": "kernel_shape", "ints": [2], "type": 7},
+ {"name": "pads", "ints": [0, 0], "type": 7},
+ {"name": "strides", "ints": [1], "type": 7}]}]
+ ops = {9 : ops_9, 10 : ops_10}
+ x = torch.randn(20, 16, 50)
+ check_onnx_opsets_operator(module, x, ops, opset_versions=[10])
+
+if __name__ == '__main__':
+ run_tests()
diff --git a/test/onnx/test_pytorch_onnx_caffe2.py b/test/onnx/test_pytorch_onnx_caffe2.py
index 529b186..7133bfb 100644
--- a/test/onnx/test_pytorch_onnx_caffe2.py
+++ b/test/onnx/test_pytorch_onnx_caffe2.py
@@ -1574,6 +1574,14 @@
self.run_model_test(MyModel(), train=False, input=lstm_in, batch_size=3)
+ def test_topk(self):
+ class TopKModel(torch.nn.Module):
+ def forward(self, input):
+ return torch.topk(input, 3)
+ model = TopKModel()
+ x = torch.arange(1., 6.)
+ self.run_model_test(model, train=False, input=x, batch_size=BATCH_SIZE)
+
def test_floor(self):
class FloorModel(torch.nn.Module):
def forward(self, input):
diff --git a/test/onnx/test_utility_funs.py b/test/onnx/test_utility_funs.py
index 4d3ee99..0470669 100644
--- a/test/onnx/test_utility_funs.py
+++ b/test/onnx/test_utility_funs.py
@@ -4,7 +4,7 @@
import torch
import torch.onnx
from torch.onnx import utils
-from torch.onnx.symbolic import _set_opset_version
+from torch.onnx.symbolic_helper import _set_opset_version
import onnx
diff --git a/torch/csrc/jit/passes/onnx.cpp b/torch/csrc/jit/passes/onnx.cpp
index 6cf6794..83e9933 100644
--- a/torch/csrc/jit/passes/onnx.cpp
+++ b/torch/csrc/jit/passes/onnx.cpp
@@ -171,7 +171,8 @@
torch::autograd::SymbolicContext ctx{};
ctx.block = new_block;
py::object onnx = py::module::import("torch.onnx");
- py::object onnx_symbolic = py::module::import("torch.onnx.symbolic");
+ py::object onnx_symbolic = py::module::import("torch.onnx.symbolic_helper");
+ py::object onnx_registry = py::module::import("torch.onnx.symbolic_registry");
// Returns a node that n maps to in the new graph
auto envFn = [&env](Value* n) -> Value* {
@@ -295,7 +296,6 @@
if (func) {
pyobj = func->get();
}
-
if (!py::hasattr(pyobj, "symbolic")) {
cloneNode(op);
return;
@@ -331,6 +331,8 @@
// Call the symbolic function
// Use a little trampoline function so we can give good error messages
// upon argument mismatch
+ py::object opset_version = onnx_symbolic.attr("_export_onnx_opset_version");
+ onnx_registry.attr("register_op")(op->name(), pyobj.attr("symbolic"), "", opset_version);
py::object raw_output = onnx.attr("_run_symbolic_method")(
op->name(), pyobj.attr("symbolic"), py_symbolic_args);
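In Python terms, the two added lines roughly do the following for a node whose autograd Function defines a `symbolic` method (a sketch; `op_name` and `symbolic_fn` stand in for `op->name()` and the Function's `symbolic` attribute):

```python
from torch.onnx.symbolic_helper import _export_onnx_opset_version
import torch.onnx.symbolic_registry as sym_registry

def symbolic_fn(g, *inputs):      # placeholder for the Function's own `symbolic`
    return g.op("Identity", inputs[0])

op_name = "prim::PythonOp"        # placeholder for op->name()

# Register the Function's symbolic under the default ('') domain at the current
# export opset; the pass then still calls torch.onnx._run_symbolic_method on it.
sym_registry.register_op(op_name, symbolic_fn, "", _export_onnx_opset_version)
```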
diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py
index 519a10d..4a7b2e5 100644
--- a/torch/onnx/__init__.py
+++ b/torch/onnx/__init__.py
@@ -6,6 +6,13 @@
ONNX_ARCHIVE_MODEL_PROTO_NAME = "__MODEL_PROTO"
+# TODO: Update these variables when there
+# is a new ir_version and producer_version
+# and use these values in the exporter
+ir_version = 4
+producer_name = "pytorch"
+producer_version = "1.1"
+
class ExportTypes:
PROTOBUF_FILE = 1
@@ -58,3 +65,8 @@
def is_in_onnx_export():
from torch.onnx import utils
return utils.is_in_onnx_export()
+
+
+def register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version):
+ from torch.onnx import utils
+ return utils.register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version)
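A short sketch of how the new module-level constants can be used to sanity-check an exported model, mirroring the checks added in test/onnx/test_onnx_opset.py:

```python
import io
import onnx
import torch
from torch.onnx import ir_version, producer_name, producer_version

f = io.BytesIO()
torch.onnx.export(torch.nn.ReLU(), torch.randn(2, 3), f, opset_version=10)
model = onnx.load(io.BytesIO(f.getvalue()))

# the exported model should carry the advertised IR version, producer and opset
assert model.ir_version == ir_version
assert model.producer_name == producer_name
assert model.producer_version == producer_version
assert model.opset_import[0].version == 10
```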
diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py
new file mode 100644
index 0000000..575043c
--- /dev/null
+++ b/torch/onnx/symbolic_helper.py
@@ -0,0 +1,293 @@
+import torch
+from torch._C import ListType
+import warnings
+
+import torch.onnx
+# This import monkey-patches graph manipulation methods on Graph, used for the
+# ONNX symbolics
+import torch.onnx.utils
+
+from functools import wraps
+
+
+# Note [Edit Symbolic Files]
+# EDITING THIS FILE AND SYMBOLIC_OPSET<VERSION> FILES? READ THIS FIRST!
+#
+# - These files are ONLY for ATen operators (e.g., operators that show up in the
+# trace as aten::blah). If you need to special case a primitive operator,
+# look at _run_symbolic_function
+# - Parameter ordering does NOT necessarily match what is in VariableType.cpp;
+# tensors are always first, then non-tensor arguments.
+# - Parameter names must *exactly* match the names in VariableType.cpp, because
+# dispatch is done with keyword arguments.
+# - Looking for inplace ops? They're detected by the trailing underscore, and
+# transparently dispatched to their non inplace versions in
+# 'run_symbolic_function'. See Note [Export inplace]
+#
+# ----------------------------------------------------------------------------------
+# A note on Tensor types
+# ----------------------------------------------------------------------------------
+#
+# In general, we should avoid depending on the type of Tensor Values contained
+# within the trace graph. However, this is sometimes unavoidable (due to ONNX
+# spec requirements, etc). If you are implementing a symbolic and need Tensor
+# type information, note that there are several levels of Tensor types, defined
+# in aten/src/ATen/core/jit_type.h:
+#
+# TensorType - This is a Tensor, but we don't know anything about its
+# properties (e.g. scalar type, # dims, shapes).
+# Appears as `Tensor` in graph print-outs.
+# DimensionedTensorType <: TensorType - Denotes a Tensor for which we know the scalar
+# type and number of dimensions, but not the concrete
+# shapes. For example, appears as 'Float(*, *)' in
+# graph print-outs. Useful accessor methods include
+# dim() and scalarType()
+# CompleteTensorType <: DimensionedTensorType - Denotes a Tensor for which we know the
+# concrete sizes in addition to the information
+# contained in TensorType. This adds a sizes()
+# method which can be used to retrieve the
+# concrete sizes.
+#
+# In general, we should prefer to rely on the least specific information possible.
+# For example, not relying on tensor properties at all is better than relying
+# on the number of dimensions (DimensionedTensorType) which is better than relying on
+# concrete shapes (CompleteTensorType). Doing so will make the export symbolics
+# more robust to different graphs.
+
+# ---------------------------------------------------------------------------------
+# Helper functions
+# ---------------------------------------------------------------------------------
+
+# Save some builtins as locals, because we'll shadow them below
+_sum = sum
+
+
+def _parse_arg(value, desc):
+ if desc == 'none':
+ return value
+ if desc == 'v' or not _is_value(value):
+ return value
+ if value.node().kind() == 'onnx::Constant':
+ tval = value.node()['value']
+ if desc == 'i':
+ return int(tval)
+ elif desc == 'f':
+ return float(tval)
+ elif desc == 'b':
+ return bool(tval)
+ elif desc == 't':
+ return tval
+ elif desc == 'is':
+ return [int(v) for v in tval]
+ else:
+ raise RuntimeError("ONNX symbolic doesn't know how to interpret Constant node")
+ elif value.node().kind() == 'prim::ListConstruct':
+ if desc == 'is':
+ for v in value.node().inputs():
+ if v.node().kind() != 'onnx::Constant':
+ raise RuntimeError("Failed to export an ONNX attribute, "
+ "since it's not constant, please try to make "
+ "things (e.g., kernel size) static if possible")
+ return [int(v.node()['value']) for v in value.node().inputs()]
+ else:
+ raise RuntimeError("ONNX symbolic doesn't know how to interpret ListConstruct node")
+
+ raise RuntimeError("Unexpected node type: {}".format(value.node().kind()))
+
+
+def _maybe_get_const(value, desc):
+ if _is_value(value) and value.node().kind() == 'onnx::Constant':
+ return _parse_arg(value, desc)
+ return value
+
+
+def _maybe_get_scalar(value):
+ value_t = _maybe_get_const(value, 't')
+ if isinstance(value_t, torch.Tensor) and value_t.shape == ():
+ return value_t
+ return value
+
+
+def _get_const(value, desc, arg_name):
+ if _is_value(value) and value.node().kind() != 'onnx::Constant':
+ raise RuntimeError("ONNX symbolic expected a constant value of the {} argument".format(arg_name))
+ return _parse_arg(value, desc)
+
+
+def _unpack_list(list_value):
+ list_node = list_value.node()
+ assert list_node.kind() == "prim::ListConstruct"
+ return list(list_node.inputs())
+
+
+def parse_args(*arg_descriptors):
+ def decorator(fn):
+ def wrapper(g, *args):
+ # some args may be optional, so the length may be smaller
+ assert len(arg_descriptors) >= len(args)
+ args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)]
+ return fn(g, *args)
+ # In Python 2 functools.wraps chokes on partially applied functions, so we need this as a workaround
+ try:
+ wrapper = wraps(fn)(wrapper)
+ except Exception:
+ pass
+ return wrapper
+ return decorator
+
+
+def _scalar(x):
+ """Convert a scalar tensor into a Python value."""
+ assert x.numel() == 1
+ return x.item()
+
+
+def _if_scalar_type_as(g, self, tensor):
+ """
+ Convert self into the same type of tensor, as necessary.
+
+ We only support implicit casting for scalars, so we never
+ actually need to insert an ONNX cast operator here; just
+ fix up the scalar.
+ """
+ if isinstance(self, torch._C.Value):
+ return self
+ elif tensor.type().kind() == "DimensionedTensorType" or tensor.type().kind() == "CompleteTensorType":
+ ty = tensor.type().scalarType().lower()
+ return getattr(self, ty)()
+ else:
+ return self
+
+
+def _is_value(x):
+ return isinstance(x, torch._C.Value)
+
+
+def _is_tensor_list(x):
+ return x.type().isSubtypeOf(ListType.ofTensors())
+
+
+def _unimplemented(op, msg):
+ warnings.warn("ONNX export failed on " + op + " because " + msg + " not supported")
+
+
+def _black_list_in_opset(name):
+ def symbolic_fn(*args, **kwargs):
+ warnings.warn("ONNX export failed on {}, which is not yet implemented for opset 10. "
+ "Try exporting with a previous opset version."
+ .format(name))
+ return symbolic_fn
+
+
+def _try_get_scalar_type(*args):
+ for arg in args:
+ try:
+ return arg.type().scalarType()
+ except RuntimeError:
+ pass
+ return None
+
+
+# ---------------------------------------------------------------------
+# ONNX operator version
+# ---------------------------------------------------------------------
+
+# READ ME BEFORE EDITING _default_onnx_opset_version:
+#
+# The variable below controls which ONNX operator set version we are
+# targeting. THIS VARIABLE HAS SEMANTIC EFFECT! Say a breaking
+# change occurred in version 8. As long as this variable < 8, you can
+# export models targeting the old behavior. However, if you bump
+# this variable to 8 or later, the breaking change will take effect:
+# you MUST adjust any symbolic affected by breaking changes. The ONNX
+# spec publishes a *comprehensive* list of BC-breaking changes for every
+# operator revision at:
+#
+# https://github.com/onnx/onnx/blob/master/docs/Changelog.md
+#
+# Please be sure to go through and check all of our implementations here before
+# increasing this number. This includes symbolic definitions NOT in this
+# file, so grep for "OpName" (with quotes)
+#
+# Besides, opset_version can be specified in the invocation of export()
+# and export_to_pretty_string(), and _export_onnx_opset_version will be set
+# and the symbolic functions should check it to determine the behavior
+# of the exporter.
+
+
+_default_onnx_opset_version = 9
+_onnx_master_opset = 10
+_onnx_stable_opsets = [9, 10]
+_export_onnx_opset_version = _default_onnx_opset_version
+
+
+def _set_opset_version(opset_version):
+ global _export_onnx_opset_version
+ if opset_version == _default_onnx_opset_version:
+ _export_onnx_opset_version = opset_version
+ return
+ if opset_version in _onnx_stable_opsets + [_onnx_master_opset]:
+ _export_onnx_opset_version = opset_version
+ return
+ raise ValueError("Unsupported ONNX opset version: " + str(opset_version))
+
+
+# Metaprogram symbolics for each ATen native specialized cast operator.
+# For e.g. we specify a function named `_cast_uint8_t` that instantiates an
+# ONNX cast node with `to` attribute 'UINT8'
+#
+# TODO: remove these once we support Type's in the JIT IR and we can once again
+# use the unified toType operator
+cast_pytorch_to_onnx = {
+ 'Byte': torch.onnx.TensorProtoDataType.UINT8,
+ 'Char': torch.onnx.TensorProtoDataType.INT8,
+ 'Double': torch.onnx.TensorProtoDataType.DOUBLE,
+ 'Float': torch.onnx.TensorProtoDataType.FLOAT,
+ 'Half': torch.onnx.TensorProtoDataType.FLOAT16,
+ 'Int': torch.onnx.TensorProtoDataType.INT32,
+ 'Long': torch.onnx.TensorProtoDataType.INT64,
+ 'Short': torch.onnx.TensorProtoDataType.INT16,
+}
+
+
+scalar_name_to_pytorch = {
+ 'uint8_t': 'Byte',
+ 'int8_t': 'Char',
+ 'double': 'Double',
+ 'float': 'Float',
+ 'half': 'Half',
+ 'int': 'Int',
+ 'int64_t': 'Long',
+ 'int16_t': 'Short',
+}
+
+
+# This indicates each scalar type's corresponding
+# torch type. Related source:
+# https://github.com/pytorch/pytorch/blob/da7468853ae322252270bbb58032668bd21b7457/c10/core/ScalarType.h
+scalar_type_to_pytorch_type = [
+ torch.uint8, # 0
+ torch.int8, # 1
+ torch.short, # 2
+ torch.int, # 3
+ torch.int64, # 4
+ torch.half, # 5
+ torch.float, # 6
+ torch.double, # 7
+]
+
+
+def _cast_func_template(to_i, g, input, non_blocking):
+ return g.op("Cast", input, to_i=to_i)
+
+
+scalar_type_to_onnx = [
+ cast_pytorch_to_onnx["Byte"],
+ cast_pytorch_to_onnx["Char"],
+ cast_pytorch_to_onnx["Short"],
+ cast_pytorch_to_onnx["Int"],
+ cast_pytorch_to_onnx["Long"],
+ cast_pytorch_to_onnx["Half"],
+ cast_pytorch_to_onnx["Float"],
+ cast_pytorch_to_onnx["Double"],
+]
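As the comment on _default_onnx_opset_version notes, a symbolic function can consult _export_onnx_opset_version to adapt its behavior to the opset chosen for the current export. A minimal sketch (the operator and ONNX node names are hypothetical):

```python
import torch.onnx.symbolic_helper as sym_help

def my_op(g, input):
    # pick a lowering based on the opset selected for this export (sketch only)
    if sym_help._export_onnx_opset_version >= 10:
        return g.op("NewerForm", input)   # hypothetical opset-10 lowering
    return g.op("OlderForm", input)       # hypothetical opset-9 fallback
```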
diff --git a/torch/onnx/symbolic_opset10.py b/torch/onnx/symbolic_opset10.py
new file mode 100644
index 0000000..ee75d7e
--- /dev/null
+++ b/torch/onnx/symbolic_opset10.py
@@ -0,0 +1,122 @@
+import torch
+from torch.nn.modules.utils import _single, _pair, _triple
+import torch.onnx
+# This import monkey-patches graph manipulation methods on Graph, used for the
+# ONNX symbolics
+import torch.onnx.utils
+
+from torch.onnx.symbolic_helper import parse_args, _unimplemented, _black_list_in_opset
+import torch.onnx.symbolic_opset9
+
+
+# EDITING THIS FILE? READ THIS FIRST!
+# see Note [Edit Symbolic Files] in symbolic_helper.py
+
+# This file exports ONNX ops for opset 10
+# Opset 10 is supported by ONNX release 1.5.0,
+# released on 04/24/19
+
+
+# Blacklist operators for this opset version.
+# These operators have been updated in ONNX but not re-implemented here.
+# It is very important to blacklist these operators to avoid exporting
+# models with mixed versions of operators.
+# TODO : add support for the blacklisted operators in black_listed_operators
+black_listed_operators = ["flip",
+ "slice",
+ "upsample_nearest2d", "upsample_bilinear2d",
+ "dropout", "feature_dropout", "alpha_dropout", "feature_alpha_dropout",
+ "dropout_", "feature_dropout_", "alpha_dropout_", "feature_alpha_dropout_"]
+
+for black_listed_op in black_listed_operators:
+ vars()[black_listed_op] = _black_list_in_opset(black_listed_op)
+
+
+# Add new operator here
+@parse_args('v', 'i', 'i', 'i', 'i')
+def topk(g, self, k, dim, largest, sorted, out=None):
+ if out is not None:
+ _unimplemented("TopK", "Out parameter is not supported for topk")
+ if not largest:
+ _unimplemented("TopK", "Ascending TopK is not supported")
+ k = g.op("Constant", value_t=torch.tensor(k, dtype=torch.int64))
+ from torch.onnx.symbolic_opset9 import unsqueeze
+ k = unsqueeze(g, k, 0)
+ return g.op("TopK", self, k, axis_i=dim, outputs=2)
+
+
+def _max_pool(name, tuple_fn, ndims, return_indices):
+ @parse_args('v', 'is', 'is', 'is', 'is', 'i')
+ def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode):
+ if not stride:
+ stride = kernel_size
+ kwargs = {
+ 'kernel_shape_i': tuple_fn(kernel_size),
+ 'pads_i': tuple_fn(padding) * 2,
+ 'strides_i': tuple_fn(stride),
+ 'ceil_mode_i': ceil_mode,
+ }
+ if set(tuple_fn(dilation)) != {1}:
+ kwargs['dilations_i'] = tuple_fn(dilation)
+ # easy but hacky way to get flattened indices values
+ # to be used to convert the indices values to non-flattened.
+ # In ONNX the indices are computed as a flattened 1-D tensor,
+ # so the values in indices are in [0, N x C x D1 x ... x Dn).
+ # To convert the indices to the same format used by PyTorch,
+ # we first execute a maxpool with a kernel and stride of 1 on the same input.
+ # This will result in a tensor of indices in which each index will have its own value.
+ # Using this tensor as a reference, we extract the first index of each axis and subtract
+ # it from each index of this axis in the indices to convert.
+ # This step will result in a tensor where each dimension has values of indices within
+ # the dimension it is in.
+ # For more information :
+ # https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407
+ if return_indices:
+ # `sub` and `_slice_op` are defined in the opset 9 symbolic file; import them
+ # here so the index conversion below resolves (they are not defined in this file)
+ from torch.onnx.symbolic_opset9 import sub, _slice_op
+ r, indices = g.op("MaxPool", input, outputs=2, **kwargs)
+ _, flattened_indices = g.op("MaxPool", input, outputs=2,
+ kernel_shape_i=[1 for _ in range(ndims)],
+ strides_i=[1 for _ in range(ndims)])
+ # convert indices to have non-flattened indices values
+ s = _slice_op(g, flattened_indices, axes=[2 + i for i in range(ndims)],
+ starts=tuple_fn(0), ends=tuple_fn(1))
+ indices = sub(g, indices, s)
+ return r, indices
+ else:
+ r = g.op("MaxPool", input, outputs=1, **kwargs)
+ return r
+
+ return symbolic_fn
+
+
+max_pool1d = _max_pool("max_pool1d", _single, 1, return_indices=False)
+max_pool2d = _max_pool("max_pool2d", _pair, 2, return_indices=False)
+max_pool3d = _max_pool("max_pool3d", _triple, 3, return_indices=False)
+max_pool1d_with_indices = _max_pool("max_pool1d_with_indices", _single, 1, return_indices=True)
+max_pool2d_with_indices = _max_pool("max_pool2d_with_indices", _pair, 2, return_indices=True)
+max_pool3d_with_indices = _max_pool("max_pool3d_with_indices", _triple, 3, return_indices=True)
+
+
+def _avg_pool(name, tuple_fn):
+ @parse_args('v', 'is', 'is', 'is', 'i', 'i')
+ def symbolic_fn(g, input, kernel_size, stride, padding, ceil_mode, count_include_pad):
+ if not stride:
+ stride = kernel_size
+ padding = tuple(tuple_fn(padding))
+ if count_include_pad:
+ input = g.op("Pad", input,
+ pads_i=((0,) * 2 + padding) * 2,
+ mode_s='constant',
+ value_f=0.)
+ padding = (0,) * len(padding)
+ output = g.op("AveragePool", input,
+ kernel_shape_i=tuple_fn(kernel_size),
+ strides_i=tuple_fn(stride),
+ pads_i=padding * 2,
+ ceil_mode_i=ceil_mode)
+ return output
+ return symbolic_fn
+
+
+avg_pool1d = _avg_pool('avg_pool1d', _single)
+avg_pool2d = _avg_pool('avg_pool2d', _pair)
+avg_pool3d = _avg_pool('avg_pool3d', _triple)
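A rough usage sketch of what the opset-10 symbolics above enable, e.g. exporting a dilated MaxPool (the same case covered by test_maxpool in test/onnx/test_onnx_opset.py; the opset-9 symbolic now rejects dilations, see symbolic_opset9.py below):

```python
import io
import torch

pool = torch.nn.MaxPool1d(2, stride=1, dilation=2)
x = torch.randn(20, 16, 50)

f = io.BytesIO()
# produces a MaxPool node carrying the `dilations` and `ceil_mode` attributes
torch.onnx.export(pool, x, f, opset_version=10)
```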
diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic_opset9.py
similarity index 77%
rename from torch/onnx/symbolic.py
rename to torch/onnx/symbolic_opset9.py
index 6c0be5c..3a17b95 100644
--- a/torch/onnx/symbolic.py
+++ b/torch/onnx/symbolic_opset9.py
@@ -1,233 +1,28 @@
import torch
from torch._C import ListType, OptionalType
from torch.nn.modules.utils import _single, _pair, _triple
-import warnings
import torch.onnx
# This import monkey-patches graph manipulation methods on Graph, used for the
# ONNX symbolics
import torch.onnx.utils
-from functools import partial, wraps
+from functools import partial
+
+import torch.onnx.symbolic_helper as sym_help
+from torch.onnx.symbolic_helper import parse_args, _parse_arg, _unimplemented
import numpy
import math
+import warnings
+
# EDITING THIS FILE? READ THIS FIRST!
-#
-# - This file is ONLY for ATen operators (e.g., operators that show up in the
-# trace as aten::blah). If you need to special case a primitive operator,
-# look at _run_symbolic_function
-# - Parameter ordering does NOT necessarily match what is in VariableType.cpp;
-# tensors are always first, then non-tensor arguments.
-# - Parameter names must *exactly* match the names in VariableType.cpp, because
-# dispatch is done with keyword arguments.
-# - Looking for inplace ops? They're detected by the trailing underscore, and
-# transparently dispatched to their non inplace versions in
-# 'run_symbolic_function'. See Note [Export inplace]
-#
-# ----------------------------------------------------------------------------------
-# A note on Tensor types
-# ----------------------------------------------------------------------------------
-#
-# In general, we should avoid depending on the type of Tensor Values contained
-# within the trace graph. However, this is sometimes unavoidable (due to ONNX
-# spec requirements, etc). If you are implementing a symbolic and need Tensor
-# type information, note that there are several levels of Tensor types, defined
-# in aten/src/ATen/core/jit_type.h:
-#
-# TensorType - This is a Tensor, but we don't know anything about its
-# properties (e.g. scalar type, # dims, shapes).
-# Appears as `Tensor` in graph print-outs.
-# DimensionedTensorType <: TensorType - Denotes a Tensor for which we know the scalar
-# type and number of dimensions, but not the concrete
-# shapes. For example, appears as 'Float(*, *)' in
-# graph print-outs. Useful accessor methods include
-# dim() and scalarType()
-# CompleteTensorType <: DimensionedTensorType - Denotes a Tensor for which we know the
-# concrete sizes in addition to the information
-# contained in TensorTyper. This adds a sizes()
-# method which can be used to retrieve the
-# concrete sizes.
-#
-# In general, we should prefer to rely on the least specific information possible.
-# For example, not relying on tensor properties at all is better than relying
-# on the number of dimensions (DimensionedTensorType) which is better than relying on
-# concrete shapes (CompleteTensorType). Doing so will make the export symbolics
-# more robust to different graphs.
+# see Note [Edit Symbolic Files] in symbolic_helper.py
-# ---------------------------------------------------------------------------------
-# Helper functions
-# ---------------------------------------------------------------------------------
-
-# Save some builtins as locals, because we'll shadown them below
-_sum = sum
-
-
-def _parse_arg(value, desc):
- if desc == 'none':
- return value
- if desc == 'v' or not _is_value(value):
- return value
- if value.node().kind() == 'onnx::Constant':
- tval = value.node()['value']
- if desc == 'i':
- return int(tval)
- elif desc == 'f':
- return float(tval)
- elif desc == 'b':
- return bool(tval)
- elif desc == 't':
- return tval
- elif desc == 'is':
- return [int(v) for v in tval]
- else:
- raise RuntimeError("ONNX symbolic doesn't know to interpret Constant node")
- elif value.node().kind() == 'prim::ListConstruct':
- if desc == 'is':
- for v in value.node().inputs():
- if v.node().kind() != 'onnx::Constant':
- raise RuntimeError("Failed to export an ONNX attribute, "
- "since it's not constant, please try to make "
- "things (e.g., kernel size) static if possible")
- return [int(v.node()['value']) for v in value.node().inputs()]
- else:
- raise RuntimeError("ONNX symbolic doesn't know to interpret ListConstruct node")
-
- raise RuntimeError("Unexpected node type: {}".format(value.node().kind()))
-
-
-def _maybe_get_const(value, desc):
- if _is_value(value) and value.node().kind() == 'onnx::Constant':
- return _parse_arg(value, desc)
- return value
-
-
-def _maybe_get_scalar(value):
- value_t = _maybe_get_const(value, 't')
- if isinstance(value_t, torch.Tensor) and value_t.shape == ():
- return value_t
- return value
-
-
-def _get_const(value, desc, arg_name):
- if _is_value(value) and value.node().kind() != 'onnx::Constant':
- raise RuntimeError("ONNX symbolic expected a constant value of the {} argument".format(arg_name))
- return _parse_arg(value, desc)
-
-
-def _unpack_list(list_value):
- list_node = list_value.node()
- assert list_node.kind() == "prim::ListConstruct"
- return list(list_node.inputs())
-
-
-def parse_args(*arg_descriptors):
- def decorator(fn):
- def wrapper(g, *args):
- # some args may be optional, so the length may be smaller
- assert len(arg_descriptors) >= len(args)
- args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)]
- return fn(g, *args)
- # In Python 2 functools.wraps chokes on partially applied functions, so we need this as a workaround
- try:
- wrapper = wraps(fn)(wrapper)
- except Exception:
- pass
- return wrapper
- return decorator
-
-
-def _scalar(x):
- """Convert a scalar tensor into a Python value."""
- assert x.numel() == 1
- return x.item()
-
-
-def _if_scalar_type_as(g, self, tensor):
- """
- Convert self into the same type of tensor, as necessary.
-
- We only support implicit casting for scalars, so we never
- actually need to insert an ONNX cast operator here; just
- fix up the scalar.
- """
- if isinstance(self, torch._C.Value):
- return self
- elif tensor.type().kind() == "DimensionedTensorType" or tensor.type().kind() == "CompleteTensorType":
- ty = tensor.type().scalarType().lower()
- return getattr(self, ty)()
- else:
- return self
-
-
-def _is_value(x):
- return isinstance(x, torch._C.Value)
-
-
-def _is_tensor_list(x):
- return x.type().isSubtypeOf(ListType.ofTensors())
-
-
-def _unimplemented(op, msg):
- warnings.warn("ONNX export failed on " + op + " because " + msg + " not supported")
-
-
-def _try_get_scalar_type(*args):
- for arg in args:
- try:
- return arg.type().scalarType()
- except RuntimeError:
- pass
- return None
-
-
-# ---------------------------------------------------------------------
-# ONNX operator version
-# ---------------------------------------------------------------------
-
-# READ ME BEFORE EDITING _default_onnx_opset_version:
-#
-# The variable below controls which ONNX operator set version we are
-# targeting. THIS VARIABLE HAS SEMANTIC EFFECT! Say a breaking
-# change occurred in version 8. As long as this variable < 8, you can
-# export models targeting the old behavior. However, if you bump
-# this variable to 8 or later, the breaking change will take into effect:
-# you MUST adjust any symbolic affected by breaking changes. The ONNX
-# spec publishes a *comprehensive* list of BC-breaking changes for every
-# operator revision at:
-#
-# https://github.com/onnx/onnx/blob/master/docs/Changelog.md
-#
-# Please be sure to go through and check all of our implementations here before
-# increasing this number. This includes symbolic definitions NOT in this
-# file, so grep for "OpName" (with quotes)
-#
-# Besides, opset_version can be specified in the invocation of export()
-# and export_to_pretty_string(), and _export_onnx_opset_version will be set
-# and the symbolic functions should check it to determine the behavior
-# of the exporter.
-
-
-_default_onnx_opset_version = 9
-_onnx_master_opset = 10
-_onnx_stable_opsets = [9]
-_export_onnx_opset_version = _default_onnx_opset_version
-
-
-def _set_opset_version(opset_version):
- global _export_onnx_opset_version
- if opset_version == _export_onnx_opset_version:
- return
- if opset_version in _onnx_stable_opsets + [_onnx_master_opset]:
- _export_onnx_opset_version = opset_version
- return
- raise ValueError("Unsupported ONNX opset version: " + str(opset_version))
-
-
-# ---------------------------------------------------------------------
-# Symbolic definitions
-# ---------------------------------------------------------------------
+# This file exports ONNX ops for opset 9
+# Opset 9 is supported by ONNX release 1.4.1,
+# released on 01/23/19
# Note [Pointwise by scalar]
@@ -284,60 +79,60 @@
def add(g, self, other, alpha=None):
# default alpha arg is to allow no-alpha add (aten add st overload no alpha)
- if alpha and _scalar(_maybe_get_scalar(alpha)) != 1:
+ if alpha and sym_help._scalar(sym_help._maybe_get_scalar(alpha)) != 1:
return _unimplemented("add", "alpha != 1")
# See Note [Pointwise by scalar]
- other = _maybe_get_scalar(other)
- return g.op("Add", self, _if_scalar_type_as(g, other, self))
+ other = sym_help._maybe_get_scalar(other)
+ return g.op("Add", self, sym_help._if_scalar_type_as(g, other, self))
def sub(g, self, other, alpha=None):
# default alpha arg is to allow no-alpha sub (aten sub st overload no alpha)
- if alpha and _scalar(_maybe_get_scalar(alpha)) != 1:
+ if alpha and sym_help._scalar(sym_help._maybe_get_scalar(alpha)) != 1:
return _unimplemented("sub", "alpha != 1")
# See Note [Pointwise by scalar]. Note that self or other may be scalars.
- other = _maybe_get_scalar(other)
- return g.op("Sub", self, _if_scalar_type_as(g, other, self))
+ other = sym_help._maybe_get_scalar(other)
+ return g.op("Sub", self, sym_help._if_scalar_type_as(g, other, self))
def rsub(g, self, other, alpha=None):
- other = _maybe_get_scalar(other)
- other = _if_scalar_type_as(g, other, self)
+ other = sym_help._maybe_get_scalar(other)
+ other = sym_help._if_scalar_type_as(g, other, self)
return sub(g, other, self, alpha=alpha)
def mul(g, self, other):
# See Note [Pointwise by scalar]
- other = _maybe_get_scalar(other)
- return g.op("Mul", self, _if_scalar_type_as(g, other, self))
+ other = sym_help._maybe_get_scalar(other)
+ return g.op("Mul", self, sym_help._if_scalar_type_as(g, other, self))
def div(g, self, other):
# See Note [Pointwise by scalar]
- other = _maybe_get_scalar(other)
- return g.op("Div", self, _if_scalar_type_as(g, other, self))
+ other = sym_help._maybe_get_scalar(other)
+ return g.op("Div", self, sym_help._if_scalar_type_as(g, other, self))
def reciprocal(g, self):
- return g.op("Div", _if_scalar_type_as(g, torch.ones(1), self), self)
+ return g.op("Div", sym_help._if_scalar_type_as(g, torch.ones(1), self), self)
@parse_args('v', 'i')
def cat(g, tensor_list, dim):
- tensors = _unpack_list(tensor_list)
+ tensors = sym_help._unpack_list(tensor_list)
return g.op("Concat", *tensors, axis_i=dim)
@parse_args('v', 'i')
def stack(g, tensor_list, dim):
- unsqueezed = [g.op("Unsqueeze", t, axes_i=[dim]) for t in _unpack_list(tensor_list)]
+ unsqueezed = [g.op("Unsqueeze", t, axes_i=[dim]) for t in sym_help._unpack_list(tensor_list)]
return g.op("Concat", *unsqueezed, axis_i=dim)
def mm(g, self, other):
# Create a dummy C tensor. Only needed for API purposes, the value is
# since beta = 0
- ty = _try_get_scalar_type(self, other).lower()
+ ty = sym_help._try_get_scalar_type(self, other).lower()
C = g.constant(0, [1], ty)
return g.op("Gemm", self, other, C, beta_f=0.0, alpha_f=1.0)
@@ -352,7 +147,7 @@
@parse_args('v', 'v', 'v', 't', 't')
def addmm(g, self, mat1, mat2, beta, alpha):
- return g.op("Gemm", mat1, mat2, self, beta_f=_scalar(beta), alpha_f=_scalar(alpha))
+ return g.op("Gemm", mat1, mat2, self, beta_f=sym_help._scalar(beta), alpha_f=sym_help._scalar(alpha))
def neg(g, self):
@@ -409,7 +204,7 @@
return g.op(onnx_op_name, self, keepdims_i=0)
else:
# dim-reduce path
- dim, keepdim = _get_const(dim, 'i', 'dim'), _get_const(keepdim, 'i', 'keepdim')
+ dim, keepdim = sym_help._get_const(dim, 'i', 'dim'), sym_help._get_const(keepdim, 'i', 'keepdim')
return g.op(onnx_op_name, self, axes_i=[dim], keepdims_i=keepdim)
return symbolic
@@ -442,8 +237,8 @@
def expand(g, self, size, implicit):
- size = _maybe_get_const(size, 'is')
- if not _is_value(size):
+ size = sym_help._maybe_get_const(size, 'is')
+ if not sym_help._is_value(size):
size = g.op("Constant", value_t=torch.LongTensor(size))
return g.op("Expand", self, size)
@@ -504,8 +299,8 @@
def view(g, self, size):
- size = _maybe_get_const(size, 'is')
- if _is_value(size):
+ size = sym_help._maybe_get_const(size, 'is')
+ if sym_help._is_value(size):
shape = size
else:
if self.isCompleteTensor():
@@ -570,7 +365,7 @@
if size == 1:
dims.append(i)
else:
- dims = [_get_const(dim, 'i', 'dim')]
+ dims = [sym_help._get_const(dim, 'i', 'dim')]
# Handle negative dims
for i, dim in enumerate(dims):
if dim < 0:
@@ -607,18 +402,18 @@
@parse_args('v', 't', 't')
def threshold(g, self, threshold, value):
# See Note [Export inplace]
- if _scalar(threshold) != 0:
+ if sym_help._scalar(threshold) != 0:
return _unimplemented("threshold", "non-zero threshold")
- if _scalar(value) != 0:
+ if sym_help._scalar(value) != 0:
return _unimplemented("threshold", "non-zero value")
return g.op("Relu", self)
def leaky_relu(g, input, negative_slope, inplace=False):
- negative_slope = _get_const(negative_slope, 't', 'negative_slope')
+ negative_slope = sym_help._get_const(negative_slope, 't', 'negative_slope')
# See Note [Export inplace]
# TODO: Talk to ONNX about unconditional cast of scalar to float
- return g.op("LeakyRelu", input, alpha_f=_scalar(negative_slope))
+ return g.op("LeakyRelu", input, alpha_f=sym_help._scalar(negative_slope))
@parse_args('v', 'i')
@@ -655,13 +450,13 @@
if input.type().dim() == dim + 1:
softmax = g.op('Softmax', input, axis_i=dim)
if dtype:
- softmax = g.op("Cast", softmax, to_i=scalar_type_to_onnx[dtype])
+ softmax = g.op("Cast", softmax, to_i=sym_help.scalar_type_to_onnx[dtype])
return softmax
exp = g.op('Exp', input)
sum = g.op('ReduceSum', exp, axes_i=[dim])
softmax = g.op('Div', exp, sum)
if dtype:
- softmax = g.op("Cast", softmax, to_i=scalar_type_to_onnx[dtype])
+ softmax = g.op("Cast", softmax, to_i=sym_help.scalar_type_to_onnx[dtype])
return softmax
@@ -700,6 +495,8 @@
def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode):
if ceil_mode and input.type().kind() != "CompleteTensorType":
return _unimplemented(name, "input size not accesible")
+ if set(tuple_fn(dilation)) != {1}:
+ return _unimplemented(name, "dilation")
if not stride:
stride = kernel_size
padding = tuple(tuple_fn(padding))
@@ -713,8 +510,6 @@
'pads_i': padding,
'strides_i': tuple_fn(stride),
}
- if set(tuple_fn(dilation)) != {1}:
- kwargs['dilations_i'] = tuple_fn(dilation)
# easy but hacky way to get flattened indices values
# to be used to convert the indices values to non-flattened.
# In ONNX the indices are computed as a flatten 1-D tensor,
@@ -888,7 +683,7 @@
def wrap_logical_op_with_cast_to_uint8(func):
def wrap_with_cast(g, input, other):
- return g.op("Cast", func(g, input, other), to_i=cast_pytorch_to_onnx['Byte'])
+ return g.op("Cast", func(g, input, other), to_i=sym_help.cast_pytorch_to_onnx['Byte'])
return wrap_with_cast
@@ -915,8 +710,8 @@
def gt_impl(g, input, other):
- other = _maybe_get_scalar(other)
- return g.op("Greater", input, _if_scalar_type_as(g, other, input))
+ other = sym_help._maybe_get_scalar(other)
+ return g.op("Greater", input, sym_help._if_scalar_type_as(g, other, input))
@wrap_logical_op_with_cast_to_uint8
@@ -925,22 +720,22 @@
def lt_impl(g, input, other):
- other = _maybe_get_scalar(other)
- return g.op("Less", input, _if_scalar_type_as(g, other, input))
+ other = sym_help._maybe_get_scalar(other)
+ return g.op("Less", input, sym_help._if_scalar_type_as(g, other, input))
@wrap_logical_op_with_cast_to_uint8
@wrap_logical_op_with_negation
def ge(g, input, other):
- other = _maybe_get_scalar(other)
- return lt_impl(g, input, _if_scalar_type_as(g, other, input))
+ other = sym_help._maybe_get_scalar(other)
+ return lt_impl(g, input, sym_help._if_scalar_type_as(g, other, input))
@wrap_logical_op_with_cast_to_uint8
@wrap_logical_op_with_negation
def le(g, input, other):
- other = _maybe_get_scalar(other)
- return gt_impl(g, input, _if_scalar_type_as(g, other, input))
+ other = sym_help._maybe_get_scalar(other)
+ return gt_impl(g, input, sym_help._if_scalar_type_as(g, other, input))
def where(g, condition, self, other):
@@ -957,7 +752,7 @@
return _unimplemented("dim", "ONNX and PyTorch use different strategies to split the input.")
return_op = g.op("LogSoftmax", input, axis_i=dim)
if dtype:
- return_op = g.op("Cast", return_op, to_i=scalar_type_to_onnx[dtype])
+ return_op = g.op("Cast", return_op, to_i=sym_help.scalar_type_to_onnx[dtype])
return return_op
@@ -1064,7 +859,7 @@
if input_scale and input_scale != 1.:
return _unimplemented("input_scale", "does not support input_scale in Elu")
# See Note [Export inplace]
- return g.op("Elu", input, alpha_f=_scalar(alpha))
+ return g.op("Elu", input, alpha_f=sym_help._scalar(alpha))
def selu(g, input):
@@ -1077,7 +872,7 @@
def index_put(g, self, indices_list_value, values, accumulate):
- indices_list = _unpack_list(indices_list_value)
+ indices_list = sym_help._unpack_list(indices_list_value)
args = [self] + indices_list + [values, accumulate]
return g.op("ATen", *args, operator_s='index_put')
@@ -1088,7 +883,7 @@
if other.isCompleteTensor():
other_type_name = other.type().scalarType()
- return g.op("Cast", self, to_i=cast_pytorch_to_onnx[other_type_name])
+ return g.op("Cast", self, to_i=sym_help.cast_pytorch_to_onnx[other_type_name])
else:
# We don't know the type of other, bail by emitting ATen
return g.op("ATen", self, other, operator_s="type_as")
@@ -1114,8 +909,8 @@
def pow(g, self, exponent):
- exponent = _maybe_get_scalar(exponent)
- return g.op("Pow", self, _if_scalar_type_as(g, exponent, self))
+ exponent = sym_help._maybe_get_scalar(exponent)
+ return g.op("Pow", self, sym_help._if_scalar_type_as(g, exponent, self))
def clamp(g, self, min, max):
@@ -1152,8 +947,8 @@
return g.op("Max", self, dim_or_y)
# torch.max(input, dim, keepdim)
else:
- dim = _get_const(dim_or_y, 'i', 'dim')
- keepdim = _get_const(keepdim, 'i', 'keepdim')
+ dim = sym_help._get_const(dim_or_y, 'i', 'dim')
+ keepdim = sym_help._get_const(keepdim, 'i', 'keepdim')
max = g.op("ReduceMax", self, axes_i=[dim], keepdims_i=keepdim)
indices = g.op('ArgMax', self, axis_i=dim, keepdims_i=keepdim)
return max, indices
@@ -1168,8 +963,8 @@
return g.op("Min", self, dim_or_y)
# torch.min(input, dim, keepdim)
else:
- dim = _get_const(dim_or_y, 'i', 'dim')
- keepdim = _get_const(keepdim, 'i', 'keepdim')
+ dim = sym_help._get_const(dim_or_y, 'i', 'dim')
+ keepdim = sym_help._get_const(keepdim, 'i', 'keepdim')
min = g.op("ReduceMin", self, axes_i=[dim], keepdims_i=keepdim)
indices = g.op('ArgMin', self, axis_i=dim, keepdims_i=keepdim)
return min, indices
@@ -1191,7 +986,6 @@
@parse_args('v', 'f', 'i')
def feature_dropout(g, input, p, train):
# NB: In inference mode, FeatureDropout is exported as an identity op.
- from torch.onnx.symbolic import _unimplemented
if train:
return _unimplemented(name, "training mode")
return input
@@ -1238,69 +1032,9 @@
outputs=3)
-# Metaprogram symbolics for each ATen native specialized cast operator.
-# For e.g. we specify a function named `_cast_uint8_t` that instantiates an
-# ONNX cast node with `to` attribute 'UINT8'
-#
-# TODO: remove these once we support Type's in the JIT IR and we can once again
-# use the unified toType operator
-cast_pytorch_to_onnx = {
- 'Byte': torch.onnx.TensorProtoDataType.UINT8,
- 'Char': torch.onnx.TensorProtoDataType.INT8,
- 'Double': torch.onnx.TensorProtoDataType.DOUBLE,
- 'Float': torch.onnx.TensorProtoDataType.FLOAT,
- 'Half': torch.onnx.TensorProtoDataType.FLOAT16,
- 'Int': torch.onnx.TensorProtoDataType.INT32,
- 'Long': torch.onnx.TensorProtoDataType.INT64,
- 'Short': torch.onnx.TensorProtoDataType.INT16,
-}
-
-scalar_name_to_pytorch = {
- 'uint8_t': 'Byte',
- 'int8_t': 'Char',
- 'double': 'Double',
- 'float': 'Float',
- 'half': 'Half',
- 'int': 'Int',
- 'int64_t': 'Long',
- 'int16_t': 'Short',
-}
-
-
-# This indicates each scalar type's corresponding
-# torch type. Related source:
-# https://github.com/pytorch/pytorch/blob/da7468853ae322252270bbb58032668bd21b7457/c10/core/ScalarType.h
-scalar_type_to_pytorch_type = [
- torch.uint8, # 0
- torch.int8, # 1
- torch.short, # 2
- torch.int, # 3
- torch.int64, # 4
- torch.half, # 5
- torch.float, # 6
- torch.double, # 7
-]
-
-
-def _cast_func_template(to_i, g, input, non_blocking):
- return g.op("Cast", input, to_i=to_i)
-
-
-for k, v in cast_pytorch_to_onnx.items():
+for k, v in sym_help.cast_pytorch_to_onnx.items():
name = '_cast_{}'.format(k)
- globals()[name] = parse_args('v', 'i')(partial(_cast_func_template, v))
-
-
-scalar_type_to_onnx = [
- cast_pytorch_to_onnx["Byte"],
- cast_pytorch_to_onnx["Char"],
- cast_pytorch_to_onnx["Short"],
- cast_pytorch_to_onnx["Int"],
- cast_pytorch_to_onnx["Long"],
- cast_pytorch_to_onnx["Half"],
- cast_pytorch_to_onnx["Float"],
- cast_pytorch_to_onnx["Double"],
-]
+ globals()[name] = parse_args('v', 'i')(partial(sym_help._cast_func_template, v))
@parse_args('v', 'i', 'v', 'v', 'b')
@@ -1309,7 +1043,7 @@
raise RuntimeError("onnx pin_memory support is not implemented")
# NOTE: no way to set device and layout in ONNX, so we ignore it
return g.op("ConstantOfShape", sizes,
- value_t=torch.tensor([0], dtype=scalar_type_to_pytorch_type[dtype]))
+ value_t=torch.tensor([0], dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
@parse_args('v', 'i', 'v', 'v', 'b')
@@ -1318,7 +1052,7 @@
raise RuntimeError("onnx pin_memory support is not implemented")
shape = g.op("Shape", input)
return g.op("ConstantOfShape", shape,
- value_t=torch.tensor([0], dtype=scalar_type_to_pytorch_type[dtype]))
+ value_t=torch.tensor([0], dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
@parse_args('v', 'i', 'v', 'v', 'b')
@@ -1326,7 +1060,7 @@
if pin_memory:
raise RuntimeError("onnx pin_memory support is not implemented")
return g.op("ConstantOfShape", sizes,
- value_t=torch.tensor([1], dtype=scalar_type_to_pytorch_type[dtype]))
+ value_t=torch.tensor([1], dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
@parse_args('v', 'i', 'v', 'v', 'b')
@@ -1335,20 +1069,20 @@
raise RuntimeError("onnx pin_memory support is not implemented")
shape = g.op("Shape", input)
return g.op("ConstantOfShape", shape,
- value_t=torch.tensor([1], dtype=scalar_type_to_pytorch_type[dtype]))
+ value_t=torch.tensor([1], dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
def full(g, sizes, value, dtype, layout, device, pin_memory=False):
if pin_memory and _parse_arg(pin_memory, 'b'):
raise RuntimeError("onnx pin_memory support is not implemented")
- const_value = _maybe_get_const(value, 't')
- if _is_value(const_value):
+ const_value = sym_help._maybe_get_const(value, 't')
+ if sym_help._is_value(const_value):
tmp = zeros(sizes, dtype, layout, device)
return add(tmp, value, g.op("Constant", value_t=torch.tensor(1)))
else:
- dtype = _get_const(dtype, 'i', 'dtype')
+ dtype = sym_help._get_const(dtype, 'i', 'dtype')
return g.op("ConstantOfShape", sizes,
- value_t=torch.tensor([const_value], dtype=scalar_type_to_pytorch_type[dtype]))
+ value_t=torch.tensor([const_value], dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
@parse_args('v', 'f', 'i', 'v', 'v', 'b')
@@ -1357,7 +1091,7 @@
raise RuntimeError("onnx pin_memory support is not implemented")
shape = g.op("Shape", input)
return g.op("ConstantOfShape", shape,
- value_t=torch.tensor([fill_value], dtype=scalar_type_to_pytorch_type[dtype]))
+ value_t=torch.tensor([fill_value], dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
@parse_args('v', 'v', 'v', 'v', 'i')
@@ -1422,32 +1156,32 @@
return self
else:
# aten::to(Tensor, ScalarType, bool, bool)
- dtype = _get_const(args[0], 'i', 'dtype')
- return g.op("Cast", self, to_i=scalar_type_to_onnx[dtype])
+ dtype = sym_help._get_const(args[0], 'i', 'dtype')
+ return g.op("Cast", self, to_i=sym_help.scalar_type_to_onnx[dtype])
elif len(args) == 4:
# aten::to(Tensor, Device, ScalarType, bool, bool)
- dtype = _get_const(args[1], 'i', 'dtype')
- return g.op("Cast", self, to_i=scalar_type_to_onnx[dtype])
+ dtype = sym_help._get_const(args[1], 'i', 'dtype')
+ return g.op("Cast", self, to_i=sym_help.scalar_type_to_onnx[dtype])
elif len(args) == 5:
# aten::to(Tensor, ScalarType, Layout, Device, bool, bool) -> Tensor
- dtype = _get_const(args[0], 'i', 'dtype')
+ dtype = sym_help._get_const(args[0], 'i', 'dtype')
# Layout and device are ignored
- return g.op("Cast", self, to_i=scalar_type_to_onnx[dtype])
+ return g.op("Cast", self, to_i=sym_help.scalar_type_to_onnx[dtype])
elif len(args) == 6:
# aten::to(Tensor, ScalarType, Layout, Device, bool, bool, bool) -> Tensor
- dtype = _get_const(args[0], 'i', 'dtype')
+ dtype = sym_help._get_const(args[0], 'i', 'dtype')
# Layout and device are ignored
- return g.op("Cast", self, to_i=scalar_type_to_onnx[dtype])
+ return g.op("Cast", self, to_i=sym_help.scalar_type_to_onnx[dtype])
else:
raise NotImplementedError("Unknown aten::to signature")
def repeat(g, self, repeats):
- if not _is_value(repeats):
+ if not sym_help._is_value(repeats):
repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
- const_repeats = _maybe_get_const(repeats, 'is')
+ const_repeats = sym_help._maybe_get_const(repeats, 'is')
- if self.isCompleteTensor() and not _is_value(const_repeats):
+ if self.isCompleteTensor() and not sym_help._is_value(const_repeats):
sizes = self.type().sizes()
diff_dims = len(const_repeats) - len(sizes)
if diff_dims > 0:
@@ -1600,20 +1334,20 @@
@parse_args('v', 'v', 'v', 'i', 'i', 'f', 'i', 'i', 'i')
def _lstm_full(g, input, hidden_v, weight_v, has_biases, num_layers, dropout, train, bidirectional, batch_first):
- hidden, weight = _unpack_list(hidden_v), _unpack_list(weight_v)
+ hidden, weight = sym_help._unpack_list(hidden_v), sym_help._unpack_list(weight_v)
return _generic_rnn(g, 'LSTM', input, hidden, weight, has_biases, num_layers,
dropout, train, bidirectional, batch_first)
@parse_args('v', 'v', 'v', 'v', 'i', 'i', 'f', 'i', 'i')
def _lstm_packed(g, input, batch_sizes, hidden_v, weight_v, has_biases, num_layers, dropout, train, bidirectional):
- hidden, weight = _unpack_list(hidden_v), _unpack_list(weight_v)
+ hidden, weight = sym_help._unpack_list(hidden_v), sym_help._unpack_list(weight_v)
return _generic_rnn(g, 'LSTM', input, hidden, weight, has_biases, num_layers,
dropout, train, bidirectional, batch_sizes=batch_sizes)
def lstm(g, *args):
- if _is_tensor_list(args[3]):
+ if sym_help._is_tensor_list(args[3]):
return _lstm_packed(g, *args)
else:
return _lstm_full(g, *args)
@@ -1622,18 +1356,18 @@
def _one_hidden_rnn(kind):
@parse_args('v', 'v', 'v', 'i', 'i', 'f', 'i', 'i', 'i')
def _rnn_full(g, input, hidden, weight_v, has_biases, num_layers, dropout, train, bidirectional, batch_first):
- weight = _unpack_list(weight_v)
+ weight = sym_help._unpack_list(weight_v)
return _generic_rnn(g, kind, input, hidden, weight, has_biases, num_layers,
dropout, train, bidirectional, batch_first)
@parse_args('v', 'v', 'v', 'v', 'i', 'i', 'f', 'i', 'i')
def _rnn_packed(g, input, batch_sizes, hidden, weight_v, has_biases, num_layers, dropout, train, bidirectional):
- weight = _unpack_list(weight_v)
+ weight = sym_help._unpack_list(weight_v)
return _generic_rnn(g, kind, input, hidden, weight, has_biases, num_layers,
dropout, train, bidirectional, batch_sizes=batch_sizes)
def symbolic(g, *args):
- if _is_tensor_list(args[3]):
+ if sym_help._is_tensor_list(args[3]):
return _rnn_packed(g, *args)
else:
return _rnn_full(g, *args)
@@ -1692,7 +1426,7 @@
def randn(g, *shapes):
shapes_list = list(shapes)
- shape = _maybe_get_const(shapes_list[0], "is")
+ shape = sym_help._maybe_get_const(shapes_list[0], "is")
return g.op('RandomNormal', shape_i=shape)
@@ -1750,7 +1484,7 @@
@parse_args('v')
def isnan(g, input):
output = g.op('IsNaN', input)
- output = _cast_func_template(cast_pytorch_to_onnx['Byte'], g, output, None)
+ output = sym_help._cast_func_template(sym_help.cast_pytorch_to_onnx['Byte'], g, output, None)
return output
diff --git a/torch/onnx/symbolic_registry.py b/torch/onnx/symbolic_registry.py
new file mode 100644
index 0000000..4d4d272
--- /dev/null
+++ b/torch/onnx/symbolic_registry.py
@@ -0,0 +1,70 @@
+import warnings
+import importlib
+from inspect import getmembers, isfunction
+
+# The symbolic registry "_registry" is a dictionary that maps operators
+# (for a specific domain and opset version) to their symbolic functions.
+# An operator is defined by its domain, opset version, and opname.
+# The keys are tuples (domain, version), where domain is a string and version is an int;
+# each entry maps an operator name (string) to its symbolic function:
+# _registry[(domain, version)][op_name] = op_symbolic
+_registry = {}
+
+_symbolic_versions = {}
+from torch.onnx.symbolic_helper import _onnx_stable_opsets
+for opset_version in _onnx_stable_opsets:
+ module = importlib.import_module('torch.onnx.symbolic_opset{}'.format(opset_version))
+ _symbolic_versions[opset_version] = module
+
+def register_version(domain, version):
+ if not is_registered_version(domain, version):
+ global _registry
+ _registry[(domain, version)] = {}
+ register_ops_in_version(domain, version)
+
+
+def register_ops_in_version(domain, version):
+ # Iterates through the symbolic functions of the specified opset
+ # version and of all previous opset versions (down to opset 9), so
+ # that operators supported only in earlier opsets are also registered.
+ iter_version = version
+ while iter_version >= 9:
+ version_ops = get_ops_in_version(iter_version)
+ for op in version_ops:
+ if isfunction(op[1]) and \
+ not is_registered_op(op[0], domain, version):
+ register_op(op[0], op[1], domain, version)
+ iter_version = iter_version - 1
+
+
+def get_ops_in_version(version):
+ return getmembers(_symbolic_versions[version])
+
+
+def is_registered_version(domain, version):
+ global _registry
+ return (domain, version) in _registry
+
+
+def register_op(opname, op, domain, version):
+ if domain is None or version is None:
+ warnings.warn("ONNX export failed. The ONNX domain and/or version to register are None.")
+ global _registry
+ if not is_registered_version(domain, version):
+ _registry[(domain, version)] = {}
+ _registry[(domain, version)][opname] = op
+
+
+def is_registered_op(opname, domain, version):
+ if domain is None or version is None:
+ warnings.warn("ONNX export failed. The ONNX domain and/or version are None.")
+ global _registry
+ return (domain, version) in _registry and opname in _registry[(domain, version)]
+
+
+def get_registered_op(opname, domain, version):
+ if domain is None or version is None:
+ warnings.warn("ONNX export failed. The ONNX domain and/or version are None.")
+ global _registry
+ return _registry[(domain, version)][opname]
diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py
index b5f59d5..0a8b884 100644
--- a/torch/onnx/utils.py
+++ b/torch/onnx/utils.py
@@ -93,7 +93,7 @@
output nodes of the graph, in order
aten (bool, default False): [DEPRECATED. use operator_export_type] export the
model in aten mode. If using aten mode, all the ops original exported
- by the functions in symbolic.py are exported as ATen ops.
+ by the functions in symbolic_opset<version>.py are exported as ATen ops.
export_raw_ir (bool, default False): [DEPRECATED. use operator_export_type]
export the internal IR directly instead of converting it to ONNX ops.
operator_export_type (enum, default OperatorExportTypes.ONNX):
@@ -107,7 +107,7 @@
evolve before next stable release, by default we export to one stable
opset version. Right now, supported stable opset version is 9.
The opset_version must be _onnx_master_opset or in _onnx_stable_opsets
- which are defined in torch/onnx/symbolic.py
+ which are defined in torch/onnx/symbolic_helper.py
do_constant_folding (bool, default False): If True, the constant-folding
optimization is applied to the model during export. Constant-folding
optimization will replace some of the ops that have all constant
@@ -237,7 +237,7 @@
example_outputs=None, propagate=False,
_retain_param_name=False, do_constant_folding=False,
_disable_torch_constant_prop=False):
- from torch.onnx.symbolic import _export_onnx_opset_version
+ from torch.onnx.symbolic_helper import _export_onnx_opset_version
# Special case for common case of passing a single Tensor
if isinstance(args, torch.Tensor):
args = (args, )
@@ -326,7 +326,7 @@
export_type=ExportTypes.PROTOBUF_FILE, example_outputs=None, propagate=False,
google_printer=False, opset_version=None, _retain_param_name=False,
do_constant_folding=False):
- from torch.onnx.symbolic import _default_onnx_opset_version, _set_opset_version
+ from torch.onnx.symbolic_helper import _default_onnx_opset_version, _set_opset_version
if opset_version is None:
opset_version = _default_onnx_opset_version
_set_opset_version(opset_version)
@@ -352,7 +352,7 @@
assert __IN_ONNX_EXPORT is False
__IN_ONNX_EXPORT = True
try:
- from torch.onnx.symbolic import _default_onnx_opset_version, _set_opset_version
+ from torch.onnx.symbolic_helper import _default_onnx_opset_version, _set_opset_version
if opset_version is None:
opset_version = _default_onnx_opset_version
_set_opset_version(opset_version)
@@ -553,7 +553,10 @@
# NB: Returning None means the node gets cloned as is into
# the new graph
try:
- import torch.onnx.symbolic
+ from torch.onnx.symbolic_helper import _export_onnx_opset_version as opset_version
+ import torch.onnx.symbolic_registry as sym_registry
+
+ sym_registry.register_version('', opset_version)
# See Note [Export inplace]
# TODO: I think this is not necessary anymore
@@ -568,7 +571,7 @@
return None
elif ns == "aten":
- is_exportable_aten_op = hasattr(torch.onnx.symbolic, op_name)
+ is_exportable_aten_op = sym_registry.is_registered_op(op_name, '', opset_version)
is_onnx_aten_export = operator_export_type == OperatorExportTypes.ONNX_ATEN
is_aten_fallback_export = operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK
if is_onnx_aten_export or (not is_exportable_aten_op and is_aten_fallback_export):
@@ -582,11 +585,11 @@
# Export it regularly
attrs = {k: n[k] for k in n.attributeNames()}
if not is_exportable_aten_op:
- warnings.warn("ONNX export failed on ATen operator {} because torch.onnx.symbolic.{} does not exist"
- .format(op_name, op_name))
- return None
- fn = getattr(torch.onnx.symbolic, op_name)
- return fn(g, *inputs, **attrs)
+ warnings.warn("ONNX export failed on ATen operator {} because "
+ "torch.onnx.symbolic_opset{}.{} does not exist"
+ .format(op_name, opset_version, op_name))
+ op_fn = sym_registry.get_registered_op(op_name, '', opset_version)
+ return op_fn(g, *inputs, **attrs)
elif ns == "prim":
if op_name == "Constant" and not n.mustBeNone():
@@ -614,17 +617,32 @@
torch._C._jit_pass_onnx_block(b, new_block, operator_export_type, env)
return new_op_outputs
else:
+ # TODO: we should lift prim's symbolic out
symbolic_name = 'prim_' + op_name
- symbolic_fn = getattr(torch.onnx.symbolic, symbolic_name, None)
- if symbolic_fn is None:
+ is_exportable = sym_registry.is_registered_op(symbolic_name, '', opset_version)
+ if not is_exportable:
warnings.warn("ONNX export failed on primitive operator {}; please report a bug".format(op_name))
- return None
+ symbolic_fn = sym_registry.get_registered_op(symbolic_name, '', opset_version)
attrs = {k: n[k] for k in n.attributeNames()}
return symbolic_fn(g, *inputs, **attrs)
+ # custom ops
+ elif sym_registry.is_registered_version(ns, opset_version):
+ if not sym_registry.is_registered_op(op_name, ns, opset_version):
+ warnings.warn("ONNX export failed on custom operator {}::{} because "
+ "torch.onnx.symbolic_opset{}.{} does not exist. "
+ "Have you registered your symbolic function with "
+ "torch.onnx.register_custom_op_symbolic(symbolic_name, symbolic_fn)?"
+ .format(ns, op_name, opset_version, op_name))
+ symbolic_fn = sym_registry.get_registered_op(symbolic_name, ns, opset_version)
+ attrs = {k: n[k] for k in n.attributeNames()}
+ return symbolic_fn(g, *inputs, **attrs)
+
else:
warnings.warn("ONNX export failed on an operator with unrecognized namespace {}::{}; "
- "please report a bug".format(ns, op_name))
+ "If you are trying to export a custom operator, make sure you registered "
+ "it with the right domain and version."
+ "Otherwise please report a bug".format(ns, op_name))
return None
except TypeError as e:
@@ -686,6 +704,22 @@
return getattr(self, sel)(k)
+def register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version):
+ if not bool(re.match(r"^[a-zA-Z0-9-_]*::[a-zA-Z]+[a-zA-Z0-9-_]*$", symbolic_name)):
+ raise RuntimeError("Failed to register operator {}. \
+ The symbolic name must match the format Domain::Name, \
+ and sould start with a letter and contain only \
+ alphanumerical characters"
+ .format(symbolic_name))
+ ns, op_name = symbolic_name.split('::')
+ unaccepted_domain_names = ["onnx", "aten", "prim"]
+ if ns in unaccepted_domain_names:
+ raise RuntimeError("Failed to register operator {}. The domain {} is already a used domain."
+ .format(symbolic_name, ns))
+ import torch.onnx.symbolic_registry as sym_registry
+ sym_registry.register_op(op_name, symbolic_fn, ns, opset_version)
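+# Example usage (an illustrative sketch; 'mydomain::gelu' and its symbolic
+# function below are hypothetical, not part of this PR):
+#
+#   def gelu_symbolic(g, input):
+#       return g.op("mydomain::Gelu", input)
+#
+#   torch.onnx.register_custom_op_symbolic('mydomain::gelu', gelu_symbolic, 9)
+#
+# The exporter will then emit a node in the 'mydomain' custom domain whenever
+# it encounters a mydomain::gelu op in the traced graph.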
+
+
torch._C.Graph.op = _graph_op
torch._C.Graph.at = _graph_at
torch._C.Graph.constant = _graph_constant