Revert "[ONNX] Refactor to remove inline imports (#77142)"
This reverts commit c08b8f0967efc2eec078da4541c5fdd003fbdd75.
Reverted https://github.com/pytorch/pytorch/pull/77142 on behalf of https://github.com/malfet
diff --git a/test/onnx/test_onnx_export.py b/test/onnx/test_onnx_export.py
index 6e955d1..52e32c5 100644
--- a/test/onnx/test_onnx_export.py
+++ b/test/onnx/test_onnx_export.py
@@ -5,7 +5,7 @@
import itertools
import os
import sys
-import unittest.mock
+import unittest
from typing import Callable, Iterable, Optional, Tuple, Union
import onnx
@@ -13,7 +13,6 @@
import torch
from torch.onnx import OperatorExportTypes, symbolic_registry
-from torch.onnx._globals import GLOBALS
from torch.onnx.symbolic_helper import _onnx_unsupported
from torch.testing._internal.common_utils import custom_op, skipIfCaffe2
@@ -30,9 +29,9 @@
Union[contextlib.AbstractContextManager, contextlib.ContextDecorator],
]
] = None,
- mocks: Optional[Iterable] = None,
+ mocks: Optional[Iterable[unittest.mock.patch]] = None,
operator_export_type: OperatorExportTypes = OperatorExportTypes.ONNX,
- opset_version: int = GLOBALS.export_onnx_opset_version,
+ opset_version: int = torch.onnx.symbolic_helper._export_onnx_opset_version,
) -> onnx.ModelProto:
"""Exports `model(input)` to ONNX and returns it.
diff --git a/test/onnx/test_onnx_opset.py b/test/onnx/test_onnx_opset.py
index cd672ac..549666d 100644
--- a/test/onnx/test_onnx_opset.py
+++ b/test/onnx/test_onnx_opset.py
@@ -10,12 +10,10 @@
import torch.onnx
from torch.nn import Module
from torch.onnx import producer_name, producer_version
-from torch.onnx._globals import GLOBALS
+from torch.onnx.symbolic_helper import _export_onnx_opset_version
-def check_onnx_opset_operator(
- model, ops, opset_version=GLOBALS.export_onnx_opset_version
-):
+def check_onnx_opset_operator(model, ops, opset_version=_export_onnx_opset_version):
# check_onnx_components
assert (
model.producer_name == producer_name
diff --git a/test/onnx/test_pytorch_onnx_onnxruntime_cuda.py b/test/onnx/test_pytorch_onnx_onnxruntime_cuda.py
index 38ac87d..c6279e1 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime_cuda.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime_cuda.py
@@ -9,18 +9,16 @@
skipIfUnsupportedMinOpsetVersion,
skipScriptTest,
)
-
-# TODO(justinchuby): Remove reference to other unit tests.
from test_pytorch_onnx_onnxruntime import TestONNXRuntime
import torch
from torch.cuda.amp import autocast
-from torch.onnx._globals import GLOBALS
class TestONNXRuntime_cuda(unittest.TestCase):
+ from torch.onnx.symbolic_helper import _export_onnx_opset_version
- opset_version = GLOBALS.export_onnx_opset_version
+ opset_version = _export_onnx_opset_version
keep_initializers_as_inputs = True
onnx_shape_inference = True
diff --git a/test/onnx/test_pytorch_onnx_shape_inference.py b/test/onnx/test_pytorch_onnx_shape_inference.py
index 50e40ae..9e354fb 100644
--- a/test/onnx/test_pytorch_onnx_shape_inference.py
+++ b/test/onnx/test_pytorch_onnx_shape_inference.py
@@ -6,8 +6,11 @@
from test_pytorch_common import skipIfUnsupportedMinOpsetVersion
import torch
-from torch.onnx import _constants
-from torch.onnx.symbolic_helper import _set_onnx_shape_inference, _set_opset_version
+from torch.onnx.symbolic_helper import (
+ _onnx_main_opset,
+ _set_onnx_shape_inference,
+ _set_opset_version,
+)
def expect_tensor(scalar_type, shape=None):
@@ -24,7 +27,7 @@
class TestONNXShapeInference(unittest.TestCase):
def __init__(self, *args, **kwargs):
unittest.TestCase.__init__(self, *args, **kwargs)
- self.opset_version = _constants.onnx_main_opset
+ self.opset_version = _onnx_main_opset
_set_onnx_shape_inference(True)
_set_opset_version(self.opset_version)
diff --git a/test/onnx/test_utility_funs.py b/test/onnx/test_utility_funs.py
index 9638e53..1c39140 100644
--- a/test/onnx/test_utility_funs.py
+++ b/test/onnx/test_utility_funs.py
@@ -2,6 +2,7 @@
import copy
import io
+import unittest
import onnx
import torchvision
@@ -33,6 +34,8 @@
parse_args,
)
+skip = unittest.skip
+
class _BaseTestCase(TestCase):
def setUp(self):
diff --git a/tools/onnx/update_default_opset_version.py b/tools/onnx/update_default_opset_version.py
index dfdbf1f..c7ecd59 100755
--- a/tools/onnx/update_default_opset_version.py
+++ b/tools/onnx/update_default_opset_version.py
@@ -78,8 +78,8 @@
read_sub_write(
- os.path.join("torch", "onnx", "_constants.py"),
- r"(onnx_default_opset = )\d+",
+ os.path.join("torch", "onnx", "symbolic_helper.py"),
+ r"(_default_onnx_opset_version = )\d+",
)
read_sub_write(
os.path.join("torch", "onnx", "__init__.py"), r"(opset_version \(int, default )\d+"
diff --git a/torch/csrc/jit/passes/onnx.cpp b/torch/csrc/jit/passes/onnx.cpp
index 150eee9..2227842 100644
--- a/torch/csrc/jit/passes/onnx.cpp
+++ b/torch/csrc/jit/passes/onnx.cpp
@@ -242,7 +242,7 @@
::torch::onnx::OperatorExportTypes operator_export_type,
std::unordered_map<Value*, Value*>& env) {
py::object onnx = py::module::import("torch.onnx");
- py::object onnx_globals = py::module::import("torch.onnx._globals");
+ py::object onnx_symbolic = py::module::import("torch.onnx.symbolic_helper");
py::object onnx_registry = py::module::import("torch.onnx.symbolic_registry");
// Setup all the lambda helper functions.
@@ -273,8 +273,8 @@
}
// For const node, it does not need params_dict info, so set it to {}.
const ParamMap empty_params_dict = {};
- auto opset_version = py::cast<int>(
- onnx_globals.attr("GLOBALS").attr("export_onnx_opset_version"));
+ auto opset_version =
+ py::cast<int>(onnx_symbolic.attr("_export_onnx_opset_version"));
for (const auto i : c10::irange(num_old_outputs)) {
auto old = old_outputs[i];
if (outputs[i]) {
@@ -435,8 +435,7 @@
pyobj = func->get();
}
- py::object opset_version =
- onnx_globals.attr("GLOBALS").attr("export_onnx_opset_version");
+ py::object opset_version = onnx_symbolic.attr("_export_onnx_opset_version");
py::object is_registered_op = onnx_registry.attr("is_registered_op")(
"PythonOp", "prim", opset_version);
if (!py::hasattr(pyobj, "symbolic") &&
diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py
index 64bf14c..06541b4 100644
--- a/torch/onnx/__init__.py
+++ b/torch/onnx/__init__.py
@@ -1,4 +1,4 @@
-from typing import Dict
+from typing import Dict, Optional
import torch._C as _C
diff --git a/torch/onnx/_constants.py b/torch/onnx/_constants.py
deleted file mode 100644
index 0a7e853..0000000
--- a/torch/onnx/_constants.py
+++ /dev/null
@@ -1,6 +0,0 @@
-"""Constant values used in ONNX."""
-
-onnx_default_opset = 13
-onnx_main_opset = 16
-onnx_stable_opsets = tuple(range(7, onnx_main_opset))
-onnx_constant_folding_opsets = tuple(range(9, onnx_main_opset + 1))
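
For reference, with `onnx_main_opset = 16` the two deleted range expressions evaluate to stable opsets 7 through 15 and constant-folding opsets 9 through 16:

```
# What the deleted _constants.py expressions evaluate to with
# onnx_main_opset = 16 (values copied from the file above).
onnx_main_opset = 16
assert tuple(range(7, onnx_main_opset)) == (7, 8, 9, 10, 11, 12, 13, 14, 15)
assert tuple(range(9, onnx_main_opset + 1)) == (9, 10, 11, 12, 13, 14, 15, 16)
```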
diff --git a/torch/onnx/_globals.py b/torch/onnx/_globals.py
deleted file mode 100644
index 67d478e..0000000
--- a/torch/onnx/_globals.py
+++ /dev/null
@@ -1,48 +0,0 @@
-"""Globals used internally by the ONNX exporter.
-
-Do not use this module outside of `torch.onnx` and its tests.
-
-Be very judicious when adding any new global variables. Do not create new global
-variables unless they are absolutely necessary.
-"""
-from __future__ import annotations
-
-import typing
-from typing import Optional
-
-# This module should only depend on _constants and nothing else in torch.onnx to keep
-# dependency direction clean.
-from torch.onnx import _constants
-
-if typing.TYPE_CHECKING:
- # Postpone type checking to avoid circular dependencies.
- from torch.onnx import OperatorExportTypes, TrainingMode
-
-
-class _InternalGlobals:
- """Globals used internally by ONNX exporter.
-
- NOTE: Be very judicious when adding any new variables. Do not create new
- global variables unless they are absolutely necessary.
- """
-
- def __init__(self):
- self._export_onnx_opset_version = _constants.onnx_default_opset
- self.operator_export_type: Optional[OperatorExportTypes] = None
- self.training_mode: Optional[TrainingMode] = None
- self.onnx_shape_inference: bool = False
-
- @property
- def export_onnx_opset_version(self):
- return self._export_onnx_opset_version
-
- @export_onnx_opset_version.setter
- def export_onnx_opset_version(self, value: int):
- supported_versions = [_constants.onnx_main_opset]
- supported_versions.extend(_constants.onnx_stable_opsets)
- if value not in supported_versions:
- raise ValueError(f"Unsupported ONNX opset version: {value}")
- self._export_onnx_opset_version = value
-
-
-GLOBALS = _InternalGlobals()
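
The property setter above is what gave the singleton its validation; a minimal usage sketch of the interface this revert removes:

```
# Sketch of the validated assignment the GLOBALS singleton provided
# (only meaningful before this revert deletes torch/onnx/_globals.py).
from torch.onnx._globals import GLOBALS

GLOBALS.export_onnx_opset_version = 13  # accepted: 13 is in the stable range
try:
    GLOBALS.export_onnx_opset_version = 6  # below the supported range 7..16
except ValueError as err:
    print(err)  # Unsupported ONNX opset version: 6
```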
diff --git a/torch/onnx/_patch_torch.py b/torch/onnx/_patch_torch.py
deleted file mode 100644
index 01af5fe..0000000
--- a/torch/onnx/_patch_torch.py
+++ /dev/null
@@ -1,238 +0,0 @@
-"""Importing this patches torch._C classes to add ONNX conveniences."""
-import numbers
-import re
-from typing import Iterable, Tuple, Union
-
-import torch
-from torch.onnx._globals import GLOBALS
-
-
-def _graph_op(
- g: torch._C.Graph,
- opname: str,
- *raw_args: torch._C.Node,
- outputs: int = 1,
- **kwargs,
-) -> Union[torch._C.Value, Tuple[torch._C.Value, ...]]:
- r"""Creates an ONNX operator "opname", taking "args" as inputs and attributes "kwargs".
-
- The set of operators and the inputs/attributes they take
- is documented at https://github.com/onnx/onnx/blob/master/docs/Operators.md
-
- This function is monkey-patched onto Graph.
-
- Args:
- g: The Torch graph.
- opname: The ONNX operator name, e.g., `Abs` or `Add`. TODO(justinchu): Update examples to correct ones.
- raw_args: The inputs to the operator; usually provided
- as arguments to the `symbolic` definition.
- outputs: The number of outputs this operator returns.
- By default an operator is assumed to return a single output.
- If `outputs` is greater than one, this functions returns a tuple
- of output `Node`, representing each output of the ONNX operator
- in positional.
- kwargs: The attributes of the ONNX operator, whose keys are named
- according to the following convention: `alpha_f` indicates
- the `alpha` attribute with type `f`. The valid type specifiers are
- `f` (float), `i` (int), `s` (string) or `t` (Tensor). An attribute
- specified with type float accepts either a single float, or a
- list of floats (e.g., you would say `dims_i` for a `dims` attribute
- that takes a list of integers).
-
- Returns:
- The node representing the single output of this operator (see the `outputs`
- keyword argument for multi-return nodes).
- """
- # Filter out None attributes, this can be convenient client side because
- # now they can pass through None attributes, and have them not show up
- kwargs = dict((k, v) for k, v in kwargs.items() if v is not None)
-
- def const_if_tensor(arg):
- if arg is None:
- return arg
- elif isinstance(arg, torch._C.Value):
- return arg
- else:
- return g.op("Constant", value_z=arg) # type: ignore[attr-defined]
-
- args = [const_if_tensor(arg) for arg in raw_args]
- n = g.insertNode(_new_node(g, opname, outputs, *args, **kwargs)) # type: ignore[attr-defined]
-
- # Import utils to get _params_dict because it is a global that is accessed by c++ code
- from torch.onnx import utils
-
- if GLOBALS.onnx_shape_inference:
- torch._C._jit_pass_onnx_node_shape_type_inference(
- n, utils._params_dict, GLOBALS.export_onnx_opset_version
- )
-
- if outputs == 1:
- return n.output()
- return tuple(n.outputs())
-
-
-# Generate an ONNX ATen op node.
-def _aten_op(g, operator, *args, overload_name="", **kwargs):
- kwargs["aten"] = True
- return g.op(
- "ATen", *args, operator_s=operator, overload_name_s=overload_name, **kwargs
- )
-
-
-def _block_op(b, opname, *args, **kwargs):
- if "::" in opname:
- aten = False
- ns_opname = opname
- else:
- aten = kwargs.pop("aten", False)
- ns = "aten" if aten else "onnx"
- ns_opname = ns + "::" + opname
- n = b.addNode(ns_opname, list(args))
- for k, v in sorted(kwargs.items()):
- # TODO: enable inplace in aten exporting mode.
- if k == "inplace":
- continue
- _add_attribute(n, k, v, aten=aten)
- if len(list(n.outputs())) == 1:
- return n.output()
- return tuple(o for o in n.outputs())
-
-
-def _new_node(g: torch._C.Graph, opname: str, outputs, *args, **kwargs):
- if "::" in opname:
- aten = False
- ns_opname = opname
- else:
- aten = kwargs.pop("aten", False)
- ns = "aten" if aten else "onnx"
- ns_opname = ns + "::" + opname
- n = g.create(ns_opname, args, outputs) # type: ignore[attr-defined]
- for k, v in sorted(kwargs.items()):
- # TODO: enable inplace in aten exporting mode.
- if k == "inplace":
- continue
- _add_attribute(n, k, v, aten=aten)
- return n
-
-
-_attr_pattern = re.compile("^(.+)_(([ifstgz])|(ty))$")
-
-
-def _is_onnx_list(value):
- return (
- not isinstance(value, torch._six.string_classes)
- and not isinstance(value, torch.Tensor)
- and isinstance(value, Iterable)
- )
-
-
-def _scalar(x):
- """Convert a scalar tensor into a Python value."""
- assert x.numel() == 1
- return x[0]
-
-
-def _is_caffe2_aten_fallback():
- return (
- GLOBALS.operator_export_type
- == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
- and torch.onnx._CAFFE2_ATEN_FALLBACK
- )
-
-
-def _add_attribute(node, key, value, aten):
- r"""Initializes the right attribute based on type of value."""
- m = _attr_pattern.match(key)
- if m is None:
- raise IndexError(
- (
- "Invalid attribute specifier '{}' names "
- + " must be suffixed with type, e.g. 'dim_i' or 'dims_i'"
- ).format(key)
- )
- name, kind = m.group(1), m.group(2)
- if _is_onnx_list(value):
- kind += "s"
-
- if aten and _is_caffe2_aten_fallback():
- if isinstance(value, torch.Tensor):
- # Caffe2 proto does not support tensor attribute.
- if value.numel() > 1:
- raise ValueError("Should not pass tensor attribute")
- value = _scalar(value)
- if isinstance(value, float):
- kind = "f"
- else:
- kind = "i"
- return getattr(node, kind + "_")(name, value)
-
-
-# TODO: We might not need this anymore, since most scalars now show up as tensors
-# TODO(#76254): Remove the helper function if not needed.
-def _graph_constant(
- g,
- value,
- dims,
- type_: str,
- *args,
- **kwargs,
-):
- """This helper function can create either constant tensor or constant scalar.
-
- If dims is None or 0 or [0], generate a 0-d tensor (scalar).
- """
- assert isinstance(value, numbers.Number)
- assert type_ is not None
- isscalar = False
- if dims is None or dims == 0 or set(dims) == set([0]):
- dims = [1]
- isscalar = True
- type_ = type_.lower()
- tensor: Union[
- torch.CharTensor,
- torch.ShortTensor,
- torch.IntTensor,
- torch.LongTensor,
- torch.HalfTensor,
- torch.FloatTensor,
- torch.DoubleTensor,
- ]
- if type_ == "char":
- tensor = torch.CharTensor(*dims)
- elif type_ == "short":
- tensor = torch.ShortTensor(*dims)
- elif type_ == "int":
- tensor = torch.IntTensor(*dims)
- elif type_ == "long":
- tensor = torch.LongTensor(*dims)
- elif type_ == "half":
- tensor = torch.HalfTensor(*dims)
- elif type_ == "float":
- tensor = torch.FloatTensor(*dims)
- elif type_ == "double":
- tensor = torch.DoubleTensor(*dims)
- else:
- raise ValueError(
- "Unknown type, type should be one of the following strings: "
- "char, short, int, long, half, float, double"
- )
- tensor.fill_(value) # type: ignore[call-overload]
- if isscalar:
- return g.op("Constant", *args, value_z=tensor, **kwargs)
- return g.op("Constant", *args, value_t=tensor, **kwargs)
-
-
-def _node_getitem(self, k):
- """Gets attributes of a node which is polymorphic over return type.
-
- This is monkey-patched onto Node.
- """
- sel = self.kindOf(k)
- return getattr(self, sel)(k)
-
-
-torch._C.Graph.op = _graph_op # type: ignore[attr-defined]
-torch._C.Graph.at = _aten_op # type: ignore[attr-defined]
-torch._C.Block.op = _block_op # type: ignore[attr-defined]
-torch._C.Graph.constant = _graph_constant # type: ignore[attr-defined]
-torch._C.Node.__getitem__ = _node_getitem # type: ignore[attr-defined, misc, assignment]
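
The attribute-suffix convention described in the `_graph_op` docstring survives the revert (the same helpers move back into `torch.onnx.utils` below); a small illustrative symbolic function, where `example_sum` is hypothetical:

```
import torch

# Hypothetical symbolic function showing the suffix convention from the
# docstring above: value_t is a Tensor attribute, while axes_i/keepdims_i
# become the "axes" (ints) and "keepdims" (int) attributes of ReduceSum.
def example_sum(g, input):
    two = g.op("Constant", value_t=torch.tensor(2.0))
    scaled = g.op("Mul", input, two)
    return g.op("ReduceSum", scaled, axes_i=[0], keepdims_i=0)
```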
diff --git a/torch/onnx/onnx_supported_ops.py b/torch/onnx/onnx_supported_ops.py
index eacc77d..bc85029 100644
--- a/torch/onnx/onnx_supported_ops.py
+++ b/torch/onnx/onnx_supported_ops.py
@@ -2,11 +2,10 @@
from typing import Dict, List, Union
import torch._C
-from torch.onnx import _constants, symbolic_registry
+from torch.onnx import symbolic_helper, symbolic_registry
-for v in _constants.onnx_stable_opsets:
+for v in symbolic_helper._onnx_stable_opsets + [symbolic_helper._onnx_main_opset]:
symbolic_registry.register_version("", v)
-symbolic_registry.register_version("", _constants.onnx_main_opset)
class _TorchSchema:
diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py
index 59cd61e..f1a588a 100644
--- a/torch/onnx/symbolic_helper.py
+++ b/torch/onnx/symbolic_helper.py
@@ -7,12 +7,11 @@
import torch
import torch.onnx
-from torch._C import OptionalType
# This import monkey-patches graph manipulation methods on Graph, used for the
# ONNX symbolics
-from torch.onnx import _patch_torch # noqa: F401
-from torch.onnx._globals import GLOBALS
+import torch.onnx.utils
+from torch._C import OptionalType
# Note [Edit Symbolic Files]
# EDITING THIS FILE AND SYMBOLIC_OPSET<VERSION> FILES? READ THIS FIRST!
@@ -166,27 +165,24 @@
"""A decorator which converts args from torch._C.Value to built-in types.
For example:
-
- ```
@parse_args('v', 'i', 'fs')
foo(g, a, b, c):
- assert isinstance(a, torch._C.Value)
- assert isinstance(b, int)
- assert isinstance(c, list)
- assert isinstance(c[0], float)
- ```
+ assert isinstance(a, torch._C.Value)
+ assert isinstance(b, int)
+ assert isinstance(c, list)
+ assert isinstance(c[0], float)
Args:
- arg_descriptors: list of str, where each element is
- a string that specifies the type to convert to. Valid descriptors:
- "v": no conversion, keep torch._C.Value.
- "i": int
- "is": list of int
- "f": float
- "fs": list of float
- "b": bool
- "s": str
- "t": torch.Tensor
+ arg_descriptors: list of str, where each element is
+ a string that specifies the type to convert to. Valid descriptors:
+ "v": no conversion, keep torch._C.Value.
+ "i": int
+ "is": list(int)
+ "f": float
+ "fs": list of float
+ "b": bool
+ "s": str
+ "t": torch.Tensor
"""
def decorator(fn):
@@ -317,7 +313,7 @@
return x.item()
-def _if_scalar_type_as(g: torch._C.Graph, self, tensor):
+def _if_scalar_type_as(g, self, tensor):
"""
Convert self into the same type of tensor, as necessary.
@@ -380,8 +376,7 @@
def is_caffe2_aten_fallback():
return (
- GLOBALS.operator_export_type
- == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
+ _operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
and torch.onnx._CAFFE2_ATEN_FALLBACK
)
@@ -431,7 +426,7 @@
warnings.warn(
"ONNX export failed on " + op + " because " + msg + " not supported"
)
- elif GLOBALS.operator_export_type == torch.onnx.OperatorExportTypes.ONNX:
+ elif _operator_export_type == torch.onnx.OperatorExportTypes.ONNX:
_onnx_unsupported(f"{op}, {msg}")
@@ -467,7 +462,7 @@
raise RuntimeError(
"ONNX export failed on {}, which is not implemented for opset {}. "
"Try exporting with other opset versions.".format(
- name, GLOBALS.export_onnx_opset_version
+ name, _export_onnx_opset_version
)
)
@@ -503,7 +498,7 @@
def _slice_helper(g, input, axes, starts, ends, steps=None, dynamic_slice=False):
- if GLOBALS.export_onnx_opset_version <= 9:
+ if _export_onnx_opset_version <= 9:
from torch.onnx.symbolic_opset9 import _slice as _slice9
return _slice9(g, input, axes, starts, ends)
@@ -559,7 +554,7 @@
shape_,
g.op("Constant", value_t=torch.tensor([dim], dtype=torch.int64)),
)
- if GLOBALS.export_onnx_opset_version <= 10:
+ if _export_onnx_opset_version <= 10:
if not decending:
_unimplemented("Sort", "Ascending is not supported")
return g.op("TopK", input, dim_size_, axis_i=dim, outputs=2)
@@ -578,7 +573,7 @@
k = _reshape_helper(g, k, g.op("Constant", value_t=torch.tensor([1])))
if _try_get_scalar_type(k) != "Long":
k = g.op("Cast", k, to_i=torch.onnx.TensorProtoDataType.INT64)
- if GLOBALS.export_onnx_opset_version <= 10:
+ if _export_onnx_opset_version <= 10:
if not largest:
_unimplemented("TopK", "Ascending is not supported")
return g.op("TopK", input, k, axis_i=dim, outputs=2)
@@ -589,7 +584,7 @@
def _lt_helper(g, input, other):
- if GLOBALS.export_onnx_opset_version <= 8:
+ if _export_onnx_opset_version <= 8:
from torch.onnx.symbolic_opset8 import lt as _lt8
return _lt8(g, input, other)
@@ -600,14 +595,12 @@
def _interpolate_warning(interpolate_mode):
- onnx_op = (
- "onnx:Resize" if GLOBALS.export_onnx_opset_version >= 10 else "onnx:Upsample"
- )
+ onnx_op = "onnx:Resize" if _export_onnx_opset_version >= 10 else "onnx:Upsample"
warnings.warn(
"You are trying to export the model with "
+ onnx_op
+ " for ONNX opset version "
- "" + str(GLOBALS.export_onnx_opset_version) + ". "
+ "" + str(_export_onnx_opset_version) + ". "
"This operator might cause results to not match the expected results by PyTorch.\n"
"ONNX's Upsample/Resize operator did not match Pytorch's Interpolation until opset 11. "
"Attributes to determine how to transform the input were added in onnx:Resize in opset 11 "
@@ -618,12 +611,12 @@
def _unsqueeze_helper(g, input, axes_i):
if _is_constant(axes_i[0]):
- if GLOBALS.export_onnx_opset_version >= 13:
+ if _export_onnx_opset_version >= 13:
axes = g.op("Constant", value_t=torch.tensor(axes_i, dtype=torch.long))
return g.op("Unsqueeze", input, axes)
return g.op("Unsqueeze", input, axes_i=axes_i)
# Tensor type
- if GLOBALS.export_onnx_opset_version < 13:
+ if _export_onnx_opset_version < 13:
raise ValueError(
f"Opset version must be >= 13 for Unsqueeze with dynamic axes. {input.node().sourceRange()}"
)
@@ -632,12 +625,12 @@
def _squeeze_helper(g, input, axes_i):
if _is_constant(axes_i[0]):
- if GLOBALS.export_onnx_opset_version >= 13:
+ if _export_onnx_opset_version >= 13:
axes = g.op("Constant", value_t=torch.tensor(axes_i, dtype=torch.long))
return g.op("Squeeze", input, axes)
return g.op("Squeeze", input, axes_i=axes_i)
# Tensor type
- if GLOBALS.export_onnx_opset_version < 13:
+ if _export_onnx_opset_version < 13:
raise ValueError(
f"Opset version must be >= 13 for Squeeze with dynamic axes. {input.node().sourceRange()}"
)
@@ -656,7 +649,7 @@
def _reducesum_helper(g, input, axes_i=None, keepdims_i=1, noop_with_empty_axes_i=0):
keepdims_i = _maybe_get_const(keepdims_i, "i")
- if GLOBALS.export_onnx_opset_version >= 13:
+ if _export_onnx_opset_version >= 13:
if axes_i:
if not _is_value(axes_i):
axes_i = g.op(
@@ -800,7 +793,7 @@
output_size = g.op("Cast", output_size, to_i=cast_pytorch_to_onnx["Long"])
output_size = g.op("Concat", input_size_beg, output_size, axis_i=0)
- if GLOBALS.export_onnx_opset_version >= 13:
+ if _export_onnx_opset_version >= 13:
empty_roi = _optional_input_placeholder_tensor(g)
empty_scales = _optional_input_placeholder_tensor(g)
else:
@@ -823,7 +816,7 @@
nearest_mode_s="floor",
) # only valid when mode="nearest"
else:
- if GLOBALS.export_onnx_opset_version >= 13:
+ if _export_onnx_opset_version >= 13:
empty_roi = _optional_input_placeholder_tensor(g)
else:
empty_roi = g.op(
@@ -894,7 +887,7 @@
size = g.op("Cast", size, to_i=cast_pytorch_to_onnx["Long"])
size = g.op("Concat", input_size, size, axis_i=0)
- if GLOBALS.export_onnx_opset_version >= 13:
+ if _export_onnx_opset_version >= 13:
empty_roi = _optional_input_placeholder_tensor(g)
empty_scales = _optional_input_placeholder_tensor(g)
else:
@@ -919,7 +912,7 @@
if rank is None:
return _unimplemented("interpolate (with scales)", "missing input shape")
- if GLOBALS.export_onnx_opset_version >= 13:
+ if _export_onnx_opset_version >= 13:
empty_roi = _optional_input_placeholder_tensor(g)
else:
empty_roi = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))
@@ -938,9 +931,9 @@
def _unbind_helper(g, self, dim, _outputs):
- if GLOBALS.export_onnx_opset_version < 11:
+ if _export_onnx_opset_version < 11:
from torch.onnx.symbolic_opset9 import unbind
- elif GLOBALS.export_onnx_opset_version <= 12:
+ elif _export_onnx_opset_version <= 12:
from torch.onnx.symbolic_opset11 import unbind # type: ignore[no-redef]
else:
from torch.onnx.symbolic_opset13 import unbind # type: ignore[no-redef]
@@ -948,7 +941,7 @@
def _scatter_helper(g, self, dim, index, src):
- if GLOBALS.export_onnx_opset_version <= 10:
+ if _export_onnx_opset_version <= 10:
from torch.onnx.symbolic_opset9 import scatter
else:
# for mypy, scatter was imported two lines above
@@ -957,7 +950,7 @@
def _repeat_interleave_split_helper(g, self, reps, dim):
- if GLOBALS.export_onnx_opset_version <= 12:
+ if _export_onnx_opset_version <= 12:
split_out = g.op("Split", self, split_i=[1] * reps, axis_i=dim, outputs=reps)
else:
from torch.onnx.symbolic_opset13 import split
@@ -996,7 +989,7 @@
def _arange_helper(g, *args):
- if GLOBALS.export_onnx_opset_version <= 10:
+ if _export_onnx_opset_version <= 10:
from torch.onnx.symbolic_opset9 import arange
else:
from torch.onnx.symbolic_opset11 import arange # type: ignore[no-redef]
@@ -1018,7 +1011,7 @@
from torch.onnx.symbolic_opset9 import expand
- if GLOBALS.export_onnx_opset_version <= 10:
+ if _export_onnx_opset_version <= 10:
from torch.onnx.symbolic_opset9 import scatter
else:
# for mypy, scatter was imported two lines above
@@ -1047,10 +1040,10 @@
shape = _maybe_get_const(shape, "is")
if not _is_value(shape):
shape = g.op("Constant", value_t=torch.LongTensor(shape))
- if GLOBALS.export_onnx_opset_version <= 13:
+ if _export_onnx_opset_version <= 13:
if allowzero == 1:
raise _onnx_opset_unsupported(
- "Reshape with allowzero=1", GLOBALS.export_onnx_opset_version, 14
+ "Reshape with allowzero=1", _export_onnx_opset_version, 14
)
return g.op("Reshape", input, shape)
else:
@@ -1118,11 +1111,12 @@
def check_training_mode(op_train_mode, op_name):
+ global _training_mode
op_train_mode = True if op_train_mode == 1 else False
- if GLOBALS.training_mode is not None and op_train_mode != GLOBALS.training_mode:
+ if _training_mode is not None and op_train_mode != _training_mode:
op_mode = "training " if op_train_mode else "inference"
- training_mode = "training " if GLOBALS.training_mode else "inference"
- # setting the model mode could result in op_mode != _flags.training_mode
+ training_mode = "training " if _training_mode else "inference"
+ # setting the model mode could result in op_mode != _training_mode
# if the model is a FuncModule. In this case we warn the user of
# the state and export depending on op_mode
# This is to support use-cases of fixing certain layer weights
@@ -1213,10 +1207,10 @@
scale = g.op("Cast", scale, to_i=torch.onnx.TensorProtoDataType.FLOAT)
zero_point = g.op("Cast", zero_point, to_i=qdtype)
- if axis_i is not None and GLOBALS.export_onnx_opset_version < 13:
+ if axis_i is not None and _export_onnx_opset_version < 13:
_onnx_opset_unsupported_detailed(
"DequantizeLinear",
- GLOBALS.export_onnx_opset_version,
+ _export_onnx_opset_version,
13,
"Attribute axis is not supported.",
)
@@ -1238,16 +1232,12 @@
scale: torch._C.Value, quantized scale.
zero_point: torch._C.Value, quantized zero point.
axis: Optional[torch._C.Value] default None, if None, represents per tensor quantization.
- Otherwise, represents per channel quantization, along given axis.
+ Otherwise, represents per channel quantization, along given axis.
"""
- if (
- axis is not None
- and not _is_none(axis)
- and GLOBALS.export_onnx_opset_version < 13
- ):
+ if axis is not None and not _is_none(axis) and _export_onnx_opset_version < 13:
_onnx_opset_unsupported_detailed(
"QuantizeLinear",
- GLOBALS.export_onnx_opset_version,
+ _export_onnx_opset_version,
13,
"Attribute axis is not supported.",
)
@@ -1299,23 +1289,43 @@
return has_same_dtype
-# TODO(justinchuby): Delete these setters, users should set the vars directly.
-def _set_opset_version(opset_version: int):
- GLOBALS.export_onnx_opset_version = opset_version
+_default_onnx_opset_version = 13
+_onnx_main_opset = 16
+_onnx_stable_opsets = list(range(7, _onnx_main_opset))
+_export_onnx_opset_version = _default_onnx_opset_version
+_constant_folding_opset_versions = list(range(9, _onnx_main_opset + 1))
+
+
+def _set_opset_version(opset_version):
+ global _export_onnx_opset_version
+ if opset_version in _onnx_stable_opsets + [_onnx_main_opset]:
+ _export_onnx_opset_version = opset_version
+ return
+ raise ValueError("Unsupported ONNX opset version: " + str(opset_version))
+
+
+_operator_export_type = None
def _set_operator_export_type(operator_export_type):
- GLOBALS.operator_export_type = operator_export_type
+ global _operator_export_type
+ _operator_export_type = operator_export_type
+
+
+_training_mode = None
def _set_training_mode(training_mode):
- GLOBALS.training_mode = training_mode
+ global _training_mode
+ _training_mode = training_mode
+_onnx_shape_inference = False
# This function is for debug use only.
-# onnx_shape_inference = False by default.
-def _set_onnx_shape_inference(onnx_shape_inference: bool):
- GLOBALS.onnx_shape_inference = onnx_shape_inference
+# onnx_shape_inference = False by default.
+def _set_onnx_shape_inference(onnx_shape_inference):
+ global _onnx_shape_inference
+ _onnx_shape_inference = onnx_shape_inference
# Metaprogram symbolics for each ATen native specialized cast operator.
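
The `parse_args` docstring reflowed above is easiest to read next to a concrete use; a short sketch, where `clamp_example` is a hypothetical symbolic function:

```
from torch.onnx.symbolic_helper import parse_args

# "v" leaves self as a torch._C.Value; "f" converts min_val/max_val from
# constant Values into Python floats before the body runs.
@parse_args("v", "f", "f")
def clamp_example(g, self, min_val, max_val):
    return g.op("Clip", self, min_f=min_val, max_f=max_val)
```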
diff --git a/torch/onnx/symbolic_opset10.py b/torch/onnx/symbolic_opset10.py
index fe8e099..53ec445 100644
--- a/torch/onnx/symbolic_opset10.py
+++ b/torch/onnx/symbolic_opset10.py
@@ -1,17 +1,17 @@
-import sys
-import warnings
+# -*- coding: utf-8 -*-
+
+from sys import maxsize
import torch
import torch.onnx
import torch.onnx.symbolic_helper as sym_help
import torch.onnx.symbolic_opset9
-from torch.nn.modules.utils import _pair, _single, _triple
# This import monkey-patches graph manipulation methods on Graph, used for the
# ONNX symbolics
-from torch.onnx import _patch_torch # noqa: F401
-from torch.onnx._globals import GLOBALS
-from torch.onnx.symbolic_helper import parse_args, quantized_args
+import torch.onnx.utils
+from torch.nn.modules.utils import _pair, _single, _triple
+from torch.onnx.symbolic_helper import _unimplemented, parse_args, quantized_args
from torch.onnx.symbolic_opset9 import (
add,
conv2d,
@@ -202,7 +202,7 @@
sym_help._interpolate_warning(interpolate_mode)
align_corners = sym_help._maybe_get_scalar(align_corners)
if align_corners:
- return sym_help._unimplemented(name, "align_corners == True")
+ return _unimplemented(name, "align_corners == True")
if scales is None:
scales = sym_help._interpolate_size_to_scales(g, input, output_size, dim)
return g.op("Resize", input, scales, mode_s=interpolate_mode)
@@ -328,12 +328,13 @@
include_last_offset,
padding_idx,
):
- if scale_grad_by_freq and GLOBALS.training_mode:
+ if scale_grad_by_freq and sym_help._training_mode:
return sym_help._onnx_unsupported(
"embedding_bag with scale_grad_by_freq for training mode"
)
if padding_idx is not None and padding_idx >= 0:
raise RuntimeError("embedding_bag with padding_idx")
+ import warnings
from torch.onnx.symbolic_opset9 import select
@@ -350,7 +351,7 @@
offset_len = offsets_dim_0
offsets_extended = [
offsets,
- g.op("Constant", value_t=torch.tensor([sys.maxsize])),
+ g.op("Constant", value_t=torch.tensor([maxsize])),
]
offsets_extended = g.op("Concat", *offsets_extended, axis_i=0)
list_ = []
diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py
index b3dc019..6d96135 100644
--- a/torch/onnx/symbolic_opset11.py
+++ b/torch/onnx/symbolic_opset11.py
@@ -1,11 +1,12 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
import warnings
from sys import maxsize
import torch
import torch.onnx.symbolic_helper as sym_help
-import torch.onnx.utils
from torch.nn.modules.utils import _pair, _single, _triple
-from torch.onnx._globals import GLOBALS
from torch.onnx.symbolic_helper import (
ScalarType,
_is_tensor_list,
@@ -16,6 +17,11 @@
from torch.onnx.symbolic_opset9 import _pad_circular, expand
from torch.onnx.symbolic_opset9 import linalg_vector_norm as lvn
from torch.onnx.symbolic_opset9 import mul, op_with_optional_float_cast, unused
+from torch.onnx.utils import (
+ _add_block,
+ _add_input_to_block,
+ _add_output_to_block,
+)
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py
@@ -1072,7 +1078,7 @@
include_last_offset,
padding_idx,
):
- if scale_grad_by_freq and GLOBALS.training_mode:
+ if scale_grad_by_freq and sym_help._training_mode:
return sym_help._onnx_unsupported(
"embedding_bag with scale_grad_by_freq for training mode"
)
@@ -1107,9 +1113,9 @@
)
loop = g.op("Loop", loop_len, loop_condition)
- loop_block = torch.onnx.utils._add_block(loop.node())
- block_input_iter = torch.onnx.utils._add_input_to_block(loop_block)
- cond = torch.onnx.utils._add_input_to_block(loop_block)
+ loop_block = _add_block(loop.node())
+ block_input_iter = _add_input_to_block(loop_block)
+ cond = _add_input_to_block(loop_block)
indices_start = loop_block.op("Gather", offsets_starts, block_input_iter, axis_i=0)
indices_end = loop_block.op("Gather", offsets_ends, block_input_iter, axis_i=0)
@@ -1136,8 +1142,8 @@
embeddings = loop_block.op("ReduceMax", embeddings, axes_i=[0], keepdims_i=0)
cond_out = loop_block.op("Cast", loop_condition, to_i=9)
- torch.onnx.utils._add_output_to_block(loop_block, cond_out)
- torch.onnx.utils._add_output_to_block(loop_block, embeddings)
+ _add_output_to_block(loop_block, cond_out)
+ _add_output_to_block(loop_block, embeddings)
# aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices.
# But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag.
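
The `_add_block` / `_add_input_to_block` / `_add_output_to_block` trio used above is the recurring way these files build an ONNX Loop body; a skeleton of the pattern (a sketch, not tied to any one operator):

```
from torch.onnx.utils import (
    _add_block,
    _add_input_to_block,
    _add_output_to_block,
)

# Skeleton of the Loop construction above: g is the enclosing graph,
# loop_len and loop_condition are Values already produced by g.op(...).
def loop_skeleton(g, loop_len, loop_condition, data):
    loop = g.op("Loop", loop_len, loop_condition)
    loop_block = _add_block(loop.node())
    iter_num = _add_input_to_block(loop_block)  # mandatory iteration counter
    cond = _add_input_to_block(loop_block)      # mandatory loop condition
    gathered = loop_block.op("Gather", data, iter_num, axis_i=0)
    cond_out = loop_block.op("Cast", loop_condition, to_i=9)  # 9 == BOOL
    _add_output_to_block(loop_block, cond_out)  # output 0: continue condition
    _add_output_to_block(loop_block, gathered)  # later outputs: loop-carried
    return loop.node().output()
```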
diff --git a/torch/onnx/symbolic_opset12.py b/torch/onnx/symbolic_opset12.py
index df8e6bc..3825eb2 100644
--- a/torch/onnx/symbolic_opset12.py
+++ b/torch/onnx/symbolic_opset12.py
@@ -3,9 +3,13 @@
import torch
import torch.onnx.symbolic_helper as sym_help
-import torch.onnx.utils
from torch.onnx.symbolic_helper import _parse_arg, _unimplemented, parse_args
from torch.onnx.symbolic_opset9 import _reshape_from_tensor, permute
+from torch.onnx.utils import (
+ _add_block,
+ _add_input_to_block,
+ _add_output_to_block,
+)
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py
@@ -270,9 +274,9 @@
loop_len = g.op("Min", low_size, hi_size)
loop = g.op("Loop", loop_len, loop_condition)
- loop_block = torch.onnx.utils._add_block(loop.node())
- block_input_iter = torch.onnx.utils._add_input_to_block(loop_block)
- cond = torch.onnx.utils._add_input_to_block(loop_block)
+ loop_block = _add_block(loop.node())
+ block_input_iter = _add_input_to_block(loop_block)
+ cond = _add_input_to_block(loop_block)
starts = loop_block.op("Gather", low_indices, block_input_iter)
ends = loop_block.op("Gather", hi_indices, block_input_iter)
@@ -288,8 +292,8 @@
concat = loop_block.op("Concat", *unsqueeze_list, axis_i=0)
cond_out = loop_block.op("Cast", loop_condition, to_i=9)
- torch.onnx.utils._add_output_to_block(loop_block, cond_out)
- torch.onnx.utils._add_output_to_block(loop_block, concat)
+ _add_output_to_block(loop_block, cond_out)
+ _add_output_to_block(loop_block, concat)
loop_output = loop.node().output()
perm = [0, 1, 2, 3, 4]
diff --git a/torch/onnx/symbolic_opset13.py b/torch/onnx/symbolic_opset13.py
index 6167e40..6de4f86 100644
--- a/torch/onnx/symbolic_opset13.py
+++ b/torch/onnx/symbolic_opset13.py
@@ -4,7 +4,6 @@
# This file exports ONNX ops for opset 13
import torch
import torch.onnx.symbolic_helper as sym_help
-import torch.onnx.utils
from torch.onnx.symbolic_helper import _unimplemented, parse_args
from torch.onnx.symbolic_opset9 import (
_maybe_cast_reduce_op_input,
@@ -20,6 +19,11 @@
zeros,
)
from torch.onnx.symbolic_opset11 import unsqueeze
+from torch.onnx.utils import (
+ _add_block,
+ _add_input_to_block,
+ _add_output_to_block,
+)
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py
@@ -387,10 +391,10 @@
loop = g.op("Loop", loop_len, loop_condition, final_splits)
# Loop inputs
- loop_block = torch.onnx.utils._add_block(loop.node())
- block_input_iter = torch.onnx.utils._add_input_to_block(loop_block)
- cond = torch.onnx.utils._add_input_to_block(loop_block)
- final_splits = torch.onnx.utils._add_input_to_block(loop_block)
+ loop_block = _add_block(loop.node())
+ block_input_iter = _add_input_to_block(loop_block)
+ cond = _add_input_to_block(loop_block)
+ final_splits = _add_input_to_block(loop_block)
r_split = loop_block.op("SequenceAt", r_splits, block_input_iter)
i_split = loop_block.op("SequenceAt", i_splits, block_input_iter)
@@ -410,8 +414,8 @@
# Loop outputs
cond_out = loop_block.op("Cast", loop_condition, to_i=9)
- torch.onnx.utils._add_output_to_block(loop_block, cond_out)
- torch.onnx.utils._add_output_to_block(loop_block, final_splits)
+ _add_output_to_block(loop_block, cond_out)
+ _add_output_to_block(loop_block, final_splits)
loop_out = loop.node().output()
loop_out = g.op("ConcatFromSequence", loop_out, axis_i=dim)
@@ -514,7 +518,7 @@
if_op = g.op("If", overrun_cond)
if_node = if_op.node()
- if_block = torch.onnx.utils._add_block(if_node)
+ if_block = _add_block(if_node)
gather_indices_if_block = if_block.op("Add", gather_indices, select_window)
gather_indices_if_block = sym_help._unsqueeze_helper(
if_block, gather_indices_if_block, [rank - 1]
@@ -522,11 +526,11 @@
final_non_overrun_ = if_block.op(
"GatherND", result, gather_indices_if_block, batch_dims_i=rank - 2
)
- torch.onnx.utils._add_output_to_block(if_block, final_non_overrun_)
+ _add_output_to_block(if_block, final_non_overrun_)
- else_block = torch.onnx.utils._add_block(if_node)
+ else_block = _add_block(if_node)
final_overrun_ = zeros(else_block, gather_shape, 6, None, None)
- torch.onnx.utils._add_output_to_block(else_block, final_overrun_)
+ _add_output_to_block(else_block, final_overrun_)
return if_op
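
The If construction just above follows the same helpers: each `_add_block` call on the If node appends one branch, then-branch first. A minimal sketch:

```
import torch
from torch.onnx.utils import _add_block, _add_output_to_block

# Skeleton of the If pattern above: cond is a boolean Value in graph g.
def if_skeleton(g, cond):
    if_op = g.op("If", cond)
    if_node = if_op.node()
    then_block = _add_block(if_node)  # first _add_block: then-branch
    then_out = then_block.op("Constant", value_t=torch.tensor(1.0))
    _add_output_to_block(then_block, then_out)
    else_block = _add_block(if_node)  # second _add_block: else-branch
    else_out = else_block.op("Constant", value_t=torch.tensor(0.0))
    _add_output_to_block(else_block, else_out)
    return if_op
```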
diff --git a/torch/onnx/symbolic_opset14.py b/torch/onnx/symbolic_opset14.py
index b46b635..fe140fe 100644
--- a/torch/onnx/symbolic_opset14.py
+++ b/torch/onnx/symbolic_opset14.py
@@ -4,8 +4,7 @@
# This file exports ONNX ops for opset 14
import torch
import torch.onnx.symbolic_helper as sym_help
-from torch.onnx._globals import GLOBALS
-from torch.onnx.symbolic_helper import parse_args
+from torch.onnx.symbolic_helper import args_have_same_dtype, parse_args
# Note [ONNX operators that are added/updated in opset 14]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -59,10 +58,8 @@
if (
torch.is_autocast_enabled()
- and not sym_help.args_have_same_dtype(
- [input, weight, bias, running_mean, running_var]
- )
- and GLOBALS.export_onnx_opset_version < 15
+ and not args_have_same_dtype([input, weight, bias, running_mean, running_var])
+ and sym_help._export_onnx_opset_version < 15
):
return sym_help._onnx_opset_unsupported_detailed(
"BatchNormalization",
diff --git a/torch/onnx/symbolic_opset8.py b/torch/onnx/symbolic_opset8.py
index 6f2da64..72aa535 100644
--- a/torch/onnx/symbolic_opset8.py
+++ b/torch/onnx/symbolic_opset8.py
@@ -10,6 +10,7 @@
_unimplemented,
parse_args,
)
+from torch.onnx.symbolic_opset9 import _cast_Float # type: ignore[attr-defined]
# Note [ONNX operators that are added/updated from opset 8 to opset 9]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -128,11 +129,7 @@
if arg0_type is not None:
old_type = arg0_type
if old_type not in floating_scalar_types:
- # TODO(justinchuby): Remove the type ignore hint once _cast_Float is
- # properly defined.
- # NOTE: _cast_Float is generated programmatically so we need to make the
- # type checker happy with ignore[attr-defined].
- args = tuple(sym_opset9._cast_Float(g, arg, False) for arg in args) # type: ignore[attr-defined]
+ args = tuple(_cast_Float(g, arg, False) for arg in args)
else:
return (None,) + args
else:
diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py
index ec27e8d..e6c5978 100644
--- a/torch/onnx/symbolic_opset9.py
+++ b/torch/onnx/symbolic_opset9.py
@@ -1,3 +1,5 @@
+# -*- coding: utf-8 -*-
+
import math
import warnings
from functools import partial, wraps
@@ -7,13 +9,12 @@
import torch
import torch.onnx
import torch.onnx.symbolic_helper as sym_help
-from torch._C import ListType, OptionalType
-from torch.nn.modules.utils import _pair, _single, _triple
# This import monkey-patches graph manipulation methods on Graph, used for the
# ONNX symbolics
-from torch.onnx import _patch_torch # noqa: F401
-from torch.onnx._globals import GLOBALS
+import torch.onnx.utils
+from torch._C import ListType, OptionalType
+from torch.nn.modules.utils import _pair, _single, _triple
from torch.onnx.symbolic_helper import (
ScalarType,
_parse_arg,
@@ -528,12 +529,12 @@
@parse_args("v", "v", "i", "b", "v")
def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse):
- if scale_grad_by_freq and GLOBALS.training_mode:
+ if scale_grad_by_freq and sym_help._training_mode:
raise RuntimeError(
"Unsupported: ONNX export of embedding with scale_grad_by_freq=True "
"for training mode. ONNX does not support scaling the gradients."
)
- if padding_idx >= 0 and GLOBALS.training_mode:
+ if padding_idx >= 0 and sym_help._training_mode:
warnings.warn(
"Warning: ONNX export of embedding with padding_idx >= 0 "
"for training mode. "
@@ -837,7 +838,7 @@
dtype_0 = inputs[0].type().scalarType()
require_cast = not sym_help._is_fp(inputs[0]) and (
- opset_before is None or GLOBALS.export_onnx_opset_version < opset_before
+ opset_before is None or sym_help._export_onnx_opset_version < opset_before
)
if require_cast:
@@ -1549,7 +1550,7 @@
if condition.type().scalarType() != "Bool":
condition = g.op("Cast", condition, to_i=sym_help.cast_pytorch_to_onnx["Bool"])
if self is None:
- condition = nonzero(g, condition)
+ condition = torch.onnx.symbolic_opset9.nonzero(g, condition)
return sym_help._unbind_helper(
g, condition, g.op("Constant", value_t=torch.tensor(1)), _outputs
)
@@ -1791,7 +1792,7 @@
if (
torch.is_autocast_enabled()
and not args_have_same_dtype([input, weight, bias, running_mean, running_var])
- and GLOBALS.export_onnx_opset_version < 15
+ and sym_help._export_onnx_opset_version < 15
):
return sym_help._onnx_opset_unsupported_detailed(
"BatchNormalization",
@@ -2371,8 +2372,6 @@
sym_help._onnx_opset_unsupported("_unique2", 9, 11)
-# TODO(justinchuby): Clean up this function generation magic by defining the functions
-# explicitly.
for k, v in sym_help.cast_pytorch_to_onnx.items():
name = "_cast_{}".format(k)
globals()[name] = parse_args("v", "i")(partial(sym_help._cast_func_template, v))
@@ -2610,7 +2609,7 @@
or ((not is_end_none) and (not is_end_onnx_const))
or dim.node().kind() != "onnx::Constant"
):
- if GLOBALS.operator_export_type == torch.onnx.OperatorExportTypes.ONNX:
+ if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX:
raise RuntimeError(
"Unsupported: ONNX export of Slice with dynamic inputs. DynamicSlice "
"is a deprecated experimental op. Please use statically allocated "
@@ -3953,7 +3952,7 @@
if not sym_help._is_none(index) and (
index.type().scalarType() == "Byte" or index.type().scalarType() == "Bool"
):
- if GLOBALS.export_onnx_opset_version < 9:
+ if sym_help._export_onnx_opset_version < 9:
raise RuntimeError(
"Exporting masked indices are only supported after ONNX opset 9."
)
@@ -4009,7 +4008,7 @@
# update the warning to recommend exporting with higher opset version.
warnings.warn(
"Exporting aten::index operator of advanced indexing in opset "
- + str(GLOBALS.export_onnx_opset_version)
+ + str(sym_help._export_onnx_opset_version)
+ " is achieved by combination of multiple ONNX operators, "
+ "including Reshape, Transpose, Concat, and Gather. "
+ "If indices include negative values, the exported graph will produce incorrect results."
@@ -4827,8 +4826,8 @@
env = ctx.env
params_dict = ctx.params_dict
- operator_export_type = GLOBALS.operator_export_type
- opset_version = GLOBALS.export_onnx_opset_version
+ operator_export_type = sym_help._operator_export_type
+ opset_version = sym_help._export_onnx_opset_version
new_op_outputs = g.op("Loop", *inputs, outputs=n.outputsSize())
new_node = (
@@ -4861,7 +4860,9 @@
new_node, opset_version
)
# Run shape type inference for Loop after subblock is converted.
- if GLOBALS.onnx_shape_inference:
+ from torch.onnx.symbolic_helper import _onnx_shape_inference
+
+ if _onnx_shape_inference:
torch._C._jit_pass_onnx_node_shape_type_inference(
new_node, params_dict, opset_version
)
@@ -4874,8 +4875,8 @@
env = ctx.env
params_dict = ctx.params_dict
- operator_export_type = GLOBALS.operator_export_type
- opset_version = GLOBALS.export_onnx_opset_version
+ operator_export_type = sym_help._operator_export_type
+ opset_version = sym_help._export_onnx_opset_version
static_if = inputs[0].node().kind() == "onnx::Constant"
if static_if:
@@ -4950,7 +4951,9 @@
new_node, opset_version
)
# Run shape type inference for If after subblock is converted.
- if GLOBALS.onnx_shape_inference:
+ from torch.onnx.symbolic_helper import _onnx_shape_inference
+
+ if _onnx_shape_inference:
torch._C._jit_pass_onnx_node_shape_type_inference(
new_node, params_dict, opset_version
)
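
The cast-generation loop near the end of the opset 9 hunk (`globals()[name] = parse_args("v", "i")(partial(...))`) reads as shorthand for definitions like the following, assuming `"Float"` is a key of `sym_help.cast_pytorch_to_onnx`:

```
from functools import partial

import torch.onnx.symbolic_helper as sym_help
from torch.onnx.symbolic_helper import parse_args

# Roughly what one iteration of the generation loop expands to for k == "Float".
_cast_Float = parse_args("v", "i")(
    partial(sym_help._cast_func_template, sym_help.cast_pytorch_to_onnx["Float"])
)
```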
diff --git a/torch/onnx/symbolic_registry.py b/torch/onnx/symbolic_registry.py
index 88286c8..d90772d1 100644
--- a/torch/onnx/symbolic_registry.py
+++ b/torch/onnx/symbolic_registry.py
@@ -1,13 +1,12 @@
import importlib
import inspect
-import itertools
import warnings
from typing import Any, Callable, Dict, Tuple, Union
import torch._C
-from torch.onnx import _constants
+from torch.onnx.symbolic_helper import _onnx_main_opset, _onnx_stable_opsets
-_SymbolicFunction = Callable[..., Union[torch._C.Value, Tuple[torch._C.Value]]]
+SymbolicFunction = Callable[..., Union[torch._C.Value, Tuple[torch._C.Value]]]
"""
The symbolic registry "_registry" is a dictionary that maps operators
@@ -19,21 +18,16 @@
"""
_registry: Dict[
Tuple[str, int],
- Dict[str, _SymbolicFunction],
+ Dict[str, SymbolicFunction],
] = {}
_symbolic_versions: Dict[Union[int, str], Any] = {}
-
-def _import_symbolic_opsets():
- for opset_version in itertools.chain(
- _constants.onnx_stable_opsets, [_constants.onnx_main_opset]
- ):
- module = importlib.import_module(
- "torch.onnx.symbolic_opset{}".format(opset_version)
- )
- global _symbolic_versions
- _symbolic_versions[opset_version] = module
+for opset_version in _onnx_stable_opsets + [_onnx_main_opset]:
+ module = importlib.import_module(
+ "torch.onnx.symbolic_opset{}".format(opset_version)
+ )
+ _symbolic_versions[opset_version] = module
def register_version(domain: str, version: int):
@@ -135,7 +129,7 @@
def get_op_supported_version(opname: str, domain: str, version: int):
iter_version = version
- while iter_version <= _constants.onnx_main_opset:
+ while iter_version <= _onnx_main_opset:
ops = [(op[0], op[1]) for op in get_ops_in_version(iter_version)]
if (domain, opname) in ops:
return iter_version
@@ -143,7 +137,7 @@
return None
-def get_registered_op(opname: str, domain: str, version: int) -> _SymbolicFunction:
+def get_registered_op(opname: str, domain: str, version: int) -> SymbolicFunction:
if domain is None or version is None:
warnings.warn("ONNX export failed. The ONNX domain and/or version are None.")
global _registry
@@ -155,7 +149,7 @@
class UnsupportedOperatorError(RuntimeError):
def __init__(self, domain: str, opname: str, version: int):
supported_version = get_op_supported_version(opname, domain, version)
- if domain in {"", "aten", "prim", "quantized"}:
+ if domain in ["", "aten", "prim", "quantized"]:
msg = f"Exporting the operator {domain}::{opname} to ONNX opset version {version} is not supported. "
if supported_version is not None:
msg += (
@@ -171,6 +165,3 @@
"it with the right domain and version."
)
super().__init__(msg)
-
-
-_import_symbolic_opsets()
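
A hedged sketch of the registry interface this hunk reshapes: the restored module-level loop fills `_symbolic_versions` at import time, after which lookups go through `(domain, version)` pairs:

```
from torch.onnx import symbolic_registry

# register_version copies the ops of one opset module into the registry
# table; "" is the default (aten) domain seen in the hunks above.
symbolic_registry.register_version("", 13)
relu_symbolic = symbolic_registry.get_registered_op("relu", "", 13)
# Walks forward from the given opset to the first one providing the op.
print(symbolic_registry.get_op_supported_version("relu", "", 7))
```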
diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py
index 30b2c6a..cb02ee4 100644
--- a/torch/onnx/utils.py
+++ b/torch/onnx/utils.py
@@ -6,26 +6,16 @@
import contextlib
import copy
import inspect
-import itertools
+import numbers
import os
import re
import textwrap
import typing
import warnings
-import zipfile
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import torch
-import torch.jit._trace
import torch.serialization
-from torch.onnx import ( # noqa: F401
- _constants,
- _patch_torch,
- symbolic_caffe2,
- symbolic_helper,
- symbolic_registry,
-)
-from torch.onnx._globals import GLOBALS
# the flag to tell the user whether it's in the middle of ONNX export or not
__IN_ONNX_EXPORT = False
@@ -36,7 +26,6 @@
return __IN_ONNX_EXPORT
-# TODO(justinchuby): Remove dependency to this global variable from constant_fold.cpp
# Skip check due to cannot import IValue from torch._C
_params_dict = {} # type: ignore[var-annotated]
@@ -66,24 +55,26 @@
if mode == torch.onnx.TrainingMode.TRAINING or (
mode == torch.onnx.TrainingMode.PRESERVE and is_originally_training
):
+ from torch.onnx.symbolic_helper import _export_onnx_opset_version
- if GLOBALS.export_onnx_opset_version < 12:
+ if _export_onnx_opset_version < 12:
warnings.warn(
"You are exporting the model in training mode with onnx opset version {}. "
"Opset versions lower than opset 12 will not be able to export nodes such as "
"Dropout and BatchNorm correctly.".format(
- GLOBALS.export_onnx_opset_version
+ _export_onnx_opset_version
)
)
is_export_training = True
- symbolic_helper._set_training_mode(is_export_training)
+ from torch.onnx.symbolic_helper import _set_training_mode
+
+ _set_training_mode(is_export_training)
model.train(is_export_training)
try:
yield
finally:
if not isinstance(model, torch.jit.ScriptFunction):
- # FIXME(justinchuby): is_originally_training is possibly unbound
model.train(is_originally_training)
@@ -108,7 +99,6 @@
yield
finally:
if not isinstance(model, torch.jit.ScriptFunction):
- # FIXME(justinchuby): tmp_map is possibly unbound
for module, m_map in tmp_map.items():
for k, v in m_map.items():
module._state_dict_hooks[k] = v
@@ -247,6 +237,11 @@
torch._C._jit_pass_peephole(graph, True)
torch._C._jit_pass_fuse_addmm(graph)
torch._C._jit_pass_lint(graph)
+ from torch.onnx.symbolic_helper import (
+ _export_onnx_opset_version,
+ _onnx_shape_inference,
+ is_caffe2_aten_fallback,
+ )
torch._C._jit_pass_peephole(graph, True)
torch._C._jit_pass_lower_all_tuples(graph)
@@ -268,12 +263,12 @@
torch._C._jit_pass_onnx_remove_print(graph)
torch._C._jit_pass_onnx_preprocess_caffe2(graph)
- symbolic_helper._quantized_ops.clear()
+ torch.onnx.symbolic_helper._quantized_ops.clear()
# Unpack quantized weights for conv and linear ops and insert into graph.
torch._C._jit_pass_onnx_unpack_quantized_weights(
- graph, params_dict, symbolic_helper.is_caffe2_aten_fallback()
+ graph, params_dict, is_caffe2_aten_fallback()
)
- if symbolic_helper.is_caffe2_aten_fallback():
+ if is_caffe2_aten_fallback():
# Insert permutes before and after each conv op to ensure correct order.
torch._C._jit_pass_onnx_quantization_insert_permutes(graph, params_dict)
@@ -296,7 +291,7 @@
# onnx only supports tensors, so we turn all out number types into tensors
torch._C._jit_pass_erase_number_types(graph)
- if GLOBALS.onnx_shape_inference:
+ if _onnx_shape_inference:
input_names = [] if input_names is None else input_names
dynamic_axes = {} if dynamic_axes is None else dynamic_axes
torch._C._jit_pass_onnx_set_dynamic_input_shape(
@@ -308,12 +303,12 @@
torch._C._jit_pass_lint(graph)
torch._C._jit_pass_onnx_scalar_type_analysis(
- graph, True, GLOBALS.export_onnx_opset_version
+ graph, True, _export_onnx_opset_version
)
torch._C._jit_pass_lint(graph)
torch._C._jit_pass_onnx_peephole(
- graph, GLOBALS.export_onnx_opset_version, fixed_batch_size
+ graph, _export_onnx_opset_version, fixed_batch_size
)
torch._C._jit_pass_lint(graph)
@@ -326,9 +321,9 @@
torch._C._jit_pass_lint(graph)
graph = torch._C._jit_pass_canonicalize(graph)
torch._C._jit_pass_lint(graph)
- if GLOBALS.onnx_shape_inference:
+ if _onnx_shape_inference:
torch._C._jit_pass_onnx_graph_shape_type_inference(
- graph, params_dict, GLOBALS.export_onnx_opset_version
+ graph, params_dict, _export_onnx_opset_version
)
return graph
@@ -591,7 +586,6 @@
graph = model.graph
torch._C._jit_pass_onnx_function_substitution(graph)
param_count_list = _get_param_count_list(graph, args)
- # FIXME(justinchuby): flattened_args is possibly unbound
graph = torch._C._propagate_and_assign_input_shapes(
graph, flattened_args, param_count_list, False, False
)
@@ -717,6 +711,7 @@
this will be None, since we are not doing any tracing.
"""
# TODO: can we simplify this to always return a tuple of Tensor or None?
+ from torch.onnx.symbolic_helper import _export_onnx_opset_version
# Special case for common case of passing a single Tensor
if isinstance(args, (torch.Tensor, int, float, bool)):
@@ -740,6 +735,7 @@
except Exception as e:
torch.onnx.log("Torch IR graph at exception: ", graph)
raise
+ from torch.onnx.symbolic_helper import _onnx_shape_inference
is_script = isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule))
if is_script:
@@ -749,7 +745,7 @@
example_outputs_final += unpack_quantized_tensor(example_output)
out_vars, desc = torch.jit._flatten(example_outputs_final)
torch._C._jit_pass_onnx_assign_output_shape(
- graph, out_vars, desc, GLOBALS.onnx_shape_inference, is_script
+ graph, out_vars, desc, _onnx_shape_inference, is_script
)
# NB: ONNX requires complete information about output types, which might be
@@ -766,11 +762,7 @@
# single value in PyTorch.
if not any(getattr(out, "is_quantized", False) for out in output_tensors):
torch._C._jit_pass_onnx_assign_output_shape(
- graph,
- output_tensors,
- out_desc,
- GLOBALS.onnx_shape_inference,
- is_script,
+ graph, output_tensors, out_desc, _onnx_shape_inference, is_script
)
_set_input_and_output_names(graph, input_names, output_names)
@@ -779,25 +771,27 @@
if training is None or training == torch.onnx.TrainingMode.EVAL:
params_dict = torch._C._jit_pass_onnx_eval_peephole(graph, params_dict)
+ from torch.onnx.symbolic_helper import _constant_folding_opset_versions
+
if (
do_constant_folding
- and GLOBALS.export_onnx_opset_version in _constants.onnx_constant_folding_opsets
+ and _export_onnx_opset_version in _constant_folding_opset_versions
):
params_dict = torch._C._jit_pass_onnx_constant_fold(
- graph, params_dict, GLOBALS.export_onnx_opset_version
+ graph, params_dict, _export_onnx_opset_version
)
torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
- if GLOBALS.onnx_shape_inference:
+ if _onnx_shape_inference:
torch._C._jit_pass_onnx_graph_shape_type_inference(
- graph, params_dict, GLOBALS.export_onnx_opset_version
+ graph, params_dict, _export_onnx_opset_version
)
params_dict = torch._C._jit_pass_onnx_eliminate_unused_items(graph, params_dict)
# For ONNX opset < 9, constants only have three data types: float16, float, double.
# In this pass transform constants of other data types to float/double + cast operator.
- if GLOBALS.export_onnx_opset_version < 9:
+ if _export_onnx_opset_version < 9:
torch._C._jit_pass_onnx_cast_all_constant_to_floating(graph)
params_dict = torch._C._jit_pass_filter_non_tensor_arguments(params_dict)
@@ -828,15 +822,21 @@
do_constant_folding=True,
dynamic_axes=None,
):
+ from torch.onnx.symbolic_helper import (
+ _default_onnx_opset_version,
+ _set_operator_export_type,
+ _set_opset_version,
+ )
if opset_version is None:
- opset_version = _constants.onnx_default_opset
+ opset_version = _default_onnx_opset_version
if custom_opsets is None:
custom_opsets = {}
- symbolic_helper._set_opset_version(opset_version)
- symbolic_helper._set_operator_export_type(operator_export_type)
+ _set_opset_version(opset_version)
+ _set_operator_export_type(operator_export_type)
+ from torch.onnx.symbolic_helper import _set_onnx_shape_inference
- symbolic_helper._set_onnx_shape_inference(True)
+ _set_onnx_shape_inference(True)
with exporter_context(model, training, verbose):
val_keep_init_as_ip = _decide_keep_init_as_input(
keep_initializers_as_inputs, operator_export_type, opset_version
@@ -890,9 +890,13 @@
Tuple[torch._C.Graph, List[str]], where the list includes the names
of the unconvertible ops.
"""
+ from torch.onnx.symbolic_helper import (
+ _default_onnx_opset_version,
+ _set_opset_version,
+ )
- opset_version = opset_version or _constants.onnx_default_opset
- symbolic_helper._set_opset_version(opset_version)
+ opset_version = opset_version or _default_onnx_opset_version
+ _set_opset_version(opset_version)
# operator_export_type is set to ONNX_FALLTHROUGH by default so that if an op is not supported
# in ONNX, fall through will occur and export the operator as is, as a custom ONNX op.
with exporter_context(model, training, False):
@@ -977,9 +981,10 @@
def _get_module_attributes(module):
+ from typing import get_type_hints
- annotations = typing.get_type_hints(type(module))
- base_m_annotations = typing.get_type_hints(torch.nn.Module)
+ annotations = get_type_hints(type(module))
+ base_m_annotations = get_type_hints(torch.nn.Module)
[annotations.pop(k, None) for k in base_m_annotations]
return {k: getattr(module, k) for k in annotations}
@@ -1017,11 +1022,18 @@
assert __IN_ONNX_EXPORT is False
__IN_ONNX_EXPORT = True
try:
+ from torch.onnx.symbolic_helper import _set_onnx_shape_inference
- symbolic_helper._set_onnx_shape_inference(onnx_shape_inference)
+ _set_onnx_shape_inference(onnx_shape_inference)
+
+ from torch.onnx.symbolic_helper import (
+ _default_onnx_opset_version,
+ _set_operator_export_type,
+ _set_opset_version,
+ )
if opset_version is None:
- opset_version = _constants.onnx_default_opset
+ opset_version = _default_onnx_opset_version
if export_modules_as_functions and opset_version < 15:
raise ValueError(
@@ -1045,8 +1057,8 @@
# If you really know what you're doing, you can set
# training=TrainingMode.TRAINING or training=TrainingMode.PRESERVE
# (to preserve whatever the original training mode was).
- symbolic_helper._set_opset_version(opset_version)
- symbolic_helper._set_operator_export_type(operator_export_type)
+ _set_opset_version(opset_version)
+ _set_operator_export_type(operator_export_type)
with exporter_context(model, training, verbose):
val_keep_init_as_ip = _decide_keep_init_as_input(
keep_initializers_as_inputs, operator_export_type, opset_version
@@ -1148,6 +1160,8 @@
torch.onnx.ExportTypes.ZIP_ARCHIVE,
torch.onnx.ExportTypes.COMPRESSED_ZIP_ARCHIVE,
]:
+ import zipfile
+
compression = (
zipfile.ZIP_DEFLATED
if export_type == torch.onnx.ExportTypes.COMPRESSED_ZIP_ARCHIVE
@@ -1243,6 +1257,9 @@
set_names(list(graph.outputs()), output_names, "output")
+_attr_pattern = re.compile("^(.+)_(([ifstgz])|(ty))$")
+
+
def _run_symbolic_method(g, op_name, symbolic_fn, args):
r"""
This trampoline function gets invoked for every symbolic method
@@ -1258,6 +1275,151 @@
raise
+def _is_onnx_list(value):
+ return (
+ not isinstance(value, torch._six.string_classes)
+ and not isinstance(value, torch.Tensor)
+ and isinstance(value, Iterable)
+ )
+
+
+def _add_attribute(node, key, value, aten):
+ r"""Initializes the right attribute based on type of value."""
+ m = _attr_pattern.match(key)
+ if m is None:
+ raise IndexError(
+ "Invalid attribute specifier '{}': names must be suffixed "
+ "with type, e.g. 'dim_i' or 'dims_i'".format(key)
+ )
+ name, kind = m.group(1), m.group(2)
+ if _is_onnx_list(value):
+ kind += "s"
+ from torch.onnx.symbolic_helper import is_caffe2_aten_fallback
+
+ if aten and is_caffe2_aten_fallback():
+ if isinstance(value, torch.Tensor):
+ # Caffe2 proto does not support tensor attribute.
+ if value.numel() > 1:
+ raise ValueError("Should not pass tensor attribute")
+ value = _scalar(value)
+ if isinstance(value, float):
+ kind = "f"
+ else:
+ kind = "i"
+ return getattr(node, kind + "_")(name, value)
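
To make the suffix convention this regex and helper enforce concrete, a standalone sketch (the keys are hypothetical):

import re

_attr_pattern = re.compile("^(.+)_(([ifstgz])|(ty))$")

# 'dims_i' splits into name 'dims' and kind 'i'; if the value is a list,
# _add_attribute upgrades the kind to 'is' before setting it on the node.
m = _attr_pattern.match("dims_i")
assert (m.group(1), m.group(2)) == ("dims", "i")

# A key without a recognized type suffix does not match, which is what
# triggers the IndexError above.
assert _attr_pattern.match("dims") is None
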
+
+
+def _scalar(x):
+ """Convert a scalar tensor into a Python value."""
+ assert x.numel() == 1
+ return x[0]
+
+
+def _new_node(g: torch._C.Graph, opname: str, outputs, *args, **kwargs):
+ if "::" in opname:
+ aten = False
+ ns_opname = opname
+ else:
+ aten = kwargs.pop("aten", False)
+ ns = "aten" if aten else "onnx"
+ ns_opname = ns + "::" + opname
+ n = g.create(ns_opname, args, outputs) # type: ignore[attr-defined]
+ for k, v in sorted(kwargs.items()):
+ # TODO: enable inplace in aten exporting mode.
+ if k == "inplace":
+ continue
+ _add_attribute(n, k, v, aten=aten)
+ return n
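
The namespace resolution at the top of _new_node is simple enough to show in isolation (a pure-function sketch, not part of the diff):

def resolve_ns(opname: str, aten: bool = False) -> str:
    # Qualified names pass through; bare names get an onnx:: or aten:: prefix.
    if "::" in opname:
        return opname
    return ("aten" if aten else "onnx") + "::" + opname

assert resolve_ns("Add") == "onnx::Add"
assert resolve_ns("Add", aten=True) == "aten::Add"
assert resolve_ns("custom::Foo") == "custom::Foo"
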
+
+
+def _graph_op(
+ g: torch._C.Graph,
+ opname: str,
+ *raw_args: torch._C.Node,
+ outputs: int = 1,
+ **kwargs,
+) -> Union[torch._C.Value, Tuple[torch._C.Value, ...]]:
+ r"""Creates an ONNX operator "opname", taking "args" as inputs and attributes "kwargs".
+
+ The set of operators and the inputs/attributes they take
+ is documented at https://github.com/onnx/onnx/blob/master/docs/Operators.md
+
+ This function is monkey-patched onto Graph.
+
+ Args:
+ g: The Torch graph.
+ opname: The ONNX operator name, e.g., `Abs` or `Add`. TODO(justinchu): Update examples to correct ones.
+ raw_args: The inputs to the operator; usually provided
+ as arguments to the `symbolic` definition.
+ outputs: The number of outputs this operator returns.
+ By default an operator is assumed to return a single output.
+ If `outputs` is greater than one, this function returns a tuple
+ of output `Value`s, representing each output of the ONNX operator
+ in positional order.
+ kwargs: The attributes of the ONNX operator, whose keys are named
+ according to the following convention: `alpha_f` indicates
+ the `alpha` attribute with type `f`. The valid type specifiers are
+ `f` (float), `i` (int), `s` (string) or `t` (Tensor). An attribute
+ specified with a scalar type also accepts a list of that type
+ (e.g., you would say `dims_i` for a `dims` attribute
+ that takes a list of integers).
+
+ Returns:
+ The value representing the single output of this operator (see the
+ `outputs` keyword argument for multi-output nodes).
+ """
+ # Filter out None attributes; this is convenient for callers, which can
+ # pass None attributes and have them silently dropped.
+ kwargs = {k: v for k, v in kwargs.items() if v is not None}
+
+ def const_if_tensor(arg):
+ if arg is None:
+ return arg
+ elif isinstance(arg, torch._C.Value):
+ return arg
+ else:
+ return g.op("Constant", value_z=arg) # type: ignore[attr-defined]
+
+ args = [const_if_tensor(arg) for arg in raw_args]
+ n = g.insertNode(_new_node(g, opname, outputs, *args, **kwargs)) # type: ignore[attr-defined]
+
+ from torch.onnx.symbolic_helper import _onnx_shape_inference
+
+ if _onnx_shape_inference:
+ from torch.onnx.symbolic_helper import (
+ _export_onnx_opset_version as opset_version,
+ )
+
+ torch._C._jit_pass_onnx_node_shape_type_inference(
+ n, _params_dict, opset_version
+ )
+
+ if outputs == 1:
+ return n.output()
+ return tuple(n.outputs())
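
As a usage sketch of the attribute convention documented above (the symbolic and op choice are hypothetical, not part of this revert):

# `axis_i=1` becomes the int attribute `axis` on the emitted ONNX Flatten
# node; a list value (e.g. `axes_i=[0, 1]` on a Squeeze node) would instead
# be set as an ints attribute, with the kind upgraded to 'is' by
# _add_attribute.
def flatten_symbolic(g, input):
    return g.op("Flatten", input, axis_i=1)
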
+
+
+def _block_op(b, opname, *args, **kwargs):
+ if "::" in opname:
+ aten = False
+ ns_opname = opname
+ else:
+ aten = kwargs.pop("aten", False)
+ ns = "aten" if aten else "onnx"
+ ns_opname = ns + "::" + opname
+ n = b.addNode(ns_opname, list(args))
+ for k, v in sorted(kwargs.items()):
+ # TODO: enable inplace in aten exporting mode.
+ if k == "inplace":
+ continue
+ _add_attribute(n, k, v, aten=aten)
+ if len(list(n.outputs())) == 1:
+ return n.output()
+ return tuple(o for o in n.outputs())
+
+
def _add_block(node: torch._C.Node):
return node.addBlock() # type: ignore[attr-defined]
@@ -1297,19 +1459,19 @@
Returns:
The symbolic function if found, None otherwise.
"""
+ import torch.onnx.symbolic_registry as sym_registry
- if not symbolic_registry.is_registered_op(op_name, domain, opset_version):
+ if not sym_registry.is_registered_op(op_name, domain, opset_version):
if operator_export_type == torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH:
# Use the original node directly
return None
- return symbolic_registry.get_registered_op(op_name, domain, opset_version)
+ return sym_registry.get_registered_op(op_name, domain, opset_version)
def _should_aten_fallback(ns, op_name, opset_version, operator_export_type):
+ import torch.onnx.symbolic_registry as sym_registry
- is_exportable_aten_op = symbolic_registry.is_registered_op(
- op_name, "", opset_version
- )
+ is_exportable_aten_op = sym_registry.is_registered_op(op_name, "", opset_version)
is_onnx_aten_export = (
operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN
)
@@ -1337,10 +1499,11 @@
def _get_aten_op_overload_name(n: torch._C.Node) -> str:
+ from torch.onnx.symbolic_helper import is_caffe2_aten_fallback
# Returns the `overload_name` attribute of ATen ops on non-Caffe2 builds
schema = n.schema()
- if not schema.startswith("aten::") or symbolic_helper.is_caffe2_aten_fallback():
+ if not schema.startswith("aten::") or is_caffe2_aten_fallback():
return ""
return torch._C.parse_schema(schema).overload_name
@@ -1361,9 +1524,11 @@
A single or a tuple of Values.
None when the node gets cloned as is into the new graph.
"""
+ from torch.onnx import symbolic_helper
+ from torch.onnx import symbolic_registry as sym_registry
- opset_version = GLOBALS.export_onnx_opset_version
- symbolic_helper.is_caffe2_aten_fallback = symbolic_helper.is_caffe2_aten_fallback
+ opset_version = symbolic_helper._export_onnx_opset_version
+ is_caffe2_aten_fallback = symbolic_helper.is_caffe2_aten_fallback
# See Note [Export inplace]
# TODO(ezyang): I think this is not necessary anymore
@@ -1374,21 +1539,22 @@
ns, op_name = ns_op_name.split("::")
try:
- symbolic_registry.register_version("", opset_version)
+ sym_registry.register_version("", opset_version)
# Caffe2-specific: Quantized op symbolics are registered for opset 9 only.
- if symbolic_helper.is_caffe2_aten_fallback() and opset_version == 9:
+ if is_caffe2_aten_fallback() and opset_version == 9:
+ from torch.onnx import symbolic_caffe2
symbolic_caffe2.register_quantized_ops("caffe2", opset_version)
if ns == "aten":
domain = ""
- elif ns == "quantized" and symbolic_helper.is_caffe2_aten_fallback():
+ elif ns == "quantized" and is_caffe2_aten_fallback():
domain = "caffe2"
else:
domain = ns
- if symbolic_registry.is_registered_op(op_name, domain, opset_version):
+ if sym_registry.is_registered_op(op_name, domain, opset_version):
symbolic_fn = _find_symbolic_in_registry(
domain, op_name, opset_version, operator_export_type
)
@@ -1417,15 +1583,13 @@
op_name, *inputs, overload_name=_get_aten_op_overload_name(n), **attrs
)
else:
- raise symbolic_registry.UnsupportedOperatorError(
- domain, op_name, opset_version
- )
+ raise sym_registry.UnsupportedOperatorError(domain, op_name, opset_version)
except RuntimeError:
if operator_export_type == torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH:
return None
elif (
operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
- and not symbolic_helper.is_caffe2_aten_fallback()
+ and not is_caffe2_aten_fallback()
):
# Emit ATen op for non-Caffe2 builds when `operator_export_type==ONNX_ATEN_FALLBACK`
attrs = {k + "_" + n.kindOf(k)[0]: n[k] for k in n.attributeNames()} # type: ignore[attr-defined]
@@ -1440,6 +1604,78 @@
raise
+# Generate an ONNX ATen op node.
+def _aten_op(g, operator, *args, overload_name="", **kwargs):
+ kwargs["aten"] = True
+ return g.op(
+ "ATen", *args, operator_s=operator, overload_name_s=overload_name, **kwargs
+ )
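
Since this helper is installed as Graph.at at the bottom of the file, an ATen fallback symbolic can be written like this (a sketch; the op is illustrative):

def pow_fallback(g, self, exponent):
    # Equivalent to g.op("ATen", self, exponent, operator_s="pow",
    # overload_name_s="") with the aten namespace flag set.
    return g.at("pow", self, exponent)
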
+
+
+# TODO: We might not need this anymore, since most scalars now show up as tensors
+# TODO(#76254): Remove the helper function if not needed.
+def _graph_constant(
+ g,
+ value,
+ dims,
+ type_: str,
+ *args,
+ **kwargs,
+):
+ """This helper function can create either constant tensor or constant scalar.
+
+ If dims is None or 0 or [0], generate a 0-d tensor (scalar).
+ """
+ assert isinstance(value, numbers.Number)
+ assert type_ is not None
+ isscalar = False
+ if dims is None or dims == 0 or set(dims) == set([0]):
+ dims = [1]
+ isscalar = True
+ type_ = type_.lower()
+ tensor: Union[
+ torch.CharTensor,
+ torch.ShortTensor,
+ torch.IntTensor,
+ torch.LongTensor,
+ torch.HalfTensor,
+ torch.FloatTensor,
+ torch.DoubleTensor,
+ ]
+ if type_ == "char":
+ tensor = torch.CharTensor(*dims)
+ elif type_ == "short":
+ tensor = torch.ShortTensor(*dims)
+ elif type_ == "int":
+ tensor = torch.IntTensor(*dims)
+ elif type_ == "long":
+ tensor = torch.LongTensor(*dims)
+ elif type_ == "half":
+ tensor = torch.HalfTensor(*dims)
+ elif type_ == "float":
+ tensor = torch.FloatTensor(*dims)
+ elif type_ == "double":
+ tensor = torch.DoubleTensor(*dims)
+ else:
+ raise ValueError(
+ "Unknown type, type should be one of the following strings: "
+ "char, short, int, long, half, float, double"
+ )
+ tensor.fill_(value) # type: ignore[call-overload]
+ if isscalar:
+ return g.op("Constant", *args, value_z=tensor, **kwargs)
+ return g.op("Constant", *args, value_t=tensor, **kwargs)
+
+
+def _node_getitem(self, k):
+ """Gets attributes of a node which is polymorphic over return type.
+
+ This is monkey-patched onto Node.
+ """
+ sel = self.kindOf(k)
+ return getattr(self, sel)(k)
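
A dispatch sketch (hypothetical node and attribute): for a node whose 'alpha' attribute has kind 'f', indexing routes through the matching typed getter:

# node.kindOf("alpha") == "f", so:
# node["alpha"]  # equivalent to node.f("alpha")
# A tensor-kinded attribute would dispatch to node.t(...) instead.
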
+
+
def get_ns_op_name_from_custom_op(symbolic_name):
if not bool(
re.match(r"^[a-zA-Z0-9-_]*::[a-zA-Z-_]+[a-zA-Z0-9-_]*$", symbolic_name)
@@ -1472,22 +1708,22 @@
An example of setType is `test_aten_embedding_2` in `test_operators.py`.
"""
ns, op_name = get_ns_op_name_from_custom_op(symbolic_name)
+ import torch.onnx.symbolic_registry as sym_registry
+ from torch.onnx.symbolic_helper import _onnx_main_opset, _onnx_stable_opsets
- for version in itertools.chain(
- _constants.onnx_stable_opsets, [_constants.onnx_main_opset]
- ):
+ for version in _onnx_stable_opsets + [_onnx_main_opset]:
if version >= opset_version:
- symbolic_registry.register_op(op_name, symbolic_fn, ns, version)
+ sym_registry.register_op(op_name, symbolic_fn, ns, version)
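
As a usage sketch of the registration loop above (the `mylib::relu6` name and its symbolic are hypothetical):

import torch

def relu6_symbolic(g, input):
    # Exported as ONNX Clip with float attributes min=0.0 and max=6.0.
    return g.op("Clip", input, min_f=0.0, max_f=6.0)

# Registers the symbolic for opset 9 and every later stable opset, which
# is what the `version >= opset_version` loop above implements.
torch.onnx.register_custom_op_symbolic("mylib::relu6", relu6_symbolic, 9)
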
def unregister_custom_op_symbolic(symbolic_name, opset_version):
ns, op_name = get_ns_op_name_from_custom_op(symbolic_name)
+ import torch.onnx.symbolic_registry as sym_registry
+ from torch.onnx.symbolic_helper import _onnx_main_opset, _onnx_stable_opsets
- for version in itertools.chain(
- _constants.onnx_stable_opsets, [_constants.onnx_main_opset]
- ):
+ for version in _onnx_stable_opsets + [_onnx_main_opset]:
if version >= opset_version:
- symbolic_registry.unregister_op(op_name, ns, version)
+ sym_registry.unregister_op(op_name, ns, version)
def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names):
@@ -1538,3 +1774,10 @@
else:
value_dict[x] = str(key) + "_dynamic_axes_" + str(i + 1)
dynamic_axes[key] = value_dict
+
+
+torch._C.Graph.op = _graph_op # type: ignore[attr-defined]
+torch._C.Graph.at = _aten_op # type: ignore[attr-defined]
+torch._C.Block.op = _block_op # type: ignore[attr-defined]
+torch._C.Graph.constant = _graph_constant # type: ignore[attr-defined]
+torch._C.Node.__getitem__ = _node_getitem # type: ignore[attr-defined, misc, assignment]
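
The net effect of these assignments is that symbolic functions receive plain torch._C.Graph / torch._C.Node objects but can call the ONNX-flavored helpers directly; roughly:

# g.op(...)        -> _graph_op: emit an onnx:: (or aten::) node
# g.at(...)        -> _aten_op: emit an ONNX ATen fallback node
# g.constant(...)  -> _graph_constant: emit a Constant node
# node["alpha"]    -> _node_getitem: type-dispatched attribute access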