| from __future__ import absolute_import, division, print_function, unicode_literals |
| |
| from sys import maxsize |
| |
| import torch |
| import torch.onnx.symbolic_helper as sym_help |
| import warnings |
| import numpy |
| |
| from torch.onnx.symbolic_helper import parse_args, _unimplemented |
| from torch.onnx.symbolic_opset9 import expand |
| from torch.nn.modules.utils import _single, _pair, _triple |
| |
| |
| # EDITING THIS FILE? READ THIS FIRST! |
| # see Note [Edit Symbolic Files] in symbolic_helper.py |
| |
| # This file exports ONNX ops for opset 11 |
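# A symbolic in this file maps an ATen op to ONNX ops when exporting with
# opset_version=11, e.g. (illustrative):
#   torch.onnx.export(model, example_inputs, "model.onnx", opset_version=11)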
| |
| |
| @parse_args('v', 'f', 'f') |
| def hardtanh(g, self, min_val, max_val): |
| dtype = self.type().scalarType() |
    if dtype is None:
        dtype = 6  # float
    else:
        dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype])
| min_val = g.op("Constant", value_t=torch.tensor(min_val, dtype=sym_help.scalar_type_to_pytorch_type[dtype])) |
| max_val = g.op("Constant", value_t=torch.tensor(max_val, dtype=sym_help.scalar_type_to_pytorch_type[dtype])) |
| return g.op("Clip", self, min_val, max_val) |
| |
| |
| def clamp(g, self, min, max): |
| dtype = self.type().scalarType() |
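    # From opset 11 on, Clip takes min/max as optional tensor inputs rather
    # than attributes, so cast them to self's dtype when it is known.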
| |
| def _cast_if_not_none(tensor, dtype): |
| if tensor is not None and not sym_help._is_none(tensor): |
| return g.op("Cast", tensor, to_i=sym_help.cast_pytorch_to_onnx[dtype]) |
| else: |
| return tensor |
| |
| if dtype is not None: |
| min = _cast_if_not_none(min, dtype) |
| max = _cast_if_not_none(max, dtype) |
| return g.op("Clip", self, min, max) |
| |
| |
| def index_put(g, self, indices_list_value, values, accumulate=False): |
| indices_list = sym_help._unpack_list(indices_list_value) |
| if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK: |
| args = [self] + indices_list + [values, accumulate] |
| return g.op("ATen", *args, operator_s='index_put') |
| |
| from torch.onnx.symbolic_opset9 import add, expand |
| accumulate = sym_help._parse_arg(accumulate, 'b') |
| |
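    # Lower advanced indexing to ScatterND: broadcast the index tensors to a
    # common shape, stack them along a new trailing dim to form N-D indices,
    # and reshape `values` to match. Illustrative example: indices ([0, 2],
    # [1, 1]) on a 3x3 tensor yield index = [[0, 1], [2, 1]].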
| index = indices_list[0] |
| |
| if len(indices_list) > 1: |
| for ind in indices_list[1:]: |
| index = add(g, index, ind) |
| broadcast_index_shape = g.op("Shape", index) |
| indices_list = [ |
| g.op("Unsqueeze", expand(g, ind, broadcast_index_shape, None), axes_i=[-1]) for ind in indices_list |
| ] |
| index = g.op("Concat", *indices_list, axis_i=-1) |
| else: |
| broadcast_index_shape = g.op("Shape", index) |
| index = g.op("Unsqueeze", index, axes_i=[-1]) |
| sub_data_shape = sym_help._slice_helper( |
| g, g.op("Shape", self), axes=[0], starts=[len(indices_list)], ends=[maxsize]) |
| values_shape = g.op("Concat", broadcast_index_shape, sub_data_shape, axis_i=0) |
| values = g.op("Reshape", values, values_shape) |
| |
| if accumulate: |
| dtype = self.type().scalarType() |
| dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype]) |
| dtype = sym_help.scalar_type_to_pytorch_type[dtype] |
| zeros = g.op("ConstantOfShape", g.op("Shape", self), value_t=torch.tensor([0], dtype=dtype)) |
| result = g.op("ScatterND", zeros, index, values) |
| result = add(g, self, result) |
| else: |
| result = g.op("ScatterND", self, index, values) |
| |
| return result |
| |
| |
| @parse_args('v', 'i') |
| def pixel_shuffle(g, self, upscale_factor): |
| dims = self.type().sizes() |
| if len(dims) != 4: |
| return _unimplemented("pixel_shuffle", "only support 4d input") |
| return g.op("DepthToSpace", self, blocksize_i=upscale_factor, mode_s="CRD") |
| |
| |
| def _interpolate(name, dim, interpolate_mode): |
| def symbolic_fn(g, input, output_size, *args): |
| scales, align_corners = sym_help._get_interpolate_attributes(g, interpolate_mode, args) |
| align_corners = sym_help._maybe_get_scalar(align_corners) |
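        # Pick the ONNX Resize coordinate_transformation_mode that reproduces
        # the PyTorch kernels: "asymmetric" for floor-based nearest sampling,
        # "align_corners" when requested, else "pytorch_half_pixel".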
| coordinate_transformation_mode = "asymmetric" if interpolate_mode == "nearest" \ |
| else "align_corners" if align_corners else "pytorch_half_pixel" |
| empty_tensor = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32)) |
| |
| if scales is None: |
| input_size = g.op("Shape", input) |
| input_size_beg = sym_help._slice_helper(g, input_size, axes=[0], ends=[2], starts=[0]) |
| output_size = g.op("Cast", output_size, to_i=sym_help.cast_pytorch_to_onnx["Long"]) |
| output_size = g.op("Concat", input_size_beg, output_size, axis_i=0) |
| scales = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32)) |
| return g.op("Resize", |
| input, |
                        empty_tensor,  # roi only takes effect with coordinate_transformation_mode="tf_crop_and_resize"
| scales, # scales is not needed since we are sending out_size |
| output_size, |
| coordinate_transformation_mode_s=coordinate_transformation_mode, |
| cubic_coeff_a_f=-0.75, # only valid when mode="cubic" |
| mode_s=interpolate_mode, # nearest, linear, or cubic |
| nearest_mode_s="floor") # only valid when mode="nearest" |
| else: |
| return g.op("Resize", |
| input, |
| empty_tensor, # roi only takes effect with coordinate_transformation_mode="tf_crop_and_resize" |
                        scales,  # sizes is omitted here since scales is provided
| coordinate_transformation_mode_s=coordinate_transformation_mode, |
| cubic_coeff_a_f=-0.75, # only valid when mode="cubic" |
| mode_s=interpolate_mode, # nearest, linear, or cubic |
| nearest_mode_s="floor") # only valid when mode="nearest" |
| return symbolic_fn |
| |
| |
| upsample_nearest1d = _interpolate('upsample_nearest1d', 3, "nearest") |
| upsample_nearest2d = _interpolate('upsample_nearest2d', 4, "nearest") |
| upsample_nearest3d = _interpolate('upsample_nearest3d', 5, "nearest") |
| upsample_linear1d = _interpolate('upsample_linear1d', 3, "linear") |
| upsample_bilinear2d = _interpolate('upsample_bilinear2d', 4, "linear") |
| upsample_trilinear3d = _interpolate('upsample_trilinear3d', 5, "linear") |
| upsample_bicubic2d = _interpolate('upsample_bicubic2d', 4, "cubic") |
| |
| |
| def __interpolate(g, input, size, scale_factor, mode, align_corners, recompute_scale_factor): |
| mode = sym_help._maybe_get_const(mode, 's') |
| if 'linear' in mode: |
| mode = 'linear' |
| if 'cubic' in mode: |
| mode = 'cubic' |
| align_corners = sym_help._maybe_get_const(align_corners, 'b') |
    if not isinstance(align_corners, bool):
        align_corners = False
| coordinate_transformation_mode = "asymmetric" if mode == "nearest" \ |
| else "align_corners" if align_corners else "pytorch_half_pixel" |
| # roi only takes effect with coordinate_transformation_mode="tf_crop_and_resize" |
| roi = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32)) |
| |
    if not sym_help._is_none(size):
| input_size = g.op("Shape", input) |
| input_size = sym_help._slice_helper(g, input_size, axes=[0], ends=[2], starts=[0]) |
| # in some cases size is not a packed list but size is a scalar |
| # We need to also verify that (sym_help._maybe_get_const(size, 't').dim() == 0) |
| # but this information is not always available. Try to get the dim, |
| # and if not assume that it is not a scalar. |
| try: |
            is_scalar = not sym_help._is_packed_list(size) and (sym_help._maybe_get_const(size, 't').dim() == 0)
| except AttributeError: |
| is_scalar = not sym_help._is_packed_list(size) |
| if not is_scalar: |
| warnings.warn("Cannot verify if the output_size is a scalar " |
| "while exporting interpolate. Assuming that it is not a scalar.") |
| |
| if is_scalar: |
| if not input.type().dim(): |
| return sym_help._unimplemented("interpolate (with a scalar output_size)", |
| "missing input shape (try giving an array of output_size values)") |
| size = unsqueeze(g, size, 0) |
| size = [size for i in range(input.type().dim() - 2)] |
| size = g.op("Concat", *size, axis_i=0) |
| size = g.op("Cast", size, to_i=sym_help.cast_pytorch_to_onnx['Long']) |
| size = g.op("Concat", input_size, size, axis_i=0) |
| scales = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32)) |
| return g.op("Resize", |
| input, |
| roi, |
| scales, |
| size, |
| coordinate_transformation_mode_s=coordinate_transformation_mode, |
| cubic_coeff_a_f=-0.75, # only valid when mode="cubic" |
| mode_s=mode, # nearest, linear, or cubic |
| nearest_mode_s="floor") |
    else:  # scale_factor is provided (size is None)
| if not input.type().dim(): |
| return sym_help._unimplemented("interpolate (with scales)", "missing input shape") |
| scales = sym_help._interpolate_get_scales(g, scale_factor, input.type().dim()) |
| return g.op("Resize", |
| input, |
| roi, |
| scales, |
| coordinate_transformation_mode_s=coordinate_transformation_mode, |
| cubic_coeff_a_f=-0.75, # only valid when mode="cubic" |
| mode_s=mode, # nearest, linear, or cubic |
| nearest_mode_s="floor") # only valid when mode="nearest" |
| |

@parse_args('v', 'i', 'v', 'v')
| def gather(g, self, dim, index, sparse_grad=False): |
| if sym_help._maybe_get_const(sparse_grad, 'i'): |
| return _unimplemented("gather", "sparse_grad == True") |
| if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK: |
| return g.op("ATen", self, dim, index, sparse_grad, operator_s="gather") |
| return g.op("GatherElements", self, index, axis_i=dim) |
| |
| |
| @parse_args('v', 'i', 'v', 'v') |
| def scatter(g, self, dim, index, src): |
| if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK: |
| return g.op("ATen", self, dim, index, src, operator_s="scatter") |
| return g.op("ScatterElements", self, index, src, axis_i=dim) |
| |
| |
| @parse_args('v', 'i', 'none') |
| def cumsum(g, self, dim, dtype=None): |
| dim_tensor = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.int)) |
| csum = g.op("CumSum", self, dim_tensor) |
| if dtype and dtype.node().kind() != 'prim::Constant': |
| parsed_dtype = sym_help._get_const(dtype, 'i', 'dtype') |
| csum = g.op("Cast", csum, to_i=sym_help.scalar_type_to_onnx[parsed_dtype]) |
| return csum |
| |
| |
| def masked_select(g, self, mask): |
| from torch.onnx.symbolic_opset9 import nonzero, expand_as |
| index = nonzero(g, expand_as(g, mask, self)) |
| return g.op('GatherND', self, index) |
| |
| |
| def masked_scatter(g, self, mask, source): |
| from torch.onnx.symbolic_opset9 import nonzero, expand_as, view, size |
| index = nonzero(g, expand_as(g, mask, self)) |
| # NOTE: source can have more elements than needed. |
| # It could also have arbitrary shape. |
| # This is not supported by ONNX::ScatterND, so we need to flatten and slice source tensor. |
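    # e.g. a mask with k True elements on an r-D self gives `index` of shape
    # [k, r]; the flattened source is then truncated to its first k elements.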
| source = view(g, source, torch.LongTensor([-1])) |
| source = sym_help._slice_helper(g, source, |
| axes=torch.LongTensor([0]), |
| starts=torch.LongTensor([0]), |
| ends=size(g, index, torch.LongTensor([0])), |
| dynamic_slice=True) |
| return g.op('ScatterND', self, index, source) |
| |
| |
| def _len(g, self): |
| return g.op("SequenceLength", self) |
| |
| |
| def __getitem_(g, self, i): |
| if self.type().isSubtypeOf(torch._C.ListType.ofTensors()): |
| # SequenceAt requires that the input be a List of Tensors |
| return g.op("SequenceAt", self, i) |
| else: |
| from torch.onnx.symbolic_opset9 import __getitem_ as getitem |
| return getitem(g, self, i) |
| |
| |
| def append(g, self, tensor): |
| return g.op("SequenceInsert", self, tensor) |
| |
| |
| def insert(g, self, pos, tensor): |
| return g.op("SequenceInsert", self, tensor, pos) |
| |
| |
| def pop(g, tensor_list, dim): |
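    # `dim` is the element position to remove; SequenceErase takes it as input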
| return g.op("SequenceErase", tensor_list, dim) |
| |
| |
| def cat(g, tensor_list, dim): |
| if sym_help._is_packed_list(tensor_list): |
| from torch.onnx.symbolic_opset9 import cat as cat_opset9 |
| return cat_opset9(g, tensor_list, dim) |
| else: |
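        # tensor_list is a dynamic sequence here, so use ConcatFromSequence
        # instead of a static Concat over unpacked inputs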
| dim = sym_help._get_const(dim, 'i', 'dim') |
| return g.op("ConcatFromSequence", tensor_list, axis_i=dim) |
| |
| |
| def stack(g, tensor_list, dim): |
| if sym_help._is_packed_list(tensor_list): |
| from torch.onnx.symbolic_opset9 import stack as stack_opset9 |
| return stack_opset9(g, tensor_list, dim) |
| else: |
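        # new_axis_i=1 makes ConcatFromSequence insert a new axis at `dim`,
        # mirroring torch.stack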
| dim = sym_help._get_const(dim, 'i', 'dim') |
| return g.op("ConcatFromSequence", tensor_list, axis_i=dim, new_axis_i=1) |
| |
| |
| @parse_args('v', 'i', 'i', 'i') |
| def _unique2(g, self, sorted, return_inverse, return_counts): |
| u, indices, inverse_indices, counts = g.op("Unique", self, sorted_i=sorted, outputs=4) |
| return u, inverse_indices, counts |
| |
| |
| def _avg_pool(name, tuple_fn): |
| @parse_args('v', 'is', 'is', 'is', 'i', 'i', 'none') |
| def symbolic_fn(g, input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override=None): |
| padding = sym_help._avgpool_helper(tuple_fn, padding, kernel_size, stride, divisor_override, name) |
| if not stride: |
| stride = kernel_size |
| if count_include_pad: |
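            # fold the padding into an explicit constant Pad so the padded
            # zeros are counted in the average, then run AveragePool unpadded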
| input = g.op("Pad", input, |
| g.op("Constant", value_t=torch.tensor(((0,) * 2 + padding) * 2)), mode_s='constant') |
| padding = (0,) * len(padding) |
| output = g.op("AveragePool", input, |
| kernel_shape_i=tuple_fn(kernel_size), |
| strides_i=tuple_fn(stride), |
| pads_i=padding * 2, |
| ceil_mode_i=ceil_mode) |
| return output |
| return symbolic_fn |
| |
| |
| avg_pool1d = _avg_pool('avg_pool1d', _single) |
| avg_pool2d = _avg_pool('avg_pool2d', _pair) |
| avg_pool3d = _avg_pool('avg_pool3d', _triple) |
| |
| |
| @parse_args('v', 'i', 'i', 'i', 'i') |
| def unique_dim(g, self, dim, sorted, return_inverse, return_counts): |
| u, indices, inverse_indices, counts = g.op("Unique", self, axis_i=dim, sorted_i=sorted, outputs=4) |
| return u, inverse_indices, counts |
| |
| |
| @parse_args('v', 'v', 'i', 'i', 'i', 'none') |
| def topk(g, self, k, dim, largest, sorted, out=None): |
| return sym_help._topk_helper(g, self, k, dim, largest=largest, sorted=sorted, out=out) |
| |
| |
| @parse_args('v', 'i', 'i', 'none') |
def sort(g, self, dim, descending, out=None):
    # the helper keyword keeps the historical "decending" spelling
    return sym_help._sort_helper(g, self, dim, decending=descending, out=out)
| |
| |
| def round(g, self): |
| return g.op("Round", self) |
| |
| |
| @parse_args('v', 'v', 'i') |
| def split_with_sizes(g, self, split_sizes, dim): |
| if sym_help._is_value(split_sizes) and split_sizes.node().kind() == 'prim::ListConstruct': |
| return g.op("SplitToSequence", self, split_sizes, axis_i=dim) |
| else: |
| return torch.onnx.symbolic_opset9.split_with_sizes(g, self, split_sizes, dim) |
| |
| |
| # Generate paddings in ONNX order based on pad in pytorch. |
| # Arguments: |
| # dim: the dimension of the tensor. |
| # pad: the paddings in pytorch. |
| # The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ..., dim_m_begin, dim_m_end, |
| # where m is in range [0, n]. |
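# Example (illustrative): for a 3-D input with pad = [1, 2] (last dim only),
# the computed ONNX pads are [0, 0, 1, 0, 0, 2].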
| def _prepare_onnx_paddings(g, dim, pad): |
| # The desired order of paddings is |
| # dim_0_begin, dim_1_begin, ... , dim_0_end, ..., dim_n_end. |
| # n is the dimension of input. |
| # Assume zero-dimensions in the beginning, pad the "pad" sequence with zeros in the beginning |
| pad_len = torch.onnx.symbolic_opset9.size(g, pad, g.op("Constant", value_t=torch.tensor([0]))) |
| # Set extension = [0] * (dim * 2 - len(pad)) |
| extension = g.op("Sub", g.op("Mul", g.op("Constant", value_t=torch.tensor(dim, dtype=torch.int64)), |
| g.op("Constant", value_t=torch.tensor(2, dtype=torch.int64))), pad_len) |
| # Concat pad with extension: paddings = [dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, 0, 0, ... ] |
| # Currently ONNX only supports int64 type for Pad |
| pad = g.op("Cast", pad, to_i=sym_help.cast_pytorch_to_onnx['Long']) |
| paddings = g.op("Concat", pad, g.op("ConstantOfShape", extension, value_t=torch.tensor([0], dtype=torch.int64)), axis_i=0) |
| # Reshape and reverse order and collate first beginnings and then ends |
| # paddings = [[..., 0, dim_n-1_begin, dim_n_begin], |
| # [..., 0, dim_n-1_end, dim_n_end]] |
| # Reshape back to 1-D paddings = [..., 0, dim_n - 1_begin, dim_n_begin, ..., 0, dim_n - 1_end, dim_n_end] |
| paddings = g.op("Reshape", paddings, g.op("Constant", value_t=torch.tensor([-1, 2]))) |
| paddings = g.op("Transpose", torch.onnx.symbolic_opset10.flip(g, paddings, [0]), perm_i=[1, 0]) |
| paddings = g.op("Reshape", paddings, g.op("Constant", value_t=torch.tensor([-1]))) |
| padding_c = g.op("Cast", paddings, to_i=sym_help.cast_pytorch_to_onnx['Long']) |
| return padding_c |
| |
| |
| def constant_pad_nd(g, input, padding, value=None): |
| mode = "constant" |
| value = sym_help._maybe_get_scalar(value) |
| value = sym_help._if_scalar_type_as(g, value, input) |
| pad = _prepare_onnx_paddings(g, input.type().dim(), padding) |
| return g.op("Pad", input, pad, value, mode_s=mode) |
| |
| |
| def reflection_pad(g, input, padding): |
| mode = "reflect" |
| paddings = _prepare_onnx_paddings(g, input.type().dim(), padding) |
| return g.op("Pad", input, paddings, mode_s=mode) |
| |
| |
| def replication_pad(g, input, padding): |
| mode = "edge" |
| paddings = _prepare_onnx_paddings(g, input.type().dim(), padding) |
| return g.op("Pad", input, paddings, mode_s=mode) |
| |
| |
| reflection_pad1d = reflection_pad |
| reflection_pad2d = reflection_pad |
| reflection_pad3d = reflection_pad |
| replication_pad1d = replication_pad |
| replication_pad2d = replication_pad |
| replication_pad3d = replication_pad |
| |
| |
| def det(g, self): |
| return g.op("Det", self) |
| |
| |
| def logdet(g, input): |
| from torch.onnx.symbolic_opset9 import log |
| return log(g, det(g, input)) |
| |
| |
| def arange(g, *args): |
| def _get_arange_dtype(dtype): |
| dtype = sym_help._maybe_get_const(dtype, 'i') |
| return dtype |
| |
| if len(args) == 5: |
| # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) |
| dtype = _get_arange_dtype(args[1]) |
| type, end, start, step = sym_help._arange_cast_helper(g, end=args[0], dtype=dtype) |
| start_default = g.op("Constant", value_t=torch.tensor(0, dtype=sym_help.scalar_type_to_pytorch_type[type])) |
| delta_default = g.op("Constant", value_t=torch.tensor(1, dtype=sym_help.scalar_type_to_pytorch_type[type])) |
| arange_tensor = g.op("Range", start_default, end, delta_default) |
| elif len(args) == 6: |
| # aten::arange(Scalar start, Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) |
| dtype = _get_arange_dtype(args[2]) |
| type, end, start, step = sym_help._arange_cast_helper(g, start=args[0], end=args[1], dtype=dtype) |
| delta_default = g.op("Constant", value_t=torch.tensor(1, dtype=sym_help.scalar_type_to_pytorch_type[type])) |
| arange_tensor = g.op("Range", start, end, delta_default) |
| elif len(args) == 7: |
| # aten::arange(Scalar start, Scalar end, Scalar step, ScalarType dtype, Layout, Device, bool pin_memory) |
| dtype = _get_arange_dtype(args[3]) |
| type, end, start, step = sym_help._arange_cast_helper(g, start=args[0], end=args[1], step=args[2], dtype=dtype) |
| arange_tensor = g.op("Range", start, end, step) |
| else: |
| raise NotImplementedError("Unknown aten::arange signature taking " + str(len(args)) + " arguments.") |
| return arange_tensor |
| |
| |
| @parse_args('v', 'i') |
| def _dim_arange(g, like, dim): |
| like_shape = g.op('Shape', like) |
| stop = g.op("Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0) |
| if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK: |
| return g.op("_caffe2::Range", stop) |
| return arange(g, stop, 4, None, None, None) |
| |
| |
| def size(g, self, dim): |
| return sym_help._size_helper(g, self, dim) |
| |
| |
| def squeeze(g, self, dim=None): |
| if dim is None: |
| dims = [] |
| for i, size in enumerate(self.type().sizes()): |
| if size == 1: |
| dims.append(i) |
| else: |
| dims = [sym_help._get_const(dim, 'i', 'dim')] |
| return g.op("Squeeze", self, axes_i=dims) |
| |
| |
| @parse_args('v', 'i') |
| def unsqueeze(g, self, dim): |
| return g.op("Unsqueeze", self, axes_i=[dim]) |
| |
| |
| def mm(g, self, other): |
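    # Gemm's bias input C is optional from opset 11, and beta=0 leaves it
    # unused, so only A and B are passed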
| return g.op("Gemm", self, other, beta_f=0.0, alpha_f=1.0) |
| |
| |
| def index_fill(g, self, dim, index, value): |
| dim_value = sym_help._parse_arg(dim, 'i') |
| if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK: |
| return g.op("ATen", self, index, value, dim_i=dim_value, operator_s="index_fill") |
| expanded_index_shape, expanded_index = sym_help._index_fill_reshape_helper(g, self, dim, index) |
| value = sym_help._maybe_get_scalar(value) |
| value = sym_help._if_scalar_type_as(g, value, self) |
| expanded_value = expand(g, value, expanded_index_shape, None) |
| return scatter(g, self, dim, expanded_index, expanded_value) |
| |
| |
| def index_copy(g, self, dim, index, source): |
| dim_value = sym_help._parse_arg(dim, 'i') |
| if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK: |
| return g.op("ATen", self, index, source, dim_i=dim_value, operator_s="index_copy") |
| expanded_index_shape, expanded_index = sym_help._index_fill_reshape_helper(g, self, dim, index) |
| return scatter(g, self, dim, expanded_index, source) |
| |
| |
| def __rshift_(g, self, other): |
| # make sure to cast other to self's type |
| # (when self is long, make sure that other is not float) |
| if other.type().scalarType() != self.type().scalarType(): |
| other = g.op("Cast", other, to_i=sym_help.cast_pytorch_to_onnx[self.type().scalarType()]) |
| |
| if self.type().scalarType() == 'Byte': |
| return g.op('BitShift', self, other, direction_s="RIGHT") |
| |
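    # for other dtypes, emulate x >> n as x / 2**n (ONNX BitShift only
    # supports unsigned integer types)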
| two = g.op('Constant', value_t=torch.tensor(2, dtype=torch.float32)) |
| # exponent (same type as self) has to be float or double in onnx::Pow |
| if not sym_help._is_fp(self): |
| other = g.op("Cast", other, to_i=sym_help.cast_pytorch_to_onnx['Float']) |
| two_pow = g.op('Pow', two, other) |
| |
| rshift = g.op('Div', self, two_pow) |
| return rshift |
| |
| |
| def __lshift_(g, self, other): |
| # make sure to cast other to self's type |
| # (when self is long, make sure that other is not float) |
| if other.type().scalarType() != self.type().scalarType(): |
| other = g.op("Cast", other, to_i=sym_help.cast_pytorch_to_onnx[self.type().scalarType()]) |
| |
| if self.type().scalarType() == 'Byte': |
| return g.op('BitShift', self, other, direction_s="LEFT") |
| |
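    # for other dtypes, emulate x << n as x * 2**n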
| two = g.op('Constant', value_t=torch.tensor(2, dtype=torch.float32)) |
| # exponent (same type as self) has to be float or double in onnx::Pow |
| if not sym_help._is_fp(self): |
| other = g.op("Cast", other, to_i=sym_help.cast_pytorch_to_onnx['Float']) |
| two_pow = g.op('Pow', two, other) |
| |
| lshift = g.op('Mul', self, two_pow) |
| return lshift |
| |
| |
| def _get_im2col_indices_along_dim(g, input_d, kernel_size_d, dilation_d, padding_d, stride_d): |
| # Input is always 4-D (N, C, H, W) |
| # Calculate indices of sliding blocks along spatial dimension |
| # Slide kernel over input each dim d: |
| # each dimension d ranges from 0 to input[d]+2xpadding[d]-dilation[d]x(kernel_size[d]-1) |
| # with steps = stride |
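    # Illustrative example: input_d=5, kernel_size_d=2, dilation_d=1,
    # padding_d=0, stride_d=1 gives window starts [0, 1, 2, 3] and kernel
    # offsets [0, 1], so block_mask = [[0, 1, 2, 3], [1, 2, 3, 4]].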
| |
| blocks_d = g.op("Add", input_d, g.op("Constant", value_t=torch.tensor(padding_d * 2))) |
| blocks_d = g.op("Sub", blocks_d, g.op("Constant", value_t=torch.tensor(dilation_d * (kernel_size_d - 1)))) |
| |
| # Stride kernel over input and find starting indices along dim d |
| blocks_d_indices = g.op("Range", g.op("Constant", value_t=torch.tensor(0)), |
| blocks_d, g.op("Constant", value_t=torch.tensor(stride_d))) |
| |
| # Apply dilation on kernel and find its indices along dim d |
| kernel_grid = numpy.arange(0, kernel_size_d * dilation_d, dilation_d) |
| kernel_grid = g.op("Constant", value_t=torch.tensor([kernel_grid])) |
| |
| # Broadcast and add kernel staring positions (indices) with |
| # kernel_grid along dim d, to get block indices along dim d |
| blocks_d_indices = g.op('Unsqueeze', blocks_d_indices, axes_i=[0]) # Reshape to [1, -1] |
| kernel_mask = g.op('Reshape', kernel_grid, g.op('Constant', value_t=torch.tensor([-1, 1]))) |
| block_mask = g.op("Add", blocks_d_indices, kernel_mask) |
| |
| return block_mask |
| |
| |
| def _get_im2col_padded_input(g, input, padding_h, padding_w): |
| # Input is always 4-D tensor (N, C, H, W) |
| # Padding tensor has the following format: (padding_h, padding_w) |
| # Reshape the padding to follow ONNX format: (dim1_begin, dim2_begin,...,dim1_end, dim2_end,...) |
| pad = g.op("Constant", value_t=torch.LongTensor([0, 0, padding_h, padding_w] * 2)) |
| return g.op("Pad", input, pad) |
| |
| |
| def _get_im2col_output_shape(g, input, kernel_h, kernel_w): |
| batch_dim = size(g, input, g.op("Constant", value_t=torch.tensor(0))) |
| channel_dim = size(g, input, g.op("Constant", value_t=torch.tensor(1))) |
| channel_unfolded = g.op("Mul", channel_dim, |
| g.op("Constant", value_t=torch.tensor(kernel_h * kernel_w))) |
| |
| return g.op("Concat", |
| g.op("Unsqueeze", batch_dim, axes_i=[0]), |
| g.op("Unsqueeze", channel_unfolded, axes_i=[0]), |
| g.op("Constant", value_t=torch.tensor([-1])), axis_i=0) |
| |
| |
| @parse_args('v', 'is', 'is', 'is', 'is') |
| def im2col(g, input, kernel_size, dilation, padding, stride): |
| # Input is always 4-D tensor (N, C, H, W) |
| # All other args are int[2] |
| |
| input_h = size(g, input, g.op("Constant", value_t=torch.tensor(2))) |
| input_w = size(g, input, g.op("Constant", value_t=torch.tensor(3))) |
| |
| stride_h, stride_w = stride[0], stride[1] |
| padding_h, padding_w = padding[0], padding[1] |
| dilation_h, dilation_w = dilation[0], dilation[1] |
| kernel_h, kernel_w = kernel_size[0], kernel_size[1] |
| |
| blocks_row_indices = _get_im2col_indices_along_dim(g, input_h, kernel_h, dilation_h, padding_h, stride_h) |
| blocks_col_indices = _get_im2col_indices_along_dim(g, input_w, kernel_w, dilation_w, padding_w, stride_w) |
| |
| output_shape = _get_im2col_output_shape(g, input, kernel_h, kernel_w) |
| padded_input = _get_im2col_padded_input(g, input, padding_h, padding_w) |
| |
| # For a 4D matrix of size (1, 1, 3, 3) as below with kernel_size=2, stride=1, and dilation=1 |
| # [[[[1., 2., 3.,], |
| # [4., 5., 6.,], |
| # [7., 8., 9.,]]]] |
| # First gather indices along rows (dim=2) with blocks_row_indices = [[0,1], [1,2]] to get: |
| # [[[[[1., 2., 3.], |
| # [4., 5., 6.]], |
| # [[4., 5., 6.], |
| # [7., 8., 9.]]]]] |
    # And then gather along cols (dim=4) with blocks_col_indices = [[0,1], [1,2]] to get:
| # [[[[[[1., 2.], |
| # [4., 5.]], |
| # [[2., 3.], |
| # [5., 6]]], |
| # [[[4., 5.], |
| # [7., 8.]], |
| # [[5., 6.], |
| # [8., 9.]]]]]] |
| # Transpose dims 3 (depth) and 4 (rows), and then reshape to output shape (1, 1, 4, 4) to get: |
| # [[[1., 2., 4., 5.], |
| # [2., 3., 5., 6.], |
| # [4., 5., 7., 8.], |
| # [5., 6., 8., 9.]]] |
| output = g.op("Gather", padded_input, blocks_row_indices, axis_i=2) |
| output = g.op("Gather", output, blocks_col_indices, axis_i=4) |
| output = g.op("Transpose", output, perm_i=[0, 1, 2, 4, 3, 5]) |
| return g.op("Reshape", output, output_shape) |
| |
| |
| @parse_args('v', 'i', 'i') |
| def flatten(g, input, start_dim, end_dim): |
| dim = input.type().dim() |
| # use ONNX's Flatten operator for cases where the output shape is 2D |
| if start_dim == 1: |
| if (end_dim == -1 or (end_dim is not None and end_dim == dim - 1)): |
| return g.op("Flatten", input, axis_i=start_dim) |
| elif start_dim == 0: |
| if (end_dim == -2 or (end_dim is not None and end_dim == dim - 2)): |
| return g.op("Flatten", input, axis_i=end_dim + 1) |
| # use Reshape for cases where the output shape is not 2D |
| if not input.isCompleteTensor(): |
| return _unimplemented("flatten", |
| "input size not accessible " |
| "(consider using reshape op instead of flatten op to export to ONNX)") |
| # if end_dim is negative add dim |
| if end_dim < 0 : |
| end_dim = dim + end_dim |
| input_dims = input.type().sizes() |
| output_dims = [] |
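    # e.g. input_dims = [2, 3, 4, 5] with start_dim=1, end_dim=2
    # collapses to output_dims = [2, 12, 5]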
| for i in range(0, dim): |
| if start_dim < i and end_dim >= i: |
| output_dims[start_dim] = output_dims[start_dim] * input_dims[i] |
| else: |
| output_dims.append(input_dims[i]) |
| shape = g.op("Constant", value_t=torch.LongTensor(output_dims)) |
| from torch.onnx.symbolic_opset9 import _reshape_from_tensor |
| p = _reshape_from_tensor(g, input, shape) |
| return p |