import functools
import sys
from typing import Optional, Tuple

import torch
from torch._C import _onnx as _C_onnx
from torch.onnx import (
    _type_utils,
    errors,
    symbolic_helper,
    symbolic_opset9 as opset9,
    utils,
)
from torch.onnx._internal import _beartype, jit_utils, registration


# EDITING THIS FILE? READ THIS FIRST!
# See Note [Edit Symbolic Files] in README.md

# This file exports ONNX ops for opset 12.

__all__ = [
    "argmax",
    "argmin",
    "binary_cross_entropy_with_logits",
    "celu",
    "cross_entropy_loss",
    "dropout",
    "einsum",
    "ge",
    "le",
    "native_dropout",
    "nll_loss",
    "nll_loss2d",
    "nll_loss_nd",
    "outer",
    "pow",
    "tensordot",
    "unfold",
]

_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=12)
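
# Each exporter below is registered with `@_onnx_symbolic("aten::<op>")`, which
# binds the decorated function as the opset-12 symbolic for that ATen operator.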


@_beartype.beartype
def _einsum_helper(g: jit_utils.GraphContext, equation, tensors):
    if not tensors:
        raise RuntimeError("Einsum inputs are empty.")
    # ONNX does not support bool for Einsum inputs.
    if symbolic_helper._is_bool(tensors[0]):
        tensors = [
            g.op("Cast", tensor, to_i=_C_onnx.TensorProtoDataType.INT64)
            for tensor in tensors
        ]
        return g.op(
            "Cast",
            g.op("Einsum", *tensors, equation_s=equation),
            to_i=_C_onnx.TensorProtoDataType.BOOL,
        )
    else:
        return g.op("Einsum", *tensors, equation_s=equation)
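
# A rough eager-mode sketch of the bool workaround above (illustrative only, not
# part of this module):
#
#     ints = [t.to(torch.int64) for t in bool_tensors]
#     out = torch.einsum(equation, *ints).to(torch.bool)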


@_onnx_symbolic("aten::einsum")
@symbolic_helper.parse_args("s", "v", "is")
@_beartype.beartype
def einsum(g: jit_utils.GraphContext, equation, tensor_list, path=None):
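    # `path` is torch.einsum's optional contraction-path hint; ONNX Einsum has no
    # equivalent attribute, so the argument is accepted but unused here.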
    tensors = symbolic_helper._unpack_list(tensor_list)
    return _einsum_helper(g, equation, tensors)


@_onnx_symbolic("aten::outer")
@symbolic_helper.parse_args("v", "v")
@_beartype.beartype
def outer(g: jit_utils.GraphContext, input, other):
    # Make sure to cast `other` to `input`'s element type.
    if other.type().scalarType() != input.type().scalarType():
        other = g.op(
            "Cast",
            other,
            to_i=_type_utils.JitScalarType.from_name(
                input.type().scalarType()
            ).onnx_type(),
        )
    return _einsum_helper(g, "i,j->ij", [input, other])
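
# For 1-D inputs `a` and `b`, the "i,j->ij" einsum above matches eager-mode
# `torch.outer(a, b)` (equivalently `a.unsqueeze(1) * b.unsqueeze(0)`).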


@_beartype.beartype
def _dropout_returns_masked_input_and_mask(
    g: jit_utils.GraphContext, input: torch._C.Value, p: float, train: bool
) -> Tuple[torch._C.Value, Optional[torch._C.Value]]:
    symbolic_helper.check_training_mode(train, "dropout")
    # In eval mode, dropout is a no-op: if the node's train param is set to
    # False, dropout just returns its input.
    if not train:
        return input, None
    p = g.op("Constant", value_t=torch.tensor(p))
    t = g.op("Constant", value_t=torch.tensor(train, dtype=torch.bool))
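    # ONNX Dropout (opset 12+) takes (data, ratio, training_mode) as inputs and
    # can optionally produce the boolean mask as a second output.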
    r, mask = g.op("Dropout", input, p, t, outputs=2)
    return r, mask


@_onnx_symbolic("aten::dropout")
@symbolic_helper.parse_args("v", "f", "b")
@_beartype.beartype
def dropout(g: jit_utils.GraphContext, input, p, train):
    masked, _ = _dropout_returns_masked_input_and_mask(g, input, p, train)
    return masked


@_onnx_symbolic("aten::native_dropout")
@symbolic_helper.parse_args("v", "f", "b")
@_beartype.beartype
def native_dropout(g: jit_utils.GraphContext, input, p, train):
    return _dropout_returns_masked_input_and_mask(g, input, p, train)
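
# Unlike `dropout`, aten::native_dropout returns the (output, mask) pair, so both
# Dropout outputs are forwarded.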


@_onnx_symbolic("aten::nll_loss")
@_beartype.beartype
def nll_loss(g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index):
    # none reduction : onnx::Constant[value={0}]
    # mean reduction : onnx::Constant[value={1}]
    # sum reduction : onnx::Constant[value={2}]
    reduction = symbolic_helper._maybe_get_const(reduction, "i")
    reduction_vals = ["none", "mean", "sum"]
    reduction = reduction_vals[reduction]

    # In the ONNX NegativeLogLikelihoodLoss specification, ignore_index is optional
    # and has no default value, so the ignore_index attribute is set even when it
    # was not specified by the caller (e.g. ignore_index=-100).
    ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i")
    if weight.node().mustBeNone():
        nllloss = g.op(
            "NegativeLogLikelihoodLoss",
            self,
            target,
            reduction_s=reduction,
            ignore_index_i=ignore_index,
        )
    else:
        nllloss = g.op(
            "NegativeLogLikelihoodLoss",
            self,
            target,
            weight,
            reduction_s=reduction,
            ignore_index_i=ignore_index,
        )

    return nllloss
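
# For example, `F.nll_loss(log_probs, target)` should export as a single
# NegativeLogLikelihoodLoss node with reduction_s="mean" (the PyTorch default).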


@_onnx_symbolic("aten::nll_loss2d")
@_beartype.beartype
def nll_loss2d(
    g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index
):
    return nll_loss(g, self, target, weight, reduction, ignore_index)


@_onnx_symbolic("aten::nll_loss_nd")
@_beartype.beartype
def nll_loss_nd(
    g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index
):
    return nll_loss(g, self, target, weight, reduction, ignore_index)


@_onnx_symbolic("aten::cross_entropy_loss")
@_beartype.beartype
def cross_entropy_loss(
    g: jit_utils.GraphContext,
    self,
    target,
    weight,
    reduction,
    ignore_index,
    label_smoothing,
):
    # none reduction : onnx::Constant[value={0}]
    # mean reduction : onnx::Constant[value={1}]
    # sum reduction : onnx::Constant[value={2}]
    reduction = symbolic_helper._maybe_get_const(reduction, "i")
    reduction_vals = ["none", "mean", "sum"]
    reduction = reduction_vals[reduction]

    label_smoothing = symbolic_helper._maybe_get_const(label_smoothing, "f")
    if label_smoothing is not None and label_smoothing > 0.0:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX does not support label_smoothing", self
        )

    # In the ONNX SoftmaxCrossEntropyLoss specification, ignore_index is optional
    # and has no default value, so the ignore_index attribute is set even when it
    # was not specified by the caller (e.g. ignore_index=-100).
    ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i")
    if weight.node().mustBeNone():
        celoss = g.op(
            "SoftmaxCrossEntropyLoss",
            self,
            target,
            reduction_s=reduction,
            ignore_index_i=ignore_index,
        )
    else:
        celoss = g.op(
            "SoftmaxCrossEntropyLoss",
            self,
            target,
            weight,
            reduction_s=reduction,
            ignore_index_i=ignore_index,
        )

    return celoss
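
# Note: SoftmaxCrossEntropyLoss consumes raw (unnormalized) logits, matching
# aten::cross_entropy_loss, whereas NegativeLogLikelihoodLoss above expects
# log-probabilities.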


@_onnx_symbolic("aten::binary_cross_entropy_with_logits")
@symbolic_helper.parse_args("v", "v", "v", "v", "i")
@_beartype.beartype
def binary_cross_entropy_with_logits(
    g: jit_utils.GraphContext, input, target, weight, pos_weight, reduction
):
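    # Composes the loss from opset-9 primitives:
    #     loss = -(pos_weight * target * log(sigmoid(x))
    #              + (1 - target) * log(1 - sigmoid(x)))
    # optionally scaled by `weight` and then reduced per `reduction`.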
    p = g.op("Constant", value_t=torch.tensor([1]))
    sig_x = opset9.sigmoid(g, input)
    log_sig_x = opset9.log(g, sig_x)
    sub_1_x = opset9.sub(g, p, sig_x)
    sub_1_y = opset9.sub(g, p, target)
    log_1_x = opset9.log(g, sub_1_x)
    if pos_weight is None or symbolic_helper._is_none(pos_weight):
        output = opset9.neg(
            g,
            opset9.add(
                g, opset9.mul(g, target, log_sig_x), opset9.mul(g, sub_1_y, log_1_x)
            ),
        )
    else:
        output = opset9.neg(
            g,
            opset9.add(
                g,
                opset9.mul(g, opset9.mul(g, target, log_sig_x), pos_weight),
                opset9.mul(g, sub_1_y, log_1_x),
            ),
        )

    if weight is not None and not symbolic_helper._is_none(weight):
        output = opset9.mul(g, weight, output)

    reduction = symbolic_helper._maybe_get_const(reduction, "i")
    if reduction == 0:
        return output
    elif reduction == 1:
        return g.op("ReduceMean", output, keepdims_i=0)
    elif reduction == 2:
        return g.op("ReduceSum", output, keepdims_i=0)
    else:
        return symbolic_helper._onnx_unsupported(
            "binary_cross_entropy_with_logits with reduction other than none, mean, or sum",
            input,
        )


@_onnx_symbolic("aten::celu")
@_beartype.beartype
def celu(g: jit_utils.GraphContext, self, alpha):
    alpha = symbolic_helper._maybe_get_const(alpha, "f")
    # If the input is double, cast it to float (ONNX Celu only supports float32).
    if self.type().scalarType() == "Double":
        self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT)
        out = g.op("Celu", self, alpha_f=alpha)
        return g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.DOUBLE)

    return g.op("Celu", self, alpha_f=alpha)


@_onnx_symbolic("aten::argmax")
@symbolic_helper.parse_args("v", "v", "b")
@_beartype.beartype
def argmax(
    g: jit_utils.GraphContext,
    input: torch._C.Value,
    dim: torch._C.Value,
    keepdim: bool,
):
    return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMax")


@_onnx_symbolic("aten::argmin")
@symbolic_helper.parse_args("v", "v", "b")
@_beartype.beartype
def argmin(
    g: jit_utils.GraphContext,
    input: torch._C.Value,
    dim: torch._C.Value,
    keepdim: bool,
):
    return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMin")
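
# `_argmin_argmax_helper` also handles `dim=None` (flatten, then reduce) and maps
# `keepdim` onto the ArgMax/ArgMin keepdims attribute.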


@_onnx_symbolic("aten::pow")
@_beartype.beartype
def pow(g: jit_utils.GraphContext, self, exponent):
    return g.op("Pow", self, exponent)
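
# Pow's type constraints were relaxed in opset 12, allowing base and exponent to
# have different dtypes, so no explicit casts are needed here.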


@_onnx_symbolic("aten::ge")
@_beartype.beartype
def ge(g: jit_utils.GraphContext, input, other):
    return g.op("GreaterOrEqual", input, other)


@_onnx_symbolic("aten::le")
@_beartype.beartype
def le(g: jit_utils.GraphContext, input, other):
    return g.op("LessOrEqual", input, other)
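
# GreaterOrEqual and LessOrEqual are new in opset 12; earlier opsets have to
# compose Greater/Less with Not instead.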


@_onnx_symbolic("aten::unfold")
@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def unfold(g: jit_utils.GraphContext, input, dimension, size, step):
    const_size = symbolic_helper._maybe_get_const(size, "i")
    const_step = symbolic_helper._maybe_get_const(step, "i")
    if not symbolic_helper._is_value(const_size) and not symbolic_helper._is_value(
        const_step
    ):
        return opset9.unfold(g, input, dimension, const_size, const_step)
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("unfold", input, dimension_i=dimension, size_i=size, step_i=step)

    sizedim = symbolic_helper._get_tensor_dim_size(input, dimension)
    if sizedim is not None:
        low_start = g.op("Constant", value_t=torch.tensor(0))
        low_end = g.op("Constant", value_t=torch.tensor(sizedim))
        hi_end = g.op("Constant", value_t=torch.tensor(sizedim + 1))
        low_indices = g.op("Range", low_start, low_end, step)
        hi_indices = g.op("Range", size, hi_end, step)
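        # Window start indices: range(0, sizedim, step); window end indices:
        # range(size, sizedim + 1, step). Each (start, end) pair selects one
        # slice of length `size` along `dimension`.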

        low_size = symbolic_helper._size_helper(
            g, low_indices, g.op("Constant", value_t=torch.tensor(0))
        )
        hi_size = symbolic_helper._size_helper(
            g, hi_indices, g.op("Constant", value_t=torch.tensor(0))
        )

        ndim = symbolic_helper._get_tensor_rank(input)
        assert ndim is not None
        perm = list(range(0, ndim))
        perm.append(perm.pop(dimension))

        unsqueeze_list = []
        loop_condition = g.op("Constant", value_t=torch.tensor(1))
        loop_condition = g.op(
            "Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL
        )
        loop_len = g.op("Min", low_size, hi_size)

        loop, (loop_context,), _ = jit_utils.add_op_with_blocks(
            g, "Loop", loop_len, loop_condition, n_blocks=1
        )

        loop_block = loop_context.block
        block_input_iter = utils._add_input_to_block(loop_block)
        # FIXME(justinchuby): cond is unused?
        cond = utils._add_input_to_block(loop_block)

        starts = loop_context.op("Gather", low_indices, block_input_iter)
        ends = loop_context.op("Gather", hi_indices, block_input_iter)
        axes = loop_context.op("Constant", value_t=torch.tensor([2]))
        starts = symbolic_helper._unsqueeze_helper(loop_context, starts, [0])
        ends = symbolic_helper._unsqueeze_helper(loop_context, ends, [0])
        stack = loop_context.op("Slice", input, starts, ends, axes)

        unsqueeze = symbolic_helper._unsqueeze_helper(
            loop_context, loop_context.op("Transpose", stack, perm_i=perm), [dimension]
        )
        unsqueeze_list.append(unsqueeze)
        concat = loop_context.op("Concat", *unsqueeze_list, axis_i=0)

        cond_out = loop_context.op(
            "Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL
        )
        utils._add_output_to_block(loop_block, cond_out)
        utils._add_output_to_block(loop_block, concat)
        loop_output = loop.node().output()
        perm = [0, 1, 2, 3, 4]
        perm[0], perm[dimension + 1] = perm[dimension + 1], perm[0]
        transpose = g.op("Transpose", loop_output, perm_i=perm)
        squeeze = symbolic_helper._squeeze_helper(g, transpose, [0])

        return squeeze

    return symbolic_helper._unimplemented("Unfold", "input size not accessible")


@_onnx_symbolic("aten::tensordot")
@symbolic_helper.parse_args("v", "v", "is", "is", "v")
@_beartype.beartype
def tensordot(g: jit_utils.GraphContext, input_a, input_b, dims_a, dims_b, out=None):
    if out is not None:
        symbolic_helper._unimplemented(
            "Tensordot", "Out parameter is not supported for tensordot."
        )

    dim_count_a = symbolic_helper._get_tensor_rank(input_a)
    if dim_count_a is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of tensordot for tensor(input_a) of unknown rank.",
            input_a,
        )

    dim_count_b = symbolic_helper._get_tensor_rank(input_b)
    if dim_count_b is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of tensordot for tensor(input_b) of unknown rank.",
            input_b,
        )

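    # Normalize negative contraction dims, then derive each input's free dims.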
    dims_a = [
        (dims_a[i] + dim_count_a) if (dims_a[i] < 0) else dims_a[i]
        for i in range(len(dims_a))
    ]
    dims_b = [
        (dims_b[i] + dim_count_b) if (dims_b[i] < 0) else dims_b[i]
        for i in range(len(dims_b))
    ]

    left_dims_a = [i for i in range(dim_count_a) if (i not in dims_a)]
    left_dims_b = [i for i in range(dim_count_b) if (i not in dims_b)]

    new_input_a = opset9.permute(g, input_a, left_dims_a + dims_a)
    new_input_b = opset9.permute(g, input_b, dims_b + left_dims_b)
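
    # Strategy: with free dims moved to the front of input_a and the back of
    # input_b, flatten both to 2-D, contract with a single "ij,jk->ik" einsum
    # (a matmul), and reshape the result back to left_dims_a + left_dims_b.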

    input_shape = g.op("Shape", new_input_a)
    left_sizes_a = symbolic_helper._slice_helper(
        g, input_shape, axes=[0], starts=[0], ends=[len(left_dims_a)]
    )
    shape_sizes = [
        left_sizes_a,
        g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
    ]
    output_a = opset9._reshape_from_tensor(g, new_input_a, shape_sizes)

    input_shape = g.op("Shape", output_a)
    slices = symbolic_helper._slice_helper(
        g, input_shape, axes=[0], starts=[-1], ends=[sys.maxsize]
    )
    shape_sizes = [
        g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
        slices,
    ]
    output_a = opset9._reshape_from_tensor(g, new_input_a, shape_sizes)

    input_shape = g.op("Shape", new_input_b)
    left_sizes_b = symbolic_helper._slice_helper(
        g, input_shape, axes=[0], starts=[len(dims_b)], ends=[sys.maxsize]
    )
    slices = symbolic_helper._slice_helper(
        g, input_shape, axes=[0], starts=[0], ends=[len(dims_b)]
    )
    shape_sizes = [
        slices,
        g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
    ]
    output_b = opset9._reshape_from_tensor(g, new_input_b, shape_sizes)

    input_shape = g.op("Shape", output_b)
    slices = symbolic_helper._slice_helper(
        g, input_shape, axes=[0], starts=[-1], ends=[sys.maxsize]
    )
    shape_sizes = [
        g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
        slices,
    ]
    output_b = opset9._reshape_from_tensor(g, new_input_b, shape_sizes)

    output = einsum(g, "ij,jk->ik", g.op("prim::ListConstruct", *[output_a, output_b]))

    shape_sizes = [left_sizes_a, left_sizes_b]
    return opset9._reshape_from_tensor(g, output, shape_sizes)