import torch
from torch.nn.modules.utils import _single, _pair, _triple
import torch.onnx
# This import monkey-patches graph manipulation methods on Graph, used for the
# ONNX symbolics
import torch.onnx.utils
from torch.onnx.symbolic_helper import parse_args, _unimplemented, _black_list_in_opset
# `_slice_op` and `sub` are used by the max pool symbolics below.
from torch.onnx.symbolic_opset9 import _slice_op, sub
import torch.onnx.symbolic_opset9

# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py
# This file exports ONNX ops for opset 10
# Opset 10 is supported by ONNX release 1.5.0,
# released on 04/24/19
# Blacklist operators for this opset version.
# These operators have been updated in ONNX but not re-implemented here.
# It is very important to blacklist these operators to avoid exporting
# models with mixed versions of operators.
# TODO : add support for the blacklisted operators in black_listed_operators
black_listed_operators = ["flip",
                          "slice",
                          "upsample_nearest2d", "upsample_bilinear2d",
                          "dropout", "feature_dropout", "alpha_dropout", "feature_alpha_dropout",
                          "dropout_", "feature_dropout_", "alpha_dropout_", "feature_alpha_dropout_"]
for black_listed_op in black_listed_operators:
    vars()[black_listed_op] = _black_list_in_opset(black_listed_op)
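
# A minimal sketch (not part of the original file) of what hitting the
# blacklist looks like: exporting a model that calls a blacklisted op at
# opset 10 should raise a RuntimeError from _black_list_in_opset instead of
# silently emitting an op with opset 9 semantics. The output file name is
# illustrative.
#
#   import torch
#
#   class Flipper(torch.nn.Module):
#       def forward(self, x):
#           return torch.flip(x, dims=[1])
#
#   torch.onnx.export(Flipper(), torch.randn(2, 3), "flip_example.onnx",
#                     opset_version=10)  # RuntimeError: "flip" not implemented
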
# Add new operator here
@parse_args('v', 'i', 'i', 'i', 'i')
def topk(g, self, k, dim, largest, sorted, out=None):
    if out is not None:
        _unimplemented("TopK", "Out parameter is not supported for topk")
    if not largest:
        _unimplemented("TopK", "Ascending TopK is not supported")
    k = g.op("Constant", value_t=torch.tensor(k, dtype=torch.int64))
    from torch.onnx.symbolic_opset9 import unsqueeze
    k = unsqueeze(g, k, 0)
    return g.op("TopK", self, k, axis_i=dim, outputs=2)
def _max_pool(name, tuple_fn, ndims, return_indices):
    @parse_args('v', 'is', 'is', 'is', 'is', 'i')
    def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode):
        if not stride:
            stride = kernel_size
        kwargs = {
            'kernel_shape_i': tuple_fn(kernel_size),
            'pads_i': tuple_fn(padding) * 2,
            'strides_i': tuple_fn(stride),
            'ceil_mode_i': ceil_mode,
        }
        if set(tuple_fn(dilation)) != {1}:
            kwargs['dilations_i'] = tuple_fn(dilation)
        # Easy but hacky way to get flattened indices values
        # to be used to convert the indices values to non-flattened.
        # In ONNX the indices are computed as a flattened 1-D tensor,
        # so the values in indices are in [0, N x C x D1 x ... x Dn).
        # To convert the indices to the same format used by PyTorch,
        # we first execute a maxpool with a kernel and stride of 1 on the same input.
        # This results in a tensor of indices in which each index has its own value.
        # Using this tensor as a reference, we extract the first index of each axis and subtract
        # it from each index of that axis in the indices to convert.
        # This step results in a tensor where each dimension has values of indices within
        # the dimension it is in.
        # For more information:
        # https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407
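        # Worked example (illustrative shapes, not from the original file):
        # for input of shape (N=1, C=2, D=4), flat indices run over
        # N*C*D = 8, so channel 0 owns flat positions 0..3 and channel 1
        # owns 4..7. The kernel-1/stride-1 maxpool returns each position's
        # own flat index ([[0, 1, 2, 3], [4, 5, 6, 7]]); slicing the first
        # entry of the spatial axis gives the per-channel offsets [[0], [4]].
        # If max_pool1d with kernel 2 and stride 2 picks spatial positions
        # 1 and 3 in channel 1, ONNX reports flat indices [5, 7];
        # subtracting the offset 4 yields [1, 3], the spatial indices
        # PyTorch expects.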
        if return_indices:
            r, indices = g.op("MaxPool", input, outputs=2, **kwargs)
            _, flattened_indices = g.op("MaxPool", input, outputs=2,
                                        kernel_shape_i=[1 for _ in range(ndims)],
                                        strides_i=[1 for _ in range(ndims)])
            # convert indices to have non-flattened indices values
            s = _slice_op(g, flattened_indices, axes=[2 + i for i in range(ndims)],
                          starts=tuple_fn(0), ends=tuple_fn(1))
            indices = sub(g, indices, s)
            return r, indices
        else:
            r = g.op("MaxPool", input, outputs=1, **kwargs)
            return r
    return symbolic_fn

max_pool1d = _max_pool("max_pool1d", _single, 1, return_indices=False)
max_pool2d = _max_pool("max_pool2d", _pair, 2, return_indices=False)
max_pool3d = _max_pool("max_pool3d", _triple, 3, return_indices=False)
max_pool1d_with_indices = _max_pool("max_pool1d_with_indices", _single, 1, return_indices=True)
max_pool2d_with_indices = _max_pool("max_pool2d_with_indices", _pair, 2, return_indices=True)
max_pool3d_with_indices = _max_pool("max_pool3d_with_indices", _triple, 3, return_indices=True)
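
# Opset 10 adds ceil_mode to the ONNX MaxPool op, which is why these
# symbolics can pass ceil_mode_i through, unlike their opset 9
# counterparts. A minimal usage sketch (not part of this file; the output
# file name is illustrative):
#
#   import torch
#
#   pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)
#   torch.onnx.export(pool, torch.randn(1, 3, 32, 32),
#                     "maxpool_example.onnx", opset_version=10)
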
def _avg_pool(name, tuple_fn):
    @parse_args('v', 'is', 'is', 'is', 'i', 'i')
    def symbolic_fn(g, input, kernel_size, stride, padding, ceil_mode, count_include_pad):
        if not stride:
            stride = kernel_size
        padding = tuple(tuple_fn(padding))
        if count_include_pad:
            input = g.op("Pad", input,
                         pads_i=((0,) * 2 + padding) * 2,
                         mode_s='constant',
                         value_f=0.)
            padding = (0,) * len(padding)
        output = g.op("AveragePool", input,
                      kernel_shape_i=tuple_fn(kernel_size),
                      strides_i=tuple_fn(stride),
                      pads_i=padding * 2,
                      ceil_mode_i=ceil_mode)
        return output
    return symbolic_fn

avg_pool1d = _avg_pool('avg_pool1d', _single)
avg_pool2d = _avg_pool('avg_pool2d', _pair)
avg_pool3d = _avg_pool('avg_pool3d', _triple)
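
# The explicit Pad in _avg_pool folds the padding into the input, so the
# padded zeros are counted in every window average, matching PyTorch's
# count_include_pad=True. A worked example (illustrative numbers):
# avg_pool1d on [1., 2., 3.] with kernel_size=2 and padding=1 makes the
# first window [0, 1]; counting the pad gives an average of 0.5, whereas
# excluding it would give 1.0. Pre-padding and then zeroing pads_i
# reproduces the first behavior:
#
#   import torch
#
#   x = torch.tensor([[[1., 2., 3.]]])
#   torch.nn.functional.avg_pool1d(x, kernel_size=2, padding=1,
#                                  count_include_pad=True)
#   # -> tensor([[[0.5000, 2.5000]]])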