Revert "[ONNX] Refactor to remove inline imports (#77142)"

This reverts commit c08b8f0967efc2eec078da4541c5fdd003fbdd75.

Reverted https://github.com/pytorch/pytorch/pull/77142 on behalf of https://github.com/malfet
diff --git a/test/onnx/test_onnx_export.py b/test/onnx/test_onnx_export.py
index 6e955d1..52e32c5 100644
--- a/test/onnx/test_onnx_export.py
+++ b/test/onnx/test_onnx_export.py
@@ -5,7 +5,7 @@
 import itertools
 import os
 import sys
-import unittest.mock
+import unittest
 from typing import Callable, Iterable, Optional, Tuple, Union
 
 import onnx
@@ -13,7 +13,6 @@
 
 import torch
 from torch.onnx import OperatorExportTypes, symbolic_registry
-from torch.onnx._globals import GLOBALS
 from torch.onnx.symbolic_helper import _onnx_unsupported
 from torch.testing._internal.common_utils import custom_op, skipIfCaffe2
 
@@ -30,9 +29,9 @@
             Union[contextlib.AbstractContextManager, contextlib.ContextDecorator],
         ]
     ] = None,
-    mocks: Optional[Iterable] = None,
+    mocks: Optional[Iterable[unittest.mock.patch]] = None,
     operator_export_type: OperatorExportTypes = OperatorExportTypes.ONNX,
-    opset_version: int = GLOBALS.export_onnx_opset_version,
+    opset_version: int = torch.onnx.symbolic_helper._export_onnx_opset_version,
 ) -> onnx.ModelProto:
     """Exports `model(input)` to ONNX and returns it.
 
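Note on the hunk above: the restored annotation `Optional[Iterable[unittest.mock.patch]]` relies on `unittest.mock` being reachable as an attribute of `unittest`, but the restored bare `import unittest` does not load that submodule by itself (it only works if something else imports it first); the forward change's `import unittest.mock` is the form that guarantees it. A minimal sketch of the difference, runnable in a fresh interpreter:

```python
import sys
import unittest

# A bare `import unittest` does not load the `mock` submodule.
print("unittest.mock" in sys.modules)  # False in a fresh interpreter

import unittest.mock  # loads the submodule and binds the attribute

print("unittest.mock" in sys.modules)  # True
print(unittest.mock.patch)             # now reachable as unittest.mock.patch
```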
diff --git a/test/onnx/test_onnx_opset.py b/test/onnx/test_onnx_opset.py
index cd672ac..549666d 100644
--- a/test/onnx/test_onnx_opset.py
+++ b/test/onnx/test_onnx_opset.py
@@ -10,12 +10,10 @@
 import torch.onnx
 from torch.nn import Module
 from torch.onnx import producer_name, producer_version
-from torch.onnx._globals import GLOBALS
+from torch.onnx.symbolic_helper import _export_onnx_opset_version
 
 
-def check_onnx_opset_operator(
-    model, ops, opset_version=GLOBALS.export_onnx_opset_version
-):
+def check_onnx_opset_operator(model, ops, opset_version=_export_onnx_opset_version):
     # check_onnx_components
     assert (
         model.producer_name == producer_name
diff --git a/test/onnx/test_pytorch_onnx_onnxruntime_cuda.py b/test/onnx/test_pytorch_onnx_onnxruntime_cuda.py
index 38ac87d..c6279e1 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime_cuda.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime_cuda.py
@@ -9,18 +9,16 @@
     skipIfUnsupportedMinOpsetVersion,
     skipScriptTest,
 )
-
-# TODO(justinchuby): Remove reference to other unit tests.
 from test_pytorch_onnx_onnxruntime import TestONNXRuntime
 
 import torch
 from torch.cuda.amp import autocast
-from torch.onnx._globals import GLOBALS
 
 
 class TestONNXRuntime_cuda(unittest.TestCase):
+    from torch.onnx.symbolic_helper import _export_onnx_opset_version
 
-    opset_version = GLOBALS.export_onnx_opset_version
+    opset_version = _export_onnx_opset_version
     keep_initializers_as_inputs = True
     onnx_shape_inference = True
 
diff --git a/test/onnx/test_pytorch_onnx_shape_inference.py b/test/onnx/test_pytorch_onnx_shape_inference.py
index 50e40ae..9e354fb 100644
--- a/test/onnx/test_pytorch_onnx_shape_inference.py
+++ b/test/onnx/test_pytorch_onnx_shape_inference.py
@@ -6,8 +6,11 @@
 from test_pytorch_common import skipIfUnsupportedMinOpsetVersion
 
 import torch
-from torch.onnx import _constants
-from torch.onnx.symbolic_helper import _set_onnx_shape_inference, _set_opset_version
+from torch.onnx.symbolic_helper import (
+    _onnx_main_opset,
+    _set_onnx_shape_inference,
+    _set_opset_version,
+)
 
 
 def expect_tensor(scalar_type, shape=None):
@@ -24,7 +27,7 @@
 class TestONNXShapeInference(unittest.TestCase):
     def __init__(self, *args, **kwargs):
         unittest.TestCase.__init__(self, *args, **kwargs)
-        self.opset_version = _constants.onnx_main_opset
+        self.opset_version = _onnx_main_opset
         _set_onnx_shape_inference(True)
         _set_opset_version(self.opset_version)
 
diff --git a/test/onnx/test_utility_funs.py b/test/onnx/test_utility_funs.py
index 9638e53..1c39140 100644
--- a/test/onnx/test_utility_funs.py
+++ b/test/onnx/test_utility_funs.py
@@ -2,6 +2,7 @@
 
 import copy
 import io
+import unittest
 
 import onnx
 import torchvision
@@ -33,6 +34,8 @@
     parse_args,
 )
 
+skip = unittest.skip
+
 
 class _BaseTestCase(TestCase):
     def setUp(self):
diff --git a/tools/onnx/update_default_opset_version.py b/tools/onnx/update_default_opset_version.py
index dfdbf1f..c7ecd59 100755
--- a/tools/onnx/update_default_opset_version.py
+++ b/tools/onnx/update_default_opset_version.py
@@ -78,8 +78,8 @@
 
 
 read_sub_write(
-    os.path.join("torch", "onnx", "_constants.py"),
-    r"(onnx_default_opset = )\d+",
+    os.path.join("torch", "onnx", "symbolic_helper.py"),
+    r"(_default_onnx_opset_version = )\d+",
 )
 read_sub_write(
     os.path.join("torch", "onnx", "__init__.py"), r"(opset_version \(int, default )\d+"
diff --git a/torch/csrc/jit/passes/onnx.cpp b/torch/csrc/jit/passes/onnx.cpp
index 150eee9..2227842 100644
--- a/torch/csrc/jit/passes/onnx.cpp
+++ b/torch/csrc/jit/passes/onnx.cpp
@@ -242,7 +242,7 @@
     ::torch::onnx::OperatorExportTypes operator_export_type,
     std::unordered_map<Value*, Value*>& env) {
   py::object onnx = py::module::import("torch.onnx");
-  py::object onnx_globals = py::module::import("torch.onnx._globals");
+  py::object onnx_symbolic = py::module::import("torch.onnx.symbolic_helper");
   py::object onnx_registry = py::module::import("torch.onnx.symbolic_registry");
 
   // Setup all the lambda helper functions.
@@ -273,8 +273,8 @@
     }
     // For const node, it does not need params_dict info, so set it to {}.
     const ParamMap empty_params_dict = {};
-    auto opset_version = py::cast<int>(
-        onnx_globals.attr("GLOBALS").attr("export_onnx_opset_version"));
+    auto opset_version =
+        py::cast<int>(onnx_symbolic.attr("_export_onnx_opset_version"));
     for (const auto i : c10::irange(num_old_outputs)) {
       auto old = old_outputs[i];
       if (outputs[i]) {
@@ -435,8 +435,7 @@
       pyobj = func->get();
     }
 
-    py::object opset_version =
-        onnx_globals.attr("GLOBALS").attr("export_onnx_opset_version");
+    py::object opset_version = onnx_symbolic.attr("_export_onnx_opset_version");
     py::object is_registered_op = onnx_registry.attr("is_registered_op")(
         "PythonOp", "prim", opset_version);
     if (!py::hasattr(pyobj, "symbolic") &&
diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py
index 64bf14c..06541b4 100644
--- a/torch/onnx/__init__.py
+++ b/torch/onnx/__init__.py
@@ -1,4 +1,4 @@
-from typing import Dict
+from typing import Dict, Optional
 
 import torch._C as _C
 
diff --git a/torch/onnx/_constants.py b/torch/onnx/_constants.py
deleted file mode 100644
index 0a7e853..0000000
--- a/torch/onnx/_constants.py
+++ /dev/null
@@ -1,6 +0,0 @@
-"""Constant values used in ONNX."""
-
-onnx_default_opset = 13
-onnx_main_opset = 16
-onnx_stable_opsets = tuple(range(7, onnx_main_opset))
-onnx_constant_folding_opsets = tuple(range(9, onnx_main_opset + 1))
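For reference, the constants in the deleted `_constants.py` expand to concrete ranges, and the revert preserves the asymmetry between them: the stable-opset tuple excludes the main opset, while the constant-folding tuple includes it. A quick check:

```python
onnx_main_opset = 16
onnx_stable_opsets = tuple(range(7, onnx_main_opset))
onnx_constant_folding_opsets = tuple(range(9, onnx_main_opset + 1))

print(onnx_stable_opsets)            # (7, 8, ..., 15): main opset excluded
print(onnx_constant_folding_opsets)  # (9, 10, ..., 16): main opset included
```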
diff --git a/torch/onnx/_globals.py b/torch/onnx/_globals.py
deleted file mode 100644
index 67d478e..0000000
--- a/torch/onnx/_globals.py
+++ /dev/null
@@ -1,48 +0,0 @@
-"""Globals used internally by the ONNX exporter.
-
-Do not use this module outside of `torch.onnx` and its tests.
-
-Be very judicious when adding any new global variables. Do not create new global
-variables unless they are absolutely necessary.
-"""
-from __future__ import annotations
-
-import typing
-from typing import Optional
-
-# This module should only depend on _constants and nothing else in torch.onnx to keep
-# dependency direction clean.
-from torch.onnx import _constants
-
-if typing.TYPE_CHECKING:
-    # Postpone type checking to avoid circular dependencies.
-    from torch.onnx import OperatorExportTypes, TrainingMode
-
-
-class _InternalGlobals:
-    """Globals used internally by ONNX exporter.
-
-    NOTE: Be very judicious when adding any new variables. Do not create new
-    global variables unless they are absolutely necessary.
-    """
-
-    def __init__(self):
-        self._export_onnx_opset_version = _constants.onnx_default_opset
-        self.operator_export_type: Optional[OperatorExportTypes] = None
-        self.training_mode: Optional[TrainingMode] = None
-        self.onnx_shape_inference: bool = False
-
-    @property
-    def export_onnx_opset_version(self):
-        return self._export_onnx_opset_version
-
-    @export_onnx_opset_version.setter
-    def export_onnx_opset_version(self, value: int):
-        supported_versions = [_constants.onnx_main_opset]
-        supported_versions.extend(_constants.onnx_stable_opsets)
-        if value not in supported_versions:
-            raise ValueError(f"Unsupported ONNX opset version: {value}")
-        self._export_onnx_opset_version = value
-
-
-GLOBALS = _InternalGlobals()
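Worth noting what the deletion above gives up: `_InternalGlobals` validated the opset at assignment time through a property setter, whereas the restored module globals rely on callers going through `_set_opset_version`. A trimmed, self-contained sketch of the deleted validation (constants inlined instead of imported from `_constants`):

```python
class _InternalGlobals:
    """Trimmed sketch of the deleted class, with constants inlined."""

    def __init__(self):
        self._export_onnx_opset_version = 13  # onnx_default_opset

    @property
    def export_onnx_opset_version(self):
        return self._export_onnx_opset_version

    @export_onnx_opset_version.setter
    def export_onnx_opset_version(self, value):
        supported = [16, *range(7, 16)]  # main opset plus stable opsets
        if value not in supported:
            raise ValueError(f"Unsupported ONNX opset version: {value}")
        self._export_onnx_opset_version = value


GLOBALS = _InternalGlobals()
GLOBALS.export_onnx_opset_version = 14  # accepted
# GLOBALS.export_onnx_opset_version = 99  # would raise ValueError
```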
diff --git a/torch/onnx/_patch_torch.py b/torch/onnx/_patch_torch.py
deleted file mode 100644
index 01af5fe..0000000
--- a/torch/onnx/_patch_torch.py
+++ /dev/null
@@ -1,238 +0,0 @@
-"""Importing this patches torch._C classes to add ONNX conveniences."""
-import numbers
-import re
-from typing import Iterable, Tuple, Union
-
-import torch
-from torch.onnx._globals import GLOBALS
-
-
-def _graph_op(
-    g: torch._C.Graph,
-    opname: str,
-    *raw_args: torch._C.Node,
-    outputs: int = 1,
-    **kwargs,
-) -> Union[torch._C.Value, Tuple[torch._C.Value, ...]]:
-    r"""Creates an ONNX operator "opname", taking "args" as inputs and attributes "kwargs".
-
-    The set of operators and the inputs/attributes they take
-    is documented at https://github.com/onnx/onnx/blob/master/docs/Operators.md
-
-    This function is monkey-patched onto Graph.
-
-    Args:
-        g: The Torch graph.
-        opname: The ONNX operator name, e.g., `Abs` or `Add`. TODO(justinchu): Update examples to correct ones.
-        raw_args: The inputs to the operator; usually provided
-            as arguments to the `symbolic` definition.
-        outputs: The number of outputs this operator returns.
-            By default an operator is assumed to return a single output.
-            If `outputs` is greater than one, this functions returns a tuple
-            of output `Node`, representing each output of the ONNX operator
-            in positional.
-        kwargs: The attributes of the ONNX operator, whose keys are named
-            according to the following convention: `alpha_f` indicates
-            the `alpha` attribute with type `f`.  The valid type specifiers are
-            `f` (float), `i` (int), `s` (string) or `t` (Tensor).  An attribute
-            specified with type float accepts either a single float, or a
-            list of floats (e.g., you would say `dims_i` for a `dims` attribute
-            that takes a list of integers).
-
-    Returns:
-        The node representing the single output of this operator (see the `outputs`
-        keyword argument for multi-return nodes).
-    """
-    # Filter out None attributes, this can be convenient client side because
-    # now they can pass through None attributes, and have them not show up
-    kwargs = dict((k, v) for k, v in kwargs.items() if v is not None)
-
-    def const_if_tensor(arg):
-        if arg is None:
-            return arg
-        elif isinstance(arg, torch._C.Value):
-            return arg
-        else:
-            return g.op("Constant", value_z=arg)  # type: ignore[attr-defined]
-
-    args = [const_if_tensor(arg) for arg in raw_args]
-    n = g.insertNode(_new_node(g, opname, outputs, *args, **kwargs))  # type: ignore[attr-defined]
-
-    # Import utils to get _params_dict because it is a global that is accessed by c++ code
-    from torch.onnx import utils
-
-    if GLOBALS.onnx_shape_inference:
-        torch._C._jit_pass_onnx_node_shape_type_inference(
-            n, utils._params_dict, GLOBALS.export_onnx_opset_version
-        )
-
-    if outputs == 1:
-        return n.output()
-    return tuple(n.outputs())
-
-
-# Generate an ONNX ATen op node.
-def _aten_op(g, operator, *args, overload_name="", **kwargs):
-    kwargs["aten"] = True
-    return g.op(
-        "ATen", *args, operator_s=operator, overload_name_s=overload_name, **kwargs
-    )
-
-
-def _block_op(b, opname, *args, **kwargs):
-    if "::" in opname:
-        aten = False
-        ns_opname = opname
-    else:
-        aten = kwargs.pop("aten", False)
-        ns = "aten" if aten else "onnx"
-        ns_opname = ns + "::" + opname
-    n = b.addNode(ns_opname, list(args))
-    for k, v in sorted(kwargs.items()):
-        # TODO: enable inplace in aten exporting mode.
-        if k == "inplace":
-            continue
-        _add_attribute(n, k, v, aten=aten)
-    if len(list(n.outputs())) == 1:
-        return n.output()
-    return tuple(o for o in n.outputs())
-
-
-def _new_node(g: torch._C.Graph, opname: str, outputs, *args, **kwargs):
-    if "::" in opname:
-        aten = False
-        ns_opname = opname
-    else:
-        aten = kwargs.pop("aten", False)
-        ns = "aten" if aten else "onnx"
-        ns_opname = ns + "::" + opname
-    n = g.create(ns_opname, args, outputs)  # type: ignore[attr-defined]
-    for k, v in sorted(kwargs.items()):
-        # TODO: enable inplace in aten exporting mode.
-        if k == "inplace":
-            continue
-        _add_attribute(n, k, v, aten=aten)
-    return n
-
-
-_attr_pattern = re.compile("^(.+)_(([ifstgz])|(ty))$")
-
-
-def _is_onnx_list(value):
-    return (
-        not isinstance(value, torch._six.string_classes)
-        and not isinstance(value, torch.Tensor)
-        and isinstance(value, Iterable)
-    )
-
-
-def _scalar(x):
-    """Convert a scalar tensor into a Python value."""
-    assert x.numel() == 1
-    return x[0]
-
-
-def _is_caffe2_aten_fallback():
-    return (
-        GLOBALS.operator_export_type
-        == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
-        and torch.onnx._CAFFE2_ATEN_FALLBACK
-    )
-
-
-def _add_attribute(node, key, value, aten):
-    r"""Initializes the right attribute based on type of value."""
-    m = _attr_pattern.match(key)
-    if m is None:
-        raise IndexError(
-            (
-                "Invalid attribute specifier '{}' names "
-                + " must be suffixed with type, e.g. 'dim_i' or 'dims_i'"
-            ).format(key)
-        )
-    name, kind = m.group(1), m.group(2)
-    if _is_onnx_list(value):
-        kind += "s"
-
-    if aten and _is_caffe2_aten_fallback():
-        if isinstance(value, torch.Tensor):
-            # Caffe2 proto does not support tensor attribute.
-            if value.numel() > 1:
-                raise ValueError("Should not pass tensor attribute")
-            value = _scalar(value)
-            if isinstance(value, float):
-                kind = "f"
-            else:
-                kind = "i"
-    return getattr(node, kind + "_")(name, value)
-
-
-# TODO: We might not need this anymore, since most scalars now show up as tensors
-# TODO(#76254): Remove the helper function if not needed.
-def _graph_constant(
-    g,
-    value,
-    dims,
-    type_: str,
-    *args,
-    **kwargs,
-):
-    """This helper function can create either constant tensor or constant scalar.
-
-    If dims is None or 0 or [0], generate a 0-d tensor (scalar).
-    """
-    assert isinstance(value, numbers.Number)
-    assert type_ is not None
-    isscalar = False
-    if dims is None or dims == 0 or set(dims) == set([0]):
-        dims = [1]
-        isscalar = True
-    type_ = type_.lower()
-    tensor: Union[
-        torch.CharTensor,
-        torch.ShortTensor,
-        torch.IntTensor,
-        torch.LongTensor,
-        torch.HalfTensor,
-        torch.FloatTensor,
-        torch.DoubleTensor,
-    ]
-    if type_ == "char":
-        tensor = torch.CharTensor(*dims)
-    elif type_ == "short":
-        tensor = torch.ShortTensor(*dims)
-    elif type_ == "int":
-        tensor = torch.IntTensor(*dims)
-    elif type_ == "long":
-        tensor = torch.LongTensor(*dims)
-    elif type_ == "half":
-        tensor = torch.HalfTensor(*dims)
-    elif type_ == "float":
-        tensor = torch.FloatTensor(*dims)
-    elif type_ == "double":
-        tensor = torch.DoubleTensor(*dims)
-    else:
-        raise ValueError(
-            "Unknown type, type should be one of the following strings: "
-            "char, short, int, long, half, float, double"
-        )
-    tensor.fill_(value)  # type: ignore[call-overload]
-    if isscalar:
-        return g.op("Constant", *args, value_z=tensor, **kwargs)
-    return g.op("Constant", *args, value_t=tensor, **kwargs)
-
-
-def _node_getitem(self, k):
-    """Gets attributes of a node which is polymorphic over return type.
-
-    This is monkey-patched onto Node.
-    """
-    sel = self.kindOf(k)
-    return getattr(self, sel)(k)
-
-
-torch._C.Graph.op = _graph_op  # type: ignore[attr-defined]
-torch._C.Graph.at = _aten_op  # type: ignore[attr-defined]
-torch._C.Block.op = _block_op  # type: ignore[attr-defined]
-torch._C.Graph.constant = _graph_constant  # type: ignore[attr-defined]
-torch._C.Node.__getitem__ = _node_getitem  # type: ignore[attr-defined, misc, assignment]
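The docstring deleted above is the main reference for the `g.op` attribute-suffix convention (`alpha_f` is a float `alpha`, `dims_i` a list of ints, matching `_attr_pattern`). A hedged usage sketch, where `SomeOp` and the symbolic function are hypothetical and `g` is the patched `torch._C.Graph`:

```python
# Hypothetical symbolic definition; only the suffix convention is the point.
def some_symbolic(g, input):
    # _f -> float attribute, _i -> int, _s -> string, _t -> Tensor;
    # a list value keeps the same suffix (e.g. dims_i for a list of ints).
    return g.op("SomeOp", input, alpha_f=2.0, dims_i=[0, 1], mode_s="nearest")
```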
diff --git a/torch/onnx/onnx_supported_ops.py b/torch/onnx/onnx_supported_ops.py
index eacc77d..bc85029 100644
--- a/torch/onnx/onnx_supported_ops.py
+++ b/torch/onnx/onnx_supported_ops.py
@@ -2,11 +2,10 @@
 from typing import Dict, List, Union
 
 import torch._C
-from torch.onnx import _constants, symbolic_registry
+from torch.onnx import symbolic_helper, symbolic_registry
 
-for v in _constants.onnx_stable_opsets:
+for v in symbolic_helper._onnx_stable_opsets + [symbolic_helper._onnx_main_opset]:
     symbolic_registry.register_version("", v)
-symbolic_registry.register_version("", _constants.onnx_main_opset)
 
 
 class _TorchSchema:
diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py
index 59cd61e..f1a588a 100644
--- a/torch/onnx/symbolic_helper.py
+++ b/torch/onnx/symbolic_helper.py
@@ -7,12 +7,11 @@
 
 import torch
 import torch.onnx
-from torch._C import OptionalType
 
 # This import monkey-patches graph manipulation methods on Graph, used for the
 # ONNX symbolics
-from torch.onnx import _patch_torch  # noqa: F401
-from torch.onnx._globals import GLOBALS
+import torch.onnx.utils
+from torch._C import OptionalType
 
 # Note [Edit Symbolic Files]
 # EDITING THIS FILE AND SYMBOLIC_OPSET<VERSION> FILES? READ THIS FIRST!
@@ -166,27 +165,24 @@
     """A decorator which converts args from torch._C.Value to built-in types.
 
     For example:
-
-    ```
     @parse_args('v', 'i', 'fs')
     foo(g, a, b, c):
-        assert isinstance(a, torch._C.Value)
-        assert isinstance(b, int)
-        assert isinstance(c, list)
-        assert isinstance(c[0], float)
-    ```
+      assert isinstance(a, torch._C.Value)
+      assert isinstance(b, int)
+      assert isinstance(c, list)
+      assert isinstance(c[0], float)
 
     Args:
-        arg_descriptors: list of str, where each element is
-            a string that specifies the type to convert to. Valid descriptors:
-            "v": no conversion, keep torch._C.Value.
-            "i": int
-            "is": list of int
-            "f": float
-            "fs": list of float
-            "b": bool
-            "s": str
-            "t": torch.Tensor
+      arg_descriptors: list of str, where each element is
+        a string that specifies the type to convert to. Valid descriptors:
+        "v": no conversion, keep torch._C.Value.
+        "i": int
+        "is": list(int)
+        "f": float
+        "fs": list of float
+        "b": bool
+        "s": str
+        "t": torch.Tensor
     """
 
     def decorator(fn):
@@ -317,7 +313,7 @@
     return x.item()
 
 
-def _if_scalar_type_as(g: torch._C.Graph, self, tensor):
+def _if_scalar_type_as(g, self, tensor):
     """
     Convert self into the same type of tensor, as necessary.
 
@@ -380,8 +376,7 @@
 
 def is_caffe2_aten_fallback():
     return (
-        GLOBALS.operator_export_type
-        == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
+        _operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
         and torch.onnx._CAFFE2_ATEN_FALLBACK
     )
 
@@ -431,7 +426,7 @@
         warnings.warn(
             "ONNX export failed on " + op + " because " + msg + " not supported"
         )
-    elif GLOBALS.operator_export_type == torch.onnx.OperatorExportTypes.ONNX:
+    elif _operator_export_type == torch.onnx.OperatorExportTypes.ONNX:
         _onnx_unsupported(f"{op}, {msg}")
 
 
@@ -467,7 +462,7 @@
         raise RuntimeError(
             "ONNX export failed on {}, which is not implemented for opset {}. "
             "Try exporting with other opset versions.".format(
-                name, GLOBALS.export_onnx_opset_version
+                name, _export_onnx_opset_version
             )
         )
 
@@ -503,7 +498,7 @@
 
 
 def _slice_helper(g, input, axes, starts, ends, steps=None, dynamic_slice=False):
-    if GLOBALS.export_onnx_opset_version <= 9:
+    if _export_onnx_opset_version <= 9:
         from torch.onnx.symbolic_opset9 import _slice as _slice9
 
         return _slice9(g, input, axes, starts, ends)
@@ -559,7 +554,7 @@
         shape_,
         g.op("Constant", value_t=torch.tensor([dim], dtype=torch.int64)),
     )
-    if GLOBALS.export_onnx_opset_version <= 10:
+    if _export_onnx_opset_version <= 10:
         if not decending:
             _unimplemented("Sort", "Ascending is not supported")
         return g.op("TopK", input, dim_size_, axis_i=dim, outputs=2)
@@ -578,7 +573,7 @@
         k = _reshape_helper(g, k, g.op("Constant", value_t=torch.tensor([1])))
         if _try_get_scalar_type(k) != "Long":
             k = g.op("Cast", k, to_i=torch.onnx.TensorProtoDataType.INT64)
-    if GLOBALS.export_onnx_opset_version <= 10:
+    if _export_onnx_opset_version <= 10:
         if not largest:
             _unimplemented("TopK", "Ascending is not supported")
         return g.op("TopK", input, k, axis_i=dim, outputs=2)
@@ -589,7 +584,7 @@
 
 
 def _lt_helper(g, input, other):
-    if GLOBALS.export_onnx_opset_version <= 8:
+    if _export_onnx_opset_version <= 8:
         from torch.onnx.symbolic_opset8 import lt as _lt8
 
         return _lt8(g, input, other)
@@ -600,14 +595,12 @@
 
 
 def _interpolate_warning(interpolate_mode):
-    onnx_op = (
-        "onnx:Resize" if GLOBALS.export_onnx_opset_version >= 10 else "onnx:Upsample"
-    )
+    onnx_op = "onnx:Resize" if _export_onnx_opset_version >= 10 else "onnx:Upsample"
     warnings.warn(
         "You are trying to export the model with "
         + onnx_op
         + " for ONNX opset version "
-        "" + str(GLOBALS.export_onnx_opset_version) + ". "
+        "" + str(_export_onnx_opset_version) + ". "
         "This operator might cause results to not match the expected results by PyTorch.\n"
         "ONNX's Upsample/Resize operator did not match Pytorch's Interpolation until opset 11. "
         "Attributes to determine how to transform the input were added in onnx:Resize in opset 11 "
@@ -618,12 +611,12 @@
 
 def _unsqueeze_helper(g, input, axes_i):
     if _is_constant(axes_i[0]):
-        if GLOBALS.export_onnx_opset_version >= 13:
+        if _export_onnx_opset_version >= 13:
             axes = g.op("Constant", value_t=torch.tensor(axes_i, dtype=torch.long))
             return g.op("Unsqueeze", input, axes)
         return g.op("Unsqueeze", input, axes_i=axes_i)
     # Tensor type
-    if GLOBALS.export_onnx_opset_version < 13:
+    if _export_onnx_opset_version < 13:
         raise ValueError(
             f"Opset version must be >= 13 for Unsqueeze with dynamic axes. {input.node().sourceRange()}"
         )
@@ -632,12 +625,12 @@
 
 def _squeeze_helper(g, input, axes_i):
     if _is_constant(axes_i[0]):
-        if GLOBALS.export_onnx_opset_version >= 13:
+        if _export_onnx_opset_version >= 13:
             axes = g.op("Constant", value_t=torch.tensor(axes_i, dtype=torch.long))
             return g.op("Squeeze", input, axes)
         return g.op("Squeeze", input, axes_i=axes_i)
     # Tensor type
-    if GLOBALS.export_onnx_opset_version < 13:
+    if _export_onnx_opset_version < 13:
         raise ValueError(
             f"Opset version must be >= 13 for Squeeze with dynamic axes. {input.node().sourceRange()}"
         )
@@ -656,7 +649,7 @@
 
 def _reducesum_helper(g, input, axes_i=None, keepdims_i=1, noop_with_empty_axes_i=0):
     keepdims_i = _maybe_get_const(keepdims_i, "i")
-    if GLOBALS.export_onnx_opset_version >= 13:
+    if _export_onnx_opset_version >= 13:
         if axes_i:
             if not _is_value(axes_i):
                 axes_i = g.op(
@@ -800,7 +793,7 @@
             output_size = g.op("Cast", output_size, to_i=cast_pytorch_to_onnx["Long"])
             output_size = g.op("Concat", input_size_beg, output_size, axis_i=0)
 
-            if GLOBALS.export_onnx_opset_version >= 13:
+            if _export_onnx_opset_version >= 13:
                 empty_roi = _optional_input_placeholder_tensor(g)
                 empty_scales = _optional_input_placeholder_tensor(g)
             else:
@@ -823,7 +816,7 @@
                 nearest_mode_s="floor",
             )  # only valid when mode="nearest"
         else:
-            if GLOBALS.export_onnx_opset_version >= 13:
+            if _export_onnx_opset_version >= 13:
                 empty_roi = _optional_input_placeholder_tensor(g)
             else:
                 empty_roi = g.op(
@@ -894,7 +887,7 @@
         size = g.op("Cast", size, to_i=cast_pytorch_to_onnx["Long"])
         size = g.op("Concat", input_size, size, axis_i=0)
 
-        if GLOBALS.export_onnx_opset_version >= 13:
+        if _export_onnx_opset_version >= 13:
             empty_roi = _optional_input_placeholder_tensor(g)
             empty_scales = _optional_input_placeholder_tensor(g)
         else:
@@ -919,7 +912,7 @@
         if rank is None:
             return _unimplemented("interpolate (with scales)", "missing input shape")
 
-        if GLOBALS.export_onnx_opset_version >= 13:
+        if _export_onnx_opset_version >= 13:
             empty_roi = _optional_input_placeholder_tensor(g)
         else:
             empty_roi = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))
@@ -938,9 +931,9 @@
 
 
 def _unbind_helper(g, self, dim, _outputs):
-    if GLOBALS.export_onnx_opset_version < 11:
+    if _export_onnx_opset_version < 11:
         from torch.onnx.symbolic_opset9 import unbind
-    elif GLOBALS.export_onnx_opset_version <= 12:
+    elif _export_onnx_opset_version <= 12:
         from torch.onnx.symbolic_opset11 import unbind  # type: ignore[no-redef]
     else:
         from torch.onnx.symbolic_opset13 import unbind  # type: ignore[no-redef]
@@ -948,7 +941,7 @@
 
 
 def _scatter_helper(g, self, dim, index, src):
-    if GLOBALS.export_onnx_opset_version <= 10:
+    if _export_onnx_opset_version <= 10:
         from torch.onnx.symbolic_opset9 import scatter
     else:
         # for mypy, scatter was imported two lines above
@@ -957,7 +950,7 @@
 
 
 def _repeat_interleave_split_helper(g, self, reps, dim):
-    if GLOBALS.export_onnx_opset_version <= 12:
+    if _export_onnx_opset_version <= 12:
         split_out = g.op("Split", self, split_i=[1] * reps, axis_i=dim, outputs=reps)
     else:
         from torch.onnx.symbolic_opset13 import split
@@ -996,7 +989,7 @@
 
 
 def _arange_helper(g, *args):
-    if GLOBALS.export_onnx_opset_version <= 10:
+    if _export_onnx_opset_version <= 10:
         from torch.onnx.symbolic_opset9 import arange
     else:
         from torch.onnx.symbolic_opset11 import arange  # type: ignore[no-redef]
@@ -1018,7 +1011,7 @@
 
     from torch.onnx.symbolic_opset9 import expand
 
-    if GLOBALS.export_onnx_opset_version <= 10:
+    if _export_onnx_opset_version <= 10:
         from torch.onnx.symbolic_opset9 import scatter
     else:
         # for mypy, scatter was imported two lines above
@@ -1047,10 +1040,10 @@
     shape = _maybe_get_const(shape, "is")
     if not _is_value(shape):
         shape = g.op("Constant", value_t=torch.LongTensor(shape))
-    if GLOBALS.export_onnx_opset_version <= 13:
+    if _export_onnx_opset_version <= 13:
         if allowzero == 1:
             raise _onnx_opset_unsupported(
-                "Reshape with allowzero=1", GLOBALS.export_onnx_opset_version, 14
+                "Reshape with allowzero=1", _export_onnx_opset_version, 14
             )
         return g.op("Reshape", input, shape)
     else:
@@ -1118,11 +1111,12 @@
 
 
 def check_training_mode(op_train_mode, op_name):
+    global _training_mode
     op_train_mode = True if op_train_mode == 1 else False
-    if GLOBALS.training_mode is not None and op_train_mode != GLOBALS.training_mode:
+    if _training_mode is not None and op_train_mode != _training_mode:
         op_mode = "training " if op_train_mode else "inference"
-        training_mode = "training " if GLOBALS.training_mode else "inference"
-        # setting the model mode could result in op_mode != _flags.training_mode
+        training_mode = "training " if _training_mode else "inference"
+        # setting the model mode could result in op_mode != _training_mode
         # if the model is a FuncModule. In this case we warn the user of
         # the state and export depending on op_mode
         # This is to support use-cases of fixing certain layer weights
@@ -1213,10 +1207,10 @@
     scale = g.op("Cast", scale, to_i=torch.onnx.TensorProtoDataType.FLOAT)
     zero_point = g.op("Cast", zero_point, to_i=qdtype)
 
-    if axis_i is not None and GLOBALS.export_onnx_opset_version < 13:
+    if axis_i is not None and _export_onnx_opset_version < 13:
         _onnx_opset_unsupported_detailed(
             "DequantizeLinear",
-            GLOBALS.export_onnx_opset_version,
+            _export_onnx_opset_version,
             13,
             "Attribute axis is not supported.",
         )
@@ -1238,16 +1232,12 @@
         scale: torch._C.Value, quantized scale.
         zero_point: torch._C.Value, quantized zero point.
         axis: Optional[torch._C.Value] default None, if None, represents per tensor quantization.
-            Otherwise, represents per channel quantization, along given axis.
+          Otherwise, represents per channel quantization, along given axis.
     """
-    if (
-        axis is not None
-        and not _is_none(axis)
-        and GLOBALS.export_onnx_opset_version < 13
-    ):
+    if axis is not None and not _is_none(axis) and _export_onnx_opset_version < 13:
         _onnx_opset_unsupported_detailed(
             "QuantizeLinear",
-            GLOBALS.export_onnx_opset_version,
+            _export_onnx_opset_version,
             13,
             "Attribute axis is not supported.",
         )
@@ -1299,23 +1289,43 @@
     return has_same_dtype
 
 
-# TODO(justinchuby): Delete these setters, users should set the vars directly.
-def _set_opset_version(opset_version: int):
-    GLOBALS.export_onnx_opset_version = opset_version
+_default_onnx_opset_version = 13
+_onnx_main_opset = 16
+_onnx_stable_opsets = list(range(7, _onnx_main_opset))
+_export_onnx_opset_version = _default_onnx_opset_version
+_constant_folding_opset_versions = list(range(9, _onnx_main_opset + 1))
+
+
+def _set_opset_version(opset_version):
+    global _export_onnx_opset_version
+    if opset_version in _onnx_stable_opsets + [_onnx_main_opset]:
+        _export_onnx_opset_version = opset_version
+        return
+    raise ValueError("Unsupported ONNX opset version: " + str(opset_version))
+
+
+_operator_export_type = None
 
 
 def _set_operator_export_type(operator_export_type):
-    GLOBALS.operator_export_type = operator_export_type
+    global _operator_export_type
+    _operator_export_type = operator_export_type
+
+
+_training_mode = None
 
 
 def _set_training_mode(training_mode):
-    GLOBALS.training_mode = training_mode
+    global _training_mode
+    _training_mode = training_mode
 
 
+_onnx_shape_inference = False
 # This function is for debug use only.
-# onnx_shape_inference = False by default.
-def _set_onnx_shape_inference(onnx_shape_inference: bool):
-    GLOBALS.onnx_shape_inference = onnx_shape_inference
+# onnx_shape_inference = False by default.
+def _set_onnx_shape_inference(onnx_shape_inference):
+    global _onnx_shape_inference
+    _onnx_shape_inference = onnx_shape_inference
 
 
 # Metaprogram symbolics for each ATen native specialized cast operator.
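One pitfall with the module-level state restored above: `from torch.onnx.symbolic_helper import _export_onnx_opset_version` (as several test hunks in this diff do) copies the integer at import time, so a later `_set_opset_version` call is not visible through the imported name; only attribute access on the module sees the rebind. A stand-in sketch (the `state` module here is a placeholder for `symbolic_helper`):

```python
import types

# Placeholder module standing in for torch.onnx.symbolic_helper.
state = types.ModuleType("state")
state._export_onnx_opset_version = 13


def _set_opset_version(version):
    state._export_onnx_opset_version = version  # rebinds the module attribute


snapshot = state._export_onnx_opset_version  # what `from ... import` copies
_set_opset_version(15)

print(snapshot)                          # 13: stale snapshot
print(state._export_onnx_opset_version)  # 15: live attribute access
```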
diff --git a/torch/onnx/symbolic_opset10.py b/torch/onnx/symbolic_opset10.py
index fe8e099..53ec445 100644
--- a/torch/onnx/symbolic_opset10.py
+++ b/torch/onnx/symbolic_opset10.py
@@ -1,17 +1,17 @@
-import sys
-import warnings
+# -*- coding: utf-8 -*-
+
+from sys import maxsize
 
 import torch
 import torch.onnx
 import torch.onnx.symbolic_helper as sym_help
 import torch.onnx.symbolic_opset9
-from torch.nn.modules.utils import _pair, _single, _triple
 
 # This import monkey-patches graph manipulation methods on Graph, used for the
 # ONNX symbolics
-from torch.onnx import _patch_torch  # noqa: F401
-from torch.onnx._globals import GLOBALS
-from torch.onnx.symbolic_helper import parse_args, quantized_args
+import torch.onnx.utils
+from torch.nn.modules.utils import _pair, _single, _triple
+from torch.onnx.symbolic_helper import _unimplemented, parse_args, quantized_args
 from torch.onnx.symbolic_opset9 import (
     add,
     conv2d,
@@ -202,7 +202,7 @@
         sym_help._interpolate_warning(interpolate_mode)
         align_corners = sym_help._maybe_get_scalar(align_corners)
         if align_corners:
-            return sym_help._unimplemented(name, "align_corners == True")
+            return _unimplemented(name, "align_corners == True")
         if scales is None:
             scales = sym_help._interpolate_size_to_scales(g, input, output_size, dim)
         return g.op("Resize", input, scales, mode_s=interpolate_mode)
@@ -328,12 +328,13 @@
     include_last_offset,
     padding_idx,
 ):
-    if scale_grad_by_freq and GLOBALS.training_mode:
+    if scale_grad_by_freq and sym_help._training_mode:
         return sym_help._onnx_unsupported(
             "embedding_bag with scale_grad_by_freq for training mode"
         )
     if padding_idx is not None and padding_idx >= 0:
         raise RuntimeError("embedding_bag with padding_idx")
+    import warnings
 
     from torch.onnx.symbolic_opset9 import select
 
@@ -350,7 +351,7 @@
             offset_len = offsets_dim_0
             offsets_extended = [
                 offsets,
-                g.op("Constant", value_t=torch.tensor([sys.maxsize])),
+                g.op("Constant", value_t=torch.tensor([maxsize])),
             ]
             offsets_extended = g.op("Concat", *offsets_extended, axis_i=0)
         list_ = []
diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py
index b3dc019..6d96135 100644
--- a/torch/onnx/symbolic_opset11.py
+++ b/torch/onnx/symbolic_opset11.py
@@ -1,11 +1,12 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
 import warnings
 from sys import maxsize
 
 import torch
 import torch.onnx.symbolic_helper as sym_help
-import torch.onnx.utils
 from torch.nn.modules.utils import _pair, _single, _triple
-from torch.onnx._globals import GLOBALS
 from torch.onnx.symbolic_helper import (
     ScalarType,
     _is_tensor_list,
@@ -16,6 +17,11 @@
 from torch.onnx.symbolic_opset9 import _pad_circular, expand
 from torch.onnx.symbolic_opset9 import linalg_vector_norm as lvn
 from torch.onnx.symbolic_opset9 import mul, op_with_optional_float_cast, unused
+from torch.onnx.utils import (
+    _add_block,
+    _add_input_to_block,
+    _add_output_to_block,
+)
 
 # EDITING THIS FILE? READ THIS FIRST!
 # see Note [Edit Symbolic Files] in symbolic_helper.py
@@ -1072,7 +1078,7 @@
     include_last_offset,
     padding_idx,
 ):
-    if scale_grad_by_freq and GLOBALS.training_mode:
+    if scale_grad_by_freq and sym_help._training_mode:
         return sym_help._onnx_unsupported(
             "embedding_bag with scale_grad_by_freq for training mode"
         )
@@ -1107,9 +1113,9 @@
     )
     loop = g.op("Loop", loop_len, loop_condition)
 
-    loop_block = torch.onnx.utils._add_block(loop.node())
-    block_input_iter = torch.onnx.utils._add_input_to_block(loop_block)
-    cond = torch.onnx.utils._add_input_to_block(loop_block)
+    loop_block = _add_block(loop.node())
+    block_input_iter = _add_input_to_block(loop_block)
+    cond = _add_input_to_block(loop_block)
 
     indices_start = loop_block.op("Gather", offsets_starts, block_input_iter, axis_i=0)
     indices_end = loop_block.op("Gather", offsets_ends, block_input_iter, axis_i=0)
@@ -1136,8 +1142,8 @@
         embeddings = loop_block.op("ReduceMax", embeddings, axes_i=[0], keepdims_i=0)
 
     cond_out = loop_block.op("Cast", loop_condition, to_i=9)
-    torch.onnx.utils._add_output_to_block(loop_block, cond_out)
-    torch.onnx.utils._add_output_to_block(loop_block, embeddings)
+    _add_output_to_block(loop_block, cond_out)
+    _add_output_to_block(loop_block, embeddings)
 
     # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices.
     # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag.
diff --git a/torch/onnx/symbolic_opset12.py b/torch/onnx/symbolic_opset12.py
index df8e6bc..3825eb2 100644
--- a/torch/onnx/symbolic_opset12.py
+++ b/torch/onnx/symbolic_opset12.py
@@ -3,9 +3,13 @@
 
 import torch
 import torch.onnx.symbolic_helper as sym_help
-import torch.onnx.utils
 from torch.onnx.symbolic_helper import _parse_arg, _unimplemented, parse_args
 from torch.onnx.symbolic_opset9 import _reshape_from_tensor, permute
+from torch.onnx.utils import (
+    _add_block,
+    _add_input_to_block,
+    _add_output_to_block,
+)
 
 # EDITING THIS FILE? READ THIS FIRST!
 # see Note [Edit Symbolic Files] in symbolic_helper.py
@@ -270,9 +274,9 @@
         loop_len = g.op("Min", low_size, hi_size)
         loop = g.op("Loop", loop_len, loop_condition)
 
-        loop_block = torch.onnx.utils._add_block(loop.node())
-        block_input_iter = torch.onnx.utils._add_input_to_block(loop_block)
-        cond = torch.onnx.utils._add_input_to_block(loop_block)
+        loop_block = _add_block(loop.node())
+        block_input_iter = _add_input_to_block(loop_block)
+        cond = _add_input_to_block(loop_block)
 
         starts = loop_block.op("Gather", low_indices, block_input_iter)
         ends = loop_block.op("Gather", hi_indices, block_input_iter)
@@ -288,8 +292,8 @@
         concat = loop_block.op("Concat", *unsqueeze_list, axis_i=0)
 
         cond_out = loop_block.op("Cast", loop_condition, to_i=9)
-        torch.onnx.utils._add_output_to_block(loop_block, cond_out)
-        torch.onnx.utils._add_output_to_block(loop_block, concat)
+        _add_output_to_block(loop_block, cond_out)
+        _add_output_to_block(loop_block, concat)
 
         loop_output = loop.node().output()
         perm = [0, 1, 2, 3, 4]
diff --git a/torch/onnx/symbolic_opset13.py b/torch/onnx/symbolic_opset13.py
index 6167e40..6de4f86 100644
--- a/torch/onnx/symbolic_opset13.py
+++ b/torch/onnx/symbolic_opset13.py
@@ -4,7 +4,6 @@
 # This file exports ONNX ops for opset 13
 import torch
 import torch.onnx.symbolic_helper as sym_help
-import torch.onnx.utils
 from torch.onnx.symbolic_helper import _unimplemented, parse_args
 from torch.onnx.symbolic_opset9 import (
     _maybe_cast_reduce_op_input,
@@ -20,6 +19,11 @@
     zeros,
 )
 from torch.onnx.symbolic_opset11 import unsqueeze
+from torch.onnx.utils import (
+    _add_block,
+    _add_input_to_block,
+    _add_output_to_block,
+)
 
 # EDITING THIS FILE? READ THIS FIRST!
 # see Note [Edit Symbolic Files] in symbolic_helper.py
@@ -387,10 +391,10 @@
     loop = g.op("Loop", loop_len, loop_condition, final_splits)
 
     # Loop inputs
-    loop_block = torch.onnx.utils._add_block(loop.node())
-    block_input_iter = torch.onnx.utils._add_input_to_block(loop_block)
-    cond = torch.onnx.utils._add_input_to_block(loop_block)
-    final_splits = torch.onnx.utils._add_input_to_block(loop_block)
+    loop_block = _add_block(loop.node())
+    block_input_iter = _add_input_to_block(loop_block)
+    cond = _add_input_to_block(loop_block)
+    final_splits = _add_input_to_block(loop_block)
 
     r_split = loop_block.op("SequenceAt", r_splits, block_input_iter)
     i_split = loop_block.op("SequenceAt", i_splits, block_input_iter)
@@ -410,8 +414,8 @@
 
     # Loop outputs
     cond_out = loop_block.op("Cast", loop_condition, to_i=9)
-    torch.onnx.utils._add_output_to_block(loop_block, cond_out)
-    torch.onnx.utils._add_output_to_block(loop_block, final_splits)
+    _add_output_to_block(loop_block, cond_out)
+    _add_output_to_block(loop_block, final_splits)
 
     loop_out = loop.node().output()
     loop_out = g.op("ConcatFromSequence", loop_out, axis_i=dim)
@@ -514,7 +518,7 @@
     if_op = g.op("If", overrun_cond)
     if_node = if_op.node()
 
-    if_block = torch.onnx.utils._add_block(if_node)
+    if_block = _add_block(if_node)
     gather_indices_if_block = if_block.op("Add", gather_indices, select_window)
     gather_indices_if_block = sym_help._unsqueeze_helper(
         if_block, gather_indices_if_block, [rank - 1]
@@ -522,11 +526,11 @@
     final_non_overrun_ = if_block.op(
         "GatherND", result, gather_indices_if_block, batch_dims_i=rank - 2
     )
-    torch.onnx.utils._add_output_to_block(if_block, final_non_overrun_)
+    _add_output_to_block(if_block, final_non_overrun_)
 
-    else_block = torch.onnx.utils._add_block(if_node)
+    else_block = _add_block(if_node)
     final_overrun_ = zeros(else_block, gather_shape, 6, None, None)
-    torch.onnx.utils._add_output_to_block(else_block, final_overrun_)
+    _add_output_to_block(else_block, final_overrun_)
     return if_op
 
 
diff --git a/torch/onnx/symbolic_opset14.py b/torch/onnx/symbolic_opset14.py
index b46b635..fe140fe 100644
--- a/torch/onnx/symbolic_opset14.py
+++ b/torch/onnx/symbolic_opset14.py
@@ -4,8 +4,7 @@
 # This file exports ONNX ops for opset 14
 import torch
 import torch.onnx.symbolic_helper as sym_help
-from torch.onnx._globals import GLOBALS
-from torch.onnx.symbolic_helper import parse_args
+from torch.onnx.symbolic_helper import args_have_same_dtype, parse_args
 
 # Note [ONNX operators that are added/updated in opset 14]
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -59,10 +58,8 @@
 
     if (
         torch.is_autocast_enabled()
-        and not sym_help.args_have_same_dtype(
-            [input, weight, bias, running_mean, running_var]
-        )
-        and GLOBALS.export_onnx_opset_version < 15
+        and not args_have_same_dtype([input, weight, bias, running_mean, running_var])
+        and sym_help._export_onnx_opset_version < 15
     ):
         return sym_help._onnx_opset_unsupported_detailed(
             "BatchNormalization",
diff --git a/torch/onnx/symbolic_opset8.py b/torch/onnx/symbolic_opset8.py
index 6f2da64..72aa535 100644
--- a/torch/onnx/symbolic_opset8.py
+++ b/torch/onnx/symbolic_opset8.py
@@ -10,6 +10,7 @@
     _unimplemented,
     parse_args,
 )
+from torch.onnx.symbolic_opset9 import _cast_Float  # type: ignore[attr-defined]
 
 # Note [ONNX operators that are added/updated from opset 8 to opset 9]
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -128,11 +129,7 @@
     if arg0_type is not None:
         old_type = arg0_type
         if old_type not in floating_scalar_types:
-            # TODO(justinchuby): Remove the type ignore hint once _cast_Float is
-            # properly defined.
-            # NOTE: _cast_Float is generated programmatically so we need to make the
-            # type checker happy with ignore[attr-defined].
-            args = tuple(sym_opset9._cast_Float(g, arg, False) for arg in args)  # type: ignore[attr-defined]
+            args = tuple(_cast_Float(g, arg, False) for arg in args)
         else:
             return (None,) + args
     else:
diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py
index ec27e8d..e6c5978 100644
--- a/torch/onnx/symbolic_opset9.py
+++ b/torch/onnx/symbolic_opset9.py
@@ -1,3 +1,5 @@
+# -*- coding: utf-8 -*-
+
 import math
 import warnings
 from functools import partial, wraps
@@ -7,13 +9,12 @@
 import torch
 import torch.onnx
 import torch.onnx.symbolic_helper as sym_help
-from torch._C import ListType, OptionalType
-from torch.nn.modules.utils import _pair, _single, _triple
 
 # This import monkey-patches graph manipulation methods on Graph, used for the
 # ONNX symbolics
-from torch.onnx import _patch_torch  # noqa: F401
-from torch.onnx._globals import GLOBALS
+import torch.onnx.utils
+from torch._C import ListType, OptionalType
+from torch.nn.modules.utils import _pair, _single, _triple
 from torch.onnx.symbolic_helper import (
     ScalarType,
     _parse_arg,
@@ -528,12 +529,12 @@
 
 @parse_args("v", "v", "i", "b", "v")
 def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse):
-    if scale_grad_by_freq and GLOBALS.training_mode:
+    if scale_grad_by_freq and sym_help._training_mode:
         raise RuntimeError(
             "Unsupported: ONNX export of embedding with scale_grad_by_freq=True "
             "for training mode. ONNX does not support scaling the gradients."
         )
-    if padding_idx >= 0 and GLOBALS.training_mode:
+    if padding_idx >= 0 and sym_help._training_mode:
         warnings.warn(
             "Warning: ONNX export of embedding with padding_idx >= 0 "
             "for training mode. "
@@ -837,7 +838,7 @@
     dtype_0 = inputs[0].type().scalarType()
 
     require_cast = not sym_help._is_fp(inputs[0]) and (
-        opset_before is None or GLOBALS.export_onnx_opset_version < opset_before
+        opset_before is None or sym_help._export_onnx_opset_version < opset_before
     )
 
     if require_cast:
@@ -1549,7 +1550,7 @@
     if condition.type().scalarType() != "Bool":
         condition = g.op("Cast", condition, to_i=sym_help.cast_pytorch_to_onnx["Bool"])
     if self is None:
-        condition = nonzero(g, condition)
+        condition = torch.onnx.symbolic_opset9.nonzero(g, condition)
         return sym_help._unbind_helper(
             g, condition, g.op("Constant", value_t=torch.tensor(1)), _outputs
         )
@@ -1791,7 +1792,7 @@
     if (
         torch.is_autocast_enabled()
         and not args_have_same_dtype([input, weight, bias, running_mean, running_var])
-        and GLOBALS.export_onnx_opset_version < 15
+        and sym_help._export_onnx_opset_version < 15
     ):
         return sym_help._onnx_opset_unsupported_detailed(
             "BatchNormalization",
@@ -2371,8 +2372,6 @@
         sym_help._onnx_opset_unsupported("_unique2", 9, 11)
 
 
-# TODO(justinchuby): Clean up this function generation magic by defining the functions
-# explicitly.
 for k, v in sym_help.cast_pytorch_to_onnx.items():
     name = "_cast_{}".format(k)
     globals()[name] = parse_args("v", "i")(partial(sym_help._cast_func_template, v))
@@ -2610,7 +2609,7 @@
             or ((not is_end_none) and (not is_end_onnx_const))
             or dim.node().kind() != "onnx::Constant"
         ):
-            if GLOBALS.operator_export_type == torch.onnx.OperatorExportTypes.ONNX:
+            if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX:
                 raise RuntimeError(
                     "Unsupported: ONNX export of Slice with dynamic inputs. DynamicSlice "
                     "is a deprecated experimental op. Please use statically allocated "
@@ -3953,7 +3952,7 @@
         if not sym_help._is_none(index) and (
             index.type().scalarType() == "Byte" or index.type().scalarType() == "Bool"
         ):
-            if GLOBALS.export_onnx_opset_version < 9:
+            if sym_help._export_onnx_opset_version < 9:
                 raise RuntimeError(
                     "Exporting masked indices are only supported after ONNX opset 9."
                 )
@@ -4009,7 +4008,7 @@
             #       update the warning to recommend exporting with higher opset version.
             warnings.warn(
                 "Exporting aten::index operator of advanced indexing in opset "
-                + str(GLOBALS.export_onnx_opset_version)
+                + str(sym_help._export_onnx_opset_version)
                 + " is achieved by combination of multiple ONNX operators, "
                 + "including Reshape, Transpose, Concat, and Gather. "
                 + "If indices include negative values, the exported graph will produce incorrect results."
@@ -4827,8 +4826,8 @@
         env = ctx.env
         params_dict = ctx.params_dict
 
-        operator_export_type = GLOBALS.operator_export_type
-        opset_version = GLOBALS.export_onnx_opset_version
+        operator_export_type = sym_help._operator_export_type
+        opset_version = sym_help._export_onnx_opset_version
 
         new_op_outputs = g.op("Loop", *inputs, outputs=n.outputsSize())
         new_node = (
@@ -4861,7 +4860,9 @@
             new_node, opset_version
         )
         # Run shape type inference for Loop after subblock is converted.
-        if GLOBALS.onnx_shape_inference:
+        from torch.onnx.symbolic_helper import _onnx_shape_inference
+
+        if _onnx_shape_inference:
             torch._C._jit_pass_onnx_node_shape_type_inference(
                 new_node, params_dict, opset_version
             )
@@ -4874,8 +4875,8 @@
         env = ctx.env
         params_dict = ctx.params_dict
 
-        operator_export_type = GLOBALS.operator_export_type
-        opset_version = GLOBALS.export_onnx_opset_version
+        operator_export_type = sym_help._operator_export_type
+        opset_version = sym_help._export_onnx_opset_version
 
         static_if = inputs[0].node().kind() == "onnx::Constant"
         if static_if:
@@ -4950,7 +4951,9 @@
                 new_node, opset_version
             )
             # Run shape type inference for If after subblock is converted.
-            if GLOBALS.onnx_shape_inference:
+            from torch.onnx.symbolic_helper import _onnx_shape_inference
+
+            if _onnx_shape_inference:
                 torch._C._jit_pass_onnx_node_shape_type_inference(
                     new_node, params_dict, opset_version
                 )
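Related context for the `_cast_Float` import restored in the opset 8 hunk earlier: the loop in this file over `sym_help.cast_pytorch_to_onnx` writes `_cast_<Type>` functions into `globals()` at import time, so the names exist only at runtime and static checkers need a `type: ignore[attr-defined]`. A minimal sketch of the pattern (the template body is a placeholder, not the real Cast emitter):

```python
from functools import partial


def _cast_func_template(to_type, g, input, non_blocking):
    # Placeholder body; the real template emits an ONNX Cast node.
    return (to_type, input)


# Module-level functions generated by writing into globals(): effective at
# runtime, invisible to static analysis (hence the upstream type: ignore).
for type_name in ("Float", "Long"):
    globals()[f"_cast_{type_name}"] = partial(_cast_func_template, type_name)

print(_cast_Float(None, "x", False))  # noqa: F821 -- name exists at runtime
```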
diff --git a/torch/onnx/symbolic_registry.py b/torch/onnx/symbolic_registry.py
index 88286c8..d90772d1 100644
--- a/torch/onnx/symbolic_registry.py
+++ b/torch/onnx/symbolic_registry.py
@@ -1,13 +1,12 @@
 import importlib
 import inspect
-import itertools
 import warnings
 from typing import Any, Callable, Dict, Tuple, Union
 
 import torch._C
-from torch.onnx import _constants
+from torch.onnx.symbolic_helper import _onnx_main_opset, _onnx_stable_opsets
 
-_SymbolicFunction = Callable[..., Union[torch._C.Value, Tuple[torch._C.Value]]]
+SymbolicFunction = Callable[..., Union[torch._C.Value, Tuple[torch._C.Value]]]
 
 """
 The symbolic registry "_registry" is a dictionary that maps operators
@@ -19,21 +18,16 @@
 """
 _registry: Dict[
     Tuple[str, int],
-    Dict[str, _SymbolicFunction],
+    Dict[str, SymbolicFunction],
 ] = {}
 
 _symbolic_versions: Dict[Union[int, str], Any] = {}
 
-
-def _import_symbolic_opsets():
-    for opset_version in itertools.chain(
-        _constants.onnx_stable_opsets, [_constants.onnx_main_opset]
-    ):
-        module = importlib.import_module(
-            "torch.onnx.symbolic_opset{}".format(opset_version)
-        )
-        global _symbolic_versions
-        _symbolic_versions[opset_version] = module
+for opset_version in _onnx_stable_opsets + [_onnx_main_opset]:
+    module = importlib.import_module(
+        "torch.onnx.symbolic_opset{}".format(opset_version)
+    )
+    _symbolic_versions[opset_version] = module
 
 
 def register_version(domain: str, version: int):
@@ -135,7 +129,7 @@
 
 def get_op_supported_version(opname: str, domain: str, version: int):
     iter_version = version
-    while iter_version <= _constants.onnx_main_opset:
+    while iter_version <= _onnx_main_opset:
         ops = [(op[0], op[1]) for op in get_ops_in_version(iter_version)]
         if (domain, opname) in ops:
             return iter_version
@@ -143,7 +137,7 @@
     return None
 
 
-def get_registered_op(opname: str, domain: str, version: int) -> _SymbolicFunction:
+def get_registered_op(opname: str, domain: str, version: int) -> SymbolicFunction:
     if domain is None or version is None:
         warnings.warn("ONNX export failed. The ONNX domain and/or version are None.")
     global _registry
@@ -155,7 +149,7 @@
 class UnsupportedOperatorError(RuntimeError):
     def __init__(self, domain: str, opname: str, version: int):
         supported_version = get_op_supported_version(opname, domain, version)
-        if domain in {"", "aten", "prim", "quantized"}:
+        if domain in ["", "aten", "prim", "quantized"]:
             msg = f"Exporting the operator {domain}::{opname} to ONNX opset version {version} is not supported. "
             if supported_version is not None:
                 msg += (
@@ -171,6 +165,3 @@
                 "it with the right domain and version."
             )
         super().__init__(msg)
-
-
-_import_symbolic_opsets()
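The registry change above also moves module loading back to import time: the forward change deferred it behind `_import_symbolic_opsets()`, called once at the bottom of the file, while the restored loop runs as a side effect of importing `symbolic_registry` itself. A stdlib-only sketch of the two shapes (`json` and `csv` stand in for the `torch.onnx.symbolic_opset<N>` modules):

```python
import importlib

_symbolic_versions = {}


def _import_symbolic_opsets(names):
    # Deferred style (removed by this revert): nothing loads until the
    # explicit call at the bottom of the module.
    for name in names:
        _symbolic_versions[name] = importlib.import_module(name)


# Eager style (restored): the loop is a top-level side effect of importing
# the registry module.
for name in ("json", "csv"):  # stand-ins for symbolic_opset modules
    _symbolic_versions[name] = importlib.import_module(name)

_import_symbolic_opsets(("xml",))  # the deferred call the forward change used
print(sorted(_symbolic_versions))  # ['csv', 'json', 'xml']
```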
diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py
index 30b2c6a..cb02ee4 100644
--- a/torch/onnx/utils.py
+++ b/torch/onnx/utils.py
@@ -6,26 +6,16 @@
 import contextlib
 import copy
 import inspect
-import itertools
+import numbers
 import os
 import re
 import textwrap
 import typing
 import warnings
-import zipfile
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
-import torch.jit._trace
 import torch.serialization
-from torch.onnx import (  # noqa: F401
-    _constants,
-    _patch_torch,
-    symbolic_caffe2,
-    symbolic_helper,
-    symbolic_registry,
-)
-from torch.onnx._globals import GLOBALS
 
 # the flag to tell the user whether it's in the middle of ONNX export or not
 __IN_ONNX_EXPORT = False
@@ -36,7 +26,6 @@
     return __IN_ONNX_EXPORT
 
 
-# TODO(justinchuby): Remove dependency to this global variable from constant_fold.cpp
 # Skip check due to cannot import IValue from torch._C
 _params_dict = {}  # type: ignore[var-annotated]
 
@@ -66,24 +55,26 @@
         if mode == torch.onnx.TrainingMode.TRAINING or (
             mode == torch.onnx.TrainingMode.PRESERVE and is_originally_training
         ):
+            from torch.onnx.symbolic_helper import _export_onnx_opset_version
 
-            if GLOBALS.export_onnx_opset_version < 12:
+            if _export_onnx_opset_version < 12:
                 warnings.warn(
                     "You are exporting the model in training mode with onnx opset version {}. "
                     "Opset versions lower than opset 12 will not be able to export nodes such as "
                     "Dropout and BatchNorm correctly.".format(
-                        GLOBALS.export_onnx_opset_version
+                        _export_onnx_opset_version
                     )
                 )
             is_export_training = True
 
-        symbolic_helper._set_training_mode(is_export_training)
+        from torch.onnx.symbolic_helper import _set_training_mode
+
+        _set_training_mode(is_export_training)
         model.train(is_export_training)
     try:
         yield
     finally:
         if not isinstance(model, torch.jit.ScriptFunction):
-            # FIXME(justinchuby): is_originally_training is possibly unbound
             model.train(is_originally_training)
 
 
@@ -108,7 +99,6 @@
         yield
     finally:
         if not isinstance(model, torch.jit.ScriptFunction):
-            # FIXME(justinchuby): tmp_map is possibly unbound
             for module, m_map in tmp_map.items():
                 for k, v in m_map.items():
                     module._state_dict_hooks[k] = v
@@ -247,6 +237,11 @@
     torch._C._jit_pass_peephole(graph, True)
     torch._C._jit_pass_fuse_addmm(graph)
     torch._C._jit_pass_lint(graph)
+    from torch.onnx.symbolic_helper import (
+        _export_onnx_opset_version,
+        _onnx_shape_inference,
+        is_caffe2_aten_fallback,
+    )
 
     torch._C._jit_pass_peephole(graph, True)
     torch._C._jit_pass_lower_all_tuples(graph)
@@ -268,12 +263,12 @@
     torch._C._jit_pass_onnx_remove_print(graph)
     torch._C._jit_pass_onnx_preprocess_caffe2(graph)
 
-    symbolic_helper._quantized_ops.clear()
+    torch.onnx.symbolic_helper._quantized_ops.clear()
     # Unpack quantized weights for conv and linear ops and insert into graph.
     torch._C._jit_pass_onnx_unpack_quantized_weights(
-        graph, params_dict, symbolic_helper.is_caffe2_aten_fallback()
+        graph, params_dict, is_caffe2_aten_fallback()
     )
-    if symbolic_helper.is_caffe2_aten_fallback():
+    if is_caffe2_aten_fallback():
         # Insert permutes before and after each conv op to ensure correct order.
         torch._C._jit_pass_onnx_quantization_insert_permutes(graph, params_dict)
 
@@ -296,7 +291,7 @@
 
     # onnx only supports tensors, so we turn all out number types into tensors
     torch._C._jit_pass_erase_number_types(graph)
-    if GLOBALS.onnx_shape_inference:
+    if _onnx_shape_inference:
         input_names = [] if input_names is None else input_names
         dynamic_axes = {} if dynamic_axes is None else dynamic_axes
         torch._C._jit_pass_onnx_set_dynamic_input_shape(
@@ -308,12 +303,12 @@
     torch._C._jit_pass_lint(graph)
 
     torch._C._jit_pass_onnx_scalar_type_analysis(
-        graph, True, GLOBALS.export_onnx_opset_version
+        graph, True, _export_onnx_opset_version
     )
     torch._C._jit_pass_lint(graph)
 
     torch._C._jit_pass_onnx_peephole(
-        graph, GLOBALS.export_onnx_opset_version, fixed_batch_size
+        graph, _export_onnx_opset_version, fixed_batch_size
     )
     torch._C._jit_pass_lint(graph)
 
@@ -326,9 +321,9 @@
     torch._C._jit_pass_lint(graph)
     graph = torch._C._jit_pass_canonicalize(graph)
     torch._C._jit_pass_lint(graph)
-    if GLOBALS.onnx_shape_inference:
+    if _onnx_shape_inference:
         torch._C._jit_pass_onnx_graph_shape_type_inference(
-            graph, params_dict, GLOBALS.export_onnx_opset_version
+            graph, params_dict, _export_onnx_opset_version
         )
     return graph
 
@@ -591,7 +586,6 @@
         graph = model.graph
         torch._C._jit_pass_onnx_function_substitution(graph)
         param_count_list = _get_param_count_list(graph, args)
-        # FIXME(justinchuby): flattened_args is possibly unbound
         graph = torch._C._propagate_and_assign_input_shapes(
             graph, flattened_args, param_count_list, False, False
         )
@@ -717,6 +711,7 @@
             this will be None, since we are not doing any tracing.
     """
     # TODO: can we simplify this to always return a tuple of Tensor or None?
+    from torch.onnx.symbolic_helper import _export_onnx_opset_version
 
     # Special case for common case of passing a single Tensor
     if isinstance(args, (torch.Tensor, int, float, bool)):
@@ -740,6 +735,7 @@
     except Exception as e:
         torch.onnx.log("Torch IR graph at exception: ", graph)
         raise
+    from torch.onnx.symbolic_helper import _onnx_shape_inference
 
     is_script = isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule))
     if is_script:
@@ -749,7 +745,7 @@
             example_outputs_final += unpack_quantized_tensor(example_output)
         out_vars, desc = torch.jit._flatten(example_outputs_final)
         torch._C._jit_pass_onnx_assign_output_shape(
-            graph, out_vars, desc, GLOBALS.onnx_shape_inference, is_script
+            graph, out_vars, desc, _onnx_shape_inference, is_script
         )
 
     # NB: ONNX requires complete information about output types, which might be
@@ -766,11 +762,7 @@
         # single value in PyTorch.
         if not any(getattr(out, "is_quantized", False) for out in output_tensors):
             torch._C._jit_pass_onnx_assign_output_shape(
-                graph,
-                output_tensors,
-                out_desc,
-                GLOBALS.onnx_shape_inference,
-                is_script,
+                graph, output_tensors, out_desc, _onnx_shape_inference, is_script
             )
 
     _set_input_and_output_names(graph, input_names, output_names)
@@ -779,25 +771,27 @@
     if training is None or training == torch.onnx.TrainingMode.EVAL:
         params_dict = torch._C._jit_pass_onnx_eval_peephole(graph, params_dict)
 
+    from torch.onnx.symbolic_helper import _constant_folding_opset_versions
+
     if (
         do_constant_folding
-        and GLOBALS.export_onnx_opset_version in _constants.onnx_constant_folding_opsets
+        and _export_onnx_opset_version in _constant_folding_opset_versions
     ):
         params_dict = torch._C._jit_pass_onnx_constant_fold(
-            graph, params_dict, GLOBALS.export_onnx_opset_version
+            graph, params_dict, _export_onnx_opset_version
         )
         torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
 
-    if GLOBALS.onnx_shape_inference:
+    if _onnx_shape_inference:
         torch._C._jit_pass_onnx_graph_shape_type_inference(
-            graph, params_dict, GLOBALS.export_onnx_opset_version
+            graph, params_dict, _export_onnx_opset_version
         )
 
     params_dict = torch._C._jit_pass_onnx_eliminate_unused_items(graph, params_dict)
 
     # For ONNX opset < 9, constants only have three data types: float16, float, double.
     # In this pass transform constants of other data types to float/double + cast operator.
-    if GLOBALS.export_onnx_opset_version < 9:
+    if _export_onnx_opset_version < 9:
         torch._C._jit_pass_onnx_cast_all_constant_to_floating(graph)
 
     params_dict = torch._C._jit_pass_filter_non_tensor_arguments(params_dict)
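
As context for the constant-folding gate restored here, folding only runs when the requested opset appears in `_constant_folding_opset_versions`. A hedged sketch of the caller-side behavior (model, input, and path are placeholders):

    import torch

    model = torch.nn.Linear(4, 4)
    x = torch.randn(1, 4)
    # do_constant_folding takes effect only for opsets in the supported set;
    # for other opsets this pass is skipped.
    torch.onnx.export(model, x, "model.onnx", do_constant_folding=True, opset_version=13)
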
@@ -828,15 +822,21 @@
     do_constant_folding=True,
     dynamic_axes=None,
 ):
+    from torch.onnx.symbolic_helper import (
+        _default_onnx_opset_version,
+        _set_operator_export_type,
+        _set_opset_version,
+    )
 
     if opset_version is None:
-        opset_version = _constants.onnx_default_opset
+        opset_version = _default_onnx_opset_version
     if custom_opsets is None:
         custom_opsets = {}
-    symbolic_helper._set_opset_version(opset_version)
-    symbolic_helper._set_operator_export_type(operator_export_type)
+    _set_opset_version(opset_version)
+    _set_operator_export_type(operator_export_type)
+    from torch.onnx.symbolic_helper import _set_onnx_shape_inference
 
-    symbolic_helper._set_onnx_shape_inference(True)
+    _set_onnx_shape_inference(True)
     with exporter_context(model, training, verbose):
         val_keep_init_as_ip = _decide_keep_init_as_input(
             keep_initializers_as_inputs, operator_export_type, opset_version
@@ -890,9 +890,13 @@
         Tuple[torch._C.Graph, List[str]], where the list includes the names
         of the unconvertible ops.
     """
+    from torch.onnx.symbolic_helper import (
+        _default_onnx_opset_version,
+        _set_opset_version,
+    )
 
-    opset_version = opset_version or _constants.onnx_default_opset
-    symbolic_helper._set_opset_version(opset_version)
+    opset_version = opset_version or _default_onnx_opset_version
+    _set_opset_version(opset_version)
     # operator_export_type is set to ONNX_FALLTHROUGH by default so that if an op is not supported
     # in ONNX, fall through will occur and export the operator as is, as a custom ONNX op.
     with exporter_context(model, training, False):
@@ -977,9 +981,10 @@
 
 
 def _get_module_attributes(module):
+    from typing import get_type_hints
 
-    annotations = typing.get_type_hints(type(module))
-    base_m_annotations = typing.get_type_hints(torch.nn.Module)
+    annotations = get_type_hints(type(module))
+    base_m_annotations = get_type_hints(torch.nn.Module)
     [annotations.pop(k, None) for k in base_m_annotations]
     return {k: getattr(module, k) for k in annotations}
 
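The annotation filtering in `_get_module_attributes` can be reproduced standalone; a minimal sketch with a hypothetical module:

    from typing import get_type_hints

    import torch

    class MyModule(torch.nn.Module):
        num_layers: int  # class-level annotation visible to get_type_hints

        def __init__(self):
            super().__init__()
            self.num_layers = 3

    annotations = get_type_hints(MyModule)
    base = get_type_hints(torch.nn.Module)
    # Drop annotations inherited from torch.nn.Module (e.g. `training`),
    # keeping only MyModule's own attributes, as the helper above does.
    own = {k: v for k, v in annotations.items() if k not in base}
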
@@ -1017,11 +1022,18 @@
     assert __IN_ONNX_EXPORT is False
     __IN_ONNX_EXPORT = True
     try:
+        from torch.onnx.symbolic_helper import _set_onnx_shape_inference
 
-        symbolic_helper._set_onnx_shape_inference(onnx_shape_inference)
+        _set_onnx_shape_inference(onnx_shape_inference)
+
+        from torch.onnx.symbolic_helper import (
+            _default_onnx_opset_version,
+            _set_operator_export_type,
+            _set_opset_version,
+        )
 
         if opset_version is None:
-            opset_version = _constants.onnx_default_opset
+            opset_version = _default_onnx_opset_version
 
         if export_modules_as_functions and opset_version < 15:
             raise ValueError(
@@ -1045,8 +1057,8 @@
         # If you really know what you're doing, you can turn
         # training=TrainingMode.TRAINING or training=TrainingMode.PRESERVE,
         # (to preserve whatever the original training mode was.)
-        symbolic_helper._set_opset_version(opset_version)
-        symbolic_helper._set_operator_export_type(operator_export_type)
+        _set_opset_version(opset_version)
+        _set_operator_export_type(operator_export_type)
         with exporter_context(model, training, verbose):
             val_keep_init_as_ip = _decide_keep_init_as_input(
                 keep_initializers_as_inputs, operator_export_type, opset_version
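
A usage sketch for the `export_modules_as_functions` constraint checked above (model, input, and path are placeholders):

    import torch

    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
    x = torch.randn(1, 4)
    # ONNX local functions require opset >= 15; lower opsets raise the
    # ValueError guarded above.
    torch.onnx.export(
        model,
        x,
        "model.onnx",
        opset_version=15,
        export_modules_as_functions=True,
    )
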
@@ -1148,6 +1160,8 @@
                 torch.onnx.ExportTypes.ZIP_ARCHIVE,
                 torch.onnx.ExportTypes.COMPRESSED_ZIP_ARCHIVE,
             ]:
+                import zipfile
+
                 compression = (
                     zipfile.ZIP_DEFLATED
                     if export_type == torch.onnx.ExportTypes.COMPRESSED_ZIP_ARCHIVE
@@ -1243,6 +1257,9 @@
     set_names(list(graph.outputs()), output_names, "output")
 
 
+_attr_pattern = re.compile("^(.+)_(([ifstgz])|(ty))$")
+
+
 def _run_symbolic_method(g, op_name, symbolic_fn, args):
     r"""
     This trampoline function gets invoked for every symbolic method
@@ -1258,6 +1275,151 @@
         raise
 
 
+def _is_onnx_list(value):
+    return (
+        not isinstance(value, torch._six.string_classes)
+        and not isinstance(value, torch.Tensor)
+        and isinstance(value, Iterable)
+    )
+
+
+def _add_attribute(node, key, value, aten):
+    r"""Initializes the right attribute based on type of value."""
+    m = _attr_pattern.match(key)
+    if m is None:
+        raise IndexError(
+            (
+                "Invalid attribute specifier '{}' names "
+                + " must be suffixed with type, e.g. 'dim_i' or 'dims_i'"
+            ).format(key)
+        )
+    name, kind = m.group(1), m.group(2)
+    if _is_onnx_list(value):
+        kind += "s"
+    from torch.onnx.symbolic_helper import is_caffe2_aten_fallback
+
+    if aten and is_caffe2_aten_fallback():
+        if isinstance(value, torch.Tensor):
+            # Caffe2 proto does not support tensor attribute.
+            if value.numel() > 1:
+                raise ValueError("Should not pass tensor attribute")
+            value = _scalar(value)
+            if isinstance(value, float):
+                kind = "f"
+            else:
+                kind = "i"
+    return getattr(node, kind + "_")(name, value)
+
+
+def _scalar(x):
+    """Convert a scalar tensor into a Python value."""
+    assert x.numel() == 1
+    return x[0]
+
+
+def _new_node(g: torch._C.Graph, opname: str, outputs, *args, **kwargs):
+    if "::" in opname:
+        aten = False
+        ns_opname = opname
+    else:
+        aten = kwargs.pop("aten", False)
+        ns = "aten" if aten else "onnx"
+        ns_opname = ns + "::" + opname
+    n = g.create(ns_opname, args, outputs)  # type: ignore[attr-defined]
+    for k, v in sorted(kwargs.items()):
+        # TODO: enable inplace in aten exporting mode.
+        if k == "inplace":
+            continue
+        _add_attribute(n, k, v, aten=aten)
+    return n
+
+
+def _graph_op(
+    g: torch._C.Graph,
+    opname: str,
+    *raw_args: torch._C.Node,
+    outputs: int = 1,
+    **kwargs,
+) -> Union[torch._C.Value, Tuple[torch._C.Value, ...]]:
+    r"""Creates an ONNX operator "opname", taking "args" as inputs and attributes "kwargs".
+
+    The set of operators and the inputs/attributes they take
+    is documented at https://github.com/onnx/onnx/blob/master/docs/Operators.md
+
+    This function is monkey-patched onto Graph.
+
+    Args:
+        g: The Torch graph.
+        opname: The ONNX operator name, e.g., `Abs` or `Add`. TODO(justinchu): Update examples to correct ones.
+        raw_args: The inputs to the operator; usually provided
+            as arguments to the `symbolic` definition.
+        outputs: The number of outputs this operator returns.
+            By default an operator is assumed to return a single output.
+            If `outputs` is greater than one, this function returns a tuple
+            of output `Node`s, representing each output of the ONNX operator
+            in positional order.
+        kwargs: The attributes of the ONNX operator, whose keys are named
+            according to the following convention: `alpha_f` indicates
+            the `alpha` attribute with type `f`.  The valid type specifiers are
+            `f` (float), `i` (int), `s` (string) or `t` (Tensor).  An attribute
+            specified with type float accepts either a single float, or a
+            list of floats (e.g., you would say `dims_i` for a `dims` attribute
+            that takes a list of integers).
+
+    Returns:
+        The node representing the single output of this operator (see the `outputs`
+        keyword argument for multi-return nodes).
+    """
+    # Filter out None attributes; this is convenient for callers, which can
+    # pass through None attributes and have them omitted from the node.
+    kwargs = {k: v for k, v in kwargs.items() if v is not None}
+
+    def const_if_tensor(arg):
+        if arg is None:
+            return arg
+        elif isinstance(arg, torch._C.Value):
+            return arg
+        else:
+            return g.op("Constant", value_z=arg)  # type: ignore[attr-defined]
+
+    args = [const_if_tensor(arg) for arg in raw_args]
+    n = g.insertNode(_new_node(g, opname, outputs, *args, **kwargs))  # type: ignore[attr-defined]
+
+    from torch.onnx.symbolic_helper import _onnx_shape_inference
+
+    if _onnx_shape_inference:
+        from torch.onnx.symbolic_helper import (
+            _export_onnx_opset_version as opset_version,
+        )
+
+        torch._C._jit_pass_onnx_node_shape_type_inference(
+            n, _params_dict, opset_version
+        )
+
+    if outputs == 1:
+        return n.output()
+    return tuple(n.outputs())
+
+
+def _block_op(b, opname, *args, **kwargs):
+    if "::" in opname:
+        aten = False
+        ns_opname = opname
+    else:
+        aten = kwargs.pop("aten", False)
+        ns = "aten" if aten else "onnx"
+        ns_opname = ns + "::" + opname
+    n = b.addNode(ns_opname, list(args))
+    for k, v in sorted(kwargs.items()):
+        # TODO: enable inplace in aten exporting mode.
+        if k == "inplace":
+            continue
+        _add_attribute(n, k, v, aten=aten)
+    if len(list(n.outputs())) == 1:
+        return n.output()
+    return tuple(n.outputs())
+
+
 def _add_block(node: torch._C.Node):
     return node.addBlock()  # type: ignore[attr-defined]
 
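To illustrate the attribute-naming convention implemented by `_attr_pattern` and `_add_attribute` above, a small sketch (the `axis_i` and `pads_i` keys are illustrative):

    import re

    _attr_pattern = re.compile("^(.+)_(([ifstgz])|(ty))$")

    m = _attr_pattern.match("axis_i")
    print(m.group(1), m.group(2))  # -> axis i

Inside a symbolic function, the same convention drives `g.op`: `g.op("Concat", a, b, axis_i=0)` sets the integer attribute `axis`, and a list value appends "s" to the kind, so `pads_i=[0, 1]` is stored with kind "is".
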
@@ -1297,19 +1459,19 @@
     Returns:
         The symbolic function if found, None otherwise.
     """
+    import torch.onnx.symbolic_registry as sym_registry
 
-    if not symbolic_registry.is_registered_op(op_name, domain, opset_version):
+    if not sym_registry.is_registered_op(op_name, domain, opset_version):
         if operator_export_type == torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH:
             # Use the original node directly
             return None
-    return symbolic_registry.get_registered_op(op_name, domain, opset_version)
+    return sym_registry.get_registered_op(op_name, domain, opset_version)
 
 
 def _should_aten_fallback(ns, op_name, opset_version, operator_export_type):
+    import torch.onnx.symbolic_registry as sym_registry
 
-    is_exportable_aten_op = symbolic_registry.is_registered_op(
-        op_name, "", opset_version
-    )
+    is_exportable_aten_op = sym_registry.is_registered_op(op_name, "", opset_version)
     is_onnx_aten_export = (
         operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN
     )
@@ -1337,10 +1499,11 @@
 
 
 def _get_aten_op_overload_name(n: torch._C.Node) -> str:
+    from torch.onnx.symbolic_helper import is_caffe2_aten_fallback
 
     # Returns `overload_name` attribute to ATen ops on non-Caffe2 builds
     schema = n.schema()
-    if not schema.startswith("aten::") or symbolic_helper.is_caffe2_aten_fallback():
+    if not schema.startswith("aten::") or is_caffe2_aten_fallback():
         return ""
     return torch._C.parse_schema(schema).overload_name
 
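For context on `overload_name`, the schema string parsed above carries it after the op name; a hedged sketch (the schema literal is illustrative):

    import torch

    schema = "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
    print(torch._C.parse_schema(schema).overload_name)  # -> Tensor
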
@@ -1361,9 +1524,11 @@
         A single or a tuple of Values.
         None when the node gets cloned as is into the new graph.
     """
+    from torch.onnx import symbolic_helper
+    from torch.onnx import symbolic_registry as sym_registry
 
-    opset_version = GLOBALS.export_onnx_opset_version
-    symbolic_helper.is_caffe2_aten_fallback = symbolic_helper.is_caffe2_aten_fallback
+    opset_version = symbolic_helper._export_onnx_opset_version
+    is_caffe2_aten_fallback = symbolic_helper.is_caffe2_aten_fallback
 
     # See Note [Export inplace]
     # TODO(ezyang): I think this is not necessary anymore
@@ -1374,21 +1539,22 @@
     ns, op_name = ns_op_name.split("::")
 
     try:
-        symbolic_registry.register_version("", opset_version)
+        sym_registry.register_version("", opset_version)
 
         # Caffe2-specific: Quantized op symbolics are registered for opset 9 only.
-        if symbolic_helper.is_caffe2_aten_fallback() and opset_version == 9:
+        if is_caffe2_aten_fallback() and opset_version == 9:
+            from torch.onnx import symbolic_caffe2
 
             symbolic_caffe2.register_quantized_ops("caffe2", opset_version)
 
         if ns == "aten":
             domain = ""
-        elif ns == "quantized" and symbolic_helper.is_caffe2_aten_fallback():
+        elif ns == "quantized" and is_caffe2_aten_fallback():
             domain = "caffe2"
         else:
             domain = ns
 
-        if symbolic_registry.is_registered_op(op_name, domain, opset_version):
+        if sym_registry.is_registered_op(op_name, domain, opset_version):
             symbolic_fn = _find_symbolic_in_registry(
                 domain, op_name, opset_version, operator_export_type
             )
@@ -1417,15 +1583,13 @@
                 op_name, *inputs, overload_name=_get_aten_op_overload_name(n), **attrs
             )
         else:
-            raise symbolic_registry.UnsupportedOperatorError(
-                domain, op_name, opset_version
-            )
+            raise sym_registry.UnsupportedOperatorError(domain, op_name, opset_version)
     except RuntimeError:
         if operator_export_type == torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH:
             return None
         elif (
             operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
-            and not symbolic_helper.is_caffe2_aten_fallback()
+            and not is_caffe2_aten_fallback()
         ):
             # Emit ATen op for non-Caffe2 builds when `operator_export_type==ONNX_ATEN_FALLBACK`
             attrs = {k + "_" + n.kindOf(k)[0]: n[k] for k in n.attributeNames()}  # type: ignore[attr-defined]
@@ -1440,6 +1604,78 @@
         raise
 
 
+# Generate an ONNX ATen op node.
+def _aten_op(g, operator, *args, overload_name="", **kwargs):
+    kwargs["aten"] = True
+    return g.op(
+        "ATen", *args, operator_s=operator, overload_name_s=overload_name, **kwargs
+    )
+
+
+# TODO: We might not need this anymore, since most scalars now show up as tensors
+# TODO(#76254): Remove the helper function if not needed.
+def _graph_constant(
+    g,
+    value,
+    dims,
+    type_: str,
+    *args,
+    **kwargs,
+):
+    """This helper function can create either constant tensor or constant scalar.
+
+    If dims is None or 0 or [0], generate a 0-d tensor (scalar).
+    """
+    assert isinstance(value, numbers.Number)
+    assert type_ is not None
+    isscalar = False
+    if dims is None or dims == 0 or set(dims) == set([0]):
+        dims = [1]
+        isscalar = True
+    type_ = type_.lower()
+    tensor: Union[
+        torch.CharTensor,
+        torch.ShortTensor,
+        torch.IntTensor,
+        torch.LongTensor,
+        torch.HalfTensor,
+        torch.FloatTensor,
+        torch.DoubleTensor,
+    ]
+    if type_ == "char":
+        tensor = torch.CharTensor(*dims)
+    elif type_ == "short":
+        tensor = torch.ShortTensor(*dims)
+    elif type_ == "int":
+        tensor = torch.IntTensor(*dims)
+    elif type_ == "long":
+        tensor = torch.LongTensor(*dims)
+    elif type_ == "half":
+        tensor = torch.HalfTensor(*dims)
+    elif type_ == "float":
+        tensor = torch.FloatTensor(*dims)
+    elif type_ == "double":
+        tensor = torch.DoubleTensor(*dims)
+    else:
+        raise ValueError(
+            "Unknown type, type should be one of the following strings: "
+            "char, short, int, long, half, float, double"
+        )
+    tensor.fill_(value)  # type: ignore[call-overload]
+    if isscalar:
+        return g.op("Constant", *args, value_z=tensor, **kwargs)
+    return g.op("Constant", *args, value_t=tensor, **kwargs)
+
+
+def _node_getitem(self, k):
+    """Gets attributes of a node which is polymorphic over return type.
+
+    This is monkey-patched onto Node.
+    """
+    sel = self.kindOf(k)
+    return getattr(self, sel)(k)
+
+
 def get_ns_op_name_from_custom_op(symbolic_name):
     if not bool(
         re.match(r"^[a-zA-Z0-9-_]*::[a-zA-Z-_]+[a-zA-Z0-9-_]*$", symbolic_name)
@@ -1472,22 +1708,22 @@
     An example of setType is `test_aten_embedding_2` in `test_operators.py`.
     """
     ns, op_name = get_ns_op_name_from_custom_op(symbolic_name)
+    import torch.onnx.symbolic_registry as sym_registry
+    from torch.onnx.symbolic_helper import _onnx_main_opset, _onnx_stable_opsets
 
-    for version in itertools.chain(
-        _constants.onnx_stable_opsets, [_constants.onnx_main_opset]
-    ):
+    for version in _onnx_stable_opsets + [_onnx_main_opset]:
         if version >= opset_version:
-            symbolic_registry.register_op(op_name, symbolic_fn, ns, version)
+            sym_registry.register_op(op_name, symbolic_fn, ns, version)
 
 
 def unregister_custom_op_symbolic(symbolic_name, opset_version):
     ns, op_name = get_ns_op_name_from_custom_op(symbolic_name)
+    import torch.onnx.symbolic_registry as sym_registry
+    from torch.onnx.symbolic_helper import _onnx_main_opset, _onnx_stable_opsets
 
-    for version in itertools.chain(
-        _constants.onnx_stable_opsets, [_constants.onnx_main_opset]
-    ):
+    for version in _onnx_stable_opsets + [_onnx_main_opset]:
         if version >= opset_version:
-            symbolic_registry.unregister_op(op_name, ns, version)
+            sym_registry.unregister_op(op_name, ns, version)
 
 
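A usage sketch for the registration loop above, with a hypothetical custom op `mylib::neg` (the namespace and symbolic function are placeholders):

    import torch
    from torch.onnx import register_custom_op_symbolic

    def neg_symbolic(g, input):
        return g.op("Neg", input)

    # Registered for opset 9 and every later stable opset, plus the main opset,
    # matching the `version >= opset_version` loop above.
    register_custom_op_symbolic("mylib::neg", neg_symbolic, 9)
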
 def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names):
@@ -1538,3 +1774,10 @@
                 else:
                     value_dict[x] = str(key) + "_dynamic_axes_" + str(i + 1)
             dynamic_axes[key] = value_dict
+
+
+torch._C.Graph.op = _graph_op  # type: ignore[attr-defined]
+torch._C.Graph.at = _aten_op  # type: ignore[attr-defined]
+torch._C.Block.op = _block_op  # type: ignore[attr-defined]
+torch._C.Graph.constant = _graph_constant  # type: ignore[attr-defined]
+torch._C.Node.__getitem__ = _node_getitem  # type: ignore[attr-defined, misc, assignment]
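
After these assignments, every `torch._C.Graph` handed to a symbolic function exposes the helpers defined above; a hedged sketch of a hypothetical symbolic relying on them:

    def my_symbolic(g, x):
        y = g.op("Relu", x)     # _graph_op: builds an ONNX node
        return g.at("tanh", y)  # _aten_op: emits an ATen fallback node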