Support Exports to Multiple ONNX Opsets (#19294)

Summary:
Support exporting to multiple ONNX opsets (for now, opset 10 in addition to the default opset 9), following the proposal in https://gist.github.com/spandantiwari/99700e60919c43bd167838038d20f353.
Also add support for custom ops (merged with https://github.com/pytorch/pytorch/pull/18297).

This PR will be followed by another PR containing the changes related to testing the ops for different opsets.
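
A rough usage sketch of the new surface (the tiny module below mirrors the topk test added in this PR; the exact name handling for register_custom_op_symbolic lives in torch/onnx/utils.py and is not shown in this diff):

    import io
    import torch

    class TinyModel(torch.nn.Module):  # illustrative stand-in
        def forward(self, x):
            return torch.topk(x, 3)

    # Export the same model against different ONNX opsets: opset 9 stays the
    # default, while opset 10 routes through the new torch/onnx/symbolic_opset10.py.
    model = TinyModel()
    x = torch.arange(1., 6.)
    for opset in (9, 10):
        f = io.BytesIO()
        torch.onnx.export(model, x, f, opset_version=opset)

    # Custom ops can now be registered for a specific opset as well:
    #   torch.onnx.register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version)
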
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19294

Reviewed By: zrphercule

Differential Revision: D15043951

Pulled By: houseroad

fbshipit-source-id: d336fc35b8827145639137bc348ae07e3c14bb1c
diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h
index d0f2a3c..9939dbf 100644
--- a/aten/src/ATen/core/interned_strings.h
+++ b/aten/src/ATen/core/interned_strings.h
@@ -204,7 +204,7 @@
 // 'onnx' symbols correspond to ONNX operators.  Their semantics
 // are defined in https://github.com/onnx/onnx/blob/master/docs/Operators.md
 // The particular version we are targeting is specified by '_onnx_opset_version'
-// in torch.onnx.symbolic
+// in torch.onnx.symbolic_helper
 //
 // In general, most ONNX operators won't get an entry here, because they
 // are handled from the Python end.  However, you may occasionally need
diff --git a/docs/source/onnx.rst b/docs/source/onnx.rst
index 90bf287..ef1dcad 100644
--- a/docs/source/onnx.rst
+++ b/docs/source/onnx.rst
@@ -212,10 +212,10 @@
 If the operator is an ATen operator, which means you can find the declaration
 of the function in ``torch/csrc/autograd/generated/VariableType.h``
 (available in generated code in PyTorch install dir), you should add the symbolic
-function in ``torch/onnx/symbolic.py`` and follow the instructions listed as below:
+function in ``torch/onnx/symbolic_opset<version>.py`` and follow the instructions listed below:
 
-* Define the symbolic function in
-  `torch/onnx/symbolic.py <https://github.com/pytorch/pytorch/blob/master/torch/onnx/symbolic.py>`_.
+* Define the symbolic function in ``torch/onnx/symbolic_opset<version>.py``, for example
+  `torch/onnx/symbolic_opset9.py <https://github.com/pytorch/pytorch/blob/master/torch/onnx/symbolic_opset9.py>`_.
   Make sure the function has the same name as the ATen operator/function
   defined in ``VariableType.h``.
 * The first parameter is always the exported ONNX graph.
@@ -303,7 +303,7 @@
 Here is an example of handling missing symbolic function for ``elu`` operator.
 We try to export the model and see the error message as below::
 
-    UserWarning: ONNX export failed on elu because torch.onnx.symbolic.elu does not exist
+    UserWarning: ONNX export failed on elu because torch.onnx.symbolic_opset9.elu does not exist
     RuntimeError: ONNX export failed: Couldn't export operator elu
 
 The export fails because PyTorch does not support exporting ``elu`` operator.
@@ -311,7 +311,7 @@
 in ``VariableType.h``. This means ``elu`` is an ATen operator.
 We check the `ONNX operator list <https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_,
 and confirm that ``Elu`` is standardized in ONNX.
-We add the following lines to ``symbolic.py``::
+We add the following lines to ``symbolic_opset9.py``::
 
     def elu(g, input, alpha, inplace=False):
         return g.op("Elu", input, alpha_f=_scalar(alpha))
@@ -319,7 +319,7 @@
 Now PyTorch is able to export ``elu`` operator.
 
 There are more examples in
-`symbolic.py <https://github.com/pytorch/pytorch/blob/master/torch/onnx/symbolic.py>`_,
+`symbolic_opset9.py <https://github.com/pytorch/pytorch/blob/master/torch/onnx/symbolic_opset9.py>`_,
 `tensor.py <https://github.com/pytorch/pytorch/blob/99037d627da68cdf53d3d0315deceddfadf03bba/torch/autograd/_functions/tensor.py#L24>`_,
 `padding.py <https://github.com/pytorch/pytorch/blob/99037d627da68cdf53d3d0315deceddfadf03bba/torch/nn/_functions/padding.py#L8>`_.
 
diff --git a/test/onnx/expect/TestOperators.test_maxpool_dilations.expect b/test/onnx/expect/TestOperators.test_maxpool_dilations.expect
index 61b6d01..470a5d9 100644
--- a/test/onnx/expect/TestOperators.test_maxpool_dilations.expect
+++ b/test/onnx/expect/TestOperators.test_maxpool_dilations.expect
@@ -7,6 +7,11 @@
     output: "1"
     op_type: "MaxPool"
     attribute {
+      name: "ceil_mode"
+      i: 0
+      type: INT
+    }
+    attribute {
       name: "dilations"
       ints: 2
       type: INTS
diff --git a/test/onnx/test_onnx_opset.py b/test/onnx/test_onnx_opset.py
new file mode 100644
index 0000000..37a87af
--- /dev/null
+++ b/test/onnx/test_onnx_opset.py
@@ -0,0 +1,111 @@
+from test_pytorch_common import TestCase, run_tests
+
+import torch
+import torch.onnx
+from torch.nn import Module
+
+import onnx
+
+import io
+
+from torch.onnx.symbolic_helper import _export_onnx_opset_version
+from torch.onnx import ir_version, producer_name, producer_version
+
+
+def check_onnx_opset_operator(model, ops, opset_version=_export_onnx_opset_version):
+    # check_onnx_components
+    assert model.ir_version == ir_version and \
+        model.producer_name == producer_name and \
+        model.producer_version == producer_version and \
+        model.opset_import[0].version == opset_version
+
+    # check the schema with the onnx checker
+    onnx.checker.check_model(model)
+
+    # check the target op types and attributes
+    graph = model.graph
+    # ops should contain an object for each node
+    # in graph.node, in the right order.
+    # At least the op_name should be specified,
+    # but the op's attributes can optionally be
+    # specified as well
+    assert len(ops) == len(graph.node)
+    for i in range(0, len(ops)):
+        assert graph.node[i].op_type == ops[i]['op_name']
+        if "attributes" in ops[i] :
+            attributes = ops[i]['attributes']
+            assert len(attributes) == len(graph.node[i].attribute)
+            for j in range(0, len(attributes)):
+                for attribute_field in attributes[j].keys():
+                    assert attributes[j][attribute_field] == getattr(graph.node[i].attribute[j], attribute_field)
+
+
+def check_onnx_opsets_operator(module, x, ops, opset_versions):
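+    # Export the module once per requested opset version and validate the
+    # resulting ModelProto against the expected ops for that opset.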
+    for opset_version in opset_versions:
+        f = io.BytesIO()
+        torch.onnx.export(module, x, f, opset_version=opset_version)
+        model = onnx.load(io.BytesIO(f.getvalue()))
+        check_onnx_opset_operator(model, ops[opset_version], opset_version)
+
+
+class TestONNXOpset(TestCase):
+
+    def test_opset_fallback(self):
+        class MyModule(Module):
+            def forward(self, x):
+                return torch.isnan(x)
+
+        ops = [{"op_name" : "IsNaN"},
+               {"op_name" : "Cast", "attributes" : [{"name" : "to", "i" : 2, "type" : 2}]}]
+        ops = {9 : ops, 10 : ops}
+        x = torch.tensor([1.0, float('nan'), 2.0])
+        check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10])
+
+    def test_topk(self):
+        class MyModule(Module):
+            def forward(self, x):
+                return torch.topk(x, 3)
+
+        ops_9 = [{"op_name" : "TopK", "attributes" : [{"name" : "axis", "i" : -1, "type" : 2},
+                 {"name" : "k", "i" : 3, "type" : 2}]}]
+        ops_10 = [{"op_name" : "Constant", "attributes" : [{"name" : "value", "type" : 4}]},
+                  {"op_name" : "Unsqueeze", "attributes" : [{"name" : "axes", "ints" : [0], "type" : 7}]},
+                  {"op_name" : "TopK", "attributes" : [{"name" : "axis", "i" : -1, "type" : 2}]}]
+        ops = {9 : ops_9, 10 : ops_10}
+        x = torch.arange(1., 6., requires_grad=True)
+        check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10])
+
+    def test_maxpool(self):
+        module = torch.nn.MaxPool1d(2, stride=1)
+
+        ops_9 = [{"op_name" : "MaxPool",
+                  "attributes" :
+                  [{"name": "kernel_shape", "ints": [2], "type": 7},
+                   {"name": "pads", "ints": [0, 0], "type": 7},
+                   {"name": "strides", "ints": [1], "type": 7}]}]
+        ops_10 = [{"op_name" : "MaxPool",
+                   "attributes" :
+                   [{"name": "ceil_mode", "i": 0, "type": 2},
+                    {"name": "kernel_shape", "ints": [2], "type": 7},
+                    {"name": "pads", "ints": [0, 0], "type": 7},
+                    {"name": "strides", "ints": [1], "type": 7}]}]
+        ops = {9 : ops_9, 10 : ops_10}
+        x = torch.randn(20, 16, 50)
+        check_onnx_opsets_operator(module, x, ops, opset_versions=[10])
+
+        # add test with dilations
+        module = torch.nn.MaxPool1d(2, stride=1, dilation=2)
+
+        ops_10 = [{"op_name" : "MaxPool",
+                   "attributes" :
+                   [{"name": "ceil_mode", "i": 0, "type": 2},
+                    {"name": "dilations", "ints": [2], "type": 7},
+                    {"name": "kernel_shape", "ints": [2], "type": 7},
+                    {"name": "pads", "ints": [0, 0], "type": 7},
+                    {"name": "strides", "ints": [1], "type": 7}]}]
+        ops = {9 : ops_9, 10 : ops_10}
+        x = torch.randn(20, 16, 50)
+        check_onnx_opsets_operator(module, x, ops, opset_versions=[10])
+
+if __name__ == '__main__':
+    run_tests()
diff --git a/test/onnx/test_pytorch_onnx_caffe2.py b/test/onnx/test_pytorch_onnx_caffe2.py
index 529b186..7133bfb 100644
--- a/test/onnx/test_pytorch_onnx_caffe2.py
+++ b/test/onnx/test_pytorch_onnx_caffe2.py
@@ -1574,6 +1574,14 @@
 
         self.run_model_test(MyModel(), train=False, input=lstm_in, batch_size=3)
 
+    def test_topk(self):
+        class TopKModel(torch.nn.Module):
+            def forward(self, input):
+                return torch.topk(input, 3)
+        model = TopKModel()
+        x = torch.arange(1., 6.)
+        self.run_model_test(model, train=False, input=x, batch_size=BATCH_SIZE)
+
     def test_floor(self):
         class FloorModel(torch.nn.Module):
             def forward(self, input):
diff --git a/test/onnx/test_utility_funs.py b/test/onnx/test_utility_funs.py
index 4d3ee99..0470669 100644
--- a/test/onnx/test_utility_funs.py
+++ b/test/onnx/test_utility_funs.py
@@ -4,7 +4,7 @@
 import torch
 import torch.onnx
 from torch.onnx import utils
-from torch.onnx.symbolic import _set_opset_version
+from torch.onnx.symbolic_helper import _set_opset_version
 
 import onnx
 
diff --git a/torch/csrc/jit/passes/onnx.cpp b/torch/csrc/jit/passes/onnx.cpp
index 6cf6794..83e9933 100644
--- a/torch/csrc/jit/passes/onnx.cpp
+++ b/torch/csrc/jit/passes/onnx.cpp
@@ -171,7 +171,8 @@
   torch::autograd::SymbolicContext ctx{};
   ctx.block = new_block;
   py::object onnx = py::module::import("torch.onnx");
-  py::object onnx_symbolic = py::module::import("torch.onnx.symbolic");
+  py::object onnx_symbolic = py::module::import("torch.onnx.symbolic_helper");
+  py::object onnx_registry = py::module::import("torch.onnx.symbolic_registry");
 
   // Returns a node that n maps to in the new graph
   auto envFn = [&env](Value* n) -> Value* {
@@ -295,7 +296,6 @@
     if (func) {
       pyobj = func->get();
     }
-
     if (!py::hasattr(pyobj, "symbolic")) {
       cloneNode(op);
       return;
@@ -331,6 +331,8 @@
     // Call the symbolic function
     // Use a little trampoline function so we can give good error messages
     // upon argument mismatch
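+    // Register this op's Python "symbolic" in the symbolic registry under the
+    // currently selected export opset ("" is the default domain).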
+    py::object opset_version = onnx_symbolic.attr("_export_onnx_opset_version");
+    onnx_registry.attr("register_op")(op->name(), pyobj.attr("symbolic"), "", opset_version);
     py::object raw_output = onnx.attr("_run_symbolic_method")(
         op->name(), pyobj.attr("symbolic"), py_symbolic_args);
 
diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py
index 519a10d..4a7b2e5 100644
--- a/torch/onnx/__init__.py
+++ b/torch/onnx/__init__.py
@@ -6,6 +6,13 @@
 
 ONNX_ARCHIVE_MODEL_PROTO_NAME = "__MODEL_PROTO"
 
+# TODO: Update these variables when there is a new ir_version and
+# producer_version, and use these values in the exporter.
+ir_version = 4
+producer_name = "pytorch"
+producer_version = "1.1"
+
 
 class ExportTypes:
     PROTOBUF_FILE = 1
@@ -58,3 +65,8 @@
 def is_in_onnx_export():
     from torch.onnx import utils
     return utils.is_in_onnx_export()
+
+
+def register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version):
+    from torch.onnx import utils
+    return utils.register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version)
diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py
new file mode 100644
index 0000000..575043c
--- /dev/null
+++ b/torch/onnx/symbolic_helper.py
@@ -0,0 +1,293 @@
+import torch
+from torch._C import ListType
+import warnings
+
+import torch.onnx
+# This import monkey-patches graph manipulation methods on Graph, used for the
+# ONNX symbolics
+import torch.onnx.utils
+
+from functools import wraps
+
+
+# Note [Edit Symbolic Files]
+# EDITING THIS FILE AND SYMBOLIC_OPSET<VERSION> FILES? READ THIS FIRST!
+#
+# - These files are ONLY for ATen operators (e.g., operators that show up in the
+#   trace as aten::blah).  If you need to special case a primitive operator,
+#   look at _run_symbolic_function
+# - Parameter ordering does NOT necessarily match what is in VariableType.cpp;
+#   tensors are always first, then non-tensor arguments.
+# - Parameter names must *exactly* match the names in VariableType.cpp, because
+#   dispatch is done with keyword arguments.
+# - Looking for inplace ops?  They're detected by the trailing underscore, and
+#   transparently dispatched to their non inplace versions in
+#   'run_symbolic_function'.   See Note [Export inplace]
+#
+# ----------------------------------------------------------------------------------
+# A note on Tensor types
+# ----------------------------------------------------------------------------------
+#
+# In general, we should avoid depending on the type of Tensor Values contained
+# within the trace graph. However, this is sometimes unavoidable (due to ONNX
+# spec requirements, etc). If you are implementing a symbolic and need Tensor
+# type information, note that there are several levels of Tensor types, defined
+# in aten/src/ATen/core/jit_type.h:
+#
+# TensorType - This is a Tensor, but we don't know anything about its
+#               properties (e.g. scalar type, # dims, shapes).
+#               Appears as `Tensor` in graph print-outs.
+# DimensionedTensorType <: TensorType - Denotes a Tensor for which we know the scalar
+#                             type and number of dimensions, but not the concrete
+#                             shapes. For example, appears as 'Float(*, *)' in
+#                             graph print-outs. Useful accessor methods include
+#                             dim() and scalarType()
+# CompleteTensorType <: DimensionedTensorType - Denotes a Tensor for which we know the
+#                                               concrete sizes in addition to the information
+#                                               contained in TensorType. This adds a sizes()
+#                                               method which can be used to retrieve the
+#                                               concrete sizes.
+#
+# In general, we should prefer to rely on the least specific information possible.
+# For example, not relying on tensor properties at all is better than relying
+# on the number of dimensions (DimensionedTensorType) which is better than relying on
+# concrete shapes (CompleteTensorType). Doing so will make the export symbolics
+# more robust to different graphs.
+
+# ---------------------------------------------------------------------------------
+# Helper functions
+# ---------------------------------------------------------------------------------
+
+# Save some builtins as locals, because we'll shadow them below
+_sum = sum
+
+
+def _parse_arg(value, desc):
+    if desc == 'none':
+        return value
+    if desc == 'v' or not _is_value(value):
+        return value
+    if value.node().kind() == 'onnx::Constant':
+        tval = value.node()['value']
+        if desc == 'i':
+            return int(tval)
+        elif desc == 'f':
+            return float(tval)
+        elif desc == 'b':
+            return bool(tval)
+        elif desc == 't':
+            return tval
+        elif desc == 'is':
+            return [int(v) for v in tval]
+        else:
+            raise RuntimeError("ONNX symbolic doesn't know to interpret Constant node")
+    elif value.node().kind() == 'prim::ListConstruct':
+        if desc == 'is':
+            for v in value.node().inputs():
+                if v.node().kind() != 'onnx::Constant':
+                    raise RuntimeError("Failed to export an ONNX attribute, "
+                                       "since it's not constant, please try to make "
+                                       "things (e.g., kernel size) static if possible")
+            return [int(v.node()['value']) for v in value.node().inputs()]
+        else:
+            raise RuntimeError("ONNX symbolic doesn't know to interpret ListConstruct node")
+
+    raise RuntimeError("Unexpected node type: {}".format(value.node().kind()))
+
+
+def _maybe_get_const(value, desc):
+    if _is_value(value) and value.node().kind() == 'onnx::Constant':
+        return _parse_arg(value, desc)
+    return value
+
+
+def _maybe_get_scalar(value):
+    value_t = _maybe_get_const(value, 't')
+    if isinstance(value_t, torch.Tensor) and value_t.shape == ():
+        return value_t
+    return value
+
+
+def _get_const(value, desc, arg_name):
+    if _is_value(value) and value.node().kind() != 'onnx::Constant':
+        raise RuntimeError("ONNX symbolic expected a constant value of the {} argument".format(arg_name))
+    return _parse_arg(value, desc)
+
+
+def _unpack_list(list_value):
+    list_node = list_value.node()
+    assert list_node.kind() == "prim::ListConstruct"
+    return list(list_node.inputs())
+
+
+def parse_args(*arg_descriptors):
+    def decorator(fn):
+        def wrapper(g, *args):
+            # some args may be optional, so the length may be smaller
+            assert len(arg_descriptors) >= len(args)
+            args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)]
+            return fn(g, *args)
+        # In Python 2 functools.wraps chokes on partially applied functions, so we need this as a workaround
+        try:
+            wrapper = wraps(fn)(wrapper)
+        except Exception:
+            pass
+        return wrapper
+    return decorator
+
+
+def _scalar(x):
+    """Convert a scalar tensor into a Python value."""
+    assert x.numel() == 1
+    return x.item()
+
+
+def _if_scalar_type_as(g, self, tensor):
+    """
+    Convert self into the same type of tensor, as necessary.
+
+    We only support implicit casting for scalars, so we never
+    actually need to insert an ONNX cast operator here; just
+    fix up the scalar.
+    """
+    if isinstance(self, torch._C.Value):
+        return self
+    elif tensor.type().kind() == "DimensionedTensorType" or tensor.type().kind() == "CompleteTensorType":
+        ty = tensor.type().scalarType().lower()
+        return getattr(self, ty)()
+    else:
+        return self
+
+
+def _is_value(x):
+    return isinstance(x, torch._C.Value)
+
+
+def _is_tensor_list(x):
+    return x.type().isSubtypeOf(ListType.ofTensors())
+
+
+def _unimplemented(op, msg):
+    warnings.warn("ONNX export failed on " + op + " because " + msg + " not supported")
+
+
+def _black_list_in_opset(name):
+    def symbolic_fn(*args, **kwargs):
+        warnings.warn("ONNX export failed on {}, which is not yet implemented for opset 10. "
+                      "Try exporting with a previous opset version."
+                      .format(name))
+    return symbolic_fn
+
+
+def _try_get_scalar_type(*args):
+    for arg in args:
+        try:
+            return arg.type().scalarType()
+        except RuntimeError:
+            pass
+    return None
+
+
+# ---------------------------------------------------------------------
+# ONNX operator version
+# ---------------------------------------------------------------------
+
+# READ ME BEFORE EDITING _default_onnx_opset_version:
+#
+# The variable below controls which ONNX operator set version we are
+# targeting. THIS VARIABLE HAS SEMANTIC EFFECT! Say a breaking
+# change occurred in version 8. As long as this variable < 8, you can
+# export models targeting the old behavior. However, if you bump
+# this variable to 8 or later, the breaking change will take effect:
+# you MUST adjust any symbolic affected by breaking changes. The ONNX
+# spec publishes a *comprehensive* list of BC-breaking changes for every
+# operator revision at:
+#
+#   https://github.com/onnx/onnx/blob/master/docs/Changelog.md
+#
+# Please be sure to go through and check all of our implementations here before
+# increasing this number. This includes symbolic definitions NOT in this
+# file, so grep for "OpName" (with quotes)
+#
+# In addition, opset_version can be specified when invoking export()
+# and export_to_pretty_string(); _export_onnx_opset_version will be set
+# accordingly, and the symbolic functions should check it to determine
+# the behavior of the exporter.
+
+
+_default_onnx_opset_version = 9
+_onnx_master_opset = 10
+_onnx_stable_opsets = [9, 10]
+_export_onnx_opset_version = _default_onnx_opset_version
+
+
+def _set_opset_version(opset_version):
+    global _export_onnx_opset_version
+    if opset_version == _default_onnx_opset_version:
+        _export_onnx_opset_version = opset_version
+        return
+    if opset_version in _onnx_stable_opsets + [_onnx_master_opset]:
+        _export_onnx_opset_version = opset_version
+        return
+    raise ValueError("Unsupported ONNX opset version: " + str(opset_version))
+
+
+# Metaprogram symbolics for each ATen native specialized cast operator.
+# For example, we specify a function named `_cast_uint8_t` that instantiates an
+# ONNX cast node with `to` attribute 'UINT8'
+#
+# TODO: remove these once we support Type's in the JIT IR and we can once again
+# use the unified toType operator
+cast_pytorch_to_onnx = {
+    'Byte': torch.onnx.TensorProtoDataType.UINT8,
+    'Char': torch.onnx.TensorProtoDataType.INT8,
+    'Double': torch.onnx.TensorProtoDataType.DOUBLE,
+    'Float': torch.onnx.TensorProtoDataType.FLOAT,
+    'Half': torch.onnx.TensorProtoDataType.FLOAT16,
+    'Int': torch.onnx.TensorProtoDataType.INT32,
+    'Long': torch.onnx.TensorProtoDataType.INT64,
+    'Short': torch.onnx.TensorProtoDataType.INT16,
+}
+
+
+scalar_name_to_pytorch = {
+    'uint8_t': 'Byte',
+    'int8_t': 'Char',
+    'double': 'Double',
+    'float': 'Float',
+    'half': 'Half',
+    'int': 'Int',
+    'int64_t': 'Long',
+    'int16_t': 'Short',
+}
+
+
+# This indicates each scalar type's corresponding
+# torch type. Related source:
+# https://github.com/pytorch/pytorch/blob/da7468853ae322252270bbb58032668bd21b7457/c10/core/ScalarType.h
+scalar_type_to_pytorch_type = [
+    torch.uint8,    # 0
+    torch.int8,     # 1
+    torch.short,    # 2
+    torch.int,      # 3
+    torch.int64,    # 4
+    torch.half,     # 5
+    torch.float,    # 6
+    torch.double,   # 7
+]
+
+
+def _cast_func_template(to_i, g, input, non_blocking):
+    return g.op("Cast", input, to_i=to_i)
+
+
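+# Indexed by the same ScalarType codes as scalar_type_to_pytorch_type above.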
+scalar_type_to_onnx = [
+    cast_pytorch_to_onnx["Byte"],
+    cast_pytorch_to_onnx["Char"],
+    cast_pytorch_to_onnx["Short"],
+    cast_pytorch_to_onnx["Int"],
+    cast_pytorch_to_onnx["Long"],
+    cast_pytorch_to_onnx["Half"],
+    cast_pytorch_to_onnx["Float"],
+    cast_pytorch_to_onnx["Double"],
+]
diff --git a/torch/onnx/symbolic_opset10.py b/torch/onnx/symbolic_opset10.py
new file mode 100644
index 0000000..ee75d7e
--- /dev/null
+++ b/torch/onnx/symbolic_opset10.py
@@ -0,0 +1,122 @@
+import torch
+from torch.nn.modules.utils import _single, _pair, _triple
+import torch.onnx
+# This import monkey-patches graph manipulation methods on Graph, used for the
+# ONNX symbolics
+import torch.onnx.utils
+
+from torch.onnx.symbolic_helper import parse_args, _unimplemented, _black_list_in_opset
+import torch.onnx.symbolic_opset9
+
+
+# EDITING THIS FILE? READ THIS FIRST!
+# see Note [Edit Symbolic Files] in symbolic_helper.py
+
+# This file exports ONNX ops for opset 10
+# Opset 10 is supported by ONNX release 1.5.0,
+# released on 04/24/19.
+
+
+# Blacklist operators for this opset version.
+# These operators have been updated in ONNX but not re-implemented here.
+# It is very important to blacklist these operators to avoid exporting
+# models with mixed versions of operators.
+# TODO: add support for the blacklisted operators in black_listed_operators
+black_listed_operators = ["flip",
+                          "slice",
+                          "upsample_nearest2d", "upsample_bilinear2d",
+                          "dropout", "feature_dropout", "alpha_dropout", "feature_alpha_dropout",
+                          "dropout_", "feature_dropout_", "alpha_dropout_", "feature_alpha_dropout_"]
+
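+# Shadow each blacklisted op in this module with a stub that only emits the
+# warning above, so attempting to export it at opset 10 warns loudly instead of
+# silently reusing the opset 9 symbolic and mixing operator versions.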
+for black_listed_op in black_listed_operators:
+    vars()[black_listed_op] = _black_list_in_opset(black_listed_op)
+
+
+# Add new operator here
+@parse_args('v', 'i', 'i', 'i', 'i')
+def topk(g, self, k, dim, largest, sorted, out=None):
+    if out is not None:
+        _unimplemented("TopK", "Out parameter is not supported for topk")
+    if not largest:
+        _unimplemented("TopK", "Ascending TopK is not supported")
+    k = g.op("Constant", value_t=torch.tensor(k, dtype=torch.int64))
+    from torch.onnx.symbolic_opset9 import unsqueeze
+    k = unsqueeze(g, k, 0)
+    return g.op("TopK", self, k, axis_i=dim, outputs=2)
+
+
+def _max_pool(name, tuple_fn, ndims, return_indices):
+    @parse_args('v', 'is', 'is', 'is', 'is', 'i')
+    def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode):
+        if not stride:
+            stride = kernel_size
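+        # MaxPool in opset 10 supports the ceil_mode and dilations attributes
+        # natively, so they are passed straight through here (the opset 9
+        # symbolic cannot export dilations at all).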
+        kwargs = {
+            'kernel_shape_i': tuple_fn(kernel_size),
+            'pads_i': tuple_fn(padding) * 2,
+            'strides_i': tuple_fn(stride),
+            'ceil_mode_i': ceil_mode,
+        }
+        if set(tuple_fn(dilation)) != {1}:
+            kwargs['dilations_i'] = tuple_fn(dilation)
+        # Easy but hacky way to get flattened indices values
+        # to be used to convert the indices values to non-flattened.
+        # In ONNX the indices are computed as a flattened 1-D tensor,
+        # so the values in indices are in [0, N x C x D1 x ... x Dn).
+        # To convert the indices to the same format used by PyTorch,
+        # we first execute a maxpool with a kernel and stride of 1 on the same input.
+        # This will result in a tensor of indices in which each index has its own value.
+        # Using this tensor as a reference, we extract the first index of each axis and subtract
+        # it from each index of this axis in the indices to convert.
+        # This step will result in a tensor where each dimension has values of indices within
+        # the dimension it is in.
+        # For more information:
+        # https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407
+        if return_indices:
+            r, indices = g.op("MaxPool", input, outputs=2, **kwargs)
+            _, flattened_indices = g.op("MaxPool", input, outputs=2,
+                                        kernel_shape_i=[1 for _ in range(ndims)],
+                                        strides_i=[1 for _ in range(ndims)])
+            # convert indices to have non-flattened indices values
+            s = _slice_op(g, flattened_indices, axes=[2 + i for i in range(ndims)],
+                          starts=tuple_fn(0), ends=tuple_fn(1))
+            indices = sub(g, indices, s)
+            return r, indices
+        else:
+            r = g.op("MaxPool", input, outputs=1, **kwargs)
+            return r
+
+    return symbolic_fn
+
+
+max_pool1d = _max_pool("max_pool1d", _single, 1, return_indices=False)
+max_pool2d = _max_pool("max_pool2d", _pair, 2, return_indices=False)
+max_pool3d = _max_pool("max_pool3d", _triple, 3, return_indices=False)
+max_pool1d_with_indices = _max_pool("max_pool1d_with_indices", _single, 1, return_indices=True)
+max_pool2d_with_indices = _max_pool("max_pool2d_with_indices", _pair, 2, return_indices=True)
+max_pool3d_with_indices = _max_pool("max_pool3d_with_indices", _triple, 3, return_indices=True)
+
+
+def _avg_pool(name, tuple_fn):
+    @parse_args('v', 'is', 'is', 'is', 'i', 'i')
+    def symbolic_fn(g, input, kernel_size, stride, padding, ceil_mode, count_include_pad):
+        if not stride:
+            stride = kernel_size
+        padding = tuple(tuple_fn(padding))
+        if count_include_pad:
+            input = g.op("Pad", input,
+                         pads_i=((0,) * 2 + padding) * 2,
+                         mode_s='constant',
+                         value_f=0.)
+            padding = (0,) * len(padding)
+        output = g.op("AveragePool", input,
+                      kernel_shape_i=tuple_fn(kernel_size),
+                      strides_i=tuple_fn(stride),
+                      pads_i=padding * 2,
+                      ceil_mode_i=ceil_mode)
+        return output
+    return symbolic_fn
+
+
+avg_pool1d = _avg_pool('avg_pool1d', _single)
+avg_pool2d = _avg_pool('avg_pool2d', _pair)
+avg_pool3d = _avg_pool('avg_pool3d', _triple)
diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic_opset9.py
similarity index 77%
rename from torch/onnx/symbolic.py
rename to torch/onnx/symbolic_opset9.py
index 6c0be5c..3a17b95 100644
--- a/torch/onnx/symbolic.py
+++ b/torch/onnx/symbolic_opset9.py
@@ -1,233 +1,28 @@
 import torch
 from torch._C import ListType, OptionalType
 from torch.nn.modules.utils import _single, _pair, _triple
-import warnings
 
 import torch.onnx
 # This import monkey-patches graph manipulation methods on Graph, used for the
 # ONNX symbolics
 import torch.onnx.utils
 
-from functools import partial, wraps
+from functools import partial
+
+import torch.onnx.symbolic_helper as sym_help
+from torch.onnx.symbolic_helper import parse_args, _parse_arg, _unimplemented
 
 import numpy
 import math
+import warnings
+
 
 # EDITING THIS FILE? READ THIS FIRST!
-#
-# - This file is ONLY for ATen operators (e.g., operators that show up in the
-#   trace as aten::blah).  If you need to special case a primitive operator,
-#   look at _run_symbolic_function
-# - Parameter ordering does NOT necessarily match what is in VariableType.cpp;
-#   tensors are always first, then non-tensor arguments.
-# - Parameter names must *exactly* match the names in VariableType.cpp, because
-#   dispatch is done with keyword arguments.
-# - Looking for inplace ops?  They're detected by the trailing underscore, and
-#   transparently dispatched to their non inplace versions in
-#   'run_symbolic_function'.   See Note [Export inplace]
-#
-# ----------------------------------------------------------------------------------
-# A note on Tensor types
-# ----------------------------------------------------------------------------------
-#
-# In general, we should avoid depending on the type of Tensor Values contained
-# within the trace graph. However, this is sometimes unavoidable (due to ONNX
-# spec requirements, etc). If you are implementing a symbolic and need Tensor
-# type information, note that there are several levels of Tensor types, defined
-# in aten/src/ATen/core/jit_type.h:
-#
-# TensorType - This is a Tensor, but we don't know anything about its
-#               properties (e.g. scalar type, # dims, shapes).
-#               Appears as `Tensor` in graph print-outs.
-# DimensionedTensorType <: TensorType - Denotes a Tensor for which we know the scalar
-#                             type and number of dimensions, but not the concrete
-#                             shapes. For example, appears as 'Float(*, *)' in
-#                             graph print-outs. Useful accessor methods include
-#                             dim() and scalarType()
-# CompleteTensorType <: DimensionedTensorType - Denotes a Tensor for which we know the
-#                                               concrete sizes in addition to the information
-#                                               contained in TensorTyper. This adds a sizes()
-#                                               method which can be used to retrieve the
-#                                               concrete sizes.
-#
-# In general, we should prefer to rely on the least specific information possible.
-# For example, not relying on tensor properties at all is better than relying
-# on the number of dimensions (DimensionedTensorType) which is better than relying on
-# concrete shapes (CompleteTensorType). Doing so will make the export symbolics
-# more robust to different graphs.
+# see Note [Edit Symbolic Files] in symbolic_helper.py
 
-# ---------------------------------------------------------------------------------
-# Helper functions
-# ---------------------------------------------------------------------------------
-
-# Save some builtins as locals, because we'll shadown them below
-_sum = sum
-
-
-def _parse_arg(value, desc):
-    if desc == 'none':
-        return value
-    if desc == 'v' or not _is_value(value):
-        return value
-    if value.node().kind() == 'onnx::Constant':
-        tval = value.node()['value']
-        if desc == 'i':
-            return int(tval)
-        elif desc == 'f':
-            return float(tval)
-        elif desc == 'b':
-            return bool(tval)
-        elif desc == 't':
-            return tval
-        elif desc == 'is':
-            return [int(v) for v in tval]
-        else:
-            raise RuntimeError("ONNX symbolic doesn't know to interpret Constant node")
-    elif value.node().kind() == 'prim::ListConstruct':
-        if desc == 'is':
-            for v in value.node().inputs():
-                if v.node().kind() != 'onnx::Constant':
-                    raise RuntimeError("Failed to export an ONNX attribute, "
-                                       "since it's not constant, please try to make "
-                                       "things (e.g., kernel size) static if possible")
-            return [int(v.node()['value']) for v in value.node().inputs()]
-        else:
-            raise RuntimeError("ONNX symbolic doesn't know to interpret ListConstruct node")
-
-    raise RuntimeError("Unexpected node type: {}".format(value.node().kind()))
-
-
-def _maybe_get_const(value, desc):
-    if _is_value(value) and value.node().kind() == 'onnx::Constant':
-        return _parse_arg(value, desc)
-    return value
-
-
-def _maybe_get_scalar(value):
-    value_t = _maybe_get_const(value, 't')
-    if isinstance(value_t, torch.Tensor) and value_t.shape == ():
-        return value_t
-    return value
-
-
-def _get_const(value, desc, arg_name):
-    if _is_value(value) and value.node().kind() != 'onnx::Constant':
-        raise RuntimeError("ONNX symbolic expected a constant value of the {} argument".format(arg_name))
-    return _parse_arg(value, desc)
-
-
-def _unpack_list(list_value):
-    list_node = list_value.node()
-    assert list_node.kind() == "prim::ListConstruct"
-    return list(list_node.inputs())
-
-
-def parse_args(*arg_descriptors):
-    def decorator(fn):
-        def wrapper(g, *args):
-            # some args may be optional, so the length may be smaller
-            assert len(arg_descriptors) >= len(args)
-            args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)]
-            return fn(g, *args)
-        # In Python 2 functools.wraps chokes on partially applied functions, so we need this as a workaround
-        try:
-            wrapper = wraps(fn)(wrapper)
-        except Exception:
-            pass
-        return wrapper
-    return decorator
-
-
-def _scalar(x):
-    """Convert a scalar tensor into a Python value."""
-    assert x.numel() == 1
-    return x.item()
-
-
-def _if_scalar_type_as(g, self, tensor):
-    """
-    Convert self into the same type of tensor, as necessary.
-
-    We only support implicit casting for scalars, so we never
-    actually need to insert an ONNX cast operator here; just
-    fix up the scalar.
-    """
-    if isinstance(self, torch._C.Value):
-        return self
-    elif tensor.type().kind() == "DimensionedTensorType" or tensor.type().kind() == "CompleteTensorType":
-        ty = tensor.type().scalarType().lower()
-        return getattr(self, ty)()
-    else:
-        return self
-
-
-def _is_value(x):
-    return isinstance(x, torch._C.Value)
-
-
-def _is_tensor_list(x):
-    return x.type().isSubtypeOf(ListType.ofTensors())
-
-
-def _unimplemented(op, msg):
-    warnings.warn("ONNX export failed on " + op + " because " + msg + " not supported")
-
-
-def _try_get_scalar_type(*args):
-    for arg in args:
-        try:
-            return arg.type().scalarType()
-        except RuntimeError:
-            pass
-    return None
-
-
-# ---------------------------------------------------------------------
-# ONNX operator version
-# ---------------------------------------------------------------------
-
-# READ ME BEFORE EDITING _default_onnx_opset_version:
-#
-# The variable below controls which ONNX operator set version we are
-# targeting. THIS VARIABLE HAS SEMANTIC EFFECT! Say a breaking
-# change occurred in version 8. As long as this variable < 8, you can
-# export models targeting the old behavior. However, if you bump
-# this variable to 8 or later, the breaking change will take into effect:
-# you MUST adjust any symbolic affected by breaking changes. The ONNX
-# spec publishes a *comprehensive* list of BC-breaking changes for every
-# operator revision at:
-#
-#   https://github.com/onnx/onnx/blob/master/docs/Changelog.md
-#
-# Please be sure to go through and check all of our implementations here before
-# increasing this number. This includes symbolic definitions NOT in this
-# file, so grep for "OpName" (with quotes)
-#
-# Besides, opset_version can be specified in the invocation of export()
-# and export_to_pretty_string(), and _export_onnx_opset_version will be set
-# and the symbolic functions should check it to determine the behavior
-# of the exporter.
-
-
-_default_onnx_opset_version = 9
-_onnx_master_opset = 10
-_onnx_stable_opsets = [9]
-_export_onnx_opset_version = _default_onnx_opset_version
-
-
-def _set_opset_version(opset_version):
-    global _export_onnx_opset_version
-    if opset_version == _export_onnx_opset_version:
-        return
-    if opset_version in _onnx_stable_opsets + [_onnx_master_opset]:
-        _export_onnx_opset_version = opset_version
-        return
-    raise ValueError("Unsupported ONNX opset version: " + str(opset_version))
-
-
-# ---------------------------------------------------------------------
-# Symbolic definitions
-# ---------------------------------------------------------------------
+# This file exports ONNX ops for opset 9
+# Opset 9 is supported by ONNX release 1.4.1,
+# released on 01/23/19.
 
 
 # Note [Pointwise by scalar]
@@ -284,60 +79,60 @@
 
 def add(g, self, other, alpha=None):
     # default alpha arg is to allow no-alpha add (aten add st overload no alpha)
-    if alpha and _scalar(_maybe_get_scalar(alpha)) != 1:
+    if alpha and sym_help._scalar(sym_help._maybe_get_scalar(alpha)) != 1:
         return _unimplemented("add", "alpha != 1")
     # See Note [Pointwise by scalar]
-    other = _maybe_get_scalar(other)
-    return g.op("Add", self, _if_scalar_type_as(g, other, self))
+    other = sym_help._maybe_get_scalar(other)
+    return g.op("Add", self, sym_help._if_scalar_type_as(g, other, self))
 
 
 def sub(g, self, other, alpha=None):
     # default alpha arg is to allow no-alpha sub (aten sub st overload no alpha)
-    if alpha and _scalar(_maybe_get_scalar(alpha)) != 1:
+    if alpha and sym_help._scalar(sym_help._maybe_get_scalar(alpha)) != 1:
         return _unimplemented("sub", "alpha != 1")
     # See Note [Pointwise by scalar]. Note that self or other may be scalars.
-    other = _maybe_get_scalar(other)
-    return g.op("Sub", self, _if_scalar_type_as(g, other, self))
+    other = sym_help._maybe_get_scalar(other)
+    return g.op("Sub", self, sym_help._if_scalar_type_as(g, other, self))
 
 
 def rsub(g, self, other, alpha=None):
-    other = _maybe_get_scalar(other)
-    other = _if_scalar_type_as(g, other, self)
+    other = sym_help._maybe_get_scalar(other)
+    other = sym_help._if_scalar_type_as(g, other, self)
     return sub(g, other, self, alpha=alpha)
 
 
 def mul(g, self, other):
     # See Note [Pointwise by scalar]
-    other = _maybe_get_scalar(other)
-    return g.op("Mul", self, _if_scalar_type_as(g, other, self))
+    other = sym_help._maybe_get_scalar(other)
+    return g.op("Mul", self, sym_help._if_scalar_type_as(g, other, self))
 
 
 def div(g, self, other):
     # See Note [Pointwise by scalar]
-    other = _maybe_get_scalar(other)
-    return g.op("Div", self, _if_scalar_type_as(g, other, self))
+    other = sym_help._maybe_get_scalar(other)
+    return g.op("Div", self, sym_help._if_scalar_type_as(g, other, self))
 
 
 def reciprocal(g, self):
-    return g.op("Div", _if_scalar_type_as(g, torch.ones(1), self), self)
+    return g.op("Div", sym_help._if_scalar_type_as(g, torch.ones(1), self), self)
 
 
 @parse_args('v', 'i')
 def cat(g, tensor_list, dim):
-    tensors = _unpack_list(tensor_list)
+    tensors = sym_help._unpack_list(tensor_list)
     return g.op("Concat", *tensors, axis_i=dim)
 
 
 @parse_args('v', 'i')
 def stack(g, tensor_list, dim):
-    unsqueezed = [g.op("Unsqueeze", t, axes_i=[dim]) for t in _unpack_list(tensor_list)]
+    unsqueezed = [g.op("Unsqueeze", t, axes_i=[dim]) for t in sym_help._unpack_list(tensor_list)]
     return g.op("Concat", *unsqueezed, axis_i=dim)
 
 
 def mm(g, self, other):
     # Create a dummy C tensor. Only needed for API purposes, the value is
     # since beta = 0
-    ty = _try_get_scalar_type(self, other).lower()
+    ty = sym_help._try_get_scalar_type(self, other).lower()
     C = g.constant(0, [1], ty)
     return g.op("Gemm", self, other, C, beta_f=0.0, alpha_f=1.0)
 
@@ -352,7 +147,7 @@
 
 @parse_args('v', 'v', 'v', 't', 't')
 def addmm(g, self, mat1, mat2, beta, alpha):
-    return g.op("Gemm", mat1, mat2, self, beta_f=_scalar(beta), alpha_f=_scalar(alpha))
+    return g.op("Gemm", mat1, mat2, self, beta_f=sym_help._scalar(beta), alpha_f=sym_help._scalar(alpha))
 
 
 def neg(g, self):
@@ -409,7 +204,7 @@
             return g.op(onnx_op_name, self, keepdims_i=0)
         else:
             # dim-reduce path
-            dim, keepdim = _get_const(dim, 'i', 'dim'), _get_const(keepdim, 'i', 'keepdim')
+            dim, keepdim = sym_help._get_const(dim, 'i', 'dim'), sym_help._get_const(keepdim, 'i', 'keepdim')
             return g.op(onnx_op_name, self, axes_i=[dim], keepdims_i=keepdim)
     return symbolic
 
@@ -442,8 +237,8 @@
 
 
 def expand(g, self, size, implicit):
-    size = _maybe_get_const(size, 'is')
-    if not _is_value(size):
+    size = sym_help._maybe_get_const(size, 'is')
+    if not sym_help._is_value(size):
         size = g.op("Constant", value_t=torch.LongTensor(size))
     return g.op("Expand", self, size)
 
@@ -504,8 +299,8 @@
 
 
 def view(g, self, size):
-    size = _maybe_get_const(size, 'is')
-    if _is_value(size):
+    size = sym_help._maybe_get_const(size, 'is')
+    if sym_help._is_value(size):
         shape = size
     else:
         if self.isCompleteTensor():
@@ -570,7 +365,7 @@
             if size == 1:
                 dims.append(i)
     else:
-        dims = [_get_const(dim, 'i', 'dim')]
+        dims = [sym_help._get_const(dim, 'i', 'dim')]
         # Handle negative dims
         for i, dim in enumerate(dims):
             if dim < 0:
@@ -607,18 +402,18 @@
 @parse_args('v', 't', 't')
 def threshold(g, self, threshold, value):
     # See Note [Export inplace]
-    if _scalar(threshold) != 0:
+    if sym_help._scalar(threshold) != 0:
         return _unimplemented("threshold", "non-zero threshold")
-    if _scalar(value) != 0:
+    if sym_help._scalar(value) != 0:
         return _unimplemented("threshold", "non-zero value")
     return g.op("Relu", self)
 
 
 def leaky_relu(g, input, negative_slope, inplace=False):
-    negative_slope = _get_const(negative_slope, 't', 'negative_slope')
+    negative_slope = sym_help._get_const(negative_slope, 't', 'negative_slope')
     # See Note [Export inplace]
     # TODO: Talk to ONNX about unconditional cast of scalar to float
-    return g.op("LeakyRelu", input, alpha_f=_scalar(negative_slope))
+    return g.op("LeakyRelu", input, alpha_f=sym_help._scalar(negative_slope))
 
 
 @parse_args('v', 'i')
@@ -655,13 +450,13 @@
         if input.type().dim() == dim + 1:
             softmax = g.op('Softmax', input, axis_i=dim)
             if dtype:
-                softmax = g.op("Cast", softmax, to_i=scalar_type_to_onnx[dtype])
+                softmax = g.op("Cast", softmax, to_i=sym_help.scalar_type_to_onnx[dtype])
             return softmax
     exp = g.op('Exp', input)
     sum = g.op('ReduceSum', exp, axes_i=[dim])
     softmax = g.op('Div', exp, sum)
     if dtype:
-        softmax = g.op("Cast", softmax, to_i=scalar_type_to_onnx[dtype])
+        softmax = g.op("Cast", softmax, to_i=sym_help.scalar_type_to_onnx[dtype])
     return softmax
 
 
@@ -700,6 +495,8 @@
     def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode):
         if ceil_mode and input.type().kind() != "CompleteTensorType":
             return _unimplemented(name, "input size not accessible")
+        if set(tuple_fn(dilation)) != {1}:
+            return _unimplemented(name, "dilation")
         if not stride:
             stride = kernel_size
         padding = tuple(tuple_fn(padding))
@@ -713,8 +510,6 @@
             'pads_i': padding,
             'strides_i': tuple_fn(stride),
         }
-        if set(tuple_fn(dilation)) != {1}:
-            kwargs['dilations_i'] = tuple_fn(dilation)
         # easy but hacky way to get flattened indices values
         # to be used to convert the indices values to non-flattened.
         # In ONNX the indices are computed as a flatten 1-D tensor,
@@ -888,7 +683,7 @@
 
 def wrap_logical_op_with_cast_to_uint8(func):
     def wrap_with_cast(g, input, other):
-        return g.op("Cast", func(g, input, other), to_i=cast_pytorch_to_onnx['Byte'])
+        return g.op("Cast", func(g, input, other), to_i=sym_help.cast_pytorch_to_onnx['Byte'])
     return wrap_with_cast
 
 
@@ -915,8 +710,8 @@
 
 
 def gt_impl(g, input, other):
-    other = _maybe_get_scalar(other)
-    return g.op("Greater", input, _if_scalar_type_as(g, other, input))
+    other = sym_help._maybe_get_scalar(other)
+    return g.op("Greater", input, sym_help._if_scalar_type_as(g, other, input))
 
 
 @wrap_logical_op_with_cast_to_uint8
@@ -925,22 +720,22 @@
 
 
 def lt_impl(g, input, other):
-    other = _maybe_get_scalar(other)
-    return g.op("Less", input, _if_scalar_type_as(g, other, input))
+    other = sym_help._maybe_get_scalar(other)
+    return g.op("Less", input, sym_help._if_scalar_type_as(g, other, input))
 
 
 @wrap_logical_op_with_cast_to_uint8
 @wrap_logical_op_with_negation
 def ge(g, input, other):
-    other = _maybe_get_scalar(other)
-    return lt_impl(g, input, _if_scalar_type_as(g, other, input))
+    other = sym_help._maybe_get_scalar(other)
+    return lt_impl(g, input, sym_help._if_scalar_type_as(g, other, input))
 
 
 @wrap_logical_op_with_cast_to_uint8
 @wrap_logical_op_with_negation
 def le(g, input, other):
-    other = _maybe_get_scalar(other)
-    return gt_impl(g, input, _if_scalar_type_as(g, other, input))
+    other = sym_help._maybe_get_scalar(other)
+    return gt_impl(g, input, sym_help._if_scalar_type_as(g, other, input))
 
 
 def where(g, condition, self, other):
@@ -957,7 +752,7 @@
         return _unimplemented("dim", "ONNX and PyTorch use different strategies to split the input.")
     return_op = g.op("LogSoftmax", input, axis_i=dim)
     if dtype:
-        return_op = g.op("Cast", return_op, to_i=scalar_type_to_onnx[dtype])
+        return_op = g.op("Cast", return_op, to_i=sym_help.scalar_type_to_onnx[dtype])
     return return_op
 
 
@@ -1064,7 +859,7 @@
     if input_scale and input_scale != 1.:
         return _unimplemented("input_scale", "does not support input_scale in Elu")
     # See Note [Export inplace]
-    return g.op("Elu", input, alpha_f=_scalar(alpha))
+    return g.op("Elu", input, alpha_f=sym_help._scalar(alpha))
 
 
 def selu(g, input):
@@ -1077,7 +872,7 @@
 
 
 def index_put(g, self, indices_list_value, values, accumulate):
-    indices_list = _unpack_list(indices_list_value)
+    indices_list = sym_help._unpack_list(indices_list_value)
     args = [self] + indices_list + [values, accumulate]
     return g.op("ATen", *args, operator_s='index_put')
 
@@ -1088,7 +883,7 @@
 
     if other.isCompleteTensor():
         other_type_name = other.type().scalarType()
-        return g.op("Cast", self, to_i=cast_pytorch_to_onnx[other_type_name])
+        return g.op("Cast", self, to_i=sym_help.cast_pytorch_to_onnx[other_type_name])
     else:
         # We don't know the type of other, bail by emitting ATen
         return g.op("ATen", self, other, operator_s="type_as")
@@ -1114,8 +909,8 @@
 
 
 def pow(g, self, exponent):
-    exponent = _maybe_get_scalar(exponent)
-    return g.op("Pow", self, _if_scalar_type_as(g, exponent, self))
+    exponent = sym_help._maybe_get_scalar(exponent)
+    return g.op("Pow", self, sym_help._if_scalar_type_as(g, exponent, self))
 
 
 def clamp(g, self, min, max):
@@ -1152,8 +947,8 @@
         return g.op("Max", self, dim_or_y)
     # torch.max(input, dim, keepdim)
     else:
-        dim = _get_const(dim_or_y, 'i', 'dim')
-        keepdim = _get_const(keepdim, 'i', 'keepdim')
+        dim = sym_help._get_const(dim_or_y, 'i', 'dim')
+        keepdim = sym_help._get_const(keepdim, 'i', 'keepdim')
         max = g.op("ReduceMax", self, axes_i=[dim], keepdims_i=keepdim)
         indices = g.op('ArgMax', self, axis_i=dim, keepdims_i=keepdim)
         return max, indices
@@ -1168,8 +963,8 @@
         return g.op("Min", self, dim_or_y)
     # torch.min(input, dim, keepdim)
     else:
-        dim = _get_const(dim_or_y, 'i', 'dim')
-        keepdim = _get_const(keepdim, 'i', 'keepdim')
+        dim = sym_help._get_const(dim_or_y, 'i', 'dim')
+        keepdim = sym_help._get_const(keepdim, 'i', 'keepdim')
         min = g.op("ReduceMin", self, axes_i=[dim], keepdims_i=keepdim)
         indices = g.op('ArgMin', self, axis_i=dim, keepdims_i=keepdim)
         return min, indices
@@ -1191,7 +986,6 @@
     @parse_args('v', 'f', 'i')
     def feature_dropout(g, input, p, train):
         # NB: In inference mode, FeatureDropout is exported as an identity op.
-        from torch.onnx.symbolic import _unimplemented
         if train:
             return _unimplemented(name, "training mode")
         return input
@@ -1238,69 +1032,9 @@
                 outputs=3)
 
 
-# Metaprogram symbolics for each ATen native specialized cast operator.
-# For e.g. we specify a function named `_cast_uint8_t` that instantiates an
-# ONNX cast node with `to` attribute 'UINT8'
-#
-# TODO: remove these once we support Type's in the JIT IR and we can once again
-# use the unified toType operator
-cast_pytorch_to_onnx = {
-    'Byte': torch.onnx.TensorProtoDataType.UINT8,
-    'Char': torch.onnx.TensorProtoDataType.INT8,
-    'Double': torch.onnx.TensorProtoDataType.DOUBLE,
-    'Float': torch.onnx.TensorProtoDataType.FLOAT,
-    'Half': torch.onnx.TensorProtoDataType.FLOAT16,
-    'Int': torch.onnx.TensorProtoDataType.INT32,
-    'Long': torch.onnx.TensorProtoDataType.INT64,
-    'Short': torch.onnx.TensorProtoDataType.INT16,
-}
-
-scalar_name_to_pytorch = {
-    'uint8_t': 'Byte',
-    'int8_t': 'Char',
-    'double': 'Double',
-    'float': 'Float',
-    'half': 'Half',
-    'int': 'Int',
-    'int64_t': 'Long',
-    'int16_t': 'Short',
-}
-
-
-# This indicates each scalar type's corresponding
-# torch type. Related source:
-# https://github.com/pytorch/pytorch/blob/da7468853ae322252270bbb58032668bd21b7457/c10/core/ScalarType.h
-scalar_type_to_pytorch_type = [
-    torch.uint8,    # 0
-    torch.int8,     # 1
-    torch.short,    # 2
-    torch.int,      # 3
-    torch.int64,    # 4
-    torch.half,     # 5
-    torch.float,    # 6
-    torch.double,   # 7
-]
-
-
-def _cast_func_template(to_i, g, input, non_blocking):
-    return g.op("Cast", input, to_i=to_i)
-
-
-for k, v in cast_pytorch_to_onnx.items():
+for k, v in sym_help.cast_pytorch_to_onnx.items():
     name = '_cast_{}'.format(k)
-    globals()[name] = parse_args('v', 'i')(partial(_cast_func_template, v))
-
-
-scalar_type_to_onnx = [
-    cast_pytorch_to_onnx["Byte"],
-    cast_pytorch_to_onnx["Char"],
-    cast_pytorch_to_onnx["Short"],
-    cast_pytorch_to_onnx["Int"],
-    cast_pytorch_to_onnx["Long"],
-    cast_pytorch_to_onnx["Half"],
-    cast_pytorch_to_onnx["Float"],
-    cast_pytorch_to_onnx["Double"],
-]
+    globals()[name] = parse_args('v', 'i')(partial(sym_help._cast_func_template, v))
 
 
 @parse_args('v', 'i', 'v', 'v', 'b')
@@ -1309,7 +1043,7 @@
         raise RuntimeError("onnx pin_memory support is not implemented")
     # NOTE: no way to set device and layout in ONNX, so we ignore it
     return g.op("ConstantOfShape", sizes,
-                value_t=torch.tensor([0], dtype=scalar_type_to_pytorch_type[dtype]))
+                value_t=torch.tensor([0], dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
 
 
 @parse_args('v', 'i', 'v', 'v', 'b')
@@ -1318,7 +1052,7 @@
         raise RuntimeError("onnx pin_memory support is not implemented")
     shape = g.op("Shape", input)
     return g.op("ConstantOfShape", shape,
-                value_t=torch.tensor([0], dtype=scalar_type_to_pytorch_type[dtype]))
+                value_t=torch.tensor([0], dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
 
 
 @parse_args('v', 'i', 'v', 'v', 'b')
@@ -1326,7 +1060,7 @@
     if pin_memory:
         raise RuntimeError("onnx pin_memory support is not implemented")
     return g.op("ConstantOfShape", sizes,
-                value_t=torch.tensor([1], dtype=scalar_type_to_pytorch_type[dtype]))
+                value_t=torch.tensor([1], dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
 
 
 @parse_args('v', 'i', 'v', 'v', 'b')
@@ -1335,20 +1069,20 @@
         raise RuntimeError("onnx pin_memory support is not implemented")
     shape = g.op("Shape", input)
     return g.op("ConstantOfShape", shape,
-                value_t=torch.tensor([1], dtype=scalar_type_to_pytorch_type[dtype]))
+                value_t=torch.tensor([1], dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
 
 
 def full(g, sizes, value, dtype, layout, device, pin_memory=False):
     if pin_memory and _parse_arg(pin_memory, 'b'):
         raise RuntimeError("onnx pin_memory support is not implemented")
-    const_value = _maybe_get_const(value, 't')
-    if _is_value(const_value):
+    const_value = sym_help._maybe_get_const(value, 't')
+    if sym_help._is_value(const_value):
         tmp = zeros(sizes, dtype, layout, device)
         return add(tmp, value, g.op("Constant", value_t=torch.tensor(1)))
     else:
-        dtype = _get_const(dtype, 'i', 'dtype')
+        dtype = sym_help._get_const(dtype, 'i', 'dtype')
         return g.op("ConstantOfShape", sizes,
-                    value_t=torch.tensor([const_value], dtype=scalar_type_to_pytorch_type[dtype]))
+                    value_t=torch.tensor([const_value], dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
 
 
 @parse_args('v', 'f', 'i', 'v', 'v', 'b')
@@ -1357,7 +1091,7 @@
         raise RuntimeError("onnx pin_memory support is not implemented")
     shape = g.op("Shape", input)
     return g.op("ConstantOfShape", shape,
-                value_t=torch.tensor([fill_value], dtype=scalar_type_to_pytorch_type[dtype]))
+                value_t=torch.tensor([fill_value], dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
 
 
 @parse_args('v', 'v', 'v', 'v', 'i')
@@ -1422,32 +1156,32 @@
             return self
         else:
             # aten::to(Tensor, ScalarType, bool, bool)
-            dtype = _get_const(args[0], 'i', 'dtype')
-            return g.op("Cast", self, to_i=scalar_type_to_onnx[dtype])
+            dtype = sym_help._get_const(args[0], 'i', 'dtype')
+            return g.op("Cast", self, to_i=sym_help.scalar_type_to_onnx[dtype])
     elif len(args) == 4:
         # aten::to(Tensor, Device, ScalarType, bool, bool)
-        dtype = _get_const(args[1], 'i', 'dtype')
-        return g.op("Cast", self, to_i=scalar_type_to_onnx[dtype])
+        dtype = sym_help._get_const(args[1], 'i', 'dtype')
+        return g.op("Cast", self, to_i=sym_help.scalar_type_to_onnx[dtype])
     elif len(args) == 5:
         # aten::to(Tensor, ScalarType, Layout, Device, bool, bool) -> Tensor
-        dtype = _get_const(args[0], 'i', 'dtype')
+        dtype = sym_help._get_const(args[0], 'i', 'dtype')
         # Layout and device are ignored
-        return g.op("Cast", self, to_i=scalar_type_to_onnx[dtype])
+        return g.op("Cast", self, to_i=sym_help.scalar_type_to_onnx[dtype])
     elif len(args) == 6:
         # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, bool) -> Tensor
-        dtype = _get_const(args[0], 'i', 'dtype')
+        dtype = sym_help._get_const(args[0], 'i', 'dtype')
         # Layout and device are ignored
-        return g.op("Cast", self, to_i=scalar_type_to_onnx[dtype])
+        return g.op("Cast", self, to_i=sym_help.scalar_type_to_onnx[dtype])
     else:
         raise NotImplementedError("Unknown aten::to signature")
 
 
 def repeat(g, self, repeats):
-    if not _is_value(repeats):
+    if not sym_help._is_value(repeats):
         repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
-    const_repeats = _maybe_get_const(repeats, 'is')
+    const_repeats = sym_help._maybe_get_const(repeats, 'is')
 
-    if self.isCompleteTensor() and not _is_value(const_repeats):
+    if self.isCompleteTensor() and not sym_help._is_value(const_repeats):
         sizes = self.type().sizes()
         diff_dims = len(const_repeats) - len(sizes)
         if diff_dims > 0:
@@ -1600,20 +1334,20 @@
 
 @parse_args('v', 'v', 'v', 'i', 'i', 'f', 'i', 'i', 'i')
 def _lstm_full(g, input, hidden_v, weight_v, has_biases, num_layers, dropout, train, bidirectional, batch_first):
-    hidden, weight = _unpack_list(hidden_v), _unpack_list(weight_v)
+    hidden, weight = sym_help._unpack_list(hidden_v), sym_help._unpack_list(weight_v)
     return _generic_rnn(g, 'LSTM', input, hidden, weight, has_biases, num_layers,
                         dropout, train, bidirectional, batch_first)
 
 
 @parse_args('v', 'v', 'v', 'v', 'i', 'i', 'f', 'i', 'i')
 def _lstm_packed(g, input, batch_sizes, hidden_v, weight_v, has_biases, num_layers, dropout, train, bidirectional):
-    hidden, weight = _unpack_list(hidden_v), _unpack_list(weight_v)
+    hidden, weight = sym_help._unpack_list(hidden_v), sym_help._unpack_list(weight_v)
     return _generic_rnn(g, 'LSTM', input, hidden, weight, has_biases, num_layers,
                         dropout, train, bidirectional, batch_sizes=batch_sizes)
 
 
 def lstm(g, *args):
-    if _is_tensor_list(args[3]):
+    if sym_help._is_tensor_list(args[3]):
         return _lstm_packed(g, *args)
     else:
         return _lstm_full(g, *args)
@@ -1622,18 +1356,18 @@
 def _one_hidden_rnn(kind):
     @parse_args('v', 'v', 'v', 'i', 'i', 'f', 'i', 'i', 'i')
     def _rnn_full(g, input, hidden, weight_v, has_biases, num_layers, dropout, train, bidirectional, batch_first):
-        weight = _unpack_list(weight_v)
+        weight = sym_help._unpack_list(weight_v)
         return _generic_rnn(g, kind, input, hidden, weight, has_biases, num_layers,
                             dropout, train, bidirectional, batch_first)
 
     @parse_args('v', 'v', 'v', 'v', 'i', 'i', 'f', 'i', 'i')
     def _rnn_packed(g, input, batch_sizes, hidden, weight_v, has_biases, num_layers, dropout, train, bidirectional):
-        weight = _unpack_list(weight_v)
+        weight = sym_help._unpack_list(weight_v)
         return _generic_rnn(g, kind, input, hidden, weight, has_biases, num_layers,
                             dropout, train, bidirectional, batch_sizes=batch_sizes)
 
     def symbolic(g, *args):
-        if _is_tensor_list(args[3]):
+        if sym_help._is_tensor_list(args[3]):
             return _rnn_packed(g, *args)
         else:
             return _rnn_full(g, *args)
@@ -1692,7 +1426,7 @@
 
 def randn(g, *shapes):
     shapes_list = list(shapes)
-    shape = _maybe_get_const(shapes_list[0], "is")
+    shape = sym_help._maybe_get_const(shapes_list[0], "is")
     return g.op('RandomNormal', shape_i=shape)
 
 
@@ -1750,7 +1484,7 @@
 @parse_args('v')
 def isnan(g, input):
     output = g.op('IsNaN', input)
-    output = _cast_func_template(cast_pytorch_to_onnx['Byte'], g, output, None)
+    output = sym_help._cast_func_template(sym_help.cast_pytorch_to_onnx['Byte'], g, output, None)
     return output
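
For orientation, a minimal sketch of what a symbolic function in ``torch/onnx/symbolic_opset9.py`` looks like after this split, assuming the module imports ``torch.onnx.symbolic_helper as sym_help`` and ``parse_args`` from ``torch.onnx.symbolic_helper`` (as the hunks above imply); the operator name below is illustrative only and not part of this patch::

    import torch
    import torch.onnx.symbolic_helper as sym_help
    from torch.onnx.symbolic_helper import parse_args


    # Illustrative only: a zeros-style symbolic written against the new layout.
    # Real symbolics must be named after the ATen operator they export.
    @parse_args('v', 'i', 'v', 'v', 'b')
    def example_zeros(g, sizes, dtype, layout, device, pin_memory=False):
        if pin_memory:
            raise RuntimeError("onnx pin_memory support is not implemented")
        # Helpers that used to be referenced as bare names (e.g.
        # scalar_type_to_pytorch_type) are now reached through sym_help.
        return g.op("ConstantOfShape", sizes,
                    value_t=torch.tensor([0], dtype=sym_help.scalar_type_to_pytorch_type[dtype]))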
 
 
diff --git a/torch/onnx/symbolic_registry.py b/torch/onnx/symbolic_registry.py
new file mode 100644
index 0000000..4d4d272
--- /dev/null
+++ b/torch/onnx/symbolic_registry.py
@@ -0,0 +1,70 @@
+import warnings
+import importlib
+from inspect import getmembers, isfunction
+
+# The symbolic registry "_registry" is a dictionary that maps operators
+# (for a specific domain and opset version) to their symbolic functions.
+# An operator is identified by its domain, opset version, and opname.
+# The outer keys are tuples (domain, version), where domain is a string and
+# version is an int; the inner keys are operator names (strings).
+# Entries therefore have the form: _registry[(domain, version)][op_name] = op_symbolic
+_registry = {}
+
+_symbolic_versions = {}
+from torch.onnx.symbolic_helper import _onnx_stable_opsets
+for opset_version in _onnx_stable_opsets:
+    module = importlib.import_module('torch.onnx.symbolic_opset{}'.format(opset_version))
+    _symbolic_versions[opset_version] = module
+
+def register_version(domain, version):
+    if not is_registered_version(domain, version):
+        global _registry
+        _registry[(domain, version)] = {}
+    register_ops_in_version(domain, version)
+
+
+def register_ops_in_version(domain, version):
+    # Iterates over the symbolic functions of the specified opset
+    # version and of all earlier opset versions (down to opset 9),
+    # so that operators introduced in earlier opsets are also
+    # registered for the requested version.
+    iter_version = version
+    while iter_version >= 9:
+        version_ops = get_ops_in_version(iter_version)
+        for op in version_ops:
+            if isfunction(op[1]) and \
+               not is_registered_op(op[0], domain, version):
+                register_op(op[0], op[1], domain, version)
+        iter_version = iter_version - 1
+
+
+def get_ops_in_version(version):
+    return getmembers(_symbolic_versions[version])
+
+
+def is_registered_version(domain, version):
+    global _registry
+    return (domain, version) in _registry
+
+
+def register_op(opname, op, domain, version):
+    if domain is None or version is None:
+        warnings.warn("ONNX export failed. The ONNX domain and/or version to register are None.")
+    global _registry
+    if not is_registered_version(domain, version):
+        _registry[(domain, version)] = {}
+    _registry[(domain, version)][opname] = op
+
+
+def is_registered_op(opname, domain, version):
+    if domain is None or version is None:
+        warnings.warn("ONNX export failed. The ONNX domain and/or version are None.")
+    global _registry
+    return (domain, version) in _registry and opname in _registry[(domain, version)]
+
+
+def get_registered_op(opname, domain, version):
+    if domain is None or version is None:
+        warnings.warn("ONNX export failed. The ONNX domain and/or version are None.")
+    global _registry
+    return _registry[(domain, version)][opname]
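
A hedged sketch of how the new registry is meant to be driven; the function names come from the file above, and the default domain ``''`` with opset 9 is just an example::

    import torch.onnx.symbolic_registry as sym_registry

    # Populate the registry for the default ONNX domain ('') at opset 9.
    # This registers every function defined in torch/onnx/symbolic_opset9.py.
    sym_registry.register_version('', 9)

    # Look up the symbolic function the exporter would use for aten::elu.
    if sym_registry.is_registered_op('elu', '', 9):
        elu_symbolic = sym_registry.get_registered_op('elu', '', 9)

The exporter performs this registration itself in ``_run_symbolic_function`` (see the ``utils.py`` hunks below).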
diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py
index b5f59d5..0a8b884 100644
--- a/torch/onnx/utils.py
+++ b/torch/onnx/utils.py
@@ -93,7 +93,7 @@
             output nodes of the graph, in order
         aten (bool, default False): [DEPRECATED. use operator_export_type] export the
             model in aten mode. If using aten mode, all the ops originally exported
-            by the functions in symbolic.py are exported as ATen ops.
+            by the functions in symbolic_opset<version>.py are exported as ATen ops.
         export_raw_ir (bool, default False): [DEPRECATED. use operator_export_type]
             export the internal IR directly instead of converting it to ONNX ops.
         operator_export_type (enum, default OperatorExportTypes.ONNX):
@@ -107,7 +107,7 @@
             evolve before the next stable release, by default we export to one stable
             opset version. Right now, the supported stable opset version is 9.
             The opset_version must be _onnx_master_opset or in _onnx_stable_opsets
-            which are defined in torch/onnx/symbolic.py
+            which are defined in torch/onnx/symbolic_helper.py
         do_constant_folding (bool, default False): If True, the constant-folding
             optimization is applied to the model during export. Constant-folding
             optimization will replace some of the ops that have all constant
@@ -237,7 +237,7 @@
                     example_outputs=None, propagate=False,
                     _retain_param_name=False, do_constant_folding=False,
                     _disable_torch_constant_prop=False):
-    from torch.onnx.symbolic import _export_onnx_opset_version
+    from torch.onnx.symbolic_helper import _export_onnx_opset_version
     # Special case for common case of passing a single Tensor
     if isinstance(args, torch.Tensor):
         args = (args, )
@@ -326,7 +326,7 @@
                              export_type=ExportTypes.PROTOBUF_FILE, example_outputs=None, propagate=False,
                              google_printer=False, opset_version=None, _retain_param_name=False,
                              do_constant_folding=False):
-    from torch.onnx.symbolic import _default_onnx_opset_version, _set_opset_version
+    from torch.onnx.symbolic_helper import _default_onnx_opset_version, _set_opset_version
     if opset_version is None:
         opset_version = _default_onnx_opset_version
     _set_opset_version(opset_version)
@@ -352,7 +352,7 @@
     assert __IN_ONNX_EXPORT is False
     __IN_ONNX_EXPORT = True
     try:
-        from torch.onnx.symbolic import _default_onnx_opset_version, _set_opset_version
+        from torch.onnx.symbolic_helper import _default_onnx_opset_version, _set_opset_version
         if opset_version is None:
             opset_version = _default_onnx_opset_version
         _set_opset_version(opset_version)
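
The hunks above show how the exporter resolves the requested opset. As a hedged sketch of driving this from the public API (the module, dummy input and output filename are placeholders)::

    import torch
    import torch.onnx

    model = torch.nn.ReLU()          # placeholder module
    dummy_input = torch.randn(1, 3)  # placeholder example input

    # opset_version must be _onnx_master_opset or one of _onnx_stable_opsets
    # from torch/onnx/symbolic_helper.py; 9 is the current default.
    torch.onnx.export(model, dummy_input, "model.onnx", opset_version=9)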
@@ -553,7 +553,10 @@
     # NB: Returning None means the node gets cloned as is into
     # the new graph
     try:
-        import torch.onnx.symbolic
+        from torch.onnx.symbolic_helper import _export_onnx_opset_version as opset_version
+        import torch.onnx.symbolic_registry as sym_registry
+
+        sym_registry.register_version('', opset_version)
 
         # See Note [Export inplace]
         # TODO: I think this is not necessary anymore
@@ -568,7 +571,7 @@
             return None
 
         elif ns == "aten":
-            is_exportable_aten_op = hasattr(torch.onnx.symbolic, op_name)
+            is_exportable_aten_op = sym_registry.is_registered_op(op_name, '', opset_version)
             is_onnx_aten_export = operator_export_type == OperatorExportTypes.ONNX_ATEN
             is_aten_fallback_export = operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK
             if is_onnx_aten_export or (not is_exportable_aten_op and is_aten_fallback_export):
@@ -582,11 +585,11 @@
                 # Export it regularly
                 attrs = {k: n[k] for k in n.attributeNames()}
                 if not is_exportable_aten_op:
-                    warnings.warn("ONNX export failed on ATen operator {} because torch.onnx.symbolic.{} does not exist"
-                                  .format(op_name, op_name))
-                    return None
-                fn = getattr(torch.onnx.symbolic, op_name)
-                return fn(g, *inputs, **attrs)
+                    warnings.warn("ONNX export failed on ATen operator {} because "
+                                  "torch.onnx.symbolic_opset{}.{} does not exist"
+                                  .format(op_name, opset_version, op_name))
+                op_fn = sym_registry.get_registered_op(op_name, '', opset_version)
+                return op_fn(g, *inputs, **attrs)
 
         elif ns == "prim":
             if op_name == "Constant" and not n.mustBeNone():
@@ -614,17 +617,32 @@
                     torch._C._jit_pass_onnx_block(b, new_block, operator_export_type, env)
                 return new_op_outputs
             else:
+                # TODO: we should lift prim's symbolic out
                 symbolic_name = 'prim_' + op_name
-                symbolic_fn = getattr(torch.onnx.symbolic, symbolic_name, None)
-                if symbolic_fn is None:
+                is_exportable = sym_registry.is_registered_op(symbolic_name, '', opset_version)
+                if not is_exportable:
                     warnings.warn("ONNX export failed on primitive operator {}; please report a bug".format(op_name))
-                    return None
+                symbolic_fn = sym_registry.get_registered_op(symbolic_name, '', opset_version)
                 attrs = {k: n[k] for k in n.attributeNames()}
                 return symbolic_fn(g, *inputs, **attrs)
 
+        # custom ops
+        elif sym_registry.is_registered_version(ns, opset_version):
+            if not sym_registry.is_registered_op(op_name, ns, opset_version):
+                warnings.warn("ONNX export failed on custom operator {}::{} because "
+                              "no symbolic function is registered for it in opset {}. "
+                              "Have you registered your symbolic function with "
+                              "torch.onnx.register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version)?"
+                              .format(ns, op_name, opset_version))
+            symbolic_fn = sym_registry.get_registered_op(op_name, ns, opset_version)
+            attrs = {k: n[k] for k in n.attributeNames()}
+            return symbolic_fn(g, *inputs, **attrs)
+
         else:
             warnings.warn("ONNX export failed on an operator with unrecognized namespace {}::{}; "
-                          "please report a bug".format(ns, op_name))
+                          "If you are trying to export a custom operator, make sure you have registered "
+                          "it with the right domain and version. "
+                          "Otherwise, please report a bug".format(ns, op_name))
             return None
 
     except TypeError as e:
@@ -686,6 +704,22 @@
     return getattr(self, sel)(k)
 
 
+def register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version):
+    if not bool(re.match(r"^[a-zA-Z0-9_-]*::[a-zA-Z]+[a-zA-Z0-9_-]*$", symbolic_name)):
+        raise RuntimeError("Failed to register operator {}. "
+                           "The symbolic name must match the format Domain::Name, "
+                           "and should start with a letter and contain only "
+                           "alphanumeric characters, underscores and hyphens"
+                           .format(symbolic_name))
+    ns, op_name = symbolic_name.split('::')
+    unaccepted_domain_names = ["onnx", "aten", "prim"]
+    if ns in unaccepted_domain_names:
+        raise RuntimeError("Failed to register operator {}. The domain {} is reserved and cannot be used."
+                           .format(symbolic_name, ns))
+    import torch.onnx.symbolic_registry as sym_registry
+    sym_registry.register_op(op_name, symbolic_fn, ns, opset_version)
+
+
 torch._C.Graph.op = _graph_op
 torch._C.Graph.at = _graph_at
 torch._C.Graph.constant = _graph_constant
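
Finally, a hedged usage sketch for the new ``register_custom_op_symbolic`` entry point, assuming it is re-exported as ``torch.onnx.register_custom_op_symbolic`` as the warning message above suggests; the ``custom_ops::my_relu`` operator and its symbolic body are hypothetical and assume the corresponding C++ op is already registered with TorchScript::

    import torch.onnx

    # Hypothetical symbolic for a custom op exposed to TorchScript as custom_ops::my_relu.
    # The first argument is the ONNX graph; the remaining arguments mirror the op's inputs.
    def my_relu_symbolic(g, input):
        return g.op("Relu", input)

    # Register it for opset 9; during export, custom_ops::my_relu nodes are
    # dispatched to this function through the symbolic registry.
    torch.onnx.register_custom_op_symbolic('custom_ops::my_relu', my_relu_symbolic, 9)

Note that the domain check above rejects the reserved namespaces ``onnx``, ``aten`` and ``prim``.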