# blob: 73580bb975e9ebfb6723874bb99cff3c1a2798d0 [file] [log] [blame]
from __future__ import absolute_import, division, print_function, unicode_literals
import torch
import torch.onnx.symbolic_helper as sym_help
from torch.onnx.symbolic_helper import parse_args, _unimplemented
from torch.onnx.symbolic_helper import _black_list_in_opset
from torch.onnx.symbolic_opset9 import expand
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py
# This file exports ONNX ops for opset 11
# Operators black-listed for this opset (see _black_list_in_opset in symbolic_helper).
black_listed_operators = [
    "hardtanh",
]

# Bind each listed name at module scope to whatever _black_list_in_opset
# returns for it. At module level globals() is the same mapping that vars()
# returns, so this is equivalent to writing through vars().
for black_listed_op in black_listed_operators:
    globals()[black_listed_op] = _black_list_in_opset(black_listed_op)
def clamp(g, self, min, max):
    """Export ``aten::clamp`` as an ONNX ``Clip`` node.

    ``Clip`` receives the bounds as graph inputs. When the scalar type of
    ``self`` is known, every bound that is actually present is first cast to
    that type (presumably because Clip expects bounds of the input's type).
    Missing bounds are forwarded unchanged.
    """
    scalar_type = self.type().scalarType()

    def _maybe_cast(bound, target_type):
        # Absent bound: either Python None or a graph value that must be None.
        if bound is None or sym_help._is_none(bound):
            return bound
        return g.op("Cast", bound, to_i=sym_help.cast_pytorch_to_onnx[target_type])

    if scalar_type is not None:
        min = _maybe_cast(min, scalar_type)
        max = _maybe_cast(max, scalar_type)
    return g.op("Clip", self, min, max)
@parse_args('v', 'i')
def pixel_shuffle(g, self, upscale_factor):
    """Export ``aten::pixel_shuffle`` as ONNX ``DepthToSpace`` in "CRD" mode.

    Only inputs whose static shape has exactly four dimensions are handled.
    Any other rank is reported through ``_unimplemented``.
    """
    rank = len(self.type().sizes())
    if rank == 4:
        return g.op("DepthToSpace", self, blocksize_i=upscale_factor, mode_s="CRD")
    return _unimplemented("pixel_shuffle", "only support 4d input")
def _interpolate(name, dim, interpolate_mode):
    """Build a symbolic for one of the fixed-mode ``upsample_*`` operators.

    :param name: operator name (kept for signature compatibility; not used in the body)
    :param dim: rank associated with the variant (kept for signature compatibility; not used in the body)
    :param interpolate_mode: ONNX Resize mode to emit ("nearest", "linear" or "cubic")
    :returns: a symbolic function ``(g, input, output_size, align_corners=None)``
    """
    def symbolic_fn(g, input, output_size, align_corners=None):
        align_corners = sym_help._maybe_get_scalar(align_corners)
        if interpolate_mode == "nearest":
            transform_mode = "asymmetric"
        elif align_corners:
            transform_mode = "align_corners"
        else:
            transform_mode = "pytorch_half_pixel"

        # One empty float tensor serves as both the roi and the scales inputs.
        no_values = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))

        # Resize's sizes input covers every axis. The first two entries are
        # taken from the static shape of input (presumably batch and channel)
        # and are prepended to the requested output sizes.
        leading_dims = input.type().sizes()[0:2]
        leading = g.op("Constant", value_t=torch.tensor(leading_dims, dtype=torch.int64))
        requested = g.op("Cast", output_size, to_i=sym_help.cast_pytorch_to_onnx["Long"])
        full_size = g.op("Concat", leading, requested, axis_i=0)

        return g.op("Resize",
                    input,
                    no_values,  # roi only takes effect with coordinate_transformation_mode="tf_crop_and_resize"
                    no_values,  # scales is not needed since sizes is supplied
                    full_size,
                    coordinate_transformation_mode_s=transform_mode,
                    cubic_coeff_a_f=-0.75,  # only valid when mode="cubic"
                    mode_s=interpolate_mode,  # nearest, linear, or cubic
                    nearest_mode_s="floor")  # only valid when mode="nearest"
    return symbolic_fn
# Fixed-mode upsample symbolics. The integer passed as ``dim`` is the rank
# paired with each variant (1d -> 3, 2d -> 4, 3d -> 5).
upsample_nearest1d = _interpolate('upsample_nearest1d', dim=3, interpolate_mode="nearest")
upsample_nearest2d = _interpolate('upsample_nearest2d', dim=4, interpolate_mode="nearest")
upsample_nearest3d = _interpolate('upsample_nearest3d', dim=5, interpolate_mode="nearest")
upsample_linear1d = _interpolate('upsample_linear1d', dim=3, interpolate_mode="linear")
upsample_bilinear2d = _interpolate('upsample_bilinear2d', dim=4, interpolate_mode="linear")
upsample_trilinear3d = _interpolate('upsample_trilinear3d', dim=5, interpolate_mode="linear")
upsample_bicubic2d = _interpolate('upsample_bicubic2d', dim=4, interpolate_mode="cubic")
def __interpolate(g, input, size, scale_factor, mode, align_corners):
    """Export ``aten::__interpolate`` as an ONNX ``Resize`` node.

    Mode names containing "linear" map to Resize's "linear" mode, and names
    containing "cubic" map to "cubic". When ``size`` is given, Resize is driven
    by an int64 ``sizes`` input whose first two entries come from the static
    shape of ``input``. Otherwise ``scale_factor`` is turned into a ``scales``
    input by ``sym_help._interpolate_get_scales``.
    """
    mode = sym_help._maybe_get_const(mode, 's')
    if 'linear' in mode:
        mode = 'linear'
    if 'cubic' in mode:
        mode = 'cubic'

    align_corners = sym_help._maybe_get_const(align_corners, 'b')
    # _maybe_get_const may already have folded align_corners into a plain
    # Python bool. Only ask _is_none about graph values, never about a bool.
    if not isinstance(align_corners, bool) and sym_help._is_none(align_corners):
        align_corners = False
    coordinate_transformation_mode = (
        "asymmetric" if mode == "nearest"
        else "align_corners" if align_corners
        else "pytorch_half_pixel")
    # roi only takes effect with coordinate_transformation_mode="tf_crop_and_resize"
    roi = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))

    if not sym_help._is_none(size):
        input_size = input.type().sizes()
        input_size = g.op("Constant", value_t=torch.tensor(input_size[0:2], dtype=torch.int64))
        try:
            is_scalar = sym_help._maybe_get_const(size, 't').dim() == 0
        except AttributeError:
            # size is not a constant tensor (e.g. a list built at runtime), so
            # its rank cannot be read here. Treat it as a 1-D list of sizes.
            is_scalar = False
        if is_scalar:
            # A single scalar size is repeated for every non-leading dimension.
            size = unsqueeze(g, size, 0)
            size = [size for i in range(input.type().dim() - 2)]
            size = g.op("Concat", *size, axis_i=0)
        # Resize expects int64 sizes, and Concat needs both operands to share
        # a type with the int64 input_size constant, so cast the requested
        # sizes to Long (the same cast _interpolate applies to output_size).
        size = g.op("Cast", size, to_i=sym_help.cast_pytorch_to_onnx["Long"])
        size = g.op("Concat", input_size, size, axis_i=0)
        scales = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))
        return g.op("Resize",
                    input,
                    roi,
                    scales,
                    size,
                    coordinate_transformation_mode_s=coordinate_transformation_mode,
                    cubic_coeff_a_f=-0.75,  # only valid when mode="cubic"
                    mode_s=mode,  # nearest, linear, or cubic
                    nearest_mode_s="floor")  # only valid when mode="nearest"

    # No size given: scale_factor drives the resize.
    scales = sym_help._interpolate_get_scales(g, scale_factor, input.type().dim())
    return g.op("Resize",
                input,
                roi,
                scales,
                coordinate_transformation_mode_s=coordinate_transformation_mode,
                cubic_coeff_a_f=-0.75,  # only valid when mode="cubic"
                mode_s=mode,  # nearest, linear, or cubic
                nearest_mode_s="floor")  # only valid when mode="nearest"
@parse_args('v', 'i', 'v', 'v')
def gather(g, self, dim, index, sparse_grad=False):
    """Export ``aten::gather`` as ONNX ``GatherElements`` along ``dim``.

    A truthy ``sparse_grad`` is reported through ``_unimplemented``. Under
    ONNX_ATEN_FALLBACK an ``ATen`` node is emitted instead.
    """
    wants_sparse_grad = sym_help._maybe_get_const(sparse_grad, 'i')
    if wants_sparse_grad:
        return _unimplemented("gather", "sparse_grad == True")
    use_aten = sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
    if use_aten:
        return g.op("ATen", self, dim, index, sparse_grad, operator_s="gather")
    return g.op("GatherElements", self, index, axis_i=dim)
@parse_args('v', 'i', 'v', 'v')
def scatter(g, self, dim, index, src):
    """Export ``aten::scatter`` as ONNX ``ScatterElements`` along ``dim``.

    Under ONNX_ATEN_FALLBACK an ``ATen`` node is emitted instead.
    """
    if sym_help._operator_export_type != torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        return g.op("ScatterElements", self, index, src, axis_i=dim)
    return g.op("ATen", self, dim, index, src, operator_s="scatter")
@parse_args('v', 'i', 'none')
def cumsum(g, self, dim, dtype=None):
    """Export ``aten::cumsum`` as ONNX ``CumSum`` along ``dim``.

    When an explicit ``dtype`` is supplied, the input is cast to it BEFORE the
    accumulation, matching ``torch.cumsum``. Casting only the finished sum
    (the previous behaviour) can overflow in the original narrower type and
    gives different results for float -> int conversions.
    """
    dim_tensor = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.int))
    summand = self
    # dtype values produced by a prim::Constant node (presumably the default
    # None) are skipped; only a real dtype triggers the cast.
    if dtype and dtype.node().kind() != 'prim::Constant':
        parsed_dtype = sym_help._get_const(dtype, 'i', 'dtype')
        summand = g.op("Cast", summand, to_i=sym_help.scalar_type_to_onnx[parsed_dtype])
    return g.op("CumSum", summand, dim_tensor)
def masked_select(g, self, mask):
    """Export ``aten::masked_select`` as ``GatherND`` over the mask's non-zero positions.

    The mask is first expanded to the shape of ``self``. Its non-zero
    coordinates then index ``self``.
    """
    from torch.onnx.symbolic_opset9 import nonzero, expand_as
    broadcast_mask = expand_as(g, mask, self)
    selected_positions = nonzero(g, broadcast_mask)
    return g.op('GatherND', self, selected_positions)
def masked_scatter(g, self, mask, source):
    """Export ``aten::masked_scatter`` as ONNX ``ScatterND``.

    The mask is expanded to the shape of ``self``, and its non-zero coordinates
    become the scatter indices.
    """
    from torch.onnx.symbolic_opset9 import nonzero, expand_as, view, size
    positions = nonzero(g, expand_as(g, mask, self))

    # NOTE: source may hold more elements than needed and may have an
    # arbitrary shape. ONNX::ScatterND cannot consume that directly, so the
    # source is flattened and then sliced to exactly as many elements as
    # there are scatter positions.
    flat_source = view(g, source, torch.LongTensor([-1]))
    needed = size(g, positions, torch.LongTensor([0]))
    trimmed_source = sym_help._slice_helper(g, flat_source,
                                            axes=torch.LongTensor([0]),
                                            starts=torch.LongTensor([0]),
                                            ends=needed,
                                            dynamic_slice=True)
    return g.op('ScatterND', self, positions, trimmed_source)
@parse_args('v', 'i', 'i', 'i')
def _unique2(g, self, sorted, return_inverse, return_counts):
    """Export ``aten::_unique2`` as the four-output ONNX ``Unique`` op.

    ``return_inverse`` and ``return_counts`` are accepted but not consulted.
    The second ``Unique`` output (indices) is dropped.
    """
    unique_values, _, inverse, counts = g.op("Unique", self, sorted_i=sorted, outputs=4)
    return unique_values, inverse, counts
@parse_args('v', 'i', 'i', 'i', 'i')
def unique_dim(g, self, dim, sorted, return_inverse, return_counts):
    """Export ``aten::unique_dim`` as ONNX ``Unique`` along axis ``dim``.

    ``return_inverse`` and ``return_counts`` are accepted but not consulted.
    The second ``Unique`` output (indices) is dropped.
    """
    unique_values, _, inverse, counts = g.op("Unique", self, axis_i=dim, sorted_i=sorted, outputs=4)
    return unique_values, inverse, counts
@parse_args('v', 'v', 'i', 'i', 'i', 'none')
def topk(g, self, k, dim, largest, sorted, out=None):
    """Export ``aten::topk`` by delegating to ``sym_help._topk_helper``.

    ``k`` stays a graph value ('v'), while ``dim``, ``largest`` and ``sorted``
    are parsed to Python ints.
    """
    return sym_help._topk_helper(
        g, self, k, dim,
        largest=largest,
        sorted=sorted,
        out=out,
    )
@parse_args('v', 'i', 'i', 'none')
def sort(g, self, dim, decending, out=None):
    """Export ``aten::sort`` by delegating to ``sym_help._sort_helper``.

    The parameter keeps its historical spelling ``decending`` because it is
    forwarded by that keyword to the helper.
    """
    return sym_help._sort_helper(
        g, self, dim,
        decending=decending,
        out=out,
    )
def round(g, self):
    """Export ``aten::round`` as a single ONNX ``Round`` node."""
    rounded = g.op("Round", self)
    return rounded
def det(g, self):
    """Export ``aten::det`` as a single ONNX ``Det`` node."""
    determinant = g.op("Det", self)
    return determinant
def arange(g, *args):
    """Export ``aten::arange`` as an ONNX ``Range`` node.

    Three overloads are recognised by argument count (5, 6 or 7). Missing start
    and step values default to 0 and 1, created as constants of the scalar type
    reported by ``sym_help._arange_cast_helper``. Any other argument count
    raises ``NotImplementedError``.
    """
    def _get_arange_dtype(dtype):
        return sym_help._maybe_get_const(dtype, 'i')

    def _typed_constant(value, scalar_type):
        # Constant of the same scalar type the cast helper chose for end/start.
        torch_dtype = sym_help.scalar_type_to_pytorch_type[scalar_type]
        return g.op("Constant", value_t=torch.tensor(value, dtype=torch_dtype))

    num_args = len(args)
    if num_args == 5:
        # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
        scalar_type, end, _, _ = sym_help._arange_cast_helper(
            g, end=args[0], dtype=_get_arange_dtype(args[1]))
        start = _typed_constant(0, scalar_type)
        step = _typed_constant(1, scalar_type)
    elif num_args == 6:
        # aten::arange(Scalar start, Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
        scalar_type, end, start, _ = sym_help._arange_cast_helper(
            g, start=args[0], end=args[1], dtype=_get_arange_dtype(args[2]))
        step = _typed_constant(1, scalar_type)
    elif num_args == 7:
        # aten::arange(Scalar start, Scalar end, Scalar step, ScalarType dtype, Layout, Device, bool pin_memory)
        _, end, start, step = sym_help._arange_cast_helper(
            g, start=args[0], end=args[1], step=args[2], dtype=_get_arange_dtype(args[3]))
    else:
        raise NotImplementedError("Unknown aten::arange signature taking " + str(len(args)) + " arguments.")
    return g.op("Range", start, end, step)
@parse_args('v', 'i')
def _dim_arange(g, like, dim):
    """Export ``aten::_dim_arange``: a range from 0 up to ``like.shape[dim]``.

    Under ONNX_ATEN_FALLBACK a ``_caffe2::Range`` node is emitted. Otherwise
    this module's ``arange`` symbolic is used with dtype code 4 (presumably
    int64/Long in the ScalarType enumeration - confirm against symbolic_helper).
    """
    shape = g.op('Shape', like)
    dim_index = g.op("Constant", value_t=torch.tensor(dim))
    stop = g.op("Gather", shape, dim_index, axis_i=0)
    if sym_help._operator_export_type != torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        return arange(g, stop, 4, None, None, None)
    return g.op("_caffe2::Range", stop)
def size(g, self, dim):
    """Export ``aten::size(self, dim)`` via ``sym_help._size_helper``."""
    dim_size = sym_help._size_helper(g, self, dim)
    return dim_size
def squeeze(g, self, dim=None):
    """Export ``aten::squeeze`` as ONNX ``Squeeze`` with an explicit axes list.

    With no ``dim``, every axis whose static size is 1 is squeezed. Otherwise
    only the single constant ``dim`` is.
    """
    if dim is not None:
        axes = [sym_help._get_const(dim, 'i', 'dim')]
    else:
        axes = [axis for axis, extent in enumerate(self.type().sizes()) if extent == 1]
    return g.op("Squeeze", self, axes_i=axes)
@parse_args('v', 'i')
def unsqueeze(g, self, dim):
    """Export ``aten::unsqueeze`` as ONNX ``Unsqueeze`` on the single axis ``dim``."""
    axes = [dim]
    return g.op("Unsqueeze", self, axes_i=axes)
def mm(g, self, other):
    """Export ``aten::mm`` as an ONNX ``Gemm`` node.

    Only ``self`` and ``other`` are passed as inputs, with beta=0.0 and
    alpha=1.0.
    """
    product = g.op("Gemm", self, other, beta_f=0.0, alpha_f=1.0)
    return product
def index_fill(g, self, dim, index, value):
    """Export ``aten::index_fill``.

    Under ONNX_ATEN_FALLBACK an ``ATen`` node is emitted. Otherwise the index
    is reshaped/expanded by ``sym_help._index_fill_reshape_helper``, ``value``
    is broadcast to the expanded index shape, and the fill is expressed
    through this module's ``scatter`` symbolic.
    """
    # Parsed up front, as before, even though only the fallback path uses it.
    dim_value = sym_help._parse_arg(dim, 'i')
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        return g.op("ATen", self, index, value, dim_i=dim_value, operator_s="index_fill")

    index_shape, broadcast_index = sym_help._index_fill_reshape_helper(g, self, dim, index)
    fill = sym_help._if_scalar_type_as(g, sym_help._maybe_get_scalar(value), self)
    broadcast_fill = expand(g, fill, index_shape, None)
    return scatter(g, self, dim, broadcast_index, broadcast_fill)
def index_copy(g, self, dim, index, source):
    """Export ``aten::index_copy``.

    Under ONNX_ATEN_FALLBACK an ``ATen`` node is emitted. Otherwise the index
    is reshaped/expanded by ``sym_help._index_fill_reshape_helper`` and the
    copy is expressed through this module's ``scatter`` symbolic.
    """
    # Parsed up front, as before, even though only the fallback path uses it.
    dim_value = sym_help._parse_arg(dim, 'i')
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        return g.op("ATen", self, index, source, dim_i=dim_value, operator_s="index_copy")

    _, broadcast_index = sym_help._index_fill_reshape_helper(g, self, dim, index)
    return scatter(g, self, dim, broadcast_index, source)