from __future__ import absolute_import, division, print_function, unicode_literals
import torch
from torch._C import ListType
import warnings
import torch.onnx
# This import monkey-patches graph manipulation methods on Graph, used for the
# ONNX symbolics
import torch.onnx.utils
from functools import wraps
# Note [Edit Symbolic Files]
# EDITING THIS FILE AND SYMBOLIC_OPSET<VERSION> FILES? READ THIS FIRST!
#
# - These files are ONLY for ATen operators (e.g., operators that show up in the
#   trace as aten::blah). If you need to special case a primitive operator,
#   look at _run_symbolic_function
# - Parameter ordering does NOT necessarily match what is in VariableType.cpp;
#   tensors are always first, then non-tensor arguments.
# - Parameter names must *exactly* match the names in VariableType.cpp, because
#   dispatch is done with keyword arguments.
# - Looking for inplace ops? They're detected by the trailing underscore, and
#   transparently dispatched to their non-inplace versions in
#   _run_symbolic_function. See Note [Export inplace]
#
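# As a hedged sketch (the operator and attribute names below are illustrative,
# not taken from the opset files), a typical ATen symbolic defined in a
# symbolic_opset<version> file looks like:
#
#     @parse_args('v', 'f')
#     def my_op(g, input, alpha):
#         # g.op emits one ONNX node; the `_f` suffix marks a float attribute
#         return g.op("SomeOnnxOp", input, alpha_f=alpha)
#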
# ----------------------------------------------------------------------------------
# A note on Tensor types
# ----------------------------------------------------------------------------------
#
# In general, we should avoid depending on the type of Tensor Values contained
# within the trace graph. However, this is sometimes unavoidable (due to ONNX
# spec requirements, etc). If you are implementing a symbolic and need Tensor
# type information, note that there are several levels of Tensor types, defined
# in aten/src/ATen/core/jit_type.h:
#
#   TensorType - This is a Tensor, but we don't know anything about its
#                properties (e.g. scalar type, # dims, shapes).
#                Appears as `Tensor` in graph print-outs.
#   ProfiledTensorType <: TensorType - Denotes a Tensor for which we know the
#                           concrete sizes in addition to the information
#                           contained in TensorType. This adds a sizes()
#                           method which can be used to retrieve the
#                           concrete sizes.
#   @deprecated
#   DimensionedTensorType <: TensorType - Denotes a Tensor for which we know the
#                           scalar type and number of dimensions, but not the
#                           concrete shapes. For example, appears as
#                           'Float(*, *)' in graph print-outs. Useful accessor
#                           methods include dim() and scalarType()
#   @deprecated
#   CompleteTensorType <: DimensionedTensorType - Denotes a Tensor for which we
#                           know the concrete sizes in addition to the
#                           information contained in DimensionedTensorType.
#                           This adds a sizes() method which can be used to
#                           retrieve the concrete sizes.
#
# In general, we should prefer to rely on the least specific information possible.
# For example, not relying on tensor properties at all is better than relying
# on the number of dimensions (DimensionedTensorType), which is better than relying
# on concrete shapes (CompleteTensorType). Doing so will make the export symbolics
# more robust to different graphs.
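# As a hedged example of relying on the least specific information: a
# symbolic that needs the element type can guard on the type kind before
# calling an accessor (the helper used here is defined later in this file):
#
#     if _is_complete_or_dimensioned_tensor_type(tensor):
#         if tensor.type().scalarType() == 'Float':
#             ...  # safe to rely on the scalar type
#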
# ---------------------------------------------------------------------------------
# Helper functions
# ---------------------------------------------------------------------------------
# Save some builtins as locals, because we'll shadow them below
_sum = sum
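# Descriptor strings understood by _parse_arg (and the parse_args decorator
# below); derived directly from the branches in _parse_arg:
#   'v'    - leave the argument as a torch._C.Value
#   'i'    - parse a Constant node as a Python int
#   'f'    - parse a Constant node as a Python float
#   'b'    - parse a Constant node as a Python bool
#   't'    - parse a Constant node as a torch.Tensor
#   'is'   - parse a Constant or ListConstruct node as a list of ints
#   'none' - pass the argument through unchanged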
def _parse_arg(value, desc):
    if desc == 'none':
        return value
    if desc == 'v' or not _is_value(value):
        return value
    if value.node().kind() == 'onnx::Constant':
        tval = value.node()['value']
        if desc == 'i':
            return int(tval)
        elif desc == 'f':
            return float(tval)
        elif desc == 'b':
            return bool(tval)
        elif desc == 't':
            return tval
        elif desc == 'is':
            return [int(v) for v in tval]
        else:
            raise RuntimeError("ONNX symbolic doesn't know how to interpret Constant node")
    elif value.node().kind() == 'prim::ListConstruct':
        if desc == 'is':
            for v in value.node().inputs():
                if v.node().kind() != 'onnx::Constant':
                    raise RuntimeError("Failed to export an ONNX attribute "
                                       "because it is not constant; please try to make "
                                       "things (e.g., kernel size) static if possible")
            return [int(v.node()['value']) for v in value.node().inputs()]
        else:
            raise RuntimeError("ONNX symbolic doesn't know how to interpret ListConstruct node")
    raise RuntimeError("Unexpected node type: {}".format(value.node().kind()))
def _maybe_get_const(value, desc):
    if _is_value(value) and value.node().kind() == 'onnx::Constant':
        return _parse_arg(value, desc)
    return value


def _maybe_get_scalar(value):
    value_t = _maybe_get_const(value, 't')
    if isinstance(value_t, torch.Tensor) and value_t.shape == ():
        return value_t
    return value


def _get_const(value, desc, arg_name):
    if _is_value(value) and value.node().kind() != 'onnx::Constant':
        raise RuntimeError("ONNX symbolic expected a constant value of the {} argument, got `{}`".format(arg_name, value))
    return _parse_arg(value, desc)


def _unpack_list(list_value):
    list_node = list_value.node()
    assert list_node.kind() == "prim::ListConstruct"
    return list(list_node.inputs())


# Check if list_value is an output from prim::ListConstruct.
# This is usually called before _unpack_list to ensure the list can be unpacked.
def _is_packed_list(list_value):
    return _is_value(list_value) and list_value.node().kind() == "prim::ListConstruct"
def parse_args(*arg_descriptors):
    def decorator(fn):
        fn._arg_descriptors = arg_descriptors

        def wrapper(g, *args):
            # some args may be optional, so the length may be smaller
            assert len(arg_descriptors) >= len(args)
            args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)]
            return fn(g, *args)
        # In Python 2 functools.wraps chokes on partially applied functions, so we need this as a workaround
        try:
            wrapper = wraps(fn)(wrapper)
        except Exception:
            pass
        return wrapper
    return decorator
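# A hedged usage sketch for parse_args: the descriptors line up positionally
# with the symbolic's arguments after `g` (names here are illustrative):
#
#     @parse_args('v', 'is')
#     def my_pool(g, input, kernel_size):
#         # `input` is still a Value; `kernel_size` is now a list of ints
#         ...
#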
def _scalar(x):
    """Convert a scalar tensor into a Python value."""
    assert x.numel() == 1
    return x.item()


def _is_complete_or_dimensioned_tensor_type(tensor):
    return tensor.type().kind() == "DimensionedTensorType" or tensor.type().kind() == "CompleteTensorType"
def _if_scalar_type_as(g, self, tensor):
    """
    Convert self into the same type of tensor, as necessary.

    We only support implicit casting for scalars, so we never
    actually need to insert an ONNX cast operator here; just
    fix up the scalar.
    """
    if isinstance(self, torch._C.Value):
        return self
    scalar_type = tensor.type().scalarType()
    if scalar_type:
        ty = scalar_type.lower()
        return getattr(self, ty)()
    return self
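# A hedged example for _if_scalar_type_as above: if `self` is a scalar tensor
# (e.g. from _maybe_get_scalar) and `tensor` is known to be of scalar type
# 'Float', the call returns self.float(); if `self` is already a
# torch._C.Value it is returned unchanged.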
def _is_value(x):
    return isinstance(x, torch._C.Value)


def _is_tensor_list(x):
    return x.type().isSubtypeOf(ListType.ofTensors())


def _unimplemented(op, msg):
    warnings.warn("ONNX export failed on " + op + " because " + msg + " not supported")
def _black_list_in_opset(name):
    def symbolic_fn(*args, **kwargs):
        raise RuntimeError("ONNX export failed on {}, which is not implemented for opset {}. "
                           "Try exporting with other opset versions."
                           .format(name, _export_onnx_opset_version))
    return symbolic_fn
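# A hedged sketch of how the stub above is used in a symbolic_opset file
# (the operator name is illustrative):
#
#     some_unsupported_op = _black_list_in_opset('some_unsupported_op')
#
# Invoking the resulting symbolic then raises the RuntimeError above instead
# of silently producing a bad graph.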
def _try_get_scalar_type(*args):
    for arg in args:
        try:
            return arg.type().scalarType()
        except RuntimeError:
            pass
    return None


def _slice_helper(g, input, axes, starts, ends, steps=None, dynamic_slice=False):
    if _export_onnx_opset_version <= 9:
        from torch.onnx.symbolic_opset9 import _slice
        return _slice(g, input, axes, starts, ends)
    else:
        from torch.onnx.symbolic_opset10 import _slice
        return _slice(g, input, axes, starts, ends, steps, dynamic_slice)
# ---------------------------------------------------------------------
# ONNX operator version
# ---------------------------------------------------------------------
# READ ME BEFORE EDITING _default_onnx_opset_version:
#
# The variable below controls which ONNX operator set version we are
# targeting. THIS VARIABLE HAS SEMANTIC EFFECT! Say a breaking
# change occurred in version 8. As long as this variable is < 8, you can
# export models targeting the old behavior. However, if you bump
# this variable to 8 or later, the breaking change will take effect:
# you MUST adjust any symbolic affected by breaking changes. The ONNX
# spec publishes a *comprehensive* list of BC-breaking changes for every
# operator revision at:
#
#     https://github.com/onnx/onnx/blob/master/docs/Changelog.md
#
# Please be sure to go through and check all of our implementations here before
# increasing this number. This includes symbolic definitions NOT in this
# file, so grep for "OpName" (with quotes).
#
# In addition, opset_version can be specified in the invocation of export()
# and export_to_pretty_string(); _export_onnx_opset_version will be set
# accordingly, and the symbolic functions should check it to determine the
# behavior of the exporter.
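# For example (a minimal sketch of such an invocation; the model and input
# are assumed):
#
#     torch.onnx.export(model, dummy_input, 'model.onnx', opset_version=10)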
_default_onnx_opset_version = 9
_onnx_master_opset = 10
_onnx_stable_opsets = [7, 8, 9, 10]
_export_onnx_opset_version = _default_onnx_opset_version
def _set_opset_version(opset_version):
    global _export_onnx_opset_version
    if opset_version == _default_onnx_opset_version:
        _export_onnx_opset_version = opset_version
        return
    if opset_version in _onnx_stable_opsets + [_onnx_master_opset]:
        _export_onnx_opset_version = opset_version
        return
    raise ValueError("Unsupported ONNX opset version: " + str(opset_version))


_operator_export_type = None


def _set_operator_export_type(operator_export_type):
    global _operator_export_type
    _operator_export_type = operator_export_type
# Metaprogram symbolics for each ATen native specialized cast operator.
# For example, we specify a function named `_cast_uint8_t` that instantiates an
# ONNX cast node with `to` attribute 'UINT8'
#
# TODO: remove these once we support Types in the JIT IR and we can once again
# use the unified toType operator
cast_pytorch_to_onnx = {
    'Byte': torch.onnx.TensorProtoDataType.UINT8,
    'Char': torch.onnx.TensorProtoDataType.INT8,
    'Double': torch.onnx.TensorProtoDataType.DOUBLE,
    'Float': torch.onnx.TensorProtoDataType.FLOAT,
    'Half': torch.onnx.TensorProtoDataType.FLOAT16,
    'Int': torch.onnx.TensorProtoDataType.INT32,
    'Long': torch.onnx.TensorProtoDataType.INT64,
    'Short': torch.onnx.TensorProtoDataType.INT16,
    'Bool': torch.onnx.TensorProtoDataType.BOOL,
    'ComplexFloat': torch.onnx.TensorProtoDataType.COMPLEX64,
    'ComplexDouble': torch.onnx.TensorProtoDataType.COMPLEX128,
    'Undefined': torch.onnx.TensorProtoDataType.UNDEFINED,
}
scalar_name_to_pytorch = {
    'uint8_t': 'Byte',
    'int8_t': 'Char',
    'double': 'Double',
    'float': 'Float',
    'half': 'Half',
    'int': 'Int',
    'int64_t': 'Long',
    'int16_t': 'Short',
    'bool': 'Bool',
    'complex64': '',
    'complex128': ''
}
# This indicates each scalar type's corresponding
# torch type. Related source:
# https://github.com/pytorch/pytorch/blob/da7468853ae322252270bbb58032668bd21b7457/c10/core/ScalarType.h
scalar_type_to_pytorch_type = [
    torch.uint8,       # 0
    torch.int8,        # 1
    torch.short,       # 2
    torch.int,         # 3
    torch.int64,       # 4
    torch.half,        # 5
    torch.float,       # 6
    torch.double,      # 7
    torch.complex64,   # 9
    torch.complex128,  # 10
    torch.bool,        # 11
]
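# NB: the list above has no entry for ScalarType 8 (ComplexHalf), so its
# trailing indices do not line up with raw ScalarType values.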
def _cast_func_template(to_i, g, input, non_blocking):
    return g.op("Cast", input, to_i=to_i)
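# A hedged sketch (assuming functools.partial; the actual generation lives in
# the opset files) of how a specialized cast symbolic can be built from the
# template above:
#
#     from functools import partial
#     _cast_Byte = parse_args('v', 'i')(
#         partial(_cast_func_template, cast_pytorch_to_onnx['Byte']))
#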
scalar_type_to_onnx = [
    cast_pytorch_to_onnx["Byte"],
    cast_pytorch_to_onnx["Char"],
    cast_pytorch_to_onnx["Short"],
    cast_pytorch_to_onnx["Int"],
    cast_pytorch_to_onnx["Long"],
    cast_pytorch_to_onnx["Half"],
    cast_pytorch_to_onnx["Float"],
    cast_pytorch_to_onnx["Double"],
    cast_pytorch_to_onnx["Undefined"],
    cast_pytorch_to_onnx["ComplexFloat"],
    cast_pytorch_to_onnx["ComplexDouble"],
    cast_pytorch_to_onnx["Bool"],
]
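# A hedged usage sketch: given an integer `dtype` holding a c10 ScalarType
# value, a symbolic can emit a matching ONNX Cast via
#
#     g.op("Cast", input, to_i=scalar_type_to_onnx[dtype])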