blob: f513c866c093a760a578a366ab30212dee075e82 [file] [log] [blame]
import torch
from torch._C import ListType, OptionalType
from torch.nn.modules.utils import _single, _pair, _triple
import torch.onnx
# This import monkey-patches graph manipulation methods on Graph, used for the
# ONNX symbolics
import torch.onnx.utils
from functools import partial
from functools import wraps
import torch.onnx.symbolic_helper as sym_help
from torch.onnx.symbolic_helper import parse_args, _parse_arg, _unimplemented
from typing import Optional
from sys import maxsize as maxsize
import numpy
import math
import warnings
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py
# This file exports ONNX ops for opset 9
# Opset 9 is supported by ONNX release 1.4.1,
# released on 01/23/19
# Note [Pointwise by scalar]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
# What happens if you add a tensor with a constant (e.g., x + 2)? There are
# some moving parts to implementing the ONNX translation in this case:
#
# - By the time we get the scalar in a symbolic function here, it is no longer
# a Python long/float, but a PyTorch tensor with numel == 1 (eventually, we
# want it to be a zero dim tensor but this change has not happened yet.)
# However, the type of this scalar is *exactly* what the user wrote in
# Python, which may not match the tensor it is being added to. PyTorch
# will do implicit conversions on scalars; however, ONNX will not, so
# we must do the conversion ourselves. This is what _if_scalar_type_as
# does.
#
# - Dispatch to these functions takes advantage of an outrageous coincidence
# between the tensor and scalar name. When we add two tensors together,
# you get the dispatch:
#
# add(*[self, other], **{"alpha": alpha})
#
# When you add a tensor and a scalar, you get the dispatch:
#
# add(*[self], **{"other": other, "alpha": alpha})
#
# By having the argument name line up with the name of the scalar attribute
# if it exists, we can write a single function for both overloads.
#
# Placeholder for "missing" optional inputs.
def unused(g):
    """Return a ``prim::Constant`` node whose type is set to optional tensor."""
    placeholder = g.op("prim::Constant")
    placeholder.setType(OptionalType.ofTensor())
    return placeholder
def _shape_as_tensor(g, input):
return g.op("Shape", input)
def _reshape_from_tensor(g, input, shape):
if (isinstance(shape, list)):
shape = g.op("Concat", *shape, axis_i=0)
return g.op("Reshape", input, shape)
def reshape(g, self, shape):
    """Export ``reshape`` by delegating to the ``view`` symbolic."""
    reshaped = view(g, self, shape)
    return reshaped
def reshape_as(g, self, other):
    """Reshape ``self`` to the runtime shape of ``other`` (taken via ``Shape``)."""
    target_shape = g.op("Shape", other)
    return reshape(g, self, target_shape)
def add(g, self, other, alpha=None):
    """Export ``add`` as an ONNX ``Add``.

    Tensor-list operands are reported as unsupported for this opset, and any
    ``alpha`` other than 1 is reported as unimplemented.
    """
    is_list_operand = sym_help._is_value(self) and sym_help._is_tensor_list(self)
    if is_list_operand:
        return sym_help._onnx_opset_unsupported_detailed("Add", 9, 11, "Add between list of tensors not supported")
    # alpha defaults to None so the aten overload without alpha is accepted too
    if alpha:
        alpha_value = sym_help._scalar(sym_help._maybe_get_scalar(alpha))
        if alpha_value != 1:
            return _unimplemented("add", "alpha != 1")
    return g.op("Add", self, other)
def sub(g, self, other, alpha=None):
    """Export ``sub`` as an ONNX ``Sub``; ``alpha`` other than 1 is unimplemented."""
    # alpha defaults to None so the aten overload without alpha is accepted too
    if alpha:
        alpha_value = sym_help._scalar(sym_help._maybe_get_scalar(alpha))
        if alpha_value != 1:
            return _unimplemented("sub", "alpha != 1")
    return g.op("Sub", self, other)
def rsub(g, self, other, alpha=None):
    """Reverse subtraction: delegates to ``sub`` with the operands swapped."""
    swapped = sub(g, other, self, alpha=alpha)
    return swapped
def mul(g, self, other):
    """Export ``mul`` as an ONNX ``Mul`` of the two operands."""
    product = g.op("Mul", self, other)
    return product
def div(g, self, other, *args):
    """Export ``div``.

    With no extra arguments this is a true division; otherwise the extra
    arguments are forwarded to ``_div_rounding_mode``.
    """
    if args:
        return _div_rounding_mode(g, self, other, *args)
    return true_divide(g, self, other)
@parse_args("v", "v", "s")
def _div_rounding_mode(g, self, other, rounding_mode):
    """Dispatch division on ``rounding_mode`` (None, "floor" or "trunc").

    Raises:
        RuntimeError: for any other rounding mode.
    """
    if rounding_mode is None:
        return true_divide(g, self, other)
    if rounding_mode == "floor":
        return _floor_divide(g, self, other)
    if rounding_mode == "trunc":
        return _trunc_divide(g, self, other)
    raise RuntimeError(f'Unsupported rounding mode: "{rounding_mode}". Expected None, "floor" or "trunc"')
def _trunc_divide(g, self, other):
    """Export truncating division as ``Div`` followed by casts.

    The quotient is cast to Long (ONNX has no truncate op, and Floor would
    behave differently for negative numbers, e.g. -0.1 should become -0),
    then cast to the result type chosen below.
    """
    onnx_type_of = sym_help.cast_pytorch_to_onnx
    quotient = g.op("Div", self, other)
    truncated = g.op("Cast", quotient, to_i=onnx_type_of["Long"])

    # Matching PyTorch's behavior:
    # - if self is fp the output's type is self's type
    # - if self is not fp and other is fp, the output is of type "Float"
    # - self is not fp and other is not fp, the output's type is self's output type
    # - the output type defaults to Float
    self_type = self.type().scalarType()
    if self_type is None:
        return g.op("Cast", truncated, to_i=onnx_type_of["Float"])

    only_other_is_fp = (
        not sym_help._is_fp(self)
        and other.type().scalarType() is not None
        and sym_help._is_fp(other)
    )
    result_type = "Float" if only_other_is_fp else self_type
    return g.op("Cast", truncated, to_i=onnx_type_of[result_type])
def _floor_divide(g, self, other):
    """Export floor division.

    If either operand is floating point, this is ``Floor`` of the true
    division. Otherwise the integer ``Div`` result is decremented by one
    where the operand signs differ and the division is inexact.
    """
    if sym_help._is_fp(self) or sym_help._is_fp(other):
        quotient = true_divide(g, self, other)
        return g.op("Floor", quotient)

    # Integer division does truncation rounding
    quotient = g.op("Div", self, other)
    # Division is negative if: self < 0 != other < 0
    zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64))
    self_negative = g.op("Less", self, zero)
    other_negative = g.op("Less", other, zero)
    signs_differ = g.op("Xor", self_negative, other_negative)

    # For negative numbers with self % other != 0, subtract 1 to round down instead of up
    product = g.op("Mul", quotient, other)
    remainder = g.op("Sub", self, product)
    remainder_is_zero = g.op("Equal", remainder, zero)
    has_remainder = g.op("Not", remainder_is_zero)
    needs_fixup = g.op("And", signs_differ, has_remainder)

    one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
    rounded_down = g.op("Sub", quotient, one)
    return g.op("Where", needs_fixup, rounded_down, quotient)
def floor_divide(g, self, other):
    """Export ``floor_divide`` via the truncating division helper."""
    # Deprecated behavior: floor_divide actually truncates
    truncated = _trunc_divide(g, self, other)
    return truncated
def floordiv(g, self, other):
    """Alias that forwards to ``floor_divide``."""
    result = floor_divide(g, self, other)
    return result
# Division in which both inputs end up as floating types:
# - both floating: plain Div
# - exactly one floating: the other input is cast to its type
# - neither floating: both inputs are cast to the default scalar type
def true_divide(g, self, other):
    """Export true division, inserting ``Cast`` nodes where an operand is not floating."""
    onnx_type_of = sym_help.cast_pytorch_to_onnx

    if sym_help._is_fp(self):
        # self is floating; cast other to self's dtype only if other is not floating
        if not sym_help._is_fp(other):
            other = g.op("Cast", other, to_i=onnx_type_of[self.type().scalarType()])
        return g.op("Div", self, other)

    if sym_help._is_fp(other):
        # only other is floating; cast self to other's dtype
        self = g.op("Cast", self, to_i=onnx_type_of[other.type().scalarType()])
        return g.op("Div", self, other)

    # Neither is floating: cast both to the default scalar type
    default_dtype = torch.get_default_dtype()
    assert default_dtype is torch.float or default_dtype is torch.double
    if default_dtype is torch.double:
        target = onnx_type_of["Double"]
    else:
        target = onnx_type_of["Float"]
    self = g.op("Cast", self, to_i=target)
    other = g.op("Cast", other, to_i=target)
    return g.op("Div", self, other)
def reciprocal(g, self):
    """Export ``reciprocal`` as ``Div`` of a one-element tensor of ones by ``self``."""
    numerator = torch.ones(1)
    return g.op("Div", numerator, self)
@parse_args("v", "i")
def cat(g, tensor_list, dim):
    """Export ``cat`` as an ONNX ``Concat`` of the unpacked list along ``dim``."""
    unpacked = sym_help._unpack_list(tensor_list)
    concatenated = g.op("Concat", *unpacked, axis_i=dim)
    return concatenated
@parse_args("v", "i")
def stack(g, tensor_list, dim):
    """Export ``stack``: unsqueeze every list element at ``dim``, then ``Concat`` there."""
    tensors = sym_help._unpack_list(tensor_list)
    expanded = [sym_help._unsqueeze_helper(g, tensor, [dim]) for tensor in tensors]
    return g.op("Concat", *expanded, axis_i=dim)
def _list(g, self):
return self
def mm(g, self, other):
    """Export ``mm`` as an ONNX ``Gemm`` with ``alpha=1`` and ``beta=0``."""
    # Gemm takes a C operand, so supply a dummy one purely for API purposes;
    # with beta = 0 its value does not contribute to the result.
    dummy_c = g.op("Constant", value_t=torch.tensor([1]))
    return g.op("Gemm", self, other, dummy_c, beta_f=0.0, alpha_f=1.0)
def bmm(g, self, other):
    """Export ``bmm`` as an ONNX ``MatMul``."""
    product = g.op("MatMul", self, other)
    return product
def matmul(g, self, other):
    """Export ``matmul`` as an ONNX ``MatMul``."""
    product = g.op("MatMul", self, other)
    return product
@parse_args("v", "v", "v", "t", "t")
def addmm(g, self, mat1, mat2, beta, alpha):
    """Export ``addmm`` (``beta * self + alpha * (mat1 @ mat2)``).

    When a scalar type is known and either matrix has a known rank other
    than 2, the op is decomposed into ``MatMul`` / ``Mul`` / ``Add``.
    Otherwise a single ``Gemm`` is emitted.
    """
    # Use the first known scalar type among self, mat1, mat2 (in that order).
    self_dtype = sym_help._try_get_scalar_type(self)
    mat1_dtype = sym_help._try_get_scalar_type(mat1)
    mat2_dtype = sym_help._try_get_scalar_type(mat2)
    dtype = None
    if self_dtype is not None:
        dtype = self_dtype
    elif mat1_dtype is not None:
        dtype = mat1_dtype
    elif mat2_dtype is not None:
        dtype = mat2_dtype

    mat1_rank = sym_help._get_tensor_rank(mat1)
    mat2_rank = sym_help._get_tensor_rank(mat2)

    def isNotNoneAnd(v, u):
        return v is not None and v != u

    if dtype is not None and (isNotNoneAnd(mat1_rank, 2) or isNotNoneAnd(mat2_rank, 2)):
        # Map the scalar-type name to the torch dtype used for the constants.
        dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype])
        dtype = sym_help.scalar_type_to_pytorch_type[dtype]

        res1 = g.op("MatMul", mat1, mat2)
        res2 = self

        alpha = sym_help._scalar(alpha)
        beta = sym_help._scalar(beta)

        if alpha != 1:
            alpha = g.op("Constant",
                         value_t=torch.tensor(alpha, dtype=dtype))
            res1 = g.op("Mul", res1, alpha)
        if beta != 1:
            # beta was already converted by _scalar above, so use it directly,
            # exactly as the alpha branch does (no second conversion).
            beta = g.op("Constant",
                        value_t=torch.tensor(beta, dtype=dtype))
            res2 = g.op("Mul", res2, beta)

        return g.op("Add", res1, res2)

    return g.op("Gemm", mat1, mat2, self, beta_f=sym_help._scalar(beta), alpha_f=sym_help._scalar(alpha))
def neg(g, self):
    """Export ``neg`` as an ONNX ``Neg``."""
    negated = g.op("Neg", self)
    return negated
def sqrt(g, self):
    """Export ``sqrt`` as an ONNX ``Sqrt``."""
    root = g.op("Sqrt", self)
    return root
def rsqrt(g, self):
    """Export ``rsqrt`` as ones (passed through ``_if_scalar_type_as``) divided by ``sqrt(self)``."""
    numerator = sym_help._if_scalar_type_as(g, torch.ones(1), self)
    denominator = sqrt(g, self)
    return g.op("Div", numerator, denominator)
def tanh(g, self):
    """Export ``tanh`` as an ONNX ``Tanh``."""
    result = g.op("Tanh", self)
    return result
def sin(g, self):
    """Export ``sin`` as an ONNX ``Sin``."""
    result = g.op("Sin", self)
    return result
def cos(g, self):
    """Export ``cos`` as an ONNX ``Cos``."""
    result = g.op("Cos", self)
    return result
def tan(g, self):
    """Export ``tan`` as an ONNX ``Tan``."""
    result = g.op("Tan", self)
    return result
def asin(g, self):
    """Export ``asin`` as an ONNX ``Asin``."""
    result = g.op("Asin", self)
    return result
def acos(g, self):
    """Export ``acos`` as an ONNX ``Acos``."""
    result = g.op("Acos", self)
    return result
def atan(g, self):
    """Export ``atan`` as an ONNX ``Atan``."""
    result = g.op("Atan", self)
    return result
def sigmoid(g, self):
    """Export ``sigmoid`` as an ONNX ``Sigmoid``."""
    result = g.op("Sigmoid", self)
    return result
def sign(g, self):
    """Export ``sign`` as an ONNX ``Sign``."""
    result = g.op("Sign", self)
    return result
def _slice(g, input, axes, starts, ends):
assert len(starts) == len(ends)
if len(starts) == 1 and starts[0] == 0 and ends[0] == 9223372036854775807:
return input
return g.op("Slice", input, axes_i=axes, starts_i=starts, ends_i=ends)
def _maybe_cast_reduce_op_input(g, self):
    """Cast a reduce-op input to Long when it is a non-floating, non-Long type.

    Inputs with no scalar type, floating inputs and Long inputs are returned
    unchanged.
    """
    scalar_type = self.type().scalarType()
    # This check only covers traced modules where dtype is present
    if scalar_type is None:
        return self
    # pytorch reduce-ops cast all other integral types to int64
    if sym_help._is_fp(self) or scalar_type == "Long":
        return self
    return _cast_Long(g, self, False)  # type: ignore[name-defined]
def _reduce_op_symbolic(onnx_op_name, allow_multi_dim_support=True):
    """Build a symbolic function emitting the reduction op ``onnx_op_name``.

    The returned function reduces over everything when ``dim`` is None.
    Otherwise it reduces over ``dim``, parsed as an int list when
    ``allow_multi_dim_support`` is true and as a single int otherwise.
    """
    def symbolic(g, self, dim=None, keepdim=None):
        self = _maybe_cast_reduce_op_input(g, self)
        if dim is None:
            # all-reduce path
            return sym_help._handle_reduce_dim_none(g, self, onnx_op_name)

        # dim-reduce path
        dim_desc = "is" if allow_multi_dim_support else "i"
        dim_value = sym_help._get_const(dim, dim_desc, "dim")
        keepdim_value = sym_help._get_const(keepdim, "i", "keepdim")
        axes = dim_value if allow_multi_dim_support else [dim_value]
        return g.op(onnx_op_name, self, axes_i=axes, keepdims_i=keepdim_value)

    return symbolic
def overload_by_arg_count(fn):
    """Decorator that dispatches to an overload by number of positional arguments.

    ``fn(g, *args)`` must return an iterable of candidate callables, each
    carrying an ``_arg_descriptors`` attribute. The first candidate whose
    descriptor count equals ``len(args)`` is called with ``(g, *args)``.

    Raises:
        NotImplementedError: if no candidate matches the argument count.
    """
    @wraps(fn)
    def wrapper(g, *args):
        for overload in fn(g, *args):
            if len(overload._arg_descriptors) == len(args):
                return overload(g, *args)
        raise NotImplementedError(f"Unknown aten::{fn.__name__} signature")

    return wrapper
def _reduce_with_dtype(onnx_op, name, allow_multi_dim_support=True):
    """Build a reduction symbolic that also accepts a trailing ``dtype`` argument.

    Two overloads are offered and selected by argument count: one without
    ``dim``/``keepdim`` and one with them. A ``dtype`` whose node is not a
    ``prim::Constant`` is reported as unimplemented under ``name``.
    """
    symbolic = _reduce_op_symbolic(onnx_op, allow_multi_dim_support=allow_multi_dim_support)

    def _dtype_is_not_constant(dtype):
        return dtype.node().kind() != "prim::Constant"

    @overload_by_arg_count
    def reduce(g, *args, **kwargs):
        @parse_args("v", "none")
        def reduce_nodim(g, self, dtype):
            if _dtype_is_not_constant(dtype):
                return _unimplemented(name, "dtype")
            return symbolic(g, self)

        @parse_args("v", "is" if allow_multi_dim_support else "i", "i", "none")
        def reduce_dim(g, self, dim, keepdim, dtype):
            if _dtype_is_not_constant(dtype):
                return _unimplemented(name, "dtype")
            return symbolic(g, self, dim, keepdim)

        return reduce_nodim, reduce_dim

    return reduce
# Reduction symbolics built from the shared dtype-aware factory.
sum = _reduce_with_dtype(onnx_op="ReduceSum", name="sum")
mean = _reduce_with_dtype(onnx_op="ReduceMean", name="mean")
# torch.prod does not support multidimensional "dim"
prod = _reduce_with_dtype(onnx_op="ReduceProd", name="prod", allow_multi_dim_support=False)
@parse_args("v", "i", "none")
def cumsum(g, input, dim, dtype):
    """Export ``cumsum``.

    Under the ATen-fallback export type this emits an ``ATen`` op (a
    non-constant ``dtype`` is reported as unimplemented). For any other
    export type it is reported as unsupported in opset 9 (supported from 11).
    """
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        if dtype.node().kind() != "prim::Constant":
            # The op name is spelled out: no `name` variable exists in this scope.
            return _unimplemented("cumsum", "dtype")
        return g.op("ATen", input, operator_s="cumsum", dim_i=dim)
    return sym_help._onnx_opset_unsupported("cumsum", 9, 11)
def _sample_dirichlet(g, self, generator):
    """Export ``_sample_dirichlet`` as an ``ATen`` op (ATen-fallback export only).

    A non-None ``generator`` is reported as unimplemented; any other export
    type is reported as unsupported.
    """
    aten_fallback = torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
    if sym_help._operator_export_type != aten_fallback:
        return sym_help._onnx_unsupported("_sample_dirichlet")
    if not sym_help._is_none(generator):
        return _unimplemented("_sample_dirichlet",
                              "We are not able to export generator")
    return g.op("ATen", self, operator_s="_sample_dirichlet")
def _standard_gamma(g, self, generator):
    """Export ``_standard_gamma`` as an ``ATen`` op (ATen-fallback export only).

    A non-None ``generator`` is reported as unimplemented; any other export
    type is reported as unsupported.
    """
    aten_fallback = torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
    if sym_help._operator_export_type != aten_fallback:
        return sym_help._onnx_unsupported("_standard_gamma")
    if not sym_help._is_none(generator):
        return _unimplemented("_standard_gamma",
                              "We are not able to export generator")
    return g.op("ATen", self, operator_s="_standard_gamma")
def t(g, self):
    """Export ``t`` as an ONNX ``Transpose`` with permutation ``(1, 0)``."""
    swapped_axes = (1, 0)
    return g.op("Transpose", self, perm_i=swapped_axes)
def expand(g, self, size, implicit):
    """Export ``expand`` as an ONNX ``Expand``.

    A constant ``size`` becomes a Long ``Constant``. A packed-list ``size``
    is stacked and flattened, and its -1 entries are replaced with 1.
    """
    size = sym_help._maybe_get_const(size, "is")
    if not sym_help._is_value(size):
        size = g.op("Constant", value_t=torch.LongTensor(size))
    elif sym_help._is_packed_list(size):
        # Expand with -1 dim value means dim is unchanged.
        # Since onnx::expand supports two-way broadcasting,
        # -1 dim value can be exported to onnx as 1
        stacked = stack(g, size, 0)
        flat_shape = g.op("Constant", value_t=torch.tensor([-1]))
        size = view(g, stacked, flat_shape)
        dtype = 4  # dim type is int64
        ones = ones_like(g, size, dtype)
        minus_one = g.op("Constant", value_t=torch.tensor(-1))
        neg_ones = mul(g, ones, minus_one)
        is_minus_one = g.op("Equal", size, neg_ones)
        size = where(g, is_minus_one, ones, size)
    return g.op("Expand", self, size)
def expand_as(g, self, other):
    """Export ``expand_as`` as ``Expand`` of ``self`` to ``Shape(other)``.

    When ``self`` is a constant tensor, dimensions along which it equals its
    own mean are collapsed first, and the reduced constant is expanded.
    """
    constant = sym_help._maybe_get_const(self, "t")
    if isinstance(constant, torch.Tensor):
        original_dtype = constant.dtype
        constant = constant.to(torch.double)
        # Dimensions in which the tensor is constant (equal to its mean there).
        uniform_dims = [
            d for d in range(constant.dim())
            if torch.equal(constant.mean(d).unsqueeze(d).expand_as(constant), constant)
        ]
        self = g.op("Constant", value_t=constant.mean(uniform_dims).to(original_dtype))
    target_shape = g.op("Shape", other)
    return g.op("Expand", self, target_shape)
@parse_args("v", "v", "i", "b", "v")
def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse):
    """Export ``embedding`` as an ONNX ``Gather`` of ``weight`` by ``indices``.

    Raises:
        RuntimeError: when ``scale_grad_by_freq`` is set while exporting in
            training mode.
    """
    if scale_grad_by_freq:
        if sym_help._training_mode:
            raise RuntimeError("Unsupported: ONNX export of embedding with scale_grad_by_freq=True "
                               "for training mode. ONNX does not support scaling the gradients.")
    if padding_idx >= 0:
        if sym_help._training_mode:
            warnings.warn("Warning: ONNX export of embedding with padding_idx >= 0 "
                          "for training mode. "
                          "ONNX does not support not updating the embedding vector at padding_idx during training.")
    gathered = g.op("Gather", weight, indices)
    return gathered
@parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
def embedding_bag(g,
                  embedding_matrix,
                  indices,
                  offsets,
                  scale_grad_by_freq,
                  mode,
                  sparse,
                  per_sample_weights,
                  include_last_offset,
                  padding_idx):
    """Export ``embedding_bag`` as a four-output ``ATen`` op (ATen-fallback export only).

    ``per_sample_weights`` and every other export type are reported as
    unsupported.
    """
    if not sym_help._is_none(per_sample_weights):
        return sym_help._onnx_unsupported("embedding_bag with per_sample_weights")
    aten_fallback = torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
    if sym_help._operator_export_type != aten_fallback:
        return sym_help._onnx_unsupported("embedding_bag")
    return g.op("ATen",
                embedding_matrix,
                indices,
                offsets,
                operator_s="embedding_bag",
                outputs=4,
                scale_grad_by_freq_i=scale_grad_by_freq,
                mode_i=mode,
                sparse_i=sparse,
                include_last_offset_i=include_last_offset,
                padding_idx_i=padding_idx)
def size(g, self, dim=None):
    """Export ``size``.

    Without ``dim`` this is the full ``Shape``. With ``dim`` it is handed to
    ``_size_helper``, and a negative constant ``dim`` is first normalised
    against the tensor rank when that rank is known.
    """
    if dim is None:
        return g.op("Shape", self)
    dim_const = sym_help._maybe_get_const(dim, "i")
    if dim_const < 0:
        rank = sym_help._get_tensor_rank(self)
        if rank is not None:
            dim = g.op("Constant", value_t=torch.tensor(dim_const + rank))
    return sym_help._size_helper(g, self, dim)
@parse_args("v", "i", "i")
def transpose(g, self, dim0, dim1):
    """Export ``transpose`` as an ONNX ``Transpose`` that swaps two axes.

    Raises:
        RuntimeError: when the rank is unknown and the export type is not
            ATen fallback.
    """
    if dim0 == dim1:  # micro-optimization
        return self

    # NB: Transpose in ONNX is actually a Permute
    rank = sym_help._get_tensor_rank(self)
    if rank is not None:
        perm = list(range(rank))
        perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
        return g.op("Transpose", self, perm_i=perm)

    # Without rank information no permutation can be written; use ATen instead.
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        return g.op("ATen", self, operator_s="transpose", dim0_i=dim0, dim1_i=dim1)
    raise RuntimeError("Unsupported: ONNX export of transpose for tensor "
                       "of unknown rank.")
@parse_args("v", "is")
def permute(g, self, dims):
    """Export ``permute`` as ``Transpose``; the identity permutation emits no node."""
    identity = list(range(len(dims)))
    if dims == identity:
        return self
    return g.op("Transpose", self, perm_i=dims)
def view(g, self, size):
    """Export ``view`` as an ONNX ``Reshape``.

    A constant ``size`` is turned into a Long ``Constant``; a graph value is
    used as the shape input directly.
    """
    size = sym_help._maybe_get_const(size, "is")
    if not sym_help._is_value(size):
        return g.op("Reshape", self, g.op("Constant", value_t=torch.LongTensor(size)))
    return g.op("Reshape", self, size)
def view_as(g, self, other):
    """Export ``view_as`` as ``Reshape`` of ``self`` to ``Shape(other)``."""
    target_shape = g.op("Shape", other)
    reshaped = g.op("Reshape", self, target_shape)
    return reshaped
def prim_ConstantSplit(g, self, split_size, dim):
    """Export ``prim::ConstantSplit`` as an ONNX ``Split`` with explicit sizes.

    The dimension is cut into pieces of ``split_size`` plus one trailing
    smaller piece when it does not divide evenly. An unknown dimension size
    is reported as unimplemented.
    """
    dim_size = sym_help._get_tensor_dim_size(self, dim)
    if dim_size is None:
        return _unimplemented("prim::ConstantSplit", "unknown dimension size")
    full_pieces, remainder = divmod(dim_size, split_size)
    piece_sizes = [split_size] * full_pieces
    if remainder:
        piece_sizes.append(remainder)
    return g.op("Split", self, split_i=piece_sizes, axis_i=dim, outputs=len(piece_sizes))
# TODO: It would be better to export this as a chunk directly, as this is
# less sensitive to changes in input size.
# TODO: Once we have proper scoping, stop reimplementing chunk, delete this
# method, and use the desugared version
def prim_ConstantChunk(g, self, chunks, dim):
    """Export ``prim::ConstantChunk`` through ``prim_ConstantSplit``.

    The split size is the dimension size divided by ``chunks``, rounded up.
    An unknown dimension size is reported as unimplemented.
    """
    dim_size = sym_help._get_tensor_dim_size(self, dim)
    if dim_size is None:
        return _unimplemented("prim::ConstantChunk", "unknown dimension size")
    # ceil(dim_size / chunks) in integer arithmetic
    chunk_size = (dim_size + chunks - 1) // chunks
    return prim_ConstantSplit(g, self, chunk_size, dim)
@parse_args("v", "i", "i", "i")
def unsafe_chunk(g, self, chunks, dim, _outputs=None):
    """Export ``unsafe_chunk`` as an ONNX ``Split`` with explicit sizes.

    An unknown output count is reported as unsupported in this opset, and an
    unknown dimension size as unimplemented.
    """
    if _outputs is None:
        return sym_help._onnx_opset_unsupported_detailed("unsafe_chunk", 9, 11, "Dynamic number of outputs not supported")
    dim_size = sym_help._get_tensor_dim_size(self, dim)
    if dim_size is None:
        return _unimplemented("unsafe_chunk", "unknown dimension size")
    # ceil(dim_size / chunks) in integer arithmetic
    chunk_size = (dim_size + chunks - 1) // chunks
    full_pieces, remainder = divmod(dim_size, chunk_size)
    piece_sizes = [chunk_size] * full_pieces
    if remainder:
        piece_sizes.append(remainder)
    return g.op("Split", self, split_i=piece_sizes, axis_i=dim, outputs=_outputs)
@parse_args("v", "v", "v", "i")
def split(g, self, split_size_or_sizes, dim, _outputs=None):
    """Export ``split`` as an ONNX ``Split``.

    A non-scalar constant delegates to ``split_with_sizes``. A scalar split
    size is expanded into explicit piece sizes, using ``_outputs`` to infer
    the dimension size when that size is unknown.
    """
    if not sym_help._is_split_static(split_size_or_sizes, _outputs):
        return sym_help._onnx_opset_unsupported_detailed("split", 9, 11, "Dynamic number of outputs not supported")
    constant_value = split_size_or_sizes.node()["value"]
    if constant_value.dim() > 0:
        return split_with_sizes(g, self, split_size_or_sizes, dim, _outputs)

    split_size = sym_help._get_const(split_size_or_sizes, "i", "split_size")
    dim = sym_help._get_const(dim, "i", "dim")

    dim_size = sym_help._get_tensor_dim_size(self, dim)
    if dim_size is None:
        if _outputs is None:
            return sym_help._onnx_opset_unsupported_detailed("split", 9, 11, "Unknown dimension size not supported")
        dim_size = split_size * _outputs

    full_pieces, remainder = divmod(dim_size, split_size)
    piece_sizes = [split_size] * full_pieces
    if remainder:
        piece_sizes.append(remainder)
    return g.op("Split", self, split_i=piece_sizes, axis_i=dim, outputs=_outputs)
def unsafe_split(g, self, split_size_or_sizes, dim, _outputs=None):
    """Alias that forwards every argument to ``split``."""
    pieces = split(g, self, split_size_or_sizes, dim, _outputs)
    return pieces
@parse_args("v", "is", "i", "i")
def split_with_sizes(g, self, split_sizes, dim, _outputs=None):
    """Export ``split_with_sizes`` as an ONNX ``Split`` with the given sizes."""
    is_static = sym_help._is_split_static(split_sizes, _outputs)
    if not is_static:
        return sym_help._onnx_opset_unsupported_detailed("split_with_sizes", 9, 11, "Dynamic number of outputs not supported")
    return g.op("Split", self, split_i=split_sizes, axis_i=dim, outputs=_outputs)
def unsafe_split_with_sizes(g, self, split_sizes, dim, _outputs=None):
    """Alias that forwards every argument to ``split_with_sizes``."""
    pieces = split_with_sizes(g, self, split_sizes, dim, _outputs)
    return pieces
@parse_args("v", "i", "i")
def unbind(g, self, dim=0, _outputs=None):
    """Export ``unbind`` as a ``Split`` into size-1 pieces, each squeezed at ``dim``.

    An unknown output count is reported as unsupported in this opset.
    """
    if _outputs is None:
        return sym_help._onnx_opset_unsupported_detailed("unbind", 9, 11, "Dynamic number of outputs not supported")

    pieces = g.op("Split", self, split_i=[1] * _outputs, axis_i=dim, outputs=_outputs)
    if _outputs == 1:
        # wrap the single result so it can be iterated like the multi-output case
        pieces = [pieces]
    return [sym_help._squeeze_helper(g, piece, [dim]) for piece in pieces]
@parse_args("v", "i", "v")
def select(g, self, dim, index):
    """Export ``select``.

    A negative constant index is exported as a one-element slice followed by
    a squeeze of ``dim``; every other index is exported as ``Gather``.
    """
    index = sym_help._maybe_get_scalar(index)
    is_negative_constant = (not sym_help._is_value(index)) and (index < 0)
    if not is_negative_constant:
        return g.op("Gather", self, index, axis_i=dim)

    # For index -1 the slice runs to the end of the axis.
    end_index = 9223372036854775807 if index == -1 else index + 1
    sliced = sym_help._slice_helper(g, self, axes=[dim], starts=[index], ends=[end_index])
    return sym_help._squeeze_helper(g, sliced, [dim])
def square(g, self):
    """Export ``square`` as ``Mul`` of ``self`` with itself."""
    squared = g.op("Mul", self, self)
    return squared
def squeeze(g, self, dim=None):
    """Export ``squeeze``.

    Without ``dim`` a plain ``Squeeze`` is emitted. With ``dim``, a negative
    axis is normalised against the known rank, and the axis is squeezed only
    when its size is unknown or not greater than 1. If the size is greater
    than 1, ``self`` is returned with no node. Each of these paths emits a
    warning.
    """
    if dim is None:
        return g.op("Squeeze", self)

    axis = sym_help._get_const(dim, "i", "dim")
    # Handle negative dims
    if axis < 0:
        rank = sym_help._get_tensor_rank(self)
        if rank is None:
            return _unimplemented("squeeze", "negative axis with unknown input rank")
        warnings.warn("".join([
            "ONNX export squeeze with negative axis ", str(axis),
            " might cause the onnx model to be incorrect. ",
            "Negative axis is not supported in ONNX. ",
            "Axis is converted to ", str(axis + rank),
            " based on input shape at export time. ",
            "Passing an tensor of different rank in execution will be incorrect.",
        ]))
        axis += rank

    axis_size = sym_help._get_tensor_dim_size(self, axis)
    if axis_size is None:
        warnings.warn("".join([
            "This model contains a squeeze operation on dimension ", str(axis), " on an input ",
            "with unknown shape. Note that if the size of dimension ", str(axis), " of the input ",
            "is not 1, the ONNX model will return an error. Opset version 11 supports squeezing on ",
            "non-singleton dimensions, it is recommended to export this model using opset ",
            "version 11 or higher.",
        ]))
        return sym_help._squeeze_helper(g, self, axes_i=[axis])

    if axis_size > 1:
        warnings.warn("".join([
            "This model contains a squeeze operation on dimension ", str(axis), ". The size of ",
            "this dimension in the given input is ", str(axis_size), ". The model will ",
            "be exported without the squeeze node. If the model is intended to be used with dynamic ",
            "input shapes, please use opset version 11 to ",
            "export the model.",
        ]))
        return self

    warnings.warn("".join([
        "This model contains a squeeze operation on dimension ", str(axis), ". If the model is ",
        "intended to be used with dynamic input shapes, please use opset version 11 to export the model.",
    ]))
    return sym_help._squeeze_helper(g, self, axes_i=[axis])
def prelu(g, self, weight):
    """Export ``prelu`` as an ONNX ``PRelu``.

    When the input rank is known and greater than 2, ``weight`` is unsqueezed
    at axes ``1 .. rank - 2`` first.
    """
    rank = sym_help._get_tensor_rank(self)
    needs_unsqueeze = rank is not None and rank > 2
    if needs_unsqueeze:
        extra_axes = list(range(1, rank - 1))
        weight = sym_help._unsqueeze_helper(g, weight, extra_axes)
    return g.op("PRelu", self, weight)
def silu(g, input):
    """Export ``silu`` as ``input * Sigmoid(input)``."""
    gate = g.op("Sigmoid", input)
    return g.op("Mul", input, gate)
def mish(g, input):
    """Export ``mish`` as ``input * Tanh(Softplus(input))``."""
    softplus_out = g.op("Softplus", input)
    tanh_out = g.op("Tanh", softplus_out)
    return g.op("Mul", input, tanh_out)
def relu(g, input):
    """Export ``relu`` as an ONNX ``Relu``."""
    rectified = g.op("Relu", input)
    return rectified
def relu6(g, input):
    """Export ``relu6`` as ``Relu`` followed by ``clamp_max`` at 6."""
    rectified = g.op("Relu", input)
    return clamp_max(g, rectified, 6)
def ceil(g, input):
    """Export ``ceil`` as an ONNX ``Ceil``."""
    rounded_up = g.op("Ceil", input)
    return rounded_up
def floor(g, input):
    """Export ``floor`` as an ONNX ``Floor``."""
    rounded_down = g.op("Floor", input)
    return rounded_down
def _len(g, self):
    """Export ``len`` as the size of dimension 0, squeezed at axis 0."""
    first_dim = g.op("Constant", value_t=torch.LongTensor([0]))
    first_dim_size = size(g, self, first_dim)
    return sym_help._squeeze_helper(g, first_dim_size, [0])
@parse_args("v", "t", "t")
def threshold(g, self, threshold, value):
    """Export ``threshold`` as ``Relu``; only threshold 0 with value 0 is supported."""
    # See Note [Export inplace]
    threshold_scalar = sym_help._scalar(threshold)
    if threshold_scalar != 0:
        return _unimplemented("threshold", "non-zero threshold")
    value_scalar = sym_help._scalar(value)
    if value_scalar != 0:
        return _unimplemented("threshold", "non-zero value")
    return g.op("Relu", self)
def leaky_relu(g, input, negative_slope, inplace=False):
    """Export ``leaky_relu`` as ``LeakyRelu`` with ``alpha`` set to the constant slope."""
    slope_tensor = sym_help._get_const(negative_slope, "t", "negative_slope")
    # See Note [Export inplace]
    # TODO: Talk to ONNX about unconditional cast of scalar to float
    slope = sym_help._scalar(slope_tensor)
    return g.op("LeakyRelu", input, alpha_f=slope)
@parse_args("v", "i")
def glu(g, input, dim):
    """Export ``glu``: split ``input`` in two along ``dim``, return ``first * Sigmoid(second)``.

    When the size of ``dim`` is known it is asserted to be even.
    """
    dim_size = sym_help._get_tensor_dim_size(input, dim)
    if dim_size is not None:
        assert dim_size % 2 == 0

    halves = g.op("Split", input, axis_i=dim, outputs=2)
    first, second = halves
    gate = g.op("Sigmoid", second)
    return g.op("Mul", first, gate)
@parse_args("v", "i", "none")
def softmax(g, input, dim, dtype=None):
    """Export ``softmax`` along ``dim``, with an optional ``Cast`` to ``dtype``.

    With a known input rank, ONNX ``Softmax`` is applied on the last axis,
    transposing ``dim`` there and back when needed. With an unknown rank the
    result is built from ``ReduceMax`` / ``Sub`` / ``Exp`` / reduce-sum /
    ``Div``.
    """
    # Softmax does normalization at vector level.
    # PyTorch and ONNX use different strategies to split the input tensor into vectors.
    # Thus dim and axis have different meanings.
    # PyTorch slices the input tensor into vectors along the `dim`-th dimension.
    # ONNX reshapes the input into a 2-D tensor, and `axis` indicates where the input is coerced.
    # If input is a 2 x 3 tensor:
    # input = [[1.0, 1.0, 1.0],
    #          [1.0, 1.0, 1.0]]
    # with dim = 0, the result is:
    # result = [[0.5, 0.5, 0.5],
    #           [0.5, 0.5, 0.5]]
    # with axis = 0, the result is:
    # result = [[0.167, 0.167, 0.167],
    #           [0.167, 0.167, 0.167]]
    # So only when dim and axis both equal to ndim - 1 (the last dimension),
    # their semantics are equivalent.
    # So use softmax when dim and axis both equal to ndim - 1,
    # otherwise transpose the input to put the vectors to be normalized to the last dimension.
    # When input rank is not known at export time we compute softmax using a subgraph
    # with other operators

    def _cast_if_requested(result):
        # Cast only when dtype is given and its node is not a prim::Constant.
        if dtype and dtype.node().kind() != "prim::Constant":
            parsed_dtype = sym_help._get_const(dtype, "i", "dtype")
            result = g.op("Cast", result, to_i=sym_help.scalar_type_to_onnx[parsed_dtype])
        return result

    rank = sym_help._get_tensor_rank(input)
    if rank is not None:
        # TODO: remove this as onnx opset 11 spec allows negative axes
        if dim < 0:
            dim = rank + dim

        needs_transpose = rank != dim + 1
        if needs_transpose:
            perm = list(range(rank))
            perm[dim], perm[-1] = perm[-1], perm[dim]
            input = g.op("Transpose", input, perm_i=perm)
            dim = rank - 1

        result = _cast_if_requested(g.op("Softmax", input, axis_i=dim))
        if needs_transpose:
            result = g.op("Transpose", result, perm_i=perm)
        return result

    # Apply max normalization.
    axis_max = g.op("ReduceMax", input, axes_i=[dim], keepdims_i=1)
    shifted = g.op("Sub", input, axis_max)
    exponentials = g.op("Exp", shifted)
    total = sym_help._reducesum_helper(g, exponentials, axes_i=[dim])
    return _cast_if_requested(g.op("Div", exponentials, total))
@parse_args("v", "t", "v")
def softplus(g, self, beta, threshold):
    """Export ``softplus`` as an ONNX ``Softplus``; only ``beta == 1`` is supported."""
    beta_is_unsupported = beta != 1
    if beta_is_unsupported:
        return _unimplemented("beta", "has to be 1")
    return g.op("Softplus", self)
def get_pool_ceil_padding(input, kernel_size, stride, padding):
    """Compute the per-axis extra padding used when exporting ceil-mode pooling.

    The trailing ``len(padding)`` entries of the input's sizes are used as
    the pooled dimensions.

    Args:
        input: graph value whose sizes are read via
            ``sym_help._get_tensor_sizes``.
        kernel_size, stride, padding: per-axis integer sequences, indexed in
            parallel for ``len(padding)`` axes.

    Returns:
        A list with one int per pooled axis, or the result of
        ``_unimplemented`` when any needed input size is unknown.
    """
    sizes = sym_help._get_tensor_sizes(input)
    dim = sizes[-len(padding):] if sizes is not None else None
    if dim is None or any(i is None for i in dim):
        # No `name` variable exists in this scope (referencing one raised
        # NameError), so report under this helper's own name.
        return _unimplemented("get_pool_ceil_padding", "input size not accessible")

    axes = range(len(padding))
    ceiled_output_dim = [
        int(math.ceil((dim[i] + 2 * padding[i] - kernel_size[i]) / float(stride[i]))) + 1
        for i in axes
    ]
    # ensure last pooling starts inside
    ceiled_output_dim = [
        ceiled_output_dim[i] - 1
        if (((ceiled_output_dim[i] - 1) * stride[i]) >= (dim[i] + padding[i]))
        else ceiled_output_dim[i]
        for i in axes
    ]
    padding_ceil = [
        0
        if (stride[i] == 1)
        else (kernel_size[i] - (dim[i] + 2 * padding[i] - ((ceiled_output_dim[i] - 1) * stride[i] + 1)))
        for i in axes
    ]
    # ensure padding is not > kernel_size
    padding_ceil = [
        (int(padding_ceil[i]) if padding_ceil[i] < kernel_size[i] - 1 else int(kernel_size[i] - 1))
        if ((padding_ceil[i] + 2 * padding[i]) >= (kernel_size[i]))
        else int(padding_ceil[i])
        for i in axes
    ]
    return padding_ceil
def _max_pool(name, tuple_fn, ndims, return_indices):
    """Build the symbolic for an ``ndims``-dimensional max pool.

    ``tuple_fn`` normalises per-axis arguments. When ``return_indices`` is
    true the symbolic returns ``(values, indices)``; otherwise it returns
    only the values. Any dilation other than 1 is reported as unimplemented
    under ``name``.
    """
    @parse_args("v", "is", "is", "is", "is", "i")
    def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode):
        if set(tuple_fn(dilation)) != {1}:
            return _unimplemented(name, "dilation")
        if not stride:
            stride = kernel_size

        begin_pads = tuple(tuple_fn(padding))
        if ceil_mode:
            padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, begin_pads)
            pads = begin_pads + tuple(numpy.add(padding_ceil, begin_pads))
        else:
            pads = begin_pads * 2

        pool_attrs = dict(
            kernel_shape_i=tuple_fn(kernel_size),
            pads_i=pads,
            strides_i=tuple_fn(stride),
        )

        if not return_indices:
            return g.op("MaxPool", input, outputs=1, **pool_attrs)

        # easy but hacky way to get flattened indices values
        # to be used to convert the indices values to non-flattened.
        # In ONNX the indices are computed as a flatten 1-D tensor,
        # so the values in indices are in [0, N x C x D1 x ... x Dn).
        # To convert the indices to the same format used by Pytorch,
        # we first execute a maxpool with a kernel and stride of 1 on the same input.
        # This will result in a tensor of indices in which each index will have its own value.
        # Using this tensor as a reference, we extract the first index of each axis and subtract
        # it from each index of this axis in the indices to convert.
        # This step will result in a tensor where each dimension has values of indices within
        # the dimension it is in.
        # For more information :
        # https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407
        pooled, indices = g.op("MaxPool", input, outputs=2, **pool_attrs)
        unit_window = [1] * ndims
        _, flattened_indices = g.op("MaxPool", input, outputs=2,
                                    kernel_shape_i=unit_window,
                                    strides_i=[1] * ndims)
        # convert indices to have non-flattened indices values
        spatial_axes = list(range(2, 2 + ndims))
        first_indices = sym_help._slice_helper(g, flattened_indices, axes=spatial_axes,
                                               starts=tuple_fn(0), ends=tuple_fn(1))
        indices = sub(g, indices, first_indices)
        return pooled, indices

    return symbolic_fn
# Max-pool symbolics for 1-D/2-D/3-D inputs, without and with returned indices.
max_pool1d = _max_pool(name="max_pool1d", tuple_fn=_single, ndims=1, return_indices=False)
max_pool2d = _max_pool(name="max_pool2d", tuple_fn=_pair, ndims=2, return_indices=False)
max_pool3d = _max_pool(name="max_pool3d", tuple_fn=_triple, ndims=3, return_indices=False)
max_pool1d_with_indices = _max_pool(name="max_pool1d_with_indices", tuple_fn=_single, ndims=1, return_indices=True)
max_pool2d_with_indices = _max_pool(name="max_pool2d_with_indices", tuple_fn=_pair, ndims=2, return_indices=True)
max_pool3d_with_indices = _max_pool(name="max_pool3d_with_indices", tuple_fn=_triple, ndims=3, return_indices=True)
def _avg_pool(name, tuple_fn):
    """Build the symbolic for an average pool; ``tuple_fn`` normalises per-axis arguments.

    With ``count_include_pad`` the input is zero-padded by an explicit
    ``Pad`` node and the pool's own begin padding is set to zero.
    """
    @parse_args("v", "is", "is", "is", "i", "i", "none")
    def symbolic_fn(g, input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override=None):
        if not stride:
            stride = kernel_size
        padding = sym_help._avgpool_helper(tuple_fn, padding, kernel_size, stride, divisor_override, name)

        # The ceil padding is computed from the original input and padding,
        # before the optional explicit Pad below replaces them.
        if ceil_mode:
            padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding)

        if count_include_pad:
            explicit_pads = ((0, 0) + padding) * 2
            input = g.op("Pad", input,
                         pads_i=explicit_pads,
                         mode_s="constant",
                         value_f=0.0)
            padding = (0,) * len(padding)

        if ceil_mode:
            pads = padding + tuple(numpy.add(padding_ceil, padding))
        else:
            pads = padding * 2

        return g.op("AveragePool", input,
                    kernel_shape_i=tuple_fn(kernel_size),
                    strides_i=tuple_fn(stride),
                    pads_i=pads)

    return symbolic_fn
# Average-pool symbolics for 1-D/2-D/3-D inputs.
avg_pool1d = _avg_pool(name="avg_pool1d", tuple_fn=_single)
avg_pool2d = _avg_pool(name="avg_pool2d", tuple_fn=_pair)
avg_pool3d = _avg_pool(name="avg_pool3d", tuple_fn=_triple)
def _adaptive_pool(name, type, tuple_fn, fn=None):
    """Build the opset-9 symbolic for an adaptive pooling operator.

    Args:
        name: operator name used in error messages.
        type: ONNX pooling op to emit, ``"AveragePool"`` or ``"MaxPool"``.
        tuple_fn: ``_single``/``_pair``/``_triple``; expands kernel and stride.
        fn: for ``"MaxPool"``, the matching ``max_poolNd_with_indices``
            symbolic used to produce both output and indices.

    Returns:
        A symbolic function ``symbolic_fn(g, input, output_size)``.
    """
    def symbolic_fn(g, input, output_size):
        # Adaptive pooling is supported when output_size is 1 in every
        # dimension (exported as a global pool), and when the input size is a
        # multiple of the output size in every pooled dimension. In the latter
        # case kernel size and stride are uniform along each dimension, which
        # makes the op expressible in ONNX.
        # For MaxPool, GlobalMaxPool does not return indices, so
        # max_poolNd_with_indices is preferred; when that is not possible
        # (input sizes unknown, or input size not a multiple of output size)
        # and output_size is all ones, GlobalMaxPool is emitted and None is
        # returned for the indices.
        try:
            output_size = _parse_arg(output_size, "is")
        except Exception:
            return sym_help._onnx_unsupported("adaptive pooling, since output_size is not constant.")
        all_ones = output_size == [1] * len(output_size)
        if all_ones and type == "AveragePool":
            return g.op("GlobalAveragePool", input)
        sizes = sym_help._get_tensor_sizes(input)
        try:
            dim = sizes[2:]
        except Exception:
            dim = None
        if dim is None or any(i is None for i in dim):
            if all_ones:
                return g.op("GlobalMaxPool", input), None
            return _unimplemented(name, "input size not accessible")
        # verify that input size % output size == 0 for every pooled dimension
        mod = [dim[i] % output_size[i] for i in range(0, len(dim))]
        if mod != [0] * len(mod):
            if all_ones:
                return g.op("GlobalMaxPool", input), None
            if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
                return _unimplemented(name, "output size that are not factor of input size")
            else:
                return sym_help._onnx_unsupported(name + ", since output size is not factor of input size")
        # Exact integer division: the remainder was checked to be zero above,
        # and floor division avoids a round trip through float.
        k = [dim[i] // output_size[i] for i in range(0, len(dim))]
        # call max_poolNd_with_indices to get indices in the output
        if type == "MaxPool":
            return fn(g, input, k, k, (0,) * len(dim), (1,) * len(dim), False)
        output = g.op(type, input,
                      kernel_shape_i=tuple_fn(k),
                      strides_i=tuple_fn(k))
        return output
    return symbolic_fn
# Adaptive pooling symbolics. The max variants receive the matching
# ``max_poolNd_with_indices`` symbolic so they can return indices as well.
adaptive_avg_pool1d = _adaptive_pool("adaptive_avg_pool1d", "AveragePool", _single)
adaptive_avg_pool2d = _adaptive_pool("adaptive_avg_pool2d", "AveragePool", _pair)
adaptive_avg_pool3d = _adaptive_pool("adaptive_avg_pool3d", "AveragePool", _triple)
adaptive_max_pool1d = _adaptive_pool("adaptive_max_pool1d", "MaxPool", _single, max_pool1d_with_indices)
adaptive_max_pool2d = _adaptive_pool("adaptive_max_pool2d", "MaxPool", _pair, max_pool2d_with_indices)
adaptive_max_pool3d = _adaptive_pool("adaptive_max_pool3d", "MaxPool", _triple, max_pool3d_with_indices)
# Generate paddings in ONNX order based on pad in pytorch.
# Args:
# dim: the dimension of the tensor.
# pad: the paddings in pytorch.
# The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ...
def _prepare_onnx_paddings(dim, pad):
assert isinstance(dim, int)
# The desired order of paddings is
# dim_0_begin, dim_1_begin, ... , dim_0_end, ..., dim_n_end.
# n is the dimension of input.
# assume zero-dimensions in the beginning
paddings = list(pad[:]) + [0] * (dim * 2 - len(pad))
# reverse order and collate first beginnings and then ends
paddings = paddings[-2::-2] + paddings[-1::-2]
return paddings
def _convert_padding_node(padding):
    """Resolve a padding argument to constant integers where possible.

    A packed-list value is unpacked and each element is read as a constant
    int; if any element is not constant the export is reported as unsupported
    for opset 9. Anything else is returned as produced by
    ``sym_help._maybe_get_const``.
    """
    padding = sym_help._maybe_get_const(padding, "is")
    if not (sym_help._is_value(padding) and sym_help._is_packed_list(padding)):
        return padding
    elements = sym_help._unpack_list(padding)
    try:
        return [sym_help._get_const(elem, "i", "padding") for elem in elements]
    except Exception:
        return sym_help._onnx_opset_unsupported_detailed("Pad", 9, 11, "The sizes of the padding must be constant")
def constant_pad_nd(g, input, padding, value):
    """Export ``constant_pad_nd`` as an ONNX ``Pad`` node in ``constant`` mode.

    The fill value must be a constant; otherwise the export is reported as
    unsupported for opset 9.
    """
    try:
        fill = sym_help._get_const(value, "f", "value")
    except Exception:
        return sym_help._onnx_opset_unsupported_detailed("Pad", 9, 11, "The value for the padding must be constant")
    torch_pads = _convert_padding_node(padding)
    rank = sym_help._get_tensor_rank(input)
    onnx_pads = _prepare_onnx_paddings(rank, torch_pads)
    return g.op("Pad", input, pads_i=onnx_pads, mode_s="constant", value_f=fill)
def reflection_pad(g, input, padding):
    """Export reflection padding as an ONNX ``Pad`` node in ``reflect`` mode."""
    torch_pads = _convert_padding_node(padding)
    rank = sym_help._get_tensor_rank(input)
    onnx_pads = _prepare_onnx_paddings(rank, torch_pads)
    return g.op("Pad", input, pads_i=onnx_pads, mode_s="reflect")
def replication_pad(g, input, padding):
    """Export replication padding as an ONNX ``Pad`` node in ``edge`` mode."""
    torch_pads = _convert_padding_node(padding)
    rank = sym_help._get_tensor_rank(input)
    onnx_pads = _prepare_onnx_paddings(rank, torch_pads)
    return g.op("Pad", input, pads_i=onnx_pads, mode_s="edge")
# The rank-specific padding operators all share one symbolic, because the
# padding layout is derived from the input rank at export time.
reflection_pad1d = reflection_pad
reflection_pad2d = reflection_pad
reflection_pad3d = reflection_pad
replication_pad1d = replication_pad
replication_pad2d = replication_pad
replication_pad3d = replication_pad
def _interpolate(name, dim, interpolate_mode):
    """Build the opset-9 symbolic for an ``upsample_*`` operator.

    Args:
        name: operator name used in the "unimplemented" message.
        dim: value forwarded to ``sym_help._interpolate_size_to_scales``.
        interpolate_mode: ONNX ``Upsample`` mode string.

    Returns:
        A symbolic function that emits an ONNX ``Upsample`` node.
    """
    def symbolic_fn(g, input, output_size, *args):
        scales, align_corners = sym_help._get_interpolate_attributes(g, interpolate_mode, args)
        sym_help._interpolate_warning(interpolate_mode)
        if sym_help._maybe_get_scalar(align_corners):
            return _unimplemented(name, "align_corners == True")
        # Without explicit scales, derive them from the requested output size.
        if scales is None:
            scales = sym_help._interpolate_size_to_scales(g, input, output_size, dim)
        return g.op("Upsample", input, scales, mode_s=interpolate_mode)
    return symbolic_fn
# Upsample symbolics; the integer is the ``dim`` argument of ``_interpolate``
# and the last argument is the ONNX Upsample mode.
upsample_nearest1d = _interpolate("upsample_nearest1d", 3, "nearest")
upsample_nearest2d = _interpolate("upsample_nearest2d", 4, "nearest")
upsample_nearest3d = _interpolate("upsample_nearest3d", 5, "nearest")
upsample_linear1d = _interpolate("upsample_linear1d", 3, "linear")
upsample_bilinear2d = _interpolate("upsample_bilinear2d", 4, "linear")
upsample_trilinear3d = _interpolate("upsample_trilinear3d", 5, "linear")
def __interpolate(g, input, size, scale_factor, mode, align_corners, recompute_scale_factor):
    """Export ``__interpolate`` as an ONNX ``Upsample`` node.

    Scales and the ONNX mode are resolved by
    ``sym_help._interpolate_get_scales_and_mode``; ``recompute_scale_factor``
    is accepted but not used here.
    """
    resolved = sym_help._interpolate_get_scales_and_mode(
        g, input, size, scale_factor, mode, align_corners)
    scales, mode = resolved
    return g.op("Upsample", input, scales, mode_s=mode)
@parse_args("v")
def bitwise_not(g, inp):
    """Export ``bitwise_not`` as ONNX ``Not``; only Bool tensors are supported."""
    is_bool = inp.type().scalarType() == "Bool"
    if is_bool:
        return g.op("Not", inp)
    return _unimplemented("bitwise_not", "non-bool tensor")
def wrap_logical_op_with_cast_to(to_type):
    """Return a decorator that casts a binary symbolic's result to ``to_type``.

    Args:
        to_type: key of ``sym_help.cast_pytorch_to_onnx`` naming the target type.

    Returns:
        A decorator; the decorated ``fn(g, input, other)`` has its result
        wrapped in an ONNX ``Cast`` node.
    """
    def decorator(fn):
        # Keep the wrapped symbolic's name/docstring for debugging and tooling.
        @wraps(fn)
        def wrap_with_cast(g, input, other):
            return g.op("Cast", fn(g, input, other), to_i=sym_help.cast_pytorch_to_onnx[to_type])
        return wrap_with_cast
    return decorator
def wrap_logical_op_with_cast_to_and_from(to_type):
    """Return a decorator that runs a binary symbolic in ``to_type``.

    Both operands are cast with the module-level ``_cast_<to_type>`` symbolic,
    the wrapped function is applied, and the result is cast back to the scalar
    type of the original ``input``.

    Args:
        to_type: type name, e.g. ``"Bool"``; ``_cast_<to_type>`` must exist in
            this module's globals when the wrapper is called.
    """
    def decorator(fn):
        # Keep the wrapped symbolic's name/docstring for debugging and tooling.
        @wraps(fn)
        def wrap_with_cast(g, input, other):
            to_cast_func = globals()["_cast_{}".format(to_type)]
            from_cast_func = wrap_logical_op_with_cast_to(input.type().scalarType())(fn)
            return from_cast_func(g, to_cast_func(g, input, False), to_cast_func(g, other, False))
        return wrap_with_cast
    return decorator
def wrap_logical_op_with_negation(func):
    """Decorate a binary symbolic so that its result is wrapped in ONNX ``Not``."""
    # Keep the wrapped symbolic's name/docstring for debugging and tooling.
    @wraps(func)
    def wrap_with_not(g, input, other):
        return g.op("Not", func(g, input, other))
    return wrap_with_not
def __not_(g, self):
    """Export logical negation as ONNX ``Not``."""
    negated = g.op("Not", self)
    return negated
def eq(g, self, other):
    """Export element-wise equality as ONNX ``Equal``."""
    equal = g.op("Equal", self, other)
    return equal
@wrap_logical_op_with_negation
def ne(g, self, other):
    """Export inequality; the decorator negates the ``Equal`` emitted here."""
    equal = g.op("Equal", self, other)
    return equal
def gt(g, input, other):
    """Export ``input > other``; delegates to ``gt_impl``."""
    result = gt_impl(g, input, other)
    return result
def gt_impl(g, input, other):
    """Emit ONNX ``Greater``; two Bool operands are first cast to Int."""
    both_bool = (input.type().scalarType() == "Bool"
                 and other.type().scalarType() == "Bool")
    if both_bool:
        int_type = sym_help.cast_pytorch_to_onnx["Int"]
        input = g.op("Cast", input, to_i=int_type)
        other = g.op("Cast", other, to_i=int_type)
    return g.op("Greater", input, other)
def lt(g, input, other):
    """Export ``input < other``; delegates to ``lt_impl``."""
    result = lt_impl(g, input, other)
    return result
def lt_impl(g, input, other):
    """Emit ONNX ``Less``; two Bool operands are first cast to Int."""
    both_bool = (input.type().scalarType() == "Bool"
                 and other.type().scalarType() == "Bool")
    if both_bool:
        int_type = sym_help.cast_pytorch_to_onnx["Int"]
        input = g.op("Cast", input, to_i=int_type)
        other = g.op("Cast", other, to_i=int_type)
    return g.op("Less", input, other)
@wrap_logical_op_with_negation
def ge(g, input, other):
    """Export ``input >= other`` as the negation of ``input < other``."""
    less = lt_impl(g, input, other)
    return less
@wrap_logical_op_with_negation
def le(g, input, other):
    """Export ``input <= other`` as the negation of ``input > other``."""
    greater = gt_impl(g, input, other)
    return greater
@wrap_logical_op_with_cast_to_and_from("Bool")
def __and_(g, input, other):
    """Emit ONNX ``And``; the decorator casts operands to Bool and the result back."""
    conj = g.op("And", input, other)
    return conj
@wrap_logical_op_with_cast_to_and_from("Bool")
def __or_(g, input, other):
    """Emit ONNX ``Or``; the decorator casts operands to Bool and the result back."""
    disj = g.op("Or", input, other)
    return disj
@wrap_logical_op_with_cast_to_and_from("Bool")
def logical_and(g, input, other):
    """Emit ONNX ``And``; the decorator casts operands to Bool and the result back."""
    conj = g.op("And", input, other)
    return conj
@wrap_logical_op_with_cast_to_and_from("Bool")
def logical_or(g, input, other):
    """Emit ONNX ``Or``; the decorator casts operands to Bool and the result back."""
    disj = g.op("Or", input, other)
    return disj
@wrap_logical_op_with_cast_to_and_from("Bool")
def logical_xor(g, input, other):
    """Emit ONNX ``Xor``; the decorator casts operands to Bool and the result back."""
    exclusive = g.op("Xor", input, other)
    return exclusive
def __rshift_(g, self, other):
    """Export ``self >> other`` as ``Div(self, Pow(2, other))``."""
    onnx_types = sym_help.cast_pytorch_to_onnx
    # make sure to cast other to self's type
    # (when self is long, make sure that other is not float)
    if other.type().scalarType() != self.type().scalarType():
        other = g.op("Cast", other, to_i=onnx_types[self.type().scalarType()])

    base = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32))
    # exponent (same type as self) has to be float or double in onnx::Pow
    if not sym_help._is_fp(self):
        other = g.op("Cast", other, to_i=onnx_types["Float"])
    divisor = g.op("Pow", base, other)
    # Bring the power of two back to self's type before dividing.
    divisor = g.op("Cast", divisor, to_i=onnx_types[self.type().scalarType()])
    return g.op("Div", self, divisor)
def __lshift_(g, self, other):
    """Export ``self << other`` as ``Mul(self, Pow(2, other))``."""
    onnx_types = sym_help.cast_pytorch_to_onnx
    # make sure to cast other to self's type
    # (when self is long, make sure that other is not float)
    if other.type().scalarType() != self.type().scalarType():
        other = g.op("Cast", other, to_i=onnx_types[self.type().scalarType()])

    base = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32))
    # exponent (same type as self) has to be float or double in onnx::Pow
    if not sym_help._is_fp(self):
        other = g.op("Cast", other, to_i=onnx_types["Float"])
    factor = g.op("Pow", base, other)
    # Bring the power of two back to self's type before multiplying.
    factor = g.op("Cast", factor, to_i=onnx_types[self.type().scalarType()])
    return g.op("Mul", self, factor)
@parse_args("v", "v", "v", "i")
def where(g, condition, self=None, other=None, _outputs=None):
    """Export ``torch.where``.

    With ``self`` given, emits ONNX ``Where``. With only ``condition``, the
    non-zero indices are computed and unbound into ``_outputs`` tensors.
    """
    # Assumes that torch.where's first argument takes only Bool and Byte tensors.
    if condition.type().scalarType() != "Bool":
        condition = g.op("Cast", condition, to_i=sym_help.cast_pytorch_to_onnx["Bool"])
    if self is not None:
        return g.op("Where", condition, self, other)
    nonzero_idx = torch.onnx.symbolic_opset9.nonzero(g, condition)
    unbind_axis = g.op("Constant", value_t=torch.tensor(1))
    return sym_help._unbind_helper(g, nonzero_idx, unbind_axis, _outputs)
@parse_args("v", "i", "none")
def log_softmax(g, input, dim, dtype=None):
    """Export ``log_softmax`` as ONNX ``LogSoftmax`` over the last axis.

    When ``dim`` is not the last axis, the target axis is swapped with the
    last one before the op and swapped back afterwards. The input rank must
    be known at export time.
    """
    # PyTorch dim and ONNX axis have different meanings.
    # See Softmax comment for details.
    # TODO: remove this as onnx opset 11 spec allows negative axes
    rank = sym_help._get_tensor_rank(input)
    if rank is None:
        return _unimplemented("dim",
                              "ONNX and PyTorch use different strategies to split the input. "
                              "Input rank must be known at export time.")
    if dim < 0:
        dim = rank + dim

    needs_transpose = rank != dim + 1
    perm = None
    # ONNX only supports log_softmax with dim = -1. Transpose must be added
    # before and after log_softmax to support other cases.
    if needs_transpose:
        perm = list(range(rank))
        perm[dim], perm[-1] = perm[-1], perm[dim]
        input = g.op("Transpose", input, perm_i=perm)
        dim = rank - 1

    result = g.op("LogSoftmax", input, axis_i=dim)
    if dtype and dtype.node().kind() != "prim::Constant":
        onnx_dtype = sym_help.scalar_type_to_onnx[sym_help._get_const(dtype, "i", "dtype")]
        result = g.op("Cast", result, to_i=onnx_dtype)
    if needs_transpose:
        # Swapping two axes is its own inverse, so the same perm restores order.
        result = g.op("Transpose", result, perm_i=perm)
    return result
@parse_args("v", "v", "v", "is", "is", "is", "i", "is", "i", "i", "i", "i", "i")
def _convolution(g, input, weight, bias, stride, padding, dilation,
                 transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled, allow_tf32):
    """Export a convolution as ONNX ``Conv`` or ``ConvTranspose``.

    The kernel shape is read from the weight's static sizes and must be fully
    known. A rank-1 bias is passed to the op directly; any other non-None bias
    is applied with a separate ``Add``.

    Raises:
        RuntimeError: if the kernel shape is unknown.
    """
    weight_size = sym_help._get_tensor_sizes(weight)
    try:
        kernel_shape = weight_size[2:]
    except Exception:
        kernel_shape = None

    if kernel_shape is None or any([extent is None for extent in kernel_shape]):
        raise RuntimeError("Unsupported: ONNX export of convolution for kernel "
                           "of unknown shape.")

    has_bias = not sym_help._is_none(bias)
    op_inputs = [input, weight]
    # ONNX only supports 1D bias
    if has_bias and sym_help._get_tensor_rank(bias) == 1:
        op_inputs.append(bias)

    attrs = {
        "kernel_shape_i": weight_size[2:],
        "strides_i": stride,
        # NB: ONNX supports asymmetric padding, whereas PyTorch supports only
        # symmetric padding
        "pads_i": padding + padding,
        "dilations_i": dilation,
        "group_i": groups,
    }

    if any(extra != 0 for extra in output_padding):
        # ONNX supports both output_shape and output_padding. they are equivalent expressive.
        # output_padding is more straightforward, so we use it here.
        # output_shape = stride * (input_shape - 1) + output_padding + kernel_shape - padding * 2
        assert transposed
        assert len(stride) == len(output_padding)
        attrs["output_padding_i"] = output_padding

    op_name = "ConvTranspose" if transposed else "Conv"
    conv = g.op(op_name, *op_inputs, **attrs)

    if not sym_help._is_none(bias) and sym_help._get_tensor_rank(bias) != 1:
        return g.op("Add", conv, bias)
    return conv
@parse_args("v", "v", "v", "is", "is", "is", "i")
def conv1d(g, input, weight, bias, stride, padding, dilation, groups):
    """Export ``conv1d`` through ``_convolution`` (non-transposed, no output padding)."""
    transposed, output_padding = False, ()
    return _convolution(g, input, weight, bias, stride, padding, dilation,
                        transposed, output_padding, groups, None, None, None, None)
@parse_args("v", "v", "v", "is", "is", "is", "i")
def conv2d(g, input, weight, bias, stride, padding, dilation, groups):
    """Export ``conv2d`` through ``_convolution`` (non-transposed, no output padding)."""
    transposed, output_padding = False, ()
    return _convolution(g, input, weight, bias, stride, padding, dilation,
                        transposed, output_padding, groups, None, None, None, None)
@parse_args("v", "v", "v", "is", "is", "is", "i")
def conv3d(g, input, weight, bias, stride, padding, dilation, groups):
    """Export ``conv3d`` through ``_convolution`` (non-transposed, no output padding)."""
    transposed, output_padding = False, ()
    return _convolution(g, input, weight, bias, stride, padding, dilation,
                        transposed, output_padding, groups, None, None, None, None)
@parse_args("v", "v", "v", "is", "is", "is", "i", "is")
def conv_transpose1d(g, input, weight, bias, stride, padding, output_padding, groups, dilation):
    """Export ``conv_transpose1d`` through ``_convolution`` with ``transposed=True``."""
    transposed = True
    return _convolution(g, input, weight, bias, stride, padding, dilation,
                        transposed, output_padding, groups, None, None, None, None)
@parse_args("v", "v", "v", "is", "is", "is", "i", "is")
def conv_transpose2d(g, input, weight, bias, stride, padding, output_padding, groups, dilation):
    """Export ``conv_transpose2d`` through ``_convolution`` with ``transposed=True``."""
    transposed = True
    return _convolution(g, input, weight, bias, stride, padding, dilation,
                        transposed, output_padding, groups, None, None, None, None)
@parse_args("v", "v", "v", "is", "is", "is", "i", "is")
def conv_transpose3d(g, input, weight, bias, stride, padding, output_padding, groups, dilation):
    """Export ``conv_transpose3d`` through ``_convolution`` with ``transposed=True``."""
    transposed = True
    return _convolution(g, input, weight, bias, stride, padding, dilation,
                        transposed, output_padding, groups, None, None, None, None)
@parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i")
def batch_norm(g, input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled):
    """Export ``batch_norm`` as ONNX ``BatchNormalization``.

    Missing weight/bias are replaced by constant ones/zeros of the channel
    size. Missing running statistics are computed from the input itself.
    ONNX receives ``1 - momentum`` as its momentum attribute.

    Raises:
        RuntimeError: if weight or bias must be synthesised but the channel
            size is unknown.
    """
    sym_help.assert_training_mode(training, "batch_norm")
    batch_size = sym_help._get_tensor_dim_size(input, 0)
    channel_size = sym_help._get_tensor_dim_size(input, 1)

    def _per_channel_constant(fill):
        # Constant of length channel_size, typed like the input tensor.
        if channel_size is None:
            raise RuntimeError("Unsupported: ONNX export of batch_norm for unknown "
                               "channel size.")
        values = torch.tensor([fill] * channel_size).type(
            "torch." + input.type().scalarType() + "Tensor")
        return g.op("Constant", value_t=values)

    if weight is None or sym_help._is_none(weight):
        weight = _per_channel_constant(1.)
    if bias is None or sym_help._is_none(bias):
        bias = _per_channel_constant(0.)

    # If track_running_stats is set to False batch statistics are instead used during evaluation time
    stats_missing = (running_mean is None or sym_help._is_none(running_mean)
                     or running_var is None or sym_help._is_none(running_var))
    if stats_missing:
        assert batch_size is not None and channel_size is not None
        target_shape = g.op("Constant",
                            value_t=torch.tensor([batch_size, channel_size, -1], dtype=torch.int64))
        reshape_in = g.op("Reshape", input, target_shape)
        trans_in = g.op("Transpose", reshape_in, perm_i=[0, 2, 1])
        reduce_dims = g.op("Constant", value_t=torch.tensor([0, 1], dtype=torch.int64))
        running_var, running_mean = _var_mean(g, trans_in, reduce_dims, False, False)

    num_outputs = 5 if sym_help._training_mode else 1
    out = g.op("BatchNormalization", input, weight, bias, running_mean, running_var,
               epsilon_f=eps,
               momentum_f=1 - momentum,
               outputs=num_outputs)
    if not sym_help._training_mode:
        return out

    res, new_running_mean, new_running_var, saved_mean, saved_var = out
    new_running_mean.setType(running_mean.type())
    new_running_var.setType(running_var.type())
    # Tag the outputs that are not returned so they are recognisable as dead.
    saved_mean.setDebugName("batch_norm_dead_output-" + saved_mean.debugName())
    saved_var.setDebugName("batch_norm_dead_output-" + saved_var.debugName())
    return res
@parse_args("v", "is", "v", "v", "f", "i")
def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable):
    """Export ``layer_norm``.

    In ATen-fallback mode a single ``ATen`` node is emitted. Otherwise the
    normalisation is decomposed into ReduceMean/Sub/Pow/Sqrt/Div over the
    trailing ``len(normalized_shape)`` axes, followed by the optional affine
    weight and bias.
    """
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        return g.op("ATen", input, weight, bias, normalized_shape_i=normalized_shape,
                    eps_f=eps, cudnn_enable_i=cudnn_enable, operator_s="layer_norm")

    # Trailing axes, expressed as negative indices: [-k, ..., -1].
    axes = [-i for i in range(len(normalized_shape), 0, -1)]

    two_cst = sym_help._generate_wrapped_number(g, 2.)
    eps_cst = sym_help._generate_wrapped_number(g, eps)

    mean = g.op("ReduceMean", input, axes_i=axes)
    centered = sub(g, input, mean)
    # variance = e((x - e(x))^2), and (x - e(x)) is the numerator in the layer_norm formula
    variance = g.op("ReduceMean", pow(g, centered, two_cst), axes_i=axes)
    std = sqrt(g, add(g, variance, eps_cst))

    normalized = g.op("Div", centered, std)

    has_weight = not (weight is None or sym_help._is_none(weight))
    if has_weight:
        normalized = mul(g, normalized, weight)
    has_bias = not (bias is None or sym_help._is_none(bias))
    if has_bias:
        normalized = add(g, normalized, bias)

    return normalized
@parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i")
def instance_norm(g, input, weight, bias, running_mean, running_var, use_input_stats, momentum, eps, cudnn_enabled):
    """Export ``instance_norm``.

    Without running statistics, ONNX ``InstanceNormalization`` is emitted
    (missing weight/bias become constant ones/zeros). With running
    statistics, the export is delegated to ``batch_norm``.

    Raises:
        RuntimeError: if weight or bias must be synthesised but the channel
            size is unknown.
    """
    has_running_stats = not (running_mean is None or sym_help._is_none(running_mean)
                             or running_var is None or sym_help._is_none(running_var))
    if has_running_stats:
        # Now if track_running_stats is set to True it would get the same result with batchnorm. The PyTorch internal
        # implementation of instance_norm may have a problem with running_mean and running_var repeat, will open a github issue.
        return batch_norm(g, input, weight, bias, running_mean, running_var, use_input_stats, momentum, eps, cudnn_enabled)

    channel_size = sym_help._get_tensor_dim_size(input, 1)

    def _per_channel_constant(fill):
        # Constant of length channel_size, typed like the input tensor.
        if channel_size is None:
            raise RuntimeError("Unsupported: ONNX export of instance_norm for unknown "
                               "channel size.")
        values = torch.tensor([fill] * channel_size).type(
            "torch." + input.type().scalarType() + "Tensor")
        return g.op("Constant", value_t=values)

    if weight is None or sym_help._is_none(weight):
        weight = _per_channel_constant(1.)
    if bias is None or sym_help._is_none(bias):
        bias = _per_channel_constant(0.)
    return g.op("InstanceNormalization", input, weight, bias, epsilon_f=eps)
@parse_args("v", "i", "i", "i")
def unfold(g, input, dimension, size, step):
    """Export ``Tensor.unfold``.

    In ATen-fallback mode a single ``ATen`` node is emitted. Otherwise, when
    the size of ``dimension`` is statically known, each window is sliced out,
    transposed so the unfolded axis comes last, unsqueezed at ``dimension``
    and the pieces are concatenated along ``dimension``.
    """
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        return g.op("ATen", input, operator_s="unfold", dimension_i=dimension, size_i=size, step_i=step)

    sizes = sym_help._get_tensor_sizes(input)
    try:
        sizedim = sizes[dimension]
    except Exception:
        sizedim = None
    if sizedim is None:
        return _unimplemented("Unfold", "input size not accessible")

    # Window i covers [i * step, i * step + size) along ``dimension``.
    window_starts = range(0, sizedim, step)
    window_ends = range(size, sizedim + 1, step)
    windows = [
        sym_help._slice_helper(g, input, axes=[dimension], starts=[begin], ends=[end])
        for begin, end in zip(window_starts, window_ends)
    ]
    # Move the unfolded axis to the end.
    perm = list(range(0, len(sizes)))
    perm.append(perm.pop(dimension))
    pieces = [
        sym_help._unsqueeze_helper(g, g.op("Transpose", window, perm_i=perm), [dimension])
        for window in windows
    ]
    return g.op("Concat", *pieces, axis_i=dimension)
@parse_args("v", "t", "t", "t")
def elu(g, input, alpha, scale, input_scale):
    """Export ``elu`` as ONNX ``Elu``; only ``scale`` and ``input_scale`` of 1 are supported."""
    unsupported_scale = scale and scale != 1.
    if unsupported_scale:
        return _unimplemented("scale", "does not support scale in Elu")
    unsupported_input_scale = input_scale and input_scale != 1.
    if unsupported_input_scale:
        return _unimplemented("input_scale", "does not support input_scale in Elu")
    # See Note [Export inplace]
    alpha_value = sym_help._scalar(alpha)
    return g.op("Elu", input, alpha_f=alpha_value)
def selu(g, input):
    """Export ``selu`` as ONNX ``Selu``."""
    activated = g.op("Selu", input)
    return activated
@parse_args("v", "i", "v")
def index_select(g, self, dim, index):
    """Export ``index_select`` via ``sym_help._select_helper``."""
    # In case of a scalar index, index_select returns a tensor with the same rank as the input.
    # To match this behavior in ONNX, we make index a 1D tensor so that the following gather
    # also produces a tensor with the same rank as the input.
    selected = sym_help._select_helper(g, self, dim, index)
    return selected
def index_put(g, self, indices_list_value, values, accumulate):
    """Export ``index_put``.

    In ATen-fallback mode an ``ATen`` node is emitted. Otherwise only the
    empty-index case is expressible in opset 9: the result is ``self + values``
    when accumulating, else ``values``. Any other case is reported as
    unsupported for opset 9.
    """
    if sym_help._is_packed_list(indices_list_value):
        indices_list = sym_help._unpack_list(indices_list_value)
    else:
        indices_list = [indices_list_value]
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        args = [self] + indices_list + [values, accumulate]
        return g.op("ATen", *args, operator_s="index_put")

    accumulate = sym_help._parse_arg(accumulate, "b")

    if len(indices_list) == 0:
        if accumulate:
            return add(g, self, values)
        return values
    # Explicit return, consistent with the other "unsupported" branches in
    # this file, so the function never falls through returning None silently.
    return sym_help._onnx_opset_unsupported("index_put", 9, 11)
def index_fill(g, self, dim, index, value):
    """Export ``index_fill``.

    In ATen-fallback mode an ``ATen`` node is emitted. Otherwise the fill
    value is expanded to the shape of the reshaped index and written with
    ``scatter``.
    """
    dim_value = sym_help._parse_arg(dim, "i")
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        return g.op("ATen", self, index, value, dim_i=dim_value, operator_s="index_fill")
    index_shape, scatter_index = sym_help._index_fill_reshape_helper(g, self, dim, index)
    # Make the scalar fill value match self's type before expanding it.
    fill = sym_help._if_scalar_type_as(g, sym_help._maybe_get_scalar(value), self)
    expanded_fill = expand(g, fill, index_shape, None)
    return scatter(g, self, dim, scatter_index, expanded_fill)
def index_copy(g, self, dim, index, source):
    """Export ``index_copy``.

    In ATen-fallback mode an ``ATen`` node is emitted. Otherwise ``source``
    is written into ``self`` with ``scatter`` using the reshaped index.
    """
    dim_value = sym_help._parse_arg(dim, "i")
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        return g.op("ATen", self, index, source, dim_i=dim_value, operator_s="index_copy")
    # Only the expanded index is needed here; its shape is not used.
    _, scatter_index = sym_help._index_fill_reshape_helper(g, self, dim, index)
    return scatter(g, self, dim, scatter_index, source)
def type_as(g, self, other):
    """Export ``type_as``.

    Returns ``self`` unchanged when both dtypes are known and equal, an ONNX
    ``Cast`` when ``other``'s dtype is known, and an ``ATen`` node in
    ATen-fallback mode otherwise.

    Raises:
        RuntimeError: if ``other``'s dtype is unknown and ATen fallback is
            not enabled.
    """
    self_dtype = sym_help._try_get_scalar_type(self)
    other_dtype = sym_help._try_get_scalar_type(other)
    if self_dtype == other_dtype and self_dtype is not None:
        return self
    if other_dtype is not None:
        return g.op("Cast", self, to_i=sym_help.cast_pytorch_to_onnx[other_dtype])
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        # We don't know the type of other, bail by emitting ATen
        return g.op("ATen", self, other, operator_s="type_as")
    raise RuntimeError('Unsupported: ONNX export of type_as for tensor '
                       'of unknown dtype. Please check if the dtype of the '
                       'parameter passed to the type_as function is correct.')
@parse_args("v", "v", "i", "f")
def cosine_similarity(g, x1, x2, dim, eps):
    """Export ``cosine_similarity``; only available through ATen fallback."""
    aten_fallback = sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
    if not aten_fallback:
        return sym_help._onnx_unsupported("cosine_similarity")
    return g.op("ATen", x1, x2, dim_i=dim, eps_f=eps, operator_s="cosine_similarity")
# ignore clone operators that are inserted by PyTorch autograd
def clone(g, input, unused_memory_format):
    """Export ``clone`` as a no-op: the input value is returned unchanged."""
    passthrough = input
    return passthrough
def abs(g, self):
    """Export ``abs`` as ONNX ``Abs``."""
    magnitude = g.op("Abs", self)
    return magnitude
def log(g, self):
    """Export ``log`` as ONNX ``Log``."""
    logarithm = g.op("Log", self)
    return logarithm
def log1p(g, self):
    """Export ``log1p`` as ``log(1 + self)``, with the constant typed like ``self``."""
    one = sym_help._if_scalar_type_as(g, torch.ones(1), self)
    shifted = add(g, one, self)
    return log(g, shifted)
def pow(g, self, exponent):
    """Export ``pow`` as ONNX ``Pow``.

    A non-floating-point base is cast to Float. A non-floating-point exponent
    is cast to the base's (possibly promoted) floating type.
    """
    float_dtype = self.type().scalarType()
    if not sym_help._is_fp(self):
        float_dtype = "Float"
        self = g.op("Cast", self, to_i=sym_help.cast_pytorch_to_onnx[float_dtype])
    if not sym_help._is_fp(exponent):
        exponent = g.op("Cast", exponent, to_i=sym_help.cast_pytorch_to_onnx[float_dtype])
    return g.op("Pow", self, exponent)
def clamp(g, self, min, max):
    """Export ``clamp``.

    A None bound dispatches to the one-sided ``clamp_max``/``clamp_min``.
    Two constant bounds become a single ONNX ``Clip``; otherwise the two
    one-sided clamps are chained.
    """
    # min or max may be None that we need to dispatch to
    # Clip separately, as ONNX does not have None syntax
    if sym_help._is_none(min):
        return clamp_max(g, self, max)
    if sym_help._is_none(max):
        return clamp_min(g, self, min)
    both_constant = sym_help._is_constant(min) and sym_help._is_constant(max)
    if both_constant:
        return g.op("Clip", self, min_f=_parse_arg(min, "f"), max_f=_parse_arg(max, "f"))
    lower_clamped = clamp_min(g, self, min)
    return clamp_max(g, lower_clamped, max)
@parse_args("v", "v")
def clamp_min(g, self, min):
    """Export ``clamp_min``: ``Clip`` for a constant bound, else ``Max`` with the bound cast to self's type."""
    if not sym_help._is_constant(min):
        self_type = self.type().scalarType()
        bound = g.op("Cast", min, to_i=sym_help.cast_pytorch_to_onnx[self_type])
        return g.op("Max", self, bound)
    return g.op("Clip", self, min_f=_parse_arg(min, "f"))
@parse_args("v", "v")
def clamp_max(g, self, max):
    """Export ``clamp_max``: ``Clip`` for a constant bound, else ``Min`` with the bound cast to self's type."""
    if not sym_help._is_constant(max):
        self_type = self.type().scalarType()
        bound = g.op("Cast", max, to_i=sym_help.cast_pytorch_to_onnx[self_type])
        return g.op("Min", self, bound)
    return g.op("Clip", self, max_f=_parse_arg(max, "f"))
# torch.max (same for torch.min) actually has two interfaces smashed together:
# torch.max(x, dim, keepdim) and torch.max(x, y)
def max(g, self, dim_or_y=None, keepdim=None):
    """Export the three ``torch.max`` overloads.

    Returns a single value for ``max(input)`` and ``max(input, other)``, and
    a ``(values, indices)`` pair for ``max(input, dim, keepdim)``.
    """
    if keepdim is None:
        if dim_or_y is None:
            # torch.max(input)
            return g.op("ReduceMax", self, keepdims_i=0)
        # torch.max(input, other)
        return g.op("Max", self, dim_or_y)
    # torch.max(input, dim, keepdim)
    dim = sym_help._get_const(dim_or_y, "i", "dim")
    keepdim = sym_help._get_const(keepdim, "i", "keepdim")
    values = g.op("ReduceMax", self, axes_i=[dim], keepdims_i=keepdim)
    indices = g.op("ArgMax", self, axis_i=dim, keepdims_i=keepdim)
    return values, indices
def min(g, self, dim_or_y=None, keepdim=None):
    """Export the three ``torch.min`` overloads.

    Returns a single value for ``min(input)`` and ``min(input, other)``, and
    a ``(values, indices)`` pair for ``min(input, dim, keepdim)``.
    """
    if keepdim is None:
        if dim_or_y is None:
            # torch.min(input)
            return g.op("ReduceMin", self, keepdims_i=0)
        # torch.min(input, other)
        return g.op("Min", self, dim_or_y)
    # torch.min(input, dim, keepdim)
    dim = sym_help._get_const(dim_or_y, "i", "dim")
    keepdim = sym_help._get_const(keepdim, "i", "keepdim")
    values = g.op("ReduceMin", self, axes_i=[dim], keepdims_i=keepdim)
    indices = g.op("ArgMin", self, axis_i=dim, keepdims_i=keepdim)
    return values, indices
def exp(g, self):
    """Export ``exp`` as ONNX ``Exp``."""
    exponential = g.op("Exp", self)
    return exponential
@parse_args("v", "f", "i")
def dropout(g, input, p, train):
    """Export ``dropout``.

    When the exporter is not in training mode the input is returned
    unchanged. Otherwise a warning is issued and an ONNX ``Dropout`` node is
    emitted, of which only the first output is returned.
    """
    sym_help.assert_training_mode(train, "dropout")
    # in eval mode, dropout is non-op - if the node's train param is set to False, dropout is non-op
    if sym_help._training_mode:
        warnings.warn("Dropout is a training op and should not be exported in inference mode. "
                      "Make sure to call eval() on the model, and to export it with param training=False.")
        dropped, _mask = g.op("Dropout", input, ratio_f=p, outputs=2)
        return dropped
    return input
def _unsupported_dropout(name):
    """Build a symbolic for a dropout variant that is exportable only in inference.

    Args:
        name: operator name used in the "unimplemented" message.
    """
    @parse_args("v", "f", "i")
    def feature_dropout(g, input, p, train):
        # NB: In inference mode, FeatureDropout is exported as an identity op.
        if not train:
            return input
        return _unimplemented(name, "training mode")
    return feature_dropout
# Dropout variants that can only be exported in inference mode (as identity).
feature_dropout = _unsupported_dropout("feature_dropout")
alpha_dropout = _unsupported_dropout("alpha_dropout")
feature_alpha_dropout = _unsupported_dropout("feature_alpha_dropout")
# See Note [Export inplace]
# The in-place (trailing underscore) operators reuse the out-of-place symbolics.
dropout_ = dropout
feature_dropout_ = feature_dropout
alpha_dropout_ = alpha_dropout
feature_alpha_dropout_ = feature_alpha_dropout
@parse_args("v", "t", "is", "i")
def norm(g, self, p, dim, keepdim):
    """Export ``norm`` for ``p`` of 1 or 2 as ONNX ``ReduceL1``/``ReduceL2``.

    Raises:
        RuntimeError: for any other ``p``.
    """
    if p == 1:
        reduce_op = "ReduceL1"
    elif p == 2:
        reduce_op = "ReduceL2"
    else:
        raise RuntimeError("ONNX export only p-norms with p of 1 or 2")
    reduce_symbolic = _reduce_op_symbolic(reduce_op)
    return reduce_symbolic(g, self, dim=dim, keepdim=keepdim)
@parse_args("v", "v", "v", "i")
def conv_tbc(g, input, weight, bias, pad):
    """Export ``conv_tbc``.

    In ATen-fallback mode an ``ATen`` node is emitted. Otherwise input and
    weight are transposed into the layout ``conv1d`` expects, the
    convolution is applied, and the result is transposed back.
    """
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        return g.op("ATen", input, weight, bias, operator_s="conv_tbc", pad_i=pad)
    # input must have 3 dimensions, see:
    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/ConvolutionTBC.cpp#L8-L10
    # input = (time, batch, in_channels)
    # weight = (kernel_width, in_channels, out_channels)
    # bias = (out_channels,)
    batch_first = g.op("Transpose", input, perm_i=[1, 2, 0])
    out_first_weight = g.op("Transpose", weight, perm_i=[2, 1, 0])
    conv = conv1d(g, batch_first, out_first_weight, bias, [1], [pad], [1], 1)
    return g.op("Transpose", conv, perm_i=[2, 0, 1])
@parse_args("v", "i", "i")
def _unique(g, input, sorted, return_inverse):
    """Export ``_unique``; only available through ATen fallback (two outputs)."""
    aten_fallback = sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
    if not aten_fallback:
        return sym_help._onnx_unsupported("_unique")
    return g.op("ATen", input, operator_s="_unique", sorted_i=sorted,
                return_inverse_i=return_inverse, outputs=2)
@parse_args("v", "i", "i", "i")
def _unique2(g, input, sorted, return_inverse, return_counts):
    """Export ``_unique2``.

    In ATen-fallback mode an ``ATen`` node with three outputs is emitted;
    otherwise the operator is reported as unsupported for opset 9.
    """
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        return g.op("ATen", input, operator_s="_unique2", sorted_i=sorted,
                    return_inverse_i=return_inverse, return_counts_i=return_counts,
                    outputs=3)
    # Explicit return, consistent with ``_unique``, so the function never
    # falls through returning None silently.
    return sym_help._onnx_opset_unsupported("_unique2", 9, 11)
# Generate a ``_cast_<Type>`` symbolic for every entry of
# sym_help.cast_pytorch_to_onnx and publish it as a module-level name; these
# are looked up through globals() by wrap_logical_op_with_cast_to_and_from.
# NOTE: the loop variables k, v and name remain defined at module level.
for k, v in sym_help.cast_pytorch_to_onnx.items():
    name = "_cast_{}".format(k)
    globals()[name] = parse_args("v", "i")(partial(sym_help._cast_func_template, v))
@parse_args("v", "i", "v", "v", "v", "v")
def empty(g, sizes, dtype, layout, device, pin_memory=False, memory_format=None):
    """Export ``empty`` as ``zeros``; ``memory_format`` is ignored."""
    zero_filled = zeros(g, sizes, dtype, layout, device, pin_memory)
    return zero_filled
@parse_args("v", "i", "v", "v", "v", "v")
def empty_like(g, input, dtype=None, layout=None, device=None, pin_memory=False, memory_format=None):
    """Export ``empty_like`` as ``zeros_like``; ``memory_format`` is ignored."""
    zero_filled = zeros_like(g, input, dtype, layout, device, pin_memory)
    return zero_filled
def new_empty(g, self, sizes, dtype, layout, device, pin_memory=False):
    """Export ``Tensor.new_empty``, inheriting ``self``'s dtype when none is given.

    ``dtype`` normally arrives as a graph value, so "no dtype" has to be
    detected with ``sym_help._is_none`` as well as with a plain ``None``
    check; otherwise the dtype of ``self`` would never be inherited.
    """
    self_dtype = sym_help._try_get_scalar_type(self)
    dtype_missing = dtype is None or (sym_help._is_value(dtype) and sym_help._is_none(dtype))
    if dtype_missing and self_dtype is not None:
        # Translate the scalar type name into the integer dtype index that
        # ``empty`` expects.
        dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[self_dtype])
    return empty(g, sizes, dtype, layout, device, pin_memory)
def scalar_tensor(g, scalar, dtype, *options):
    """Export ``scalar_tensor`` as a ``Cast`` of the scalar to ``dtype`` (float by default)."""
    onnx_index = sym_help._get_const(dtype, "i", "dtype")
    if onnx_index is None:
        onnx_index = 6  # float
    return g.op("Cast", scalar, to_i=sym_help.scalar_type_to_onnx[onnx_index])
def tensor(g, data, dtype=None, device=None, requires_grad=False):
    """Export ``torch.tensor``.

    A packed list is exported element by element: each element is reshaped
    to one element, cast, and the results are concatenated. Other data is
    cast directly, after ``ConcatFromSequence`` when it is a tensor or scalar
    list. When ``dtype`` is not given it is taken from the data.
    """
    dtype = sym_help._get_const(dtype, "i", "dtype")

    def _index_of(type_name):
        # Scalar type name -> index into sym_help.scalar_type_to_onnx.
        return sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[type_name])

    if not sym_help._is_packed_list(data):
        if dtype is None:
            dtype = _index_of(data.type().scalarType())
        if sym_help._is_list(data) and (sym_help._is_tensor_list(data) or sym_help._is_scalar_list(data)):
            data = g.op("ConcatFromSequence", data, axis_i=0, new_axis_i=1)
        return g.op("Cast", data, to_i=sym_help.scalar_type_to_onnx[dtype])

    if dtype is None:
        # Infer the dtype from the first element of the list.
        dtype = _index_of(sym_help._unpack_list(data)[0].type().scalarType())
    pieces = []
    for element in sym_help._unpack_list(data):
        one_element_shape = g.op("Constant", value_t=torch.LongTensor([1]))
        element = g.op("Reshape", element, one_element_shape)
        element = g.op("Cast", element, to_i=sym_help.scalar_type_to_onnx[dtype])
        pieces.append(element)
    return g.op("Concat", *pieces, axis_i=0)
def as_tensor(g, data, dtype=None, device=None):
    """Export ``as_tensor`` exactly like ``tensor``."""
    converted = tensor(g, data, dtype, device)
    return converted
@parse_args("v", "i", "v", "v", "v")
def zeros(g, sizes, dtype, layout, device, pin_memory=False):
    """Export ``zeros`` as ONNX ``ConstantOfShape`` filled with 0 (float by default)."""
    # NOTE: no way to set device, layout and pin_memory in ONNX, so we ignore it
    if dtype is None:
        dtype = 6  # float
    fill = torch.tensor([0], dtype=sym_help.scalar_type_to_pytorch_type[dtype])
    return g.op("ConstantOfShape", sizes, value_t=fill)
@parse_args("v", "i", "v", "v", "v", "v")
def zeros_like(g, input, dtype=None, layout=None, device=None, pin_memory=False, memory_format=None):
    """Export ``zeros_like`` as ``ConstantOfShape`` over ``Shape(input)`` (float by default)."""
    input_shape = g.op("Shape", input)
    if dtype is None:
        dtype = 6  # float
    fill = torch.tensor([0], dtype=sym_help.scalar_type_to_pytorch_type[dtype])
    return g.op("ConstantOfShape", input_shape, value_t=fill)
def new_zeros(g, self, sizes, dtype, layout, device, pin_memory=False):
    """Export ``aten::new_zeros`` via :func:`zeros`.

    When no dtype is supplied and the scalar type of ``self`` is known, the
    result takes the scalar type of ``self``; otherwise ``dtype`` is
    forwarded to :func:`zeros` unchanged.
    """
    self_dtype = sym_help._try_get_scalar_type(self)
    # ``dtype`` is not parsed here, so "no dtype" can show up either as Python
    # None or as a graph Value holding a None constant.  Checking only
    # ``dtype is None`` misses the second form and drops the inherited dtype.
    dtype_missing = dtype is None or (sym_help._is_value(dtype) and sym_help._is_none(dtype))
    if dtype_missing and self_dtype is not None:
        dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[self_dtype])
    return zeros(g, sizes, dtype, layout, device, pin_memory)
@parse_args("v", "i", "v", "v", "v")
def ones(g, sizes, dtype, layout, device, pin_memory=False):
    """Emit a ``ConstantOfShape`` filled with 1 for the shape ``sizes``.

    ``dtype`` is a scalar-type index; ``None`` means 6 (float).  ``layout``,
    ``device`` and ``pin_memory`` are not used.
    """
    scalar_type = 6 if dtype is None else dtype  # 6 == float
    fill = torch.tensor([1], dtype=sym_help.scalar_type_to_pytorch_type[scalar_type])
    return g.op("ConstantOfShape", sizes, value_t=fill)
@parse_args("v", "i", "v", "v", "v", "v")
def ones_like(g, input, dtype=None, layout=None, device=None, pin_memory=False, memory_format=None):
    """Emit a ``ConstantOfShape`` filled with 1, shaped like ``input``.

    ``dtype`` is a scalar-type index; ``None`` means 6 (float).  The
    remaining options are not used.
    """
    input_shape = g.op("Shape", input)
    scalar_type = 6 if dtype is None else dtype  # 6 == float
    fill = torch.tensor([1], dtype=sym_help.scalar_type_to_pytorch_type[scalar_type])
    return g.op("ConstantOfShape", input_shape, value_t=fill)
def new_ones(g, self, sizes, dtype, layout, device, pin_memory=False):
    """Export ``aten::new_ones`` via :func:`ones`.

    When no dtype is supplied and the scalar type of ``self`` is known, the
    result takes the scalar type of ``self``; otherwise ``dtype`` is
    forwarded to :func:`ones` unchanged.
    """
    self_dtype = sym_help._try_get_scalar_type(self)
    # ``dtype`` is not parsed here, so "no dtype" can show up either as Python
    # None or as a graph Value holding a None constant.  Checking only
    # ``dtype is None`` misses the second form and drops the inherited dtype.
    dtype_missing = dtype is None or (sym_help._is_value(dtype) and sym_help._is_none(dtype))
    if dtype_missing and self_dtype is not None:
        dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[self_dtype])
    return ones(g, sizes, dtype, layout, device, pin_memory)
def full(g, sizes, value, dtype, layout, device, pin_memory=False):
    """Emit ONNX nodes for ``aten::full``.

    A constant fill value becomes a single ``ConstantOfShape``.  A fill value
    that is still a graph Value is added to a zero tensor of the requested
    shape.  ``dtype`` defaults to scalar-type index 6 (float).
    """
    const_value = sym_help._maybe_get_const(value, "t")

    if not sym_help._is_value(const_value):
        dtype = sym_help._get_const(dtype, "i", "dtype")
        if dtype is None:
            dtype = 6  # float
        fill = torch.tensor([const_value], dtype=sym_help.scalar_type_to_pytorch_type[dtype])
        return g.op("ConstantOfShape", sizes, value_t=fill)

    # Dynamic fill value: zeros(sizes) + value * 1.
    if dtype is None:
        dtype = 6  # float
    zero_base = zeros(g, sizes, dtype, layout, device)
    alpha = g.op("Constant", value_t=torch.tensor(1))
    return add(g, zero_base, value, alpha)
def full_like(g, input, fill_value, dtype=None, layout=None, device=None, pin_memory=False, memory_format=None):
    """Emit ONNX nodes for ``aten::full_like``.

    A constant ``fill_value`` becomes a ``ConstantOfShape`` over the shape of
    ``input``.  A ``fill_value`` that is still a graph Value is cast to the
    target type and added to ``zeros_like(input)``.  ``dtype`` defaults to
    scalar-type index 6 (float).
    """
    fill_value = sym_help._maybe_get_const(fill_value, "f")
    dtype = sym_help._get_const(dtype, "i", "dtype")
    if dtype is None:
        dtype = 6  # float

    if not sym_help._is_value(fill_value):
        input_shape = g.op("Shape", input)
        fill = torch.tensor([fill_value], dtype=sym_help.scalar_type_to_pytorch_type[dtype])
        return g.op("ConstantOfShape", input_shape, value_t=fill)

    zero_base = zeros_like(g, input, dtype, layout, device)
    fill_value = g.op("Cast", fill_value, to_i=sym_help.scalar_type_to_onnx[dtype])
    alpha = g.op("Constant", value_t=torch.tensor(1))
    return add(g, zero_base, fill_value, alpha)
def new_full(g, self, size, fill_value, dtype, layout, device, pin_memory=False):
    """Export ``aten::new_full`` via :func:`full`.

    When no dtype is supplied and the scalar type of ``self`` is known, the
    result takes the scalar type of ``self``; otherwise ``dtype`` is
    forwarded to :func:`full` unchanged.
    """
    self_dtype = sym_help._try_get_scalar_type(self)
    # ``dtype`` is not parsed here, so "no dtype" can show up either as Python
    # None or as a graph Value holding a None constant.  Checking only
    # ``dtype is None`` misses the second form and drops the inherited dtype.
    dtype_missing = dtype is None or (sym_help._is_value(dtype) and sym_help._is_none(dtype))
    if dtype_missing and self_dtype is not None:
        dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[self_dtype])
    return full(g, size, fill_value, dtype, layout, device, pin_memory)
def eye(g, *args):
    """Emit ``EyeLike`` over a zero tensor for both ``aten::eye`` overloads.

    Five arguments select the square form ``(n, dtype, layout, device,
    pin_memory)``; six select the rectangular form ``(n, m, dtype, layout,
    device, pin_memory)``.  Any other count raises ``NotImplementedError``.
    """
    arg_count = len(args)
    if arg_count == 5:
        # aten::eye(n, dtype, layout, device, pin_memory)
        n, dtype, layout, device, pin_memory = args
        side = sym_help._unsqueeze_helper(g, n, [0])
        shape = g.op("Concat", side, side, axis_i=0)
    elif arg_count == 6:
        # aten::eye(n, m, dtype, layout, device, pin_memory)
        n, m, dtype, layout, device, pin_memory = args
        rows = sym_help._unsqueeze_helper(g, n, [0])
        cols = sym_help._unsqueeze_helper(g, m, [0])
        shape = g.op("Concat", rows, cols, axis_i=0)
    else:
        raise NotImplementedError("Unknown aten::eye signature")

    zero_base = zeros(g, shape, dtype, layout, device)
    return g.op("EyeLike", zero_base)
def slice(g, self, *args):
    """Emit ONNX nodes for the two ``aten::slice`` overloads.

    Four arguments ``(dim, start, end, step)`` select the tensor overload and
    three ``(start, end, step)`` the list overload, which slices along
    axis 0.  A ``None`` start means 0 and a ``None`` end means
    9223372036854775807.  In the tensor overload only ``step == 1`` is
    accepted, and non-constant ``dim``/``start``/``end`` either raise (plain
    ONNX export type) or emit ``DynamicSlice``.
    """
    def _is_none_constant(value):
        return value.node().kind() == "prim::Constant" and value.type().kind() == 'NoneType'

    def _is_onnx_constant(value):
        return value.node().kind() == 'onnx::Constant'

    if len(args) == 4:
        # aten::slice(Tensor self, int dim, int start, int end, int step) -> Tensor
        dim, start, end, step = args
        step = _parse_arg(step, "i")
        if step != 1:
            raise RuntimeError("step!=1 is currently not supported")

        start_is_none = _is_none_constant(start)
        end_is_none = _is_none_constant(end)
        all_static = ((start_is_none or _is_onnx_constant(start))
                      and (end_is_none or _is_onnx_constant(end))
                      and _is_onnx_constant(dim))

        if all_static:
            first = 0 if start_is_none else _parse_arg(start, 'i')
            last = 9223372036854775807 if end_is_none else _parse_arg(end, 'i')
            axis = _parse_arg(dim, 'i')
            return sym_help._slice_helper(g, self, axes=[axis], starts=[first], ends=[last])

        if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX:
            raise RuntimeError("Unsupported: ONNX export of Slice with dynamic inputs. DynamicSlice "
                               "is a deprecated experimental op. Please use statically allocated "
                               "variables or export to a higher opset version.")

        start_unsqueezed = sym_help._unsqueeze_helper(g, start, [0])
        end_unsqueezed = sym_help._unsqueeze_helper(g, end, [0])
        dim_unsqueezed = sym_help._unsqueeze_helper(g, dim, [0])
        return g.op("DynamicSlice", self, start_unsqueezed, end_unsqueezed, dim_unsqueezed)

    if len(args) == 3:
        # aten::slice(t[] l, int start, int end, int step) -> t[]
        start, end, step = args
        first = 0 if _is_none_constant(start) else _parse_arg(start, 'i')
        last = 9223372036854775807 if _is_none_constant(end) else _parse_arg(end, 'i')
        return sym_help._slice_helper(g, self, axes=[0], starts=[first], ends=[last])

    raise NotImplementedError("Unknown aten::slice signature")
@parse_args("v", "f", "f")
def hardtanh(g, self, min_val, max_val):
    """Export ``aten::hardtanh`` as an ONNX ``Clip`` with attribute bounds."""
    clipped = g.op("Clip", self, min_f=min_val, max_f=max_val)
    return clipped
@parse_args("v")
def hardswish(g, self):
    """Export ``aten::hardswish`` as ``self * hardtanh(self + 3, 0, 6) / 6``."""
    three = g.op("Constant", value_t=torch.tensor(3, dtype=torch.float))
    shifted = g.op("Add", self, three)
    lower = g.op("Constant", value_t=torch.tensor(0, dtype=torch.float))
    upper = g.op("Constant", value_t=torch.tensor(6, dtype=torch.float))
    clipped = sym_help._hardtanh_helper(g, shifted, lower, upper)
    six = g.op("Constant", value_t=torch.tensor(6, dtype=torch.float))
    gate = g.op("Div", clipped, six)
    return g.op("Mul", self, gate)
@parse_args("v")
def hardsigmoid(g, self):
    """Export ``aten::hardsigmoid`` as ONNX ``HardSigmoid`` with ``alpha = 1 / 6``."""
    slope = 1 / 6
    return g.op("HardSigmoid", self, alpha_f=slope)
def alias(g, self):
    """Export ``aten::alias`` as a no-op: the input value is returned unchanged."""
    aliased = self
    return aliased
@parse_args("v", "i")
def unsqueeze(g, self, dim):
    """Export ``aten::unsqueeze``, resolving a negative ``dim`` at export time.

    A negative ``dim`` is rewritten to ``dim + rank + 1`` using the rank of
    ``self`` known during export, with a warning.  If the rank is unknown the
    op is reported as unimplemented.
    """
    if dim < 0:
        rank = sym_help._get_tensor_rank(self)
        if rank is None:
            return _unimplemented("unsqueeze", "negative axis with unknown input rank")
        positive_dim = dim + rank + 1
        message = ("ONNX export unsqueeze with negative axis " + str(dim) +
                   " might cause the onnx model to be incorrect. " +
                   "Negative axis is not supported in ONNX. " +
                   "Axis is converted to " + str(positive_dim) +
                   " based on input shape at export time. " +
                   "Passing an tensor of different rank in execution will be incorrect.")
        warnings.warn(message)
        dim = positive_dim

    return sym_help._unsqueeze_helper(g, self, axes_i=[dim])
@parse_args("v", "i", "i", "none")
def sort(g, self, dim, decending, out=None):
    """Export ``aten::sort`` as an ONNX ``TopK`` over the whole of axis ``dim``.

    The size of ``self`` along ``dim`` must be known at export time, because
    it is passed to ``TopK`` as the ``k`` attribute.  The emitted node takes
    no direction attribute, so ``decending`` cannot be honoured; a warning is
    issued when an ascending sort is requested.

    Returns the two outputs of the ``TopK`` node, or the result of
    ``_unimplemented`` when the size along ``dim`` is not accessible.
    """
    if out is not None:
        _unimplemented("Sort", "Out parameter is not supported for sort")
    if not decending:
        # The graph is unchanged; the warning only stops the direction
        # mismatch from going unnoticed.
        warnings.warn("ONNX export of sort with descending=False is not supported in opset 9. "
                      "The exported TopK selects the largest elements, so the result may not "
                      "match the ascending order produced by PyTorch.")

    self_sizes = sym_help._get_tensor_sizes(self)
    try:
        dim_size = self_sizes[dim]
    except (TypeError, IndexError):
        # Sizes unavailable (None) or ``dim`` outside the known rank.
        dim_size = None

    if dim_size is None:
        return _unimplemented("Sort", "input size not accessible")

    return g.op("TopK", self, k_i=dim_size, axis_i=dim, outputs=2)
def numel(g, self):
    """Export ``aten::numel`` as ``ReduceProd`` over the ``Shape`` of ``self``."""
    dims = g.op("Shape", self)
    element_count = g.op("ReduceProd", dims, keepdims_i=0)
    return element_count
@parse_args("v", "i", "i", "i", "i", "none")
def topk(g, self, k, dim, largest, sorted, out=None):
    """Export ``aten::topk`` as an ONNX ``TopK`` with ``k`` and ``dim`` as attributes.

    An ``out`` argument and ``largest=False`` are both reported through
    ``_unimplemented``.  ``sorted`` is not used.
    """
    has_out = out is not None
    if has_out:
        _unimplemented("TopK", "Out parameter is not supported for topk")
    if not largest:
        _unimplemented("TopK", "Ascending TopK is not supported")

    top = g.op("TopK", self, k_i=k, axis_i=dim, outputs=2)
    return top
def to(g, self, *args):
    """Export the ``aten::to`` overloads, selected by the number of extra arguments.

    Device-only conversions return ``self`` unchanged; every other overload
    emits a single ONNX ``Cast`` to the requested type.  An unrecognised
    argument count raises ``NotImplementedError``.
    """
    # ONNX doesn't have a concept of a device, so we ignore device casts
    if len(args) == 4:
        if args[0].node().kind() == "prim::device" or args[0].type().isSubtypeOf(ListType.ofInts()):
            # aten::to(Tensor, Device, bool, bool, memory_format)
            return self
        else:
            # TestONNXRuntime::test_ones_bool shows args[0] of aten::to() can be onnx::Constant[value=<Tensor>]()
            # In this case, the constant value is a tensor not int,
            # so sym_help._maybe_get_const(args[0], 'i') would not work.
            dtype = args[0]
            if sym_help._is_value(args[0]) and args[0].node().kind() == "onnx::Constant":
                tval = args[0].node()["value"]
                if isinstance(tval, torch.Tensor):
                    if len(tval.shape) == 0:
                        # 0-dim constant tensor: its single element is used as the scalar-type index.
                        tval = tval.item()
                        dtype = int(tval)
                    else:
                        # Non-scalar constant tensor: handled by the tensor branch below.
                        dtype = tval
            if sym_help._is_value(dtype) or isinstance(dtype, torch.Tensor):
                # aten::to(Tensor, Tensor, bool, bool, memory_format)
                # The target type is the scalar type of the tensor argument.
                dtype = args[0].type().scalarType()
                return g.op("Cast", self, to_i=sym_help.cast_pytorch_to_onnx[dtype])
            else:
                # aten::to(Tensor, ScalarType, bool, bool, memory_format)
                # memory_format is ignored
                return g.op("Cast", self, to_i=sym_help.scalar_type_to_onnx[dtype])
    elif len(args) == 5:
        # aten::to(Tensor, Device, ScalarType, bool, bool, memory_format)
        dtype = sym_help._get_const(args[1], "i", "dtype")
        # memory_format is ignored
        return g.op("Cast", self, to_i=sym_help.scalar_type_to_onnx[dtype])
    elif len(args) == 6:
        # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, memory_format) -> Tensor
        dtype = sym_help._get_const(args[0], "i", "dtype")
        # Layout, device and memory_format are ignored
        return g.op("Cast", self, to_i=sym_help.scalar_type_to_onnx[dtype])
    elif len(args) == 7:
        # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, bool, memory_format) -> Tensor
        dtype = sym_help._get_const(args[0], "i", "dtype")
        # Layout, device and memory_format are ignored
        return g.op("Cast", self, to_i=sym_help.scalar_type_to_onnx[dtype])
    else:
        raise NotImplementedError("Unknown aten::to signature")
def repeat(g, self, repeats):
    """Export ``aten::repeat`` as ``Expand`` followed by ``Tile``.

    ``self`` is first expanded against an all-ones int64 tensor shaped like
    ``repeats`` and the result is then tiled by ``repeats``.
    """
    int64_index = 4  # int64
    ones_shape = ones_like(g, repeats, int64_index)
    expanded = g.op("Expand", self, ones_shape)
    return g.op("Tile", expanded, repeats)
def repeat_interleave(g, self, repeats, dim=None, output_size=None):
    """Export ``aten::repeat_interleave`` using split / expand / reshape / concat.

    The input is split along ``dim``, each piece is expanded by its repeat
    count and reshaped back, and the pieces are concatenated along ``dim``.
    The rank and sizes of ``repeats`` and the sizes of the input must be
    known at export time, otherwise ``RuntimeError`` is raised.  An unknown
    input size along ``dim`` is reported as unsupported for this opset.
    ``output_size`` is not used.
    """
    input = self
    # if dim is None flatten
    # By default, use the flattened input array, and return a flat output array
    if sym_help._is_none(dim):
        input = reshape(g, self, g.op("Constant", value_t=torch.tensor([-1])))
        dim = 0
    else:
        dim = sym_help._maybe_get_scalar(dim)

    repeats_dim = sym_help._get_tensor_rank(repeats)
    repeats_sizes = sym_help._get_tensor_sizes(repeats)
    input_sizes = sym_help._get_tensor_sizes(input)
    if repeats_dim is None:
        raise RuntimeError("Unsupported: ONNX export of repeat_interleave for unknown "
                           "repeats rank.")
    if repeats_sizes is None:
        raise RuntimeError("Unsupported: ONNX export of repeat_interleave for unknown "
                           "repeats size.")
    if input_sizes is None:
        raise RuntimeError("Unsupported: ONNX export of repeat_interleave for unknown "
                           "input size.")

    # Unknown sizes are recorded as 0 in input_sizes (the final Reshape
    # target) and as -1 in input_sizes_temp (the Expand shape).
    input_sizes_temp = input_sizes.copy()
    for idx, input_size in enumerate(input_sizes):
        if input_size is None:
            input_sizes[idx], input_sizes_temp[idx] = 0, -1

    # Cases where repeats is an int or single value tensor
    if (repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1)):
        if not sym_help._is_tensor(repeats):
            repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
        if input_sizes[dim] == 0:
            return sym_help._onnx_opset_unsupported_detailed("repeat_interleave", 9, 11,
                                                             "Unsupported along dimension with unknown input size")
        else:
            # Broadcast the single repeat count to one entry per slice along dim.
            reps = input_sizes[dim]
            repeats = expand(g, repeats, g.op("Constant", value_t=torch.tensor([reps])), None)
    # Cases where repeats is a 1 dim Tensor
    elif repeats_dim == 1:
        if input_sizes[dim] == 0:
            return sym_help._onnx_opset_unsupported_detailed("repeat_interleave", 9, 11,
                                                             "Unsupported along dimension with unknown input size")
        assert repeats_sizes[0] == input_sizes[dim], "repeats must have the same size as input along dim"
        reps = repeats_sizes[0]
    else:
        raise RuntimeError("repeats must be 0-dim or 1-dim tensor")

    final_splits = list()
    r_splits = sym_help._repeat_interleave_split_helper(g, repeats, reps, 0)
    i_splits = sym_help._repeat_interleave_split_helper(g, input, reps, dim)
    # Along dim: -1 in the Reshape target, 1 in the Expand shape prefix.
    input_sizes[dim], input_sizes_temp[dim] = -1, 1
    for idx, r_split in enumerate(r_splits):
        i_split = unsqueeze(g, i_splits[idx], dim + 1)
        # Expand shape: sizes up to and including dim, the repeat count, then the remaining sizes.
        r_concat = [g.op("Constant", value_t=torch.LongTensor(input_sizes_temp[:dim + 1])),
                    r_split,
                    g.op("Constant", value_t=torch.LongTensor(input_sizes_temp[dim + 1:]))]
        r_concat = g.op("Concat", *r_concat, axis_i=0)
        i_split = expand(g, i_split, r_concat, None)
        i_split = reshape(g, i_split, g.op("Constant", value_t=torch.LongTensor(input_sizes)))
        final_splits.append(i_split)
    return g.op("Concat", *final_splits, axis_i=dim)
@parse_args("v", "i")
def pixel_shuffle(g, self, upscale_factor):
    """Export ``aten::pixel_shuffle`` as view -> ``Transpose`` -> view.

    Only 4-d inputs are handled, and every size except the first must be
    known at export time because the sizes are baked into the reshape
    constants.  Unsupported inputs, including an input whose sizes are not
    available at all, are reported through ``_unimplemented``.
    """
    dims = sym_help._get_tensor_sizes(self)
    # ``dims`` is None when the sizes are unavailable; report that the same
    # way as a non-4d input instead of failing on ``len(None)``.
    if dims is None or len(dims) != 4:
        return _unimplemented("pixel_shuffle", "only support 4d input")
    if any(i is None for i in dims[1:]):
        return _unimplemented("pixel_shuffle", "only support static input shape, except for batch size")
    output_channel = dims[1] // upscale_factor // upscale_factor
    after_view = view(g, self, g.op("Constant", value_t=torch.tensor([-1, output_channel, upscale_factor,
                                                                      upscale_factor, dims[2], dims[3]])))
    after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3])
    return view(g, after_transpose,
                g.op("Constant", value_t=torch.tensor([-1, output_channel, dims[2] * upscale_factor,
                                                       dims[3] * upscale_factor])))
def _generic_rnn(g, variant, input, initial_states, all_weights, has_biases,
                 num_layers, dropout, train, bidirectional, batch_first=None, batch_sizes=None):
    """Shared export path for the RNN, GRU and LSTM symbolics.

    Emits one ONNX ``RNN``/``GRU``/``LSTM`` node per layer, feeding each
    layer's output into the next.  ``variant`` is ``"LSTM"``, ``"GRU"`` or a
    string starting with ``"RNN"`` whose suffix names the activation.
    ``all_weights`` is the flat weight list, ``weights_per_layer`` entries
    per layer and direction.

    Returns ``(output, h_n)`` for RNN/GRU and ``(output, h_n, c_n)`` for
    LSTM.  Returns the result of ``_unimplemented`` for LSTMs with
    projections, dropout in training mode, or an unknown hidden size.
    """
    warnings.warn("Exporting a model to ONNX with a batch_size other than 1, " +
                  "with a variable length with " + variant + " can cause an error " +
                  "when running the ONNX model with a different batch size. " +
                  "Make sure to save the model with a batch size of 1, " +
                  "or define the initial states (h0/c0) as inputs of the model. ")

    onnxActivations = ["Relu", "Tanh", "Sigmoid", "Affine", "LeakyRelu", "ThresholdedRelu",
                       "ScaledTanh", "HardSigmoid", "Elu", "Softsign", "Softplus"]
    # Lower-cased activation name -> ONNX activation name.
    variantToOnnxActivationMap = dict(zip([act_fun.lower() for act_fun in onnxActivations], onnxActivations))
    weights_per_layer = 4 if has_biases else 2
    # this means that projections are used inside LSTM, so need to tell user that it's not supported
    if variant == "LSTM" and len(all_weights) != num_layers * weights_per_layer * (1 + bidirectional):
        return _unimplemented("LSTM", "LSTMs with projections")
    assert len(all_weights) == num_layers * weights_per_layer * (1 + bidirectional)
    # Chunk the flat weight list into groups of weights_per_layer.
    layer_weights = [all_weights[i:i + weights_per_layer] for i in range(0, len(all_weights), weights_per_layer)]
    if batch_first:
        # batch, seq, feat -> seq, batch, feat
        input = g.op("Transpose", input, perm_i=[1, 0, 2])
    if dropout and train:
        return _unimplemented("RNN/GRU/LSTM", "dropout in training mode")

    if variant.startswith("RNN"):
        # Characters after the first four name the activation (e.g. "RNN_TANH" -> "tanh").
        nonlinearity = variantToOnnxActivationMap[variant[4:].lower()]
        variant = "RNN"

    # The hidden size is read from dimension 1 of the second weight tensor.
    w_hh = all_weights[1]
    hidden_size = sym_help._get_tensor_dim_size(w_hh, 1)
    if hidden_size is None:
        return _unimplemented("RNN/GRU/LSTM", "unknown hidden size")

    unidirectional = not bidirectional

    prev_output = input

    h_outs = []
    if variant == "RNN" or variant == "GRU":
        h0 = initial_states
    elif variant == "LSTM":
        h0, c0 = initial_states
        c_outs = []

    sequence_lens = unused(g) if batch_sizes is None else batch_sizes

    # Each (x, y) pair selects rows [x * hidden_size, y * hidden_size) of a
    # weight; the pairs are listed in the order the chunks are concatenated.
    if variant == "GRU":
        # pytorch is reset, input, hidden
        # onnx is input, reset, hidden
        reform_permutation = [(1, 2), (0, 1), (2, 3)]
    elif variant == "LSTM":
        # pytorch is input, forget, cell, output.
        # onnx is input, output, forget, cell.
        reform_permutation = [(0, 1), (3, 4), (1, 3)]

    def reform_weights(g, w, n, intervals):
        # Slice w along axis 0 at each interval (in units of n) and concatenate the slices.
        slices = [sym_help._slice_helper(g, w, axes=[0], starts=[x * n], ends=[y * n]) for x, y in intervals]
        return g.op("Concat", *slices, axis_i=0)

    def transform_weights_no_bias(layer_index):
        # Returns (weight_ih, weight_hh), gate-reordered for GRU/LSTM, each unsqueezed at axis 0.
        weights = layer_weights[layer_index]
        if variant == "RNN":
            weight_ih, weight_hh = weights
        elif variant == "GRU" or variant == "LSTM":
            weight_ih, weight_hh = \
                [reform_weights(g, w, hidden_size, reform_permutation) for w in weights]
        return tuple(sym_help._unsqueeze_helper(g, x, [0]) for x in (weight_ih, weight_hh))

    def transform_weights(layer_index):
        # Same as transform_weights_no_bias, plus the two biases concatenated into one tensor.
        weights = layer_weights[layer_index]
        if variant == "RNN":
            weight_ih, weight_hh, bias_ih, bias_hh = weights
        elif variant == "GRU" or variant == "LSTM":
            weight_ih, weight_hh, bias_ih, bias_hh = \
                [reform_weights(g, w, hidden_size, reform_permutation) for w in weights]
        bias_concat = g.op("Concat", bias_ih, bias_hh, axis_i=0)
        return tuple(sym_help._unsqueeze_helper(g, x, [0]) for x in (weight_ih, weight_hh, bias_concat))

    def retrieve_state(x, start, end):
        # With a single layer the state is used whole; otherwise slice out rows [start, end).
        return x if num_layers == 1 else sym_help._slice_helper(g, x, axes=[0], starts=[start], ends=[end])

    for i in range(num_layers):
        if unidirectional:
            if weights_per_layer == 4:
                weight_ih, weight_hh, bias_concat = transform_weights(i)
            else:
                weight_ih, weight_hh = transform_weights_no_bias(i)
                bias_concat = unused(g)

            state_indices = i, i + 1
        else:
            # Bidirectional: entries 2*i and 2*i+1 of layer_weights hold the
            # two directions of layer i (named _f and _b here).
            if weights_per_layer == 4:
                weight_ih_f, weight_hh_f, bias_f = transform_weights(2 * i)
                weight_ih_b, weight_hh_b, bias_b = transform_weights(2 * i + 1)
                bias_concat = g.op("Concat", bias_f, bias_b, axis_i=0)
            else:
                weight_ih_f, weight_hh_f = transform_weights_no_bias(2 * i)
                weight_ih_b, weight_hh_b = transform_weights_no_bias(2 * i + 1)
                bias_concat = unused(g)

            weight_ih = g.op("Concat", weight_ih_f, weight_ih_b, axis_i=0)
            weight_hh = g.op("Concat", weight_hh_f, weight_hh_b, axis_i=0)

            state_indices = 2 * i, 2 * i + 2

        inputs = [prev_output, weight_ih, weight_hh, bias_concat, sequence_lens]

        inputs.append(retrieve_state(h0, *state_indices))
        if variant == "LSTM":
            inputs.append(retrieve_state(c0, *state_indices))

        extra_kwargs = {} if unidirectional else {"direction_s": "bidirectional"}
        if variant == "RNN":
            # One activation name per direction.
            if bidirectional:
                activation = [nonlinearity, nonlinearity]
            else:
                activation = [nonlinearity]

            prev_output, h_out = g.op("RNN", *inputs, outputs=2,
                                      hidden_size_i=hidden_size,
                                      activations_s=activation,
                                      **extra_kwargs)
        elif variant == "GRU":
            prev_output, h_out = g.op("GRU", *inputs, outputs=2,
                                      hidden_size_i=hidden_size,
                                      linear_before_reset_i=1,
                                      **extra_kwargs)
        elif variant == "LSTM":
            prev_output, h_out, c_out = g.op("LSTM", *inputs, outputs=3,
                                             hidden_size_i=hidden_size,
                                             **extra_kwargs)

        if bidirectional:
            # The ONNX RNN/GRU/LSTM produce an output of dimensions
            # seq_len, num_directions, batch, hidden_size
            # We have to convert to match pytorch's expected
            # seq_len, batch, num_directions * hidden_size
            # by first moving num_directions before hidden_size with
            # Transpose, and then combining it with hidden_size
            # with Reshape.
            prev_output = g.op("Transpose", prev_output, perm_i=[0, 2, 1, 3])
            prev_output = g.op("Reshape", prev_output, g.op("Constant", value_t=torch.LongTensor([0, 0, -1])))
        else:
            # Single direction: squeeze axis 1 of the output.
            prev_output = sym_help._squeeze_helper(g, prev_output, [1])

        h_outs.append(h_out)
        if variant == "LSTM":
            c_outs.append(c_out)
    if batch_first:
        # seq, batch, num_directions * hidden_size -> batch, seq, num_directions * hidden_size
        prev_output = g.op("Transpose", prev_output, perm_i=[1, 0, 2])
    # With one layer the final state is used directly; otherwise the per-layer states are stacked on axis 0.
    h_outs = h_out if num_layers == 1 else g.op("Concat", *h_outs, axis_i=0)
    if variant == "RNN" or variant == "GRU":
        return prev_output, h_outs
    elif variant == "LSTM":
        c_outs = c_out if num_layers == 1 else g.op("Concat", *c_outs, axis_i=0)
        return prev_output, h_outs, c_outs
@parse_args("v", "v", "v", "i", "i", "f", "i", "i", "i")
def _lstm_full(g, input, hidden_v, weight_v, has_biases, num_layers, dropout, train, bidirectional, batch_first):
    """Export the ``aten::lstm`` overload that carries a ``batch_first`` flag.

    Unpacks the hidden-state and weight lists and defers to
    :func:`_generic_rnn` with the ``"LSTM"`` variant.
    """
    initial_states = sym_help._unpack_list(hidden_v)
    all_weights = sym_help._unpack_list(weight_v)
    return _generic_rnn(g, "LSTM", input, initial_states, all_weights, has_biases,
                        num_layers, dropout, train, bidirectional, batch_first)
@parse_args("v", "v", "v", "v", "i", "i", "f", "i", "i")
def _lstm_packed(g, input, batch_sizes, hidden_v, weight_v, has_biases, num_layers, dropout, train, bidirectional):
    """Export the ``aten::lstm`` overload that carries ``batch_sizes``.

    Unpacks the hidden-state and weight lists and defers to
    :func:`_generic_rnn` with the ``"LSTM"`` variant, forwarding
    ``batch_sizes`` by keyword.
    """
    initial_states = sym_help._unpack_list(hidden_v)
    all_weights = sym_help._unpack_list(weight_v)
    return _generic_rnn(g, "LSTM", input, initial_states, all_weights, has_biases,
                        num_layers, dropout, train, bidirectional, batch_sizes=batch_sizes)
def lstm(g, *args):
    """Dispatch ``aten::lstm`` to the packed or full symbolic.

    The packed overload is chosen when ``args[3]`` is a tensor list.
    """
    handler = _lstm_packed if sym_help._is_tensor_list(args[3]) else _lstm_full
    return handler(g, *args)
def _one_hidden_rnn(kind):
    """Build the symbolic function for a recurrent op with one hidden state.

    ``kind`` is passed to :func:`_generic_rnn` as the variant name.  The
    returned function picks the packed overload when ``args[3]`` is a tensor
    list and the full overload otherwise.
    """
    @parse_args("v", "v", "v", "i", "i", "f", "i", "i", "i")
    def _rnn_full(g, input, hidden, weight_v, has_biases, num_layers, dropout, train, bidirectional, batch_first):
        all_weights = sym_help._unpack_list(weight_v)
        return _generic_rnn(g, kind, input, hidden, all_weights, has_biases,
                            num_layers, dropout, train, bidirectional, batch_first)

    @parse_args("v", "v", "v", "v", "i", "i", "f", "i", "i")
    def _rnn_packed(g, input, batch_sizes, hidden, weight_v, has_biases, num_layers, dropout, train, bidirectional):
        all_weights = sym_help._unpack_list(weight_v)
        return _generic_rnn(g, kind, input, hidden, all_weights, has_biases,
                            num_layers, dropout, train, bidirectional, batch_sizes=batch_sizes)

    def symbolic(g, *args):
        handler = _rnn_packed if sym_help._is_tensor_list(args[3]) else _rnn_full
        return handler(g, *args)

    return symbolic
# Symbolics for the single-hidden-state recurrent ops, each built by
# _one_hidden_rnn with the variant name that _generic_rnn receives.
gru = _one_hidden_rnn("GRU")
rnn_tanh = _one_hidden_rnn("RNN_TANH")
rnn_relu = _one_hidden_rnn("RNN_RELU")
@parse_args("v", "i")
def _dim_arange(g, like, dim):
    """Emit a range from 0 up to the size of ``like`` along ``dim``.

    With the ATen-fallback export type this is a ``_caffe2::Range`` node;
    otherwise it defers to :func:`arange` with scalar-type index 4.
    """
    like_shape = g.op("Shape", like)
    dim_index = g.op("Constant", value_t=torch.tensor(dim))
    stop = g.op("Gather", like_shape, dim_index, axis_i=0)
    if sym_help._operator_export_type != torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
        return arange(g, stop, 4, None, None, None)
    return g.op("_caffe2::Range", stop)
def detach(g, input):
    """Export ``aten::detach`` as a no-op.

    ONNX is inference only, so the detach node is erased and the input is
    passed through unchanged.
    """
    passthrough = input
    return passthrough
@parse_args("v", "i")
def contiguous(g, input, memory_format):
    """Export ``aten::contiguous`` as a no-op for the supported memory formats.

    Raises ``RuntimeError`` when ``memory_format`` is greater than 2.
    """
    # allowed values are any, preserve and contiguous_format
    if not memory_format > 2:
        return input
    raise RuntimeError("onnx memory_format support is not implemented")
@parse_args("v", "v", "i")
def _pack_padded_sequence(g, input, lengths, batch_first):
    """Emit a placeholder ``prim::PackPadded`` node for ``_pack_padded_sequence``.

    ONNX currently has no PackPadded operator; a later optimization pass is
    relied on to remove this node, and it is an error if any PackPadded node
    cannot be optimized out.  Raises ``RuntimeError`` if ``lengths`` is not a
    tensor.
    """
    if batch_first:
        input = g.op("Transpose", input, perm_i=[1, 0, 2])

    lengths_type = lengths.type()
    if not lengths_type.isSubtypeOf(torch._C.TensorType.get()):
        raise RuntimeError("Lengths must be a Tensor for ONNX export")

    # Safe to query the scalar type now that lengths is known to be a tensor.
    # The cast is really only needed because these operators expand to
    # something that only works with int32 types in Caffe2...
    if lengths_type.scalarType() != "Int":
        lengths = _cast_Int(g, lengths, False)  # type: ignore[name-defined]
    return g.op("prim::PackPadded", input, lengths, outputs=2)
@parse_args("v", "v", "i", "t", "v")
def _pad_packed_sequence(g, data, batch_sizes, batch_first, padding_value, total_length):
    """Emit a ``prim::PadPacked`` node for ``_pad_packed_sequence``.

    Returns ``(data, lengths)``, with ``data`` transposed on its first two
    axes when ``batch_first`` is set.  ``padding_value`` is not used.
    """
    # total_length is ignored: it is not supported in
    # _symbolic_pad_packed_sequence and is only useful when training with a
    # data_parallel model, so it shouldn't be relevant for ONNX anyway.
    padded, lengths = g.op("prim::PadPacked", data, batch_sizes, outputs=2)
    if batch_first:
        padded = g.op("Transpose", padded, perm_i=[1, 0, 2])
    return padded, lengths
def randn(g, shapes, dtype, *options):
    """Export ``aten::randn``.

    A constant shape becomes an ONNX ``RandomNormal`` node.  A dynamic shape
    becomes ``RandomNormalLike`` over a ``ConstantOfShape`` placeholder.
    ``dtype`` defaults to scalar-type index 6 (float) and is passed to the
    emitted node on both paths.  Extra ``options`` are ignored.
    """
    dtype = sym_help._get_const(dtype, "i", "dtype")
    if dtype is None:
        dtype = 6  # float
    shape = sym_help._maybe_get_const(shapes, "is")
    if sym_help._is_value(shape):
        shape_const = g.op("ConstantOfShape", shapes,
                           value_t=torch.tensor([0], dtype=sym_help.scalar_type_to_pytorch_type[6]))
        return g.op("RandomNormalLike", shape_const, dtype_i=sym_help.scalar_type_to_onnx[dtype])
    # The requested dtype used to be dropped on this path, leaving the node
    # with its default output type whatever the caller asked for.
    return g.op("RandomNormal", shape_i=shape, dtype_i=sym_help.scalar_type_to_onnx[dtype])
def rand(g, shapes, dtype, *options):
    """Export ``aten::rand``.

    A constant shape becomes an ONNX ``RandomUniform`` node.  A dynamic shape
    becomes ``RandomUniformLike`` over a ``ConstantOfShape`` placeholder.
    ``dtype`` defaults to scalar-type index 6 (float) and is passed to the
    emitted node on both paths.  Extra ``options`` are ignored.
    """
    dtype = sym_help._get_const(dtype, "i", "dtype")
    if dtype is None:
        dtype = 6  # float
    shape = sym_help._maybe_get_const(shapes, "is")
    if sym_help._is_value(shape):
        shape_const = g.op("ConstantOfShape", shapes,
                           value_t=torch.tensor([0], dtype=sym_help.scalar_type_to_pytorch_type[6]))
        return g.op("RandomUniformLike", shape_const, dtype_i=sym_help.scalar_type_to_onnx[dtype])
    # The requested dtype used to be dropped on this path, leaving the node
    # with its default output type whatever the caller asked for.
    return g.op("RandomUniform", shape_i=shape, dtype_i=sym_help.scalar_type_to_onnx[dtype])
def randn_like(g, self, dtype, layout=None, device=None, pin_memory=False, memory_format=None):
    """Export ``aten::randn_like`` as ONNX ``RandomNormalLike``.

    ``dtype`` defaults to scalar-type index 6 (float); the remaining options
    are not used.
    """
    scalar_type = sym_help._get_const(dtype, "i", "dtype")
    if scalar_type is None:
        scalar_type = 6  # float
    onnx_type = sym_help.scalar_type_to_onnx[scalar_type]
    return g.op("RandomNormalLike", self, dtype_i=onnx_type)
def rand_like(g, self, dtype, layout=None, device=None, pin_memory=False, memory_format=None):
    """Export ``aten::rand_like`` as ONNX ``RandomUniformLike``.

    ``dtype`` defaults to scalar-type index 6 (float); the remaining options
    are not used.
    """
    scalar_type = sym_help._get_const(dtype, "i", "dtype")
    if scalar_type is None:
        scalar_type = 6  # float
    onnx_type = sym_help.scalar_type_to_onnx[scalar_type]
    return g.op("RandomUniformLike", self, dtype_i=onnx_type)
@parse_args("v", "f", "f", "i", "none")
def rrelu(g, input, lower, upper, training, generator):
    """Export ``aten::rrelu`` as ``PRelu`` with a uniformly sampled slope.

    The slope tensor is drawn by ``RandomUniformLike`` between ``lower`` and
    ``upper``.  ``training`` and ``generator`` are not used.
    """
    slope = g.op("RandomUniformLike", input, high_f=upper, low_f=lower)
    activated = g.op("PRelu", input, slope)
    return activated
def bernoulli(g, input, generator=None, out=None):
    """Export ``aten::bernoulli`` as ``Cast(Less(RandomUniformLike(input), input))``.

    ``out`` and a non-None ``generator`` are reported through
    ``_unimplemented``.  The scalar type of ``input`` must be known, since it
    selects the type of both the random draw and the final cast.
    """
    if out is not None:
        _unimplemented("Bernoulli", "out parameter is not supported for bernoulli")
    if not (generator is None or sym_help._is_none(generator)):
        _unimplemented("Bernoulli", "generator is not supported for bernoulli")

    scalar_name = sym_help._try_get_scalar_type(input)
    if scalar_name is None:
        return _unimplemented("Bernoulli", "input dtype not accessible")

    uniform = g.op('RandomUniformLike', input, high_f=1.0, low_f=0.0,
                   dtype_i=sym_help.cast_pytorch_to_onnx[scalar_name])
    sampled = g.op('Less', uniform, input)
    return g.op("Cast", sampled, to_i=sym_help.cast_pytorch_to_onnx[scalar_name])
@parse_args('v')
def log_sigmoid(g, input):
    """Export ``aten::log_sigmoid`` as ``Log(Sigmoid(input))``."""
    sigmoid_out = g.op("Sigmoid", input)
    log_out = g.op("Log", sigmoid_out)
    return log_out
@parse_args("v")
def erf(g, input):
    """Export ``aten::erf`` as the ONNX ``Erf`` operator."""
    result = g.op("Erf", input)
    return result
@parse_args("v", "i", "i")
def flatten(g, input, start_dim, end_dim):
    """Export ``aten::flatten``.

    ONNX ``Flatten`` is used for the two cases where the output shape is 2D;
    every other case goes through ``sym_help._flatten_helper``.  The rank of
    ``input`` must be known at export time.
    """
    rank = sym_help._get_tensor_rank(input)
    if rank is None:
        return _unimplemented("dim",
                              "ONNX and PyTorch use different strategies to split the input. "
                              "Input rank must be known at export time.")

    # TODO: remove this as onnx opset 11 spec allows negative axes
    if end_dim < 0:
        end_dim = rank + end_dim

    # use ONNX's Flatten operator for cases where the output shape is 2D
    flattens_from_one_to_last = start_dim == 1 and end_dim == rank - 1
    if flattens_from_one_to_last:
        return g.op("Flatten", input, axis_i=start_dim)
    flattens_all_but_last = start_dim == 0 and end_dim == rank - 2
    if flattens_all_but_last:
        return g.op("Flatten", input, axis_i=end_dim + 1)

    return sym_help._flatten_helper(g, input, start_dim, end_dim, rank)
# Emitted from `torch.nonzero(x, as_tuple=False)`
@parse_args("v")
def nonzero(g, input):
    """Export as the ONNX ``NonZero`` operator with its result passed through ``t``."""
    indices = g.op("NonZero", input)
    return t(g, indices)
# Emitted from `torch.nonzero(x, as_tuple=True)`
def nonzero_numpy(g, input, _outputs=None):
    """Export as :func:`nonzero` followed by ``unbind`` along dimension 1."""
    indices = nonzero(g, input)
    return unbind(g, indices, 1, _outputs=_outputs)
@parse_args("v")
def isnan(g, input):
    """Export ``aten::isnan`` as the ONNX ``IsNaN`` operator."""
    return g.op("IsNaN", input)
def _any(g, input):
    """Emit "any element is non-zero": cast to long, sum everything, compare with 0."""
    as_long = _cast_Long(g, input, False)  # type: ignore[name-defined]
    total = sym_help._reducesum_helper(g, as_long, keepdims_i=0)
    zero = g.op("Constant", value_t=torch.LongTensor([0]))
    return gt(g, total, zero)
def _all(g, input):
    """Emit "all elements are true" as ``Not(_any(Not(input)))``."""
    negated = g.op("Not", input)
    any_false = _any(g, negated)
    return g.op("Not", any_false)
@parse_args("v", "i", "i", "i")
def narrow(g, input, dim, start, length):
    """Export ``aten::narrow`` as a slice of ``[start, start + length)`` along ``dim``."""
    stop = start + length
    return sym_help._slice_helper(g, input, axes=[dim], starts=[start], ends=[stop])
def argmax(g, input, dim, keepdim):
    """Export ``aten::argmax`` as ONNX ``ArgMax``.

    With no ``dim`` the input is reshaped to 1-d and reduced along axis 0
    without keeping the dimension.
    """
    if not sym_help._is_none(dim):
        axis = _parse_arg(dim, "i")
        keep = _parse_arg(keepdim, "i")
        return g.op("ArgMax", input, axis_i=axis, keepdims_i=keep)

    flat_shape = g.op("Constant", value_t=torch.tensor([-1]))
    flattened = reshape(g, input, flat_shape)
    return g.op("ArgMax", flattened, axis_i=0, keepdims_i=False)
def argmin(g, input, dim, keepdim):
    """Export ``aten::argmin`` as ONNX ``ArgMin``.

    With no ``dim`` the input is reshaped to 1-d and reduced along axis 0
    without keeping the dimension.
    """
    if not sym_help._is_none(dim):
        axis = _parse_arg(dim, "i")
        keep = _parse_arg(keepdim, "i")
        return g.op("ArgMin", input, axis_i=axis, keepdims_i=keep)

    flat_shape = g.op("Constant", value_t=torch.tensor([-1]))
    flattened = reshape(g, input, flat_shape)
    return g.op("ArgMin", flattened, axis_i=0, keepdims_i=False)
@parse_args("v", "i", "v", "v")
def scatter(g, self, dim, index, src):
    """Export ``aten::scatter`` as ONNX ``Scatter``.

    A ``src`` that is still a graph Value is scattered directly.  A scalar
    ``src`` is cast to the scalar type of ``self`` when the two differ and
    then expanded to the shape of ``index``.
    """
    src_type = src.type().scalarType()
    src = sym_help._maybe_get_scalar(src)
    if sym_help._is_value(src):
        return g.op("Scatter", self, index, src, axis_i=dim)

    # PyTorch allows a scalar src whose type differs from self (not so for a
    # tensor src); insert a Cast node when the types disagree.
    if self.type().scalarType() != src_type:
        target_type = sym_help.cast_pytorch_to_onnx[self.type().scalarType()]
        src = g.op("Cast", src, to_i=target_type)
    expanded_src = expand_as(g, src, index)
    return g.op("Scatter", self, index, expanded_src, axis_i=dim)
@parse_args("v", "i", "v", "v")
def scatter_add(g, self, dim, index, src):
    """Export ``aten::scatter_add`` as ``self + scatter(zeros, index, src)``.

    The zero tensor is a ``Constant`` when the sizes of ``self`` are fully
    static and ``zeros_like(self)`` otherwise.  The scalar type of ``self``
    must be known.
    """
    scalar_name = sym_help._try_get_scalar_type(self)
    if scalar_name is None:
        return _unimplemented("scatter_add", "input dtype not accessible")

    type_index = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[scalar_name])
    torch_dtype = sym_help.scalar_type_to_pytorch_type[type_index]

    sizes = sym_help._get_tensor_sizes(self, allow_nonstatic=False)
    if sizes:
        to_add = g.op("Constant", value_t=torch.zeros(sizes, dtype=torch_dtype))
    else:
        # zeros_like takes a scalar-type index, so map the torch dtype back.
        type_index = sym_help.scalar_type_to_pytorch_type.index(torch_dtype)
        to_add = zeros_like(g, self, type_index)

    to_add = sym_help._scatter_helper(g, to_add, dim, index, src)
    return add(g, self, to_add)
def log2(g, self):
    """Export ``aten::log2`` as ``log(self) / ln(2)``."""
    _ln2 = 0.693147180559945309
    natural_log = log(g, self)
    ln2_const = g.op("Constant", value_t=torch.tensor([_ln2]))
    return g.op("Div", natural_log, ln2_const)
def prim_shape(g, self):
    """Export ``prim::shape`` as the ONNX ``Shape`` operator."""
    shape = g.op("Shape", self)
    return shape
def prim_max(g, self, other):
    """Export ``prim::max`` of two values as the ONNX ``Max`` operator."""
    larger = g.op("Max", self, other)
    return larger
def prim_min(g, self, other=None):
    """Export ``prim::min`` by deferring to the ``min`` symbolic.

    With a truthy ``other`` this is the two-argument form.  Otherwise
    ``self`` is reduced on its own, after being stacked along dimension 0
    when it is a packed list.
    """
    if other:
        return min(g, self, other)
    if sym_help._is_packed_list(self):
        axis = g.op("Constant", value_t=torch.tensor([0]))
        self = stack(g, self, axis)
    return min(g, self)
def prim_data(g, self):
    """Export ``prim::data`` as a no-op: the input value is returned unchanged."""
    data = self
    return data
def is_floating_point(g, self):
    """Emit a boolean ``Constant`` holding the result of ``sym_help._is_fp(self)``."""
    flag = 1 if sym_help._is_fp(self) else 0
    return g.op("Constant", value_t=torch.BoolTensor([flag]))
def __isnot_(g, self, other):
    """Export ``aten::__isnot__``.

    When ``other`` is None the answer is known at export time and emitted as
    a boolean ``Constant``.  Otherwise the comparison is deferred to ``ne``.
    """
    if not sym_help._is_none(other):
        return ne(g, self, other)
    result = 0 if sym_help._is_none(self) else 1
    return g.op("Constant", value_t=torch.BoolTensor([result]))
# unchecked_cast exists to refine the type of a Value: if x is an optional
# Tensor, it casts x to Tensor so the rest of the graph knows that x is a
# Tensor.  It does nothing at runtime and is a no-op in ONNX.
def prim_unchecked_cast(g, self):
    """Return ``self`` unchanged."""
    refined = self
    return refined
def prim_dtype(g, self):
    """Emit the scalar-type index of ``self`` as a ``Constant``.

    Falls back to ``"Float"`` when the scalar type cannot be determined.
    """
    scalar_name = sym_help._try_get_scalar_type(self)
    if scalar_name is None:
        scalar_name = "Float"
    type_index = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[scalar_name])
    return g.op("Constant", value_t=torch.tensor(type_index))
# tolist is currently supported only for 1D input tensors.
# dim_val and elem_ty_val are the dimension and type annotations that need
# to match the dimension and type of the input tensor.
def prim_tolist(g, input, dim_val, elem_ty_val):
    """Return ``input`` unchanged, or report ``dim_val > 1`` as unimplemented."""
    dim = sym_help._maybe_get_const(dim_val, "i")
    if not dim > 1:
        return input
    return _unimplemented("prim_tolist", "dim_val > 1")
@parse_args('v', 'i')
def one_hot(g, self, num_classes):
    """Export ``aten::one_hot`` as ONNX ``OneHot`` along the last axis.

    The off/on values are 0 and 1 and the depth is ``num_classes``.
    """
    off_on_values = g.op("Constant", value_t=torch.LongTensor([0, 1]))
    class_count = g.op("Constant", value_t=torch.LongTensor([num_classes]))
    return g.op("OneHot", self, class_count, off_on_values, axis_i=-1)
@parse_args("v", "i", "v", "v")
def gather(g, self, dim, index, sparse_grad=False):
    """Export ``aten::gather`` through a one-hot mask, multiply and sum.

    ``sparse_grad == True`` is reported through ``_unimplemented``.
    """
    if sym_help._maybe_get_const(sparse_grad, "i"):
        return _unimplemented("gather", "sparse_grad == True")

    # NOTE: This workaround is needed since GatherElement is only supported
    # since opset 11, and Gather in ONNX is not the same as torch.gather.
    scalar_name = self.type().scalarType()
    off_on_values = g.op("Constant", value_t=torch.LongTensor([0, 1]))
    dim_const = g.op("Constant", value_t=torch.LongTensor([dim]))
    depth = size(g, self, dim_const)
    mask = g.op("OneHot", index, depth, off_on_values, axis_i=dim)
    mask = g.op("Cast", mask, to_i=sym_help.cast_pytorch_to_onnx[scalar_name])
    expanded_self = sym_help._unsqueeze_helper(g, self, [dim + 1])
    selected = g.op("Mul", expanded_self, mask)
    return sym_help._reducesum_helper(g, selected, axes_i=[dim], keepdims_i=0)
@parse_args("v", "is", "i", "i")
def _var_mean(g, input, dim, correction, keepdim):
    """Emit ONNX nodes computing the variance and mean of ``input``.

    The reduction is over ``dim``, or over every element when ``dim`` is
    None.  The variance is the mean of squared deviations, rescaled by
    ``N / (N - correction)`` when ``correction`` is non-zero; a ``correction``
    of None is treated as 1.

    Returns ``(var, mean)``.
    """
    if dim is None:
        mean = g.op("ReduceMean", input, keepdims_i=0)
        centering_mean = mean
        element_count = numel(g, input)
    else:
        mean = g.op("ReduceMean", input, axes_i=dim, keepdims_i=keepdim)
        # The mean subtracted from the input keeps its reduced dims.
        centering_mean = g.op("ReduceMean", input, axes_i=dim, keepdims_i=1)
        input_shape = g.op("Shape", input)
        # dim could contain one or multiple dimensions
        axes = g.op("Constant", value_t=torch.tensor(dim))
        reduced_sizes = g.op("Gather", input_shape, axes, axis_i=0)
        element_count = g.op("ReduceProd", reduced_sizes, keepdims_i=0)

    deviation = g.op("Sub", input, centering_mean)
    squared = g.op("Mul", deviation, deviation)
    var_keepdim = 0 if dim is None else keepdim
    variance = g.op("ReduceMean", squared, axes_i=dim, keepdims_i=var_keepdim)

    # Correct bias in calculating variance, by dividing it over (N - correction) instead of N
    if correction is None:
        correction = 1
    if correction != 0:
        element_count = g.op("Cast", element_count, to_i=sym_help.cast_pytorch_to_onnx["Float"])
        correction_const = g.op("Constant", value_t=torch.tensor(correction, dtype=torch.float))
        scaled = g.op("Mul", variance, element_count)
        variance = g.op("Div", scaled, g.op("Sub", element_count, correction_const))
    return variance, mean
def std(g, input, *args):
    """Export ``aten::std`` as ``Sqrt`` of the variance from ``var_mean``."""
    variance, _mean = var_mean(g, input, *args)
    return g.op("Sqrt", variance)
def var(g, input, *args):
    """Export ``aten::var`` by returning the variance half of ``var_mean``."""
    variance, _mean = var_mean(g, input, *args)
    return variance
# var_mean (and every other variance-related function) comes in several
# overloads, so the arguments have to be sorted out by hand:
# aten::var_mean(Tensor self, bool unbiased)
# aten::var_mean(Tensor self, int[1] dim, bool unbiased, bool keepdim=False)
# aten::var_mean(Tensor self, int[1]? dim=None, *, int? correction=None, bool keepdim=False)
def var_mean(g, input, *args):
    """Dispatch the ``aten::var_mean`` overloads onto ``_var_mean``.

    A single extra argument is forwarded in the correction slot with
    ``dim`` and ``keepdim`` set to None; any other argument count is
    forwarded positionally as given.
    """
    if len(args) != 1:
        return _var_mean(g, input, *args)
    (unbiased,) = args
    return _var_mean(g, input, None, unbiased, None)
def std_mean(g, input, *args):
    """Export ``aten::std_mean`` as ``(Sqrt(variance), mean)`` from ``var_mean``."""
    variance, mean = var_mean(g, input, *args)
    std_dev = g.op("Sqrt", variance)
    return std_dev, mean
@parse_args("v", "is", "i")
def logsumexp(g, input, dim, keepdim):
    """Export ``aten::logsumexp`` as a single ONNX ``ReduceLogSumExp`` node."""
    reduce_attrs = {"axes_i": dim, "keepdims_i": keepdim}
    return g.op("ReduceLogSumExp", input, **reduce_attrs)
def arange(g, *args):
    """Export the ``aten::arange`` overloads, selected by argument count.

    Under ``ONNX_ATEN_FALLBACK`` the call is emitted as an ``ATen`` op.
    Otherwise the sequence is built from the nonzero positions of a ones
    tensor whose size is the (possibly ceil-ed) range length, then offset and
    scaled by ``start`` / ``step`` where the overload provides them, and
    finally cast to the resolved dtype.

    Raises:
        NotImplementedError: if the number of arguments matches no known
            ``aten::arange`` signature.
    """
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        return g.op("ATen", *args, operator_s="arange")

    def _get_arange_dtype(dtype):
        # Resolve the dtype argument to a constant int when it is one.
        dtype = sym_help._maybe_get_const(dtype, "i")
        return dtype

    def _float_step_convert(range_tensor):
        # A floating-point length is rounded up and cast to the ONNX type at
        # index 4 of scalar_type_to_onnx (presumably int64 -- confirm against
        # symbolic_helper).
        if sym_help._is_fp(range_tensor):
            range_tensor = g.op("Cast", g.op("Ceil", range_tensor), to_i=sym_help.scalar_type_to_onnx[4])
        return range_tensor

    if len(args) == 2 or len(args) == 5:
        if len(args) == 2:
            # aten::arange(Scalar end, Tensor out)
            dtype = None
        else:
            # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
            dtype = _get_arange_dtype(args[1])
        dtype, end, start, step = sym_help._arange_cast_helper(g, end=args[0], dtype=dtype)
        end = sym_help._unsqueeze_helper(g, end, [0])
        range_tensor = _float_step_convert(end)
        arange_tensor = sym_help._squeeze_helper(g, nonzero(g, ones(g, range_tensor, dtype, None, None)), [1])
        return g.op("Cast", arange_tensor, to_i=sym_help.scalar_type_to_onnx[dtype])
    elif len(args) == 4 or len(args) == 7:
        if len(args) == 4:
            # aten::arange(Scalar start, Scalar end, Scalar step, Tensor out)
            dtype = None
        else:
            # aten::arange(Scalar start, Scalar end, Scalar step, ScalarType dtype, Layout, Device, bool pin_memory)
            dtype = _get_arange_dtype(args[3])
        dtype, end, start, step = sym_help._arange_cast_helper(g, start=args[0], end=args[1], step=args[2], dtype=dtype)
        step = sym_help._unsqueeze_helper(g, step, [0])
        end = sym_help._unsqueeze_helper(g, end, [0])
        start = sym_help._unsqueeze_helper(g, start, [0])
        # Length is (end - start) / step; the result is index * step + start.
        range_tensor = _float_step_convert(g.op("Div", g.op("Sub", end, start), step))
        # NOTE(review): unlike the other branches, `ones` receives None as its
        # dtype here rather than the resolved dtype.
        arange_tensor = sym_help._squeeze_helper(g, nonzero(g, ones(g, range_tensor, None, None, None)), [1])
        arange_tensor = g.op("Add", g.op("Mul", arange_tensor, step), start)
        return g.op("Cast", arange_tensor, to_i=sym_help.scalar_type_to_onnx[dtype])
    elif len(args) == 6:
        # aten::arange(Scalar start, Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
        dtype = _get_arange_dtype(args[2])
        dtype, end, start, step = sym_help._arange_cast_helper(g, start=args[0], end=args[1], dtype=dtype)
        end = sym_help._unsqueeze_helper(g, end, [0])
        start = sym_help._unsqueeze_helper(g, start, [0])
        # Length is end - start; the result is index + start.
        range_tensor = _float_step_convert(g.op("Sub", end, start))
        arange_tensor = g.op("Add", sym_help._squeeze_helper(g, nonzero(g, ones(g, range_tensor, dtype, *(args[3:]))), [1]), start)
        return g.op("Cast", arange_tensor, to_i=sym_help.scalar_type_to_onnx[dtype])
    else:
        raise NotImplementedError("Unknown aten::arange signature taking " + str(len(args)) + " arguments.")
def linspace(g, start, end, steps, dtype, layout, device, pin_memory):
    """Export ``aten::linspace`` on top of ``sym_help._arange_helper``.

    The step is ``(end - start) / (steps - 1)``; the upper bound handed to
    the arange helper is ``end + step``. ``layout``, ``device`` and
    ``pin_memory`` are not forwarded.
    """
    span = sub(g, end, start)
    one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
    interval_count = sub(g, steps, one)
    step = div(g, span, interval_count)
    end_epsilon = g.op("Add", step, end)
    return sym_help._arange_helper(g, start, end_epsilon, step, dtype, None, None, None)
def masked_fill(g, self, mask, value):
    """Export ``aten::masked_fill`` as ``Where(mask, value, self)``.

    The mask is cast to Bool first, and ``value`` is passed through
    ``_maybe_get_scalar`` and ``_if_scalar_type_as`` against ``self``.
    """
    bool_mask = _cast_Bool(g, mask, False)  # type: ignore[name-defined]
    scalar_value = sym_help._maybe_get_scalar(value)
    fill_value = sym_help._if_scalar_type_as(g, scalar_value, self)
    return g.op("Where", bool_mask, fill_value, self)
def index(g, self, index):
    """Export ``aten::index`` (tensor indexing with tensor indices).

    Under ``ONNX_ATEN_FALLBACK`` the call is emitted as an ``ATen`` op.
    Otherwise ``index`` is unpacked into a list of per-axis indices, Byte/Bool
    masks are converted to integer indices, and the result is built as:

    * one index: ``_select_helper`` on axis 0;
    * several entries, none of them a tensor index: ``self`` unchanged;
    * several entries, exactly one tensor index: ``index_select`` on that axis;
    * several tensor indices: a Transpose / Flatten / gather / Reshape
      sequence described in the comments below.

    Raises:
        RuntimeError: for a mask index when the export opset is below 9.
        NotImplementedError: for multiple tensor indices when the rank of
            ``self`` is unknown.
    """
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        return g.op("ATen", self, index, operator_s="index")

    if sym_help._is_packed_list(index):
        indices = sym_help._unpack_list(index)
    else:
        indices = [index]

    def try_mask_to_index(index):
        # Byte/Bool masks become integer indices: nonzero followed by a
        # squeeze of axis 1. Other indices (and None) pass through untouched.
        if not sym_help._is_none(index) and (index.type().scalarType() == "Byte" or index.type().scalarType() == "Bool"):
            if sym_help._export_onnx_opset_version < 9:
                raise RuntimeError("Exporting masked indices are only supported after ONNX opset 9.")
            warnings.warn("Exporting aten::index operator with indices of type Byte. "
                          "Only 1-D indices are supported. In any other case, "
                          "this will produce an incorrect ONNX graph.")
            index = sym_help._squeeze_helper(g, nonzero(g, index), [1])
        return index

    indices = [try_mask_to_index(idx) for idx in indices]
    if len(indices) == 1:
        return sym_help._select_helper(g, self, 0, indices[0], apply_reshape=False)
    else:
        # Multiple tensors as indices. Each tensor could either be
        #   1. prim::Constant()
        #           representing ":" in python indexing. E.g. tensor[:, :]
        #   2. prim::Constant[value=...] or tensor output
        #           representing advanced indexing. E.g. tensor[[0, 1], [2, 0]].
        # For more info on advanced indexing,
        # check https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing

        # Consider a general case of
        #       t: [x_1, y_1, y_2, ..., x_m, ..., y_n]
        # where t is a tensor of rank m+n, {x_i} are axes where tensor index is provided, and {y_i} are axes for ":".
        # Same results can be achieved through transposing t into
        #       t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n]
        # and use gatherND. However ONNX does not have gatherND, to use 1d gather we'll need to flatten t
        # and process the tensor indices.
        #       t: [x_1 * x_2 * ... * x_m, y_1 * y_2 * ... * y_n]
        #       tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j))
        # After gather, reshape and transpose back.
        adv_idx_indices = [i for i, idx in enumerate(indices) if not sym_help._is_none(idx)]

        if len(adv_idx_indices) == 0:
            return self
        elif len(adv_idx_indices) == 1:
            return index_select(g, self, adv_idx_indices[0], indices[adv_idx_indices[0]])
        else:
            rank = sym_help._get_tensor_rank(self)
            if rank is None:
                raise NotImplementedError("Unsupported aten::index operator of advanced indexing on tensor of unknown rank, " +
                                          "try turning on shape and type propagate during export: " +
                                          "torch.onnx._export(..., propagate=True).")
            # TODO: If indexing is supported natively in ONNX in future opsets,
            #       update the warning to recommend exporting with higher opset version.
            warnings.warn("Exporting aten::index operator of advanced indexing in opset " +
                          str(sym_help._export_onnx_opset_version) +
                          " is achieved by combination of multiple ONNX operators, " +
                          "including Reshape, Transpose, Concat, and Gather. " +
                          "If indices include negative values, the exported graph will produce incorrect results.")
            adv_idx_count = len(adv_idx_indices)
            shape_tensor = _shape_as_tensor(g, self)
            # One single-element shape tensor per axis of the original input.
            dim_tensor_list = [
                g.op("Gather", shape_tensor, g.op("Constant", value_t=torch.LongTensor([dim])), axis_i=0) for dim in range(rank)
            ]

            # Move the indexed axes to the front, then flatten at adv_idx_count.
            self = g.op("Transpose", self, perm_i=adv_idx_indices + [i for i in range(rank) if i not in adv_idx_indices])
            self = g.op("Flatten", self, axis_i=adv_idx_count)

            # Note that tensor indices will be broadcasted while accumulating. Thus we get the final subarray shape as well.
            cum_adv_index = indices[adv_idx_indices[-1]]
            multiplier = dim_tensor_list[adv_idx_indices[-1]]
            # Walk the remaining indexed axes from last to first, accumulating
            # index * (product of the sizes of the later indexed axes).
            for i in range(adv_idx_count - 2, -1, -1):
                adv_index = g.op("Mul", indices[adv_idx_indices[i]], multiplier)
                cum_adv_index = g.op("Add", cum_adv_index, adv_index)
                multiplier = g.op("Mul", multiplier, dim_tensor_list[adv_idx_indices[i]])

            # perform gather
            self = index_select(g, self, 0, cum_adv_index)

            cum_adv_index_shape_tensor = _shape_as_tensor(g, cum_adv_index)
            # check if all advanced indices are consecutive.
            # Refer to https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#combining-advanced-and-basic-indexing
            # to understand how the subarray position is decided.
            if adv_idx_indices == list(range(adv_idx_indices[0], adv_idx_indices[-1] + 1)):
                # unfold regular index axes
                folded_adv_idx_shape_list = [g.op("Constant", value_t=torch.LongTensor([-1]))] \
                    + [dim_tensor_list[i] for i in range(rank) if i not in adv_idx_indices]
                folded_adv_idx_shape = g.op("Concat", *folded_adv_idx_shape_list, axis_i=0)
                self = g.op("Reshape", self, folded_adv_idx_shape)

                # Transpose folded advanced indexed axis to its original location.
                adv_idx_permute = list(range(1, adv_idx_indices[0] + 1)) \
                    + [0] + list(range(adv_idx_indices[0] + 1, rank - adv_idx_count + 1))
                self = g.op("Transpose", self, perm_i=adv_idx_permute)

                # unfold advanced index axes
                final_shape_list = [dim_tensor_list[i] for i in range(adv_idx_indices[0])] \
                    + [cum_adv_index_shape_tensor] \
                    + [dim_tensor_list[i] for i in range(adv_idx_indices[0], rank) if i not in adv_idx_indices]
                final_shape = g.op("Concat", *final_shape_list, axis_i=0)
            else:
                # Non-consecutive indexed axes: the broadcast index shape
                # comes first, followed by the un-indexed axes in order.
                final_shape = g.op(
                    "Concat",
                    cum_adv_index_shape_tensor,
                    *[dim_tensor_list[i] for i in range(rank) if i not in adv_idx_indices],
                    axis_i=0)

            return g.op("Reshape", self, final_shape)
@parse_args("v", "is", "i")
def frobenius_norm(g, self, dim=None, keepdim=False):
    """Export ``aten::frobenius_norm`` as ``Sqrt(ReduceSum(self * self))``."""
    squared = g.op("Mul", self, self)
    sum_of_squares = sym_help._reducesum_helper(g, squared, axes_i=dim, keepdims_i=keepdim)
    return g.op("Sqrt", sum_of_squares)
@parse_args("v", "i", "b", "v")
def multinomial(g, input, num_samples, replacement=False, generator=None):
    """Export ``aten::multinomial`` as ONNX ``Multinomial`` on ``log(input)``.

    A non-None generator, and sampling without replacement for more than one
    sample, are each reported through ``_unimplemented``; the node is still
    emitted afterwards.
    """
    generator_given = generator is not None and not sym_help._is_none(generator)
    if generator_given:
        _unimplemented("Multinomial", "generator is not supported for multinomial")
    without_replacement = not replacement
    if without_replacement and num_samples > 1:
        _unimplemented("Multinomial", "replacement=False when num_samples > 1 is not supported for multinomial")

    log_probs = log(g, input)
    return g.op(
        "Multinomial",
        log_probs,
        dtype_i=sym_help.cast_pytorch_to_onnx["Long"],
        sample_size_i=num_samples,
    )
def baddbmm(g, self, batch1, batch2, beta, alpha):
    """Export ``aten::baddbmm`` as ``alpha * (batch1 @ batch2) + beta * self``.

    ``alpha`` and ``beta`` are cast to the scalar type of ``self`` before the
    multiplications.
    """
    scalar_type = self.type().scalarType()
    product = matmul(g, batch1, batch2)
    alpha_cast = g.op("Cast", alpha, to_i=sym_help.cast_pytorch_to_onnx[scalar_type])
    scaled_product = mul(g, product, alpha_cast)
    beta_cast = g.op("Cast", beta, to_i=sym_help.cast_pytorch_to_onnx[scalar_type])
    scaled_self = mul(g, self, beta_cast)
    return add(g, scaled_product, scaled_self)
def meshgrid(g, tensor_list):
    """Export ``aten::meshgrid`` as a ``prim::ListConstruct`` of expanded grids.

    Every input is first viewed as 1-D. The i-th output is that tensor
    reshaped so only position i keeps its length (all others are 1) and then
    expanded to the concatenation of all 1-D lengths.
    """
    flat_tensors = []
    for tensor in sym_help._unpack_list(tensor_list):
        flat_tensors.append(view(g, tensor, g.op("Constant", value_t=torch.LongTensor([-1]))))
    lengths = [g.op("Shape", tensor) for tensor in flat_tensors]
    grid_shape = g.op("Concat", *lengths, axis_i=0)

    count = len(flat_tensors)
    grids = []
    for axis, tensor in enumerate(flat_tensors):
        one = g.op("Constant", value_t=torch.ones(1, dtype=torch.int64))
        axis_shape = [lengths[axis] if pos == axis else one for pos in range(count)]
        reshaped = _reshape_from_tensor(g, tensor, g.op("Concat", *axis_shape, axis_i=0))
        grids.append(g.op("Expand", reshaped, grid_shape))
    return g.op("prim::ListConstruct", *grids)
def remainder(g, input, other):
    """Export ``aten::remainder`` as ``input - quotient * other``.

    The quotient is ``Div(input, other)``, floored when either operand is
    floating point.
    """
    quotient = g.op("Div", input, other)
    if sym_help._is_fp(input) or sym_help._is_fp(other):
        quotient = g.op("Floor", quotient)
    return g.op("Sub", input, g.op("Mul", quotient, other))
def gelu(g, self):
    """Export ``aten::gelu`` as ``self * (erf(self / sqrt(2)) + 1) * 0.5``."""
    sqrt_two = 1.4142135623730951
    scaled = g.op("Div", self, torch.tensor(sqrt_two, dtype=torch.double))
    erf = g.op("Erf", scaled)
    one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.double))
    erf_plus_one = add(g, erf, one)
    weighted = mul(g, self, erf_plus_one)
    half = g.op("Constant", value_t=torch.tensor(0.5, dtype=torch.double))
    return mul(g, weighted, half)
@parse_args("v", "i", "v", "v", "f", "i")
def group_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled):
    """Export ``aten::group_norm`` via ``InstanceNormalization`` on a regrouped input.

    Under ``ONNX_ATEN_FALLBACK`` the call is emitted as an ``ATen`` op. An
    input of unknown rank is reported as unimplemented. Missing ``weight`` /
    ``bias`` are replaced by constant one / zero.
    """
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        return g.op("ATen", input, weight, bias, num_groups_i=num_groups,
                    eps_f=eps, cudnn_enabled_i=cudnn_enabled, operator_s="group_norm")

    num_channels = sym_help._get_tensor_dim_size(input, 1)
    if num_channels is not None:
        assert num_channels % num_groups == 0
    rank = sym_help._get_tensor_rank(input)
    if rank is None:
        return _unimplemented("group_norm", "unknown input rank")

    def _constant_like_input(fill_value, length):
        # Constant of `length` copies of `fill_value`, typed like the input.
        base = torch.tensor([fill_value] * length)
        return g.op("Constant", value_t=base.type("torch." + input.type().scalarType() + "Tensor"))

    # 0 in the shape list keeps dimension value unchanged.
    grouped_shape = g.op("Constant", value_t=torch.LongTensor([0, num_groups, -1]))
    grouped_input = g.op("Reshape", input, grouped_shape)

    # C is always divisible by num_groups
    # Due to shape difference. we need to apply weight and bias after
    # instance norm computation and reshape
    unit_scale = _constant_like_input(1., num_groups)
    zero_shift = _constant_like_input(0., num_groups)
    grouped_norm = g.op("InstanceNormalization", grouped_input, unit_scale, zero_shift, epsilon_f=eps)
    norm = g.op("Reshape", grouped_norm, g.op("Shape", input))

    if weight is None or weight.node().mustBeNone():
        weight = _constant_like_input(1., 1)
    if bias is None or bias.node().mustBeNone():
        bias = _constant_like_input(0., 1)

    # Norm has shape [N, C, *] so we reshape weight and bias to [C, *]
    axes = list(range(1, rank - 1))
    scaled = mul(g, norm, sym_help._unsqueeze_helper(g, weight, axes))
    return add(g, scaled, sym_help._unsqueeze_helper(g, bias, axes))
@parse_args("v", "v", "i")
def _weight_norm(g, weight_v, weight_g, dim):
    """Export ``aten::_weight_norm`` as ``weight_g * weight_v / ||weight_v||``.

    The L2 norm is taken over every axis except ``dim``. With an unknown rank
    for ``weight_v`` the op is emitted as ``ATen`` under
    ``ONNX_ATEN_FALLBACK``.

    Raises:
        RuntimeError: if the rank is unknown and ATen fallback is not active.
    """
    rank = sym_help._get_tensor_rank(weight_v)
    if rank is None:
        if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
            return g.op("ATen", weight_v, weight_g, dim_i=dim, operator_s="_weight_norm")
        raise RuntimeError("Unsupported: ONNX export of _weight_norm for tensor "
                           "of unknown rank.")

    # W = g * ((v) / ||v||)
    # Compute norm_except_dim for l2 norm. dim = None means over all dims
    # torch's weight_norm module sets dim = -1 if it's None.
    # This conflicts the logic for negative axes to access dims backwards
    # TODO: Might need a fix on the torch module side
    reduce_axes = list(range(rank))
    if dim is not None:
        # Only values below -1 are wrapped; -1 itself means "all dims".
        if dim < -1:
            dim += rank
        if dim != -1:
            reduce_axes.remove(dim)
    v_norm = norm(g, weight_v, 2, reduce_axes, 1)
    direction = g.op("Div", weight_v, v_norm)
    return g.op("Mul", direction, weight_g)
def dim(g, self):
    """Implement the dim functionality available for a pytorch tensor in ONNX"""
    # No direct dim op exists in this opset: take Size of the Shape tensor.
    return g.op("Size", g.op("Shape", self))
def __getitem_(g, self, i):
    """Export list/tensor item access as ``select`` along axis 0 at ``i``."""
    first_axis = g.op("Constant", value_t=torch.tensor([0]))
    return select(g, self, first_axis, i)
def item(g, self):
    """Export ``aten::item`` as a no-op: the input value is returned as is."""
    return self
def take(g, self, index):
    """Export ``aten::take``: gather from the flattened input, shaped like ``index``."""
    flat_shape = g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
    flattened = g.op("Reshape", self, flat_shape)
    gathered = index_select(g, flattened, 0, index)
    return reshape_as(g, gathered, index)
def _kl_div_log_target_impl(g, input, target):
    """Pointwise term used by ``kl_div`` when ``log_target`` is set.

    Emits ``exp(target) * (target - input)``.
    """
    difference = sub(g, target, input)
    target_exp = exp(g, target)
    return mul(g, target_exp, difference)
def _kl_div_non_log_target_impl(g, input, target):
    """Pointwise term used by ``kl_div`` when ``log_target`` is not set.

    Emits ``target * (log(target) - input)`` where ``target > 0`` and zero
    elsewhere.
    """
    target_log = log(g, target)
    difference = sub(g, target_log, input)
    pointwise = mul(g, target, difference)
    zero_fill = zeros_like(g, pointwise)
    zero = g.op("Constant", value_t=torch.tensor(0))
    positive_mask = gt(g, target, zero)
    return where(g, positive_mask, pointwise, zero_fill)
@parse_args("v", "v", "i", "b")
def kl_div(g, input, target, reduction, log_target):
    """Export ``aten::kl_div``.

    ``reduction`` selects the post-processing of the pointwise result:
    0 returns it as is, 1 applies ``ReduceMean``, 2 applies a reduce-sum, and
    anything else is reported through ``sym_help._onnx_unsupported``.
    """
    pointwise_impl = _kl_div_log_target_impl if log_target else _kl_div_non_log_target_impl
    output = pointwise_impl(g, input, target)

    if reduction == 0:
        return output
    if reduction == 1:
        return g.op("ReduceMean", output, keepdims_i=0)
    if reduction == 2:
        return sym_help._reducesum_helper(g, output, keepdims_i=0)
    return sym_help._onnx_unsupported("kl_div with reduction other than none, mean, or sum.")
@parse_args("v", "v", "is", "i")
def as_strided(g, self, sizes, strides, offset=None):
    """Export ``aten::as_strided`` as a ``Gather`` over the flattened input.

    The gather indices are ``offset + sum_i(arange(size_i) * stride_i)``, with
    each ``arange`` reshaped so that it varies along axis ``i`` only.

    * When ``sizes`` resolves to a constant list, the indices are computed
      eagerly with torch and embedded as a single ``Constant``.
    * Otherwise they are assembled in the graph from ``select`` / ``arange`` /
      ``Reshape`` / ``Mul`` / ``Add`` nodes.

    Args:
        g: graph being built.
        self: input tensor value.
        sizes: output sizes, either a constant list or a graph value.
        strides: list of ints, one stride per output dimension.
        offset: optional int storage offset; falsy values are ignored.
    """
    sizes = sym_help._maybe_get_const(sizes, "is")
    rank = len(strides)
    self_1d = g.op("Reshape", self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)))
    ind: Optional[torch.Tensor]
    if not sym_help._is_value(sizes):
        # Static sizes: build the whole index tensor at export time.
        ind = torch.tensor([0], dtype=torch.long)
        for i, (size, stride) in enumerate(zip(sizes, strides)):
            # Shape that is 1 everywhere except axis i, so the per-axis
            # contributions broadcast against each other when added.
            r_size = [1] * rank
            r_size[i] = -1
            ind = ind + torch.arange(size).view(r_size) * stride
        if offset:
            ind = ind + offset
        return g.op("Gather", self_1d, g.op("Constant", value_t=ind))
    else:
        # Dynamic sizes: build the same index expression out of graph nodes.
        ind = None
        for i, stride in enumerate(strides):
            r_size = [1] * rank
            r_size[i] = -1
            size = select(g, sizes, g.op("Constant", value_t=torch.tensor([0])), g.op("Constant", value_t=torch.tensor(i)))
            tmp_ind = g.op("Reshape", arange(g, size, 4, None, None, None), g.op("Constant", value_t=torch.tensor(r_size)))
            tmp_ind = g.op("Mul", tmp_ind, g.op("Constant", value_t=torch.tensor([stride])))
            if ind is None:
                ind = tmp_ind
            else:
                ind = g.op("Add", ind, tmp_ind)
        if offset:
            # The tensor must be supplied through the `value_t` attribute;
            # passed positionally it would become an *input* of the Constant
            # node, which ONNX Constant does not accept.
            ind = g.op("Add", ind, g.op("Constant", value_t=torch.tensor([offset])))
        return g.op("Gather", self_1d, ind)
def __derive_index(g, index, start, step):
    """Emit ``start + index * step``."""
    scaled_index = g.op("Mul", index, step)
    return g.op("Add", start, scaled_index)
# The reference implementation of the aten op lives in
# pytorch/torch/csrc/jit/runtime/register_prim_ops.cpp:
#    if (step > 0 && lo < hi) {
#        push(stack, 1 + (hi - 1 - lo) / step);
#    } else if (step < 0 && lo > hi) {
#        push(stack, 1 + (lo - 1 - hi) / (0 - step));
#    } else {
#        push(stack, 0);
#    }
def __range_length(g, lo, hi, step):
    """Emit ``ceil((hi - lo) / step)`` cast to Long.

    The export is a single branch-free expression; it does not reproduce the
    conditional structure of the reference implementation quoted above.
    """
    span = g.op("Sub", hi, lo)
    length = g.op("Ceil", true_divide(g, span, step))
    return g.op("Cast", length, to_i=sym_help.cast_pytorch_to_onnx["Long"])
def linear(g, input, weight, bias):
    """Export ``aten::linear``.

    A rank-2 input with a bias goes through ``addmm`` with alpha = beta = 1;
    every other case is ``matmul`` followed by an ``add`` when a bias exists.
    The weight is transposed with ``t`` in both cases.
    """
    input_rank = sym_help._get_tensor_rank(input)
    weight_t = t(g, weight)
    if input_rank == 2 and not bias.node().mustBeNone():
        alpha = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
        beta = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
        return addmm(g, bias, input, weight_t, alpha, beta)

    product = matmul(g, input, weight_t)
    if bias.node().mustBeNone():
        return product
    return add(g, bias, product)
@parse_args("v", "b", "i", "v", "v", "v", "v")
def hann_window(g, window_length, periodic=True, dtype=None, layout=None, device=None, pin_memory=None, requires_grad=False):
    """Export ``aten::hann_window`` as ``sin(pi * n / N) ** 2``.

    ``n`` is ``arange(window_length)`` and ``N`` is ``window_length``, or
    ``window_length - 1`` when ``periodic`` is False. With no dtype given the
    default torch dtype is used, replaced by float if it is not floating
    point. The result is cast to the selected dtype.
    """
    if dtype is None:
        default_dtype = torch.get_default_dtype()
        if sym_help._dtype_is_fp(default_dtype) is False:
            default_dtype = torch.float
        dtype = sym_help.scalar_type_to_pytorch_type.index(default_dtype)

    positions = arange(g, window_length, 4, None, None, None)
    angle = g.op("Cast", positions, to_i=sym_help.cast_pytorch_to_onnx["Float"])
    pi_const = g.op("Constant", value_t=torch.tensor(math.pi, dtype=torch.float))
    angle = mul(g, pi_const, angle)

    if periodic is False:
        one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int))
        window_length = sub(g, window_length, one)
    angle = div(g, angle, window_length)

    window = square(g, sin(g, angle))
    return g.op("Cast", window, to_i=sym_help.scalar_type_to_onnx[dtype])
def mv(g, self, vec):
    """Export ``aten::mv`` by delegating to the ``matmul`` symbolic."""
    product = matmul(g, self, vec)
    return product
@parse_args('v', 'v')
def fill(g, self, value):
    """Export ``aten::fill`` through ``full_like`` using the dtype of ``self``.

    When the scalar type of ``self`` is unknown, dtype index 6 is used.
    """
    scalar_type = self.type().scalarType()
    if scalar_type is not None:
        onnx_type = sym_help.cast_pytorch_to_onnx[scalar_type]
        dtype = sym_help.scalar_type_to_onnx.index(onnx_type)
    else:
        dtype = 6  # float
    return full_like(g, self, value, dtype)
def index_add(g, self, dim, index, other):
    """Export ``aten::index_add`` on top of ``scatter_add``.

    ``other`` is unsqueezed at the end until its rank matches ``self`` and
    expanded against a slice of ``self`` whose ``dim`` axis has size 1.
    ``index`` is unsqueezed on both sides of ``dim`` and expanded to the
    shape of ``other``. A warning about duplicated index values is always
    issued.

    Raises:
        NotImplementedError: if ``dim`` is not a known constant, if the rank
            of ``self`` or ``other`` is unknown, or if ``other`` is larger
            than ``self`` along ``dim``.
    """
    warnings.warn("Warning: ONNX export does not support duplicated values in 'index' field, " +
                  "this will cause the ONNX model to be incorrect.")
    from torch.onnx.symbolic_opset9 import scatter_add

    dim = sym_help._maybe_get_const(dim, "i")
    if dim is None:
        raise NotImplementedError("ONNX export does NOT support exporting 'index_add_()' function with " +
                                  "unknown 'dim' value.")

    self_rank = sym_help._get_tensor_rank(self)
    other_rank = sym_help._get_tensor_rank(other)
    if self_rank is None or other_rank is None:
        raise NotImplementedError("ONNX export does NOT support exporting 'index_add_()' function while " +
                                  "the rank of self tensor or tensor to be added is unknown.")

    # Pad `other` with trailing singleton axes up to the rank of `self`.
    for _ in range(self_rank - other_rank):
        other = sym_help._unsqueeze_helper(g, other, [sym_help._get_tensor_rank(other)])

    other_dim_size = sym_help._get_tensor_dim_size(other, dim)
    self_dim_size = sym_help._get_tensor_dim_size(self, dim)
    sizes_known = (other_dim_size is not None) and (self_dim_size is not None)
    if sizes_known and other_dim_size > self_dim_size:
        raise NotImplementedError("ONNX export does NOT support exporting 'index_add_()' function with " +
                                  "duplicated values in 'index' parameter yet.")

    # Construct a new shape. It's almost as same as self except the size of the 'dim'
    # dimension is 1, so that we can expand other dimensions as expected.
    all_axes = list(range(self_rank))
    slice_starts = [0] * self_rank
    slice_ends = [1 if axis == dim else maxsize for axis in all_axes]
    new_shape = sym_help._slice_helper(g,
                                       self,
                                       axes=all_axes,
                                       starts=slice_starts,
                                       ends=slice_ends)
    other = expand_as(g, other, new_shape)

    # Give `index` one leading axis per axis before `dim` ...
    for _ in range(dim):
        index = sym_help._unsqueeze_helper(g, index, [0])
    # ... and one trailing axis per axis after `dim`.
    for _ in range(self_rank - dim - 1):
        index = sym_help._unsqueeze_helper(g, index, [sym_help._get_tensor_rank(index)])

    return scatter_add(g, self, dim, expand_as(g, index, other), other)
@parse_args('v', 'is', 'is')
def roll(g, self, shifts, dims):
    """Export ``aten::roll`` as one slice-and-concat per (shift, dim) pair.

    For each pair the slice starting at ``-shift`` is placed in front of the
    slice ending at ``-shift`` along that dim.
    """
    assert len(shifts) == len(dims)

    result = self
    for shift, axis in zip(shifts, dims):
        tail = sym_help._slice_helper(g,
                                      result,
                                      axes=[axis],
                                      starts=[-shift],
                                      ends=[maxsize])
        head = sym_help._slice_helper(g,
                                      result,
                                      axes=[axis],
                                      starts=[0],
                                      ends=[-shift])
        result = g.op("Concat", tail, head, axis_i=axis)
    return result
def broadcast_tensors(g, self):
    """Export ``aten::broadcast_tensors`` as a list of expanded tensors.

    Every input is expanded to the shape of an accumulator obtained by adding
    all inputs onto a zeros-like copy of the first one.
    """
    all_tensors = sym_help._unpack_list(self)
    # Add operator supports multidirectional broadcasting. So we leverage this function
    # to infer the final shape generated by the broadcast.
    shape_carrier = zeros_like(g, all_tensors[0])
    for tensor in all_tensors:
        shape_carrier = add(g, shape_carrier, tensor)
    expanded = [expand_as(g, tensor, shape_carrier) for tensor in all_tensors]
    return g.op("prim::ListConstruct", *expanded)