blob: eb88c4f83289eadcdaabe2ec4e021a49b341d51a [file] [log] [blame]
import importlib
import inspect
import itertools
import warnings
from typing import Any, Callable, Dict, Tuple, Union
from torch import _C
from torch.onnx import _constants, errors
# Public API of this module; helpers prefixed with ``_`` are internal.
__all__ = [
    "get_op_supported_version",
    "get_ops_in_version",
    "get_registered_op",
    "is_registered_op",
    "is_registered_version",
    "register_op",
    "register_ops_helper",
    "register_ops_in_version",
    "register_version",
    "unregister_op",
]
# Type of a registered symbolic function: accepts arbitrary arguments and
# returns a ``_C.Value`` or a tuple of them.
# NOTE(review): ``Tuple[_C.Value]`` spells a 1-tuple; a variadic tuple would be
# ``Tuple[_C.Value, ...]`` -- confirm which is intended.
_SymbolicFunction = Callable[..., Union[_C.Value, Tuple[_C.Value]]]
"""
The symbolic registry "_registry" is a dictionary that maps operators
(for a specific domain and opset version) to their symbolic functions.
An operator is defined by its domain, opset version, and opname.
The keys are tuples (domain, version), (where domain is a string, and version is an int),
and the operator's name (string).
The map's entries are as follows : _registry[(domain, version)][op_name] = op_symbolic
"""
# (domain, opset version) -> {op name: symbolic function}.
# Per-(domain, version) tables are created by register_version()/register_op()
# and removed by unregister_op() once they become empty.
_registry: Dict[
    Tuple[str, int],
    Dict[str, _SymbolicFunction],
] = {}
# Opset version -> imported ``torch.onnx.symbolic_opset{N}`` module.
# Populated lazily by _import_symbolic_opsets(), which get_ops_in_version()
# triggers the first time it finds this dict empty.
_symbolic_versions: Dict[Union[int, str], Any] = {}
def _import_symbolic_opsets():
    """Import each ``torch.onnx.symbolic_opset{N}`` module into ``_symbolic_versions``.

    The opset versions visited are ``_constants.onnx_stable_opsets`` followed by
    ``_constants.onnx_main_opset``. Each imported module is stored under its
    opset version so later calls can introspect its members.
    """
    all_versions = itertools.chain(
        _constants.onnx_stable_opsets, [_constants.onnx_main_opset]
    )
    for version in all_versions:
        # Item assignment mutates the module-level dict in place, so no
        # ``global`` declaration is required.
        _symbolic_versions[version] = importlib.import_module(
            f"torch.onnx.symbolic_opset{version}"
        )
def register_version(domain: str, version: int):
    """Ensure a table exists for ``(domain, version)`` and fill it with default ops.

    The table is created empty only when it is missing. The ops are then
    (re)collected by ``register_ops_in_version``, which never overwrites an
    entry that is already registered, so repeated calls are harmless.
    """
    key = (domain, version)
    if not is_registered_version(*key):
        _registry[key] = {}
    register_ops_in_version(domain, version)
def register_ops_helper(domain: str, version: int, iter_version: int):
    """Copy the ops defined in opset ``iter_version`` into the tables for ``version``.

    NOTE: ``domain`` is not used for filtering. Every op returned by
    ``get_ops_in_version(iter_version)`` is registered under its *own* domain,
    for example ``""`` for module-level functions or the ``domain`` attribute
    of a class. Ops already present for that domain and ``version`` are left
    untouched, so the first registration wins.
    """
    for op_domain, op_name, op_func in get_ops_in_version(iter_version):
        if is_registered_op(op_name, op_domain, version):
            continue
        register_op(op_name, op_func, op_domain, version)
def register_ops_in_version(domain: str, version: int):
    """Register the symbolic functions available in opset ``version``.

    Starting at ``version``, this walks one opset at a time toward the base
    opset 9 and then visits opset 9 itself. It walks downward when ``version``
    is greater than 9 and upward when it is smaller. Each visited opset's
    functions are copied via ``register_ops_helper``. That helper never
    overwrites an op that is already registered, so the definition from the
    opset closest to ``version`` wins.

    Opset 9 is the base version because:
    1. It is the first opset version supported by PyTorch export.
    2. It is more robust than earlier opsets. Opsets 7/8 have limitations that
       prevent some basic operators from being expressed in ONNX, so those are
       handled as special cases instead of being used as the base.
    Support for opset versions older than 7 is not planned.

    By default, every other opset inherits the symbolic functions defined in
    symbolic_opset9.py. To override an operator for a specific opset, add the
    updated symbolic function to the corresponding
    symbolic_opset{version}.py file. See ``topk`` in symbolic_opset10.py and
    ``upsample_nearest2d`` in symbolic_opset8.py for examples.

    Raises:
        TypeError: if ``version`` is not an integer. The previous
            ``while iter_version != 9`` loop never terminated in that case.
    """
    base_opset = 9
    step = -1 if version > base_opset else 1
    # range() stops *before* base_opset; the base itself is registered last.
    for iter_version in range(version, base_opset, step):
        register_ops_helper(domain, version, iter_version)
    register_ops_helper(domain, version, base_opset)
def get_ops_in_version(version: int):
    """Return ``(domain, op_name, function)`` triples defined by opset ``version``.

    The ``torch.onnx.symbolic_opset{version}`` module is introspected. The
    opset modules are imported on first use.

    * A class member that has a ``domain`` attribute contributes each of its
      functions, tagged with that class's ``domain``.
    * A module-level function contributes itself with domain ``""``. The names
      ``_len``, ``_list``, ``_any`` and ``_all`` are reported as ``len``,
      ``list``, ``any`` and ``all``, presumably because they are spelled with
      an underscore to avoid shadowing the Python builtins.
    """
    if not _symbolic_versions:
        _import_symbolic_opsets()
    renamed = {"_len": "len", "_list": "list", "_any": "any", "_all": "all"}
    triples = []
    for member_name, member in inspect.getmembers(_symbolic_versions[version]):
        if isinstance(member, type) and hasattr(member, "domain"):
            for fn_name, fn in inspect.getmembers(member, predicate=inspect.isfunction):
                triples.append((member.domain, fn_name, fn))  # type: ignore[attr-defined]
        elif inspect.isfunction(member):
            triples.append(("", renamed.get(member_name, member_name), member))
    return triples
def is_registered_version(domain: str, version: int):
    """Return whether a registry table exists for ``(domain, version)``."""
    key = (domain, version)
    return key in _registry
def register_op(opname, op, domain, version):
    """Store symbolic function ``op`` as ``opname`` in the ``(domain, version)`` table.

    The table is created if it does not exist yet, and any existing entry with
    the same name is replaced. A ``None`` domain or version only emits a
    warning; the op is still stored under that key.
    """
    if version is None or domain is None:
        warnings.warn(
            "ONNX export failed. The ONNX domain and/or version to register are None."
        )
    ops_for_version = _registry.setdefault((domain, version), {})
    ops_for_version[opname] = op
def is_registered_op(opname: str, domain: str, version: int):
    """Return whether ``opname`` is registered in the ``(domain, version)`` table.

    A ``None`` domain or version emits a warning, but the lookup still runs.
    """
    if version is None or domain is None:
        warnings.warn("ONNX export failed. The ONNX domain and/or version are None.")
    ops_for_version = _registry.get((domain, version))
    return ops_for_version is not None and opname in ops_for_version
def unregister_op(opname: str, domain: str, version: int):
    """Remove ``opname`` from the ``(domain, version)`` table.

    If the table becomes empty, it is dropped from the registry entirely. If
    the op is not registered, a warning is emitted and nothing changes.
    """
    if not is_registered_op(opname, domain, version):
        warnings.warn("The opname " + opname + " is not registered.")
        return
    key = (domain, version)
    ops_for_version = _registry[key]
    del ops_for_version[opname]
    if not ops_for_version:
        # Afterwards is_registered_version(domain, version) reports False.
        del _registry[key]
def get_op_supported_version(opname: str, domain: str, version: int):
    """Return the lowest opset >= ``version`` whose symbolic module defines the op.

    Opsets are scanned upward through ``_constants.onnx_main_opset``. ``None``
    is returned when none of them defines ``opname`` in ``domain``.
    """
    candidate = version
    while candidate <= _constants.onnx_main_opset:
        defined_here = any(
            op_domain == domain and op_name == opname
            for op_domain, op_name, _ in get_ops_in_version(candidate)
        )
        if defined_here:
            return candidate
        candidate += 1
    return None
def get_registered_op(opname: str, domain: str, version: int) -> _SymbolicFunction:
    """Return the symbolic function registered as ``opname`` for ``(domain, version)``.

    A ``None`` domain or version emits a single warning before the lookup.

    Raises:
        errors.UnsupportedOperatorError: if ``opname`` is not registered for
            ``(domain, version)``. The error carries the lowest higher opset
            that defines the op (or ``None``), as computed by
            ``get_op_supported_version``.
    """
    if domain is None or version is None:
        warnings.warn("ONNX export failed. The ONNX domain and/or version are None.")
    # Look the op up directly instead of calling is_registered_op(). That
    # function re-checks for None and would emit the same warning a second
    # time from a different source line.
    ops_for_version = _registry.get((domain, version), {})
    if opname not in ops_for_version:
        raise errors.UnsupportedOperatorError(
            domain, opname, version, get_op_supported_version(opname, domain, version)
        )
    return ops_for_version[opname]