| """This file exports ONNX ops for opset 11.""" | 
 | from __future__ import annotations | 
 |  | 
 | import functools | 
 | import sys | 
 | import warnings | 
 | from typing import Optional, Sequence | 
 |  | 
 | import torch | 
 | from torch import _C | 
 | from torch._C import _onnx as _C_onnx | 
 | from torch.onnx import ( | 
 |     _type_utils, | 
 |     errors, | 
 |     symbolic_helper, | 
 |     symbolic_opset10 as opset10, | 
 |     symbolic_opset9 as opset9, | 
 |     utils, | 
 | ) | 
 | from torch.onnx._globals import GLOBALS | 
 | from torch.onnx._internal import _beartype, jit_utils, registration | 
 |  | 
 | # EDITING THIS FILE? READ THIS FIRST! | 
 | # see Note [Edit Symbolic Files] in README.md | 
 |  | 
# Names exported by this module (see the module docstring: ONNX ops for opset 11).
__all__ = [
    "add",
    "append",
    "arange",
    "argsort",
    "atleast_1d",
    "atleast_2d",
    "atleast_3d",
    "cat",
    "chunk",
    "clamp_max",
    "clamp_min",
    "clamp",
    "constant_pad_nd",
    "cumsum",
    "Delete",
    "embedding_bag",
    "embedding_renorm",
    "flatten",
    "gather",
    "hardtanh",
    "hstack",
    "im2col",
    "index_fill",
    "index",
    "index_copy",
    "index_put",
    "insert",
    "linalg_det",
    "linalg_vector_norm",
    "logdet",
    "masked_scatter",
    "masked_select",
    "mm",
    "narrow",
    "normal",
    "pad",
    "pixel_shuffle",
    "pop",
    "prim_constant_chunk",
    "reflection_pad",
    "relu6",
    "remainder",
    "replication_pad",
    "round",
    "scatter",
    "select",
    "size",
    "sort",
    "split_with_sizes",
    "split",
    "squeeze",
    "stack",
    "topk",
    "unbind",
    "unique_dim",
    "unsqueeze",
    "vstack",
]

# `registration.onnx_symbolic` with the opset argument pre-bound to 11; used as
# the registration decorator for every symbolic function in this file.
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=11)
 |  | 
 |  | 
 | def _apply_params(*args, **kwargs): | 
 |     """Returns a decorator that calls the decorated (higher-order) function with the given parameters.""" | 
 |  | 
 |     def _apply(fn): | 
 |         return fn(*args, **kwargs) | 
 |  | 
 |     return _apply | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::hardtanh") | 
 | @symbolic_helper.quantized_args(True) | 
 | @symbolic_helper.parse_args("v", "f", "f") | 
 | @_beartype.beartype | 
 | def hardtanh(g: jit_utils.GraphContext, self: _C.Value, min_val: float, max_val: float): | 
 |     scalar_type = _type_utils.JitScalarType.from_value( | 
 |         self, _type_utils.JitScalarType.FLOAT | 
 |     ) | 
 |     min_val = g.op( | 
 |         "Constant", | 
 |         value_t=torch.tensor(min_val, dtype=scalar_type.dtype()), | 
 |     ) | 
 |     max_val = g.op( | 
 |         "Constant", | 
 |         value_t=torch.tensor(max_val, dtype=scalar_type.dtype()), | 
 |     ) | 
 |     return opset9._op_with_optional_float_cast( | 
 |         g, "Clip", self, min_val, max_val, opset_before=12 | 
 |     ) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::clamp") | 
 | @_beartype.beartype | 
 | def clamp(g: jit_utils.GraphContext, self, min, max): | 
 |     @_beartype.beartype | 
 |     def _cast_if_not_none(tensor, dtype): | 
 |         if tensor is not None and not symbolic_helper._is_none(tensor): | 
 |             return g.op( | 
 |                 "Cast", | 
 |                 tensor, | 
 |                 to_i=dtype.onnx_type(), | 
 |             ) | 
 |         else: | 
 |             return tensor | 
 |  | 
 |     scalar_type = _type_utils.JitScalarType.from_value( | 
 |         self, _type_utils.JitScalarType.UNDEFINED | 
 |     ) | 
 |     if scalar_type != _type_utils.JitScalarType.UNDEFINED: | 
 |         min = _cast_if_not_none(min, scalar_type) | 
 |         max = _cast_if_not_none(max, scalar_type) | 
 |  | 
 |     if symbolic_helper._is_none(min): | 
 |         return clamp_max(g, self, max) | 
 |     elif symbolic_helper._is_none(max): | 
 |         return clamp_min(g, self, min) | 
 |     else: | 
 |         if ( | 
 |             symbolic_helper._get_tensor_rank(min) == 0 | 
 |             and symbolic_helper._get_tensor_rank(max) == 0 | 
 |         ): | 
 |             return opset9._op_with_optional_float_cast( | 
 |                 g, "Clip", self, min, max, opset_before=12 | 
 |             ) | 
 |         else: | 
 |             return clamp_max(g, clamp_min(g, self, min), max) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::clamp_min") | 
 | @symbolic_helper.parse_args("v", "v") | 
 | @_beartype.beartype | 
 | def clamp_min(g: jit_utils.GraphContext, self, min): | 
 |     min = g.op("Cast", min, to_i=_type_utils.JitScalarType.from_value(self).onnx_type()) | 
 |     if symbolic_helper._get_tensor_rank(min) == 0: | 
 |         max = opset9.unused(g) | 
 |         return opset9._op_with_optional_float_cast( | 
 |             g, "Clip", self, min, max, opset_before=12 | 
 |         ) | 
 |     else: | 
 |         return opset9._op_with_optional_float_cast(g, "Max", self, min, opset_before=12) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::clamp_max") | 
 | @symbolic_helper.parse_args("v", "v") | 
 | @_beartype.beartype | 
 | def clamp_max(g: jit_utils.GraphContext, self, max): | 
 |     max = g.op("Cast", max, to_i=_type_utils.JitScalarType.from_value(self).onnx_type()) | 
 |     if symbolic_helper._get_tensor_rank(max) == 0: | 
 |         min = opset9.unused(g) | 
 |         return opset9._op_with_optional_float_cast( | 
 |             g, "Clip", self, min, max, opset_before=12 | 
 |         ) | 
 |     else: | 
 |         return opset9._op_with_optional_float_cast(g, "Min", self, max, opset_before=12) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::relu6") | 
 | @_beartype.beartype | 
 | def relu6(g: jit_utils.GraphContext, input): | 
 |     scalar_type = _type_utils.JitScalarType.from_value( | 
 |         input, _type_utils.JitScalarType.FLOAT | 
 |     ) | 
 |     min_val = g.op( | 
 |         "Constant", | 
 |         value_t=torch.tensor(0, dtype=scalar_type.dtype()), | 
 |     ) | 
 |     max_val = g.op( | 
 |         "Constant", | 
 |         value_t=torch.tensor(6, dtype=scalar_type.dtype()), | 
 |     ) | 
 |     return clamp(g, input, min_val, max_val) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::select") | 
 | # Opset 11 gather accepts negative indices | 
 | @symbolic_helper.quantized_args(True) | 
 | @symbolic_helper.parse_args("v", "i", "v") | 
 | @_beartype.beartype | 
 | def select(g: jit_utils.GraphContext, self, dim, index): | 
 |     return g.op("Gather", self, index, axis_i=dim) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::index_put") | 
 | @_beartype.beartype | 
 | def index_put( | 
 |     g: jit_utils.GraphContext, self, indices_list_value, values, accumulate=False | 
 | ): | 
 |     if symbolic_helper._is_packed_list(indices_list_value): | 
 |         indices_list = symbolic_helper._unpack_list(indices_list_value) | 
 |     else: | 
 |         indices_list = [indices_list_value] | 
 |     if symbolic_helper.is_caffe2_aten_fallback(): | 
 |         args = [self] + indices_list + [values, accumulate] | 
 |         return g.at("index_put", *args) | 
 |  | 
 |     accumulate = symbolic_helper._parse_arg(accumulate, "b") | 
 |  | 
 |     if len(indices_list) == 0: | 
 |         return values | 
 |  | 
 |     if len(indices_list) > 1: | 
 |         for idx_ in range(len(indices_list)): | 
 |             if symbolic_helper._is_bool(indices_list[idx_]): | 
 |                 indices_list[idx_] = g.op("NonZero", indices_list[idx_]) | 
 |         index = indices_list[0] | 
 |  | 
 |         for ind in indices_list[1:]: | 
 |             index = opset9.add(g, index, ind) | 
 |         broadcast_index_shape = g.op("Shape", index) | 
 |         indices_list = [ | 
 |             symbolic_helper._unsqueeze_helper( | 
 |                 g, opset9.expand(g, ind, broadcast_index_shape, None), [-1] | 
 |             ) | 
 |             for ind in indices_list | 
 |         ] | 
 |         index = g.op("Concat", *indices_list, axis_i=-1) | 
 |     else: | 
 |         # Replace index_put node with masked_scatter or masked_fill | 
 |         # when inputs to the index_put node contains a single boolean input. | 
 |         # | 
 |         # index_put -> masked_fill | 
 |         #   * input index contains single tensor of Bool type (e.g.: %24 <- %23). | 
 |         #   * input value contains single element (e.g.: %18). | 
 |         # | 
 |         # Torch IR | 
 |         #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6) | 
 |         #   %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = | 
 |         #               aten::to(%8, %26, %27, %11, %12, %28, %29, %15) | 
 |         #   %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]() | 
 |         #   %23 : Bool(8, strides=[1], device=cpu) = aten::view(%16, %22) | 
 |         #   %24 : Tensor?[] = prim::ListConstruct(%23) | 
 |         #   %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = | 
 |         #                aten::index_put(%mask, %24, %18, %30) | 
 |         #   return (%25) | 
 |         # | 
 |         # | 
 |         # index_put -> masked_scatter | 
 |         #   * input index contains single tensor of Bool type (e.g.: %32 <- %31). | 
 |         #   * input value contains multiple elements (e.g.: %28). | 
 |         # | 
 |         # Torch IR | 
 |         #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6) | 
 |         #   %28 : Float(8, strides=[1], requires_grad=0, device=cpu) | 
 |         #                = prim::Constant[value= 1  1  1  1  1  1  1  1 [ CPUFloatType{8} ]]() | 
 |         #   %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) | 
 |         #                = aten::ne(%mask, %some_const) | 
 |         #   %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) | 
 |         #                = aten::to(%15, %34, %35, %18, %19, %36, %37, %22) | 
 |         #   %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]() | 
 |         #   %30 : int[] = prim::Constant[value=[-1]]() | 
 |         #   %31 : Bool(8, strides=[1], device=cpu) = aten::view(%23, %30) | 
 |         #   %32 : Tensor?[] = prim::ListConstruct(%31) | 
 |         #   %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) | 
 |         #               = aten::index_put(%mask, %32, %28, %38) | 
 |         #   return (%33) | 
 |         index = indices_list[0] | 
 |         bool_inp = index | 
 |         if symbolic_helper._is_bool(bool_inp): | 
 |             rank = symbolic_helper._get_tensor_rank(values) | 
 |             if rank is not None and rank == 0: | 
 |                 return opset9.masked_fill(g, self, bool_inp, values) | 
 |             mask_rank = symbolic_helper._get_tensor_rank(bool_inp) | 
 |             self_rank = symbolic_helper._get_tensor_rank(self) | 
 |             if ( | 
 |                 mask_rank is not None | 
 |                 and self_rank is not None | 
 |                 and self_rank > mask_rank | 
 |             ): | 
 |                 # Unsqueeze 'bool_inp' to be broadcastable to shape of 'self'. | 
 |                 bool_inp = symbolic_helper._unsqueeze_helper( | 
 |                     g, bool_inp, list(range(mask_rank, self_rank)) | 
 |                 ) | 
 |             return masked_scatter(g, self, bool_inp, values) | 
 |         broadcast_index_shape = g.op("Shape", index) | 
 |         index = symbolic_helper._unsqueeze_helper(g, index, [-1]) | 
 |     sub_data_shape = symbolic_helper._slice_helper( | 
 |         g, g.op("Shape", self), axes=[0], starts=[len(indices_list)], ends=[sys.maxsize] | 
 |     ) | 
 |     values_shape = g.op("Concat", broadcast_index_shape, sub_data_shape, axis_i=0) | 
 |     # Check if values is a singular value and expand accordingly | 
 |     rank = symbolic_helper._get_tensor_rank(values) | 
 |     if rank is not None and rank == 0: | 
 |         values = opset9.expand(g, values, values_shape, None) | 
 |     values = symbolic_helper._reshape_helper(g, values, values_shape) | 
 |  | 
 |     self_scalar_type = _type_utils.JitScalarType.from_value( | 
 |         self, _type_utils.JitScalarType.UNDEFINED | 
 |     ) | 
 |     if self_scalar_type != _type_utils.JitScalarType.UNDEFINED: | 
 |         values_scalar_type = _type_utils.JitScalarType.from_value( | 
 |             values, _type_utils.JitScalarType.UNDEFINED | 
 |         ) | 
 |         if self_scalar_type != values_scalar_type: | 
 |             values = g.op("Cast", values, to_i=self_scalar_type.onnx_type()) | 
 |     elif accumulate: | 
 |         raise errors.SymbolicValueError("self does not have a valid scalar type.", self) | 
 |  | 
 |     if accumulate: | 
 |         zeros = g.op( | 
 |             "ConstantOfShape", | 
 |             g.op("Shape", self), | 
 |             value_t=torch.tensor([0], dtype=self_scalar_type.dtype()), | 
 |         ) | 
 |         result = g.op("ScatterND", zeros, index, values) | 
 |         result = add(g, self, result) | 
 |     else: | 
 |         result = g.op("ScatterND", self, index, values) | 
 |  | 
 |     return result | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::pixel_shuffle") | 
 | @symbolic_helper.parse_args("v", "i") | 
 | @_beartype.beartype | 
 | def pixel_shuffle(g: jit_utils.GraphContext, self, upscale_factor): | 
 |     rank = symbolic_helper._get_tensor_rank(self) | 
 |     if rank is not None and rank != 4: | 
 |         return symbolic_helper._unimplemented("pixel_shuffle", "only support 4d input") | 
 |     return g.op("DepthToSpace", self, blocksize_i=upscale_factor, mode_s="CRD") | 
 |  | 
 |  | 
 | @_onnx_symbolic( | 
 |     "aten::upsample_nearest1d", | 
 |     decorate=[_apply_params("upsample_nearest1d", 3, "nearest")], | 
 | ) | 
 | @_onnx_symbolic( | 
 |     "aten::upsample_nearest2d", | 
 |     decorate=[_apply_params("upsample_nearest2d", 4, "nearest")], | 
 | ) | 
 | @_onnx_symbolic( | 
 |     "aten::upsample_nearest3d", | 
 |     decorate=[_apply_params("upsample_nearest3d", 5, "nearest")], | 
 | ) | 
 | @_onnx_symbolic( | 
 |     "aten::upsample_linear1d", | 
 |     decorate=[_apply_params("upsample_linear1d", 3, "linear")], | 
 | ) | 
 | @_onnx_symbolic( | 
 |     "aten::upsample_bilinear2d", | 
 |     decorate=[_apply_params("upsample_bilinear2d", 4, "linear")], | 
 | ) | 
 | @_onnx_symbolic( | 
 |     "aten::upsample_trilinear3d", | 
 |     decorate=[_apply_params("upsample_trilinear3d", 5, "linear")], | 
 | ) | 
 | @_onnx_symbolic( | 
 |     "aten::upsample_bicubic2d", | 
 |     decorate=[_apply_params("upsample_bicubic2d", 4, "cubic")], | 
 | ) | 
 | @_beartype.beartype | 
 | def _interpolate(name: str, dim: int, interpolate_mode: str): | 
 |     return symbolic_helper._interpolate_helper(name, dim, interpolate_mode) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::__interpolate") | 
 | @symbolic_helper.quantized_args(True, False, False, False, False, False, False) | 
 | @_beartype.beartype | 
 | def __interpolate( | 
 |     g: jit_utils.GraphContext, | 
 |     input, | 
 |     size, | 
 |     scale_factor, | 
 |     mode, | 
 |     align_corners, | 
 |     recompute_scale_factor, | 
 |     antialias, | 
 | ): | 
 |     return symbolic_helper.__interpolate_helper( | 
 |         g, input, size, scale_factor, mode, align_corners, recompute_scale_factor | 
 |     ) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::gather") | 
 | @symbolic_helper.parse_args("v", "i", "v", "v") | 
 | @_beartype.beartype | 
 | def gather(g: jit_utils.GraphContext, self, dim, index, sparse_grad=False): | 
 |     if symbolic_helper._maybe_get_const(sparse_grad, "i"): | 
 |         return symbolic_helper._unimplemented("gather", "sparse_grad == True") | 
 |     if symbolic_helper.is_caffe2_aten_fallback(): | 
 |         return g.at("gather", self, dim, index, sparse_grad) | 
 |     return g.op("GatherElements", self, index, axis_i=dim) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::scatter") | 
 | @symbolic_helper.parse_args("v", "i", "v", "v") | 
 | @_beartype.beartype | 
 | def scatter(g: jit_utils.GraphContext, self, dim, index, src): | 
 |     if symbolic_helper.is_caffe2_aten_fallback(): | 
 |         return g.at("scatter", self, dim, index, src, overload_name="src") | 
 |     src_type = _type_utils.JitScalarType.from_value(src) | 
 |     src = symbolic_helper._maybe_get_scalar(src) | 
 |     if symbolic_helper._is_value(src): | 
 |         return g.op("ScatterElements", self, index, src, axis_i=dim) | 
 |     else: | 
 |         # Check if scalar "src" has same type as self (PyTorch allows different | 
 |         # type for scalar src (but not when src is tensor)). If not, insert Cast node. | 
 |         if _type_utils.JitScalarType.from_value(self) != src_type: | 
 |             src = g.op( | 
 |                 "Cast", | 
 |                 src, | 
 |                 to_i=_type_utils.JitScalarType.from_value(self).onnx_type(), | 
 |             ) | 
 |         return g.op( | 
 |             "ScatterElements", self, index, opset9.expand_as(g, src, index), axis_i=dim | 
 |         ) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::cumsum") | 
 | @symbolic_helper.parse_args("v", "i", "none") | 
 | @_beartype.beartype | 
 | def cumsum(g: jit_utils.GraphContext, self, dim, dtype=None): | 
 |     dim_tensor = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.int)) | 
 |     if dtype and dtype.node().kind() != "prim::Constant": | 
 |         parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype") | 
 |         cast = g.op( | 
 |             "Cast", self, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type() | 
 |         ) | 
 |     else: | 
 |         cast = self | 
 |     csum = g.op("CumSum", cast, dim_tensor) | 
 |     return csum | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::masked_select") | 
 | @_beartype.beartype | 
 | def masked_select(g: jit_utils.GraphContext, self, mask): | 
 |     index = opset9.nonzero(g, opset9.expand_as(g, mask, self)) | 
 |     return g.op("GatherND", self, index) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::masked_scatter") | 
 | @_beartype.beartype | 
 | def masked_scatter(g: jit_utils.GraphContext, self, mask, source): | 
 |     index = opset9.nonzero(g, opset9.expand_as(g, mask, self)) | 
 |     # NOTE: source can have more elements than needed. | 
 |     # It could also have arbitrary shape. | 
 |     # This is not supported by ONNX::ScatterND, so we need to flatten and slice source tensor. | 
 |     source = symbolic_helper._reshape_helper(g, source, torch.LongTensor([-1])) | 
 |     source = symbolic_helper._slice_helper( | 
 |         g, | 
 |         source, | 
 |         axes=torch.LongTensor([0]), | 
 |         starts=torch.LongTensor([0]), | 
 |         ends=opset9.size(g, index, torch.LongTensor([0])), | 
 |     ) | 
 |     return g.op("ScatterND", self, index, source) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::len") | 
 | @_beartype.beartype | 
 | def _len(g: jit_utils.GraphContext, self): | 
 |     if ( | 
 |         symbolic_helper._is_tensor_list(self) | 
 |         or self.node().kind() == "onnx::SplitToSequence" | 
 |     ): | 
 |         return g.op("SequenceLength", self) | 
 |     sz_0 = size(g, self, g.op("Constant", value_t=torch.LongTensor([0]))) | 
 |     return symbolic_helper._squeeze_helper(g, sz_0, [0]) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::__getitem_") | 
 | @_beartype.beartype | 
 | def __getitem_(g: jit_utils.GraphContext, self, i): | 
 |     if symbolic_helper._is_tensor_list(self): | 
 |         # SequenceAt requires that the input be a List of Tensors | 
 |         return g.op("SequenceAt", self, i) | 
 |     else: | 
 |         from torch.onnx.symbolic_opset9 import __getitem_ as getitem | 
 |  | 
 |         return getitem(g, self, i) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::_set_item") | 
 | @_beartype.beartype | 
 | def _set_item(g: jit_utils.GraphContext, tensor_list, i, v): | 
 |     tensor_list = g.op("SequenceErase", tensor_list, i) | 
 |     return g.op("SequenceInsert", tensor_list, v, i) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::append") | 
 | @_beartype.beartype | 
 | def append(g: jit_utils.GraphContext, self, tensor): | 
 |     return g.op("SequenceInsert", self, tensor) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::add") | 
 | @_beartype.beartype | 
 | def add(g: jit_utils.GraphContext, self, other, alpha=None): | 
 |     if symbolic_helper._is_value(self) and symbolic_helper._is_tensor_list(self): | 
 |         tensor_list_node = other.node() | 
 |         if tensor_list_node.kind() != "prim::ListConstruct": | 
 |             return symbolic_helper._unimplemented( | 
 |                 "add", "does not support adding dynamic tensor list to another" | 
 |             ) | 
 |         tensors = symbolic_helper._unpack_list(other) | 
 |         l = self | 
 |         for t in tensors: | 
 |             l = g.op("SequenceInsert", l, t) | 
 |         return l | 
 |  | 
 |     return opset9.add(g, self, other, alpha) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::insert") | 
 | @_beartype.beartype | 
 | def insert(g: jit_utils.GraphContext, self, pos, tensor): | 
 |     return g.op("SequenceInsert", self, tensor, pos) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::pop") | 
 | @_beartype.beartype | 
 | def pop(g: jit_utils.GraphContext, tensor_list, dim): | 
 |     return g.op("SequenceErase", tensor_list, dim) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::Delete") | 
 | @_beartype.beartype | 
 | def Delete(g: jit_utils.GraphContext, tensor_list, dim): | 
 |     return g.op("SequenceErase", tensor_list, dim) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::cat") | 
 | @symbolic_helper.quantized_args(True) | 
 | @_beartype.beartype | 
 | def cat(g: jit_utils.GraphContext, tensor_list, dim): | 
 |     if symbolic_helper._is_packed_list(tensor_list): | 
 |         return opset9.cat(g, tensor_list, dim) | 
 |     else: | 
 |         dim = symbolic_helper._get_const(dim, "i", "dim") | 
 |         return g.op("ConcatFromSequence", tensor_list, axis_i=dim) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::stack") | 
 | @_beartype.beartype | 
 | def stack(g: jit_utils.GraphContext, tensor_list, dim): | 
 |     if symbolic_helper._is_packed_list(tensor_list): | 
 |         return opset9.stack(g, tensor_list, dim) | 
 |     else: | 
 |         dim = symbolic_helper._get_const(dim, "i", "dim") | 
 |         return g.op("ConcatFromSequence", tensor_list, axis_i=dim, new_axis_i=1) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::_unique2") | 
 | @symbolic_helper.parse_args("v", "i", "i", "i") | 
 | @_beartype.beartype | 
 | def _unique2(g: jit_utils.GraphContext, self, sorted, return_inverse, return_counts): | 
 |     u, indices, inverse_indices, counts = g.op( | 
 |         "Unique", self, sorted_i=sorted, outputs=4 | 
 |     ) | 
 |     return u, inverse_indices, counts | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::unique_dim") | 
 | @symbolic_helper.parse_args("v", "i", "i", "i", "i") | 
 | @_beartype.beartype | 
 | def unique_dim( | 
 |     g: jit_utils.GraphContext, self, dim, sorted, return_inverse, return_counts | 
 | ): | 
 |     u, indices, inverse_indices, counts = g.op( | 
 |         "Unique", self, axis_i=dim, sorted_i=sorted, outputs=4 | 
 |     ) | 
 |     return u, inverse_indices, counts | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::topk") | 
 | @symbolic_helper.parse_args("v", "v", "i", "i", "i", "none") | 
 | @_beartype.beartype | 
 | def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None): | 
 |     return symbolic_helper._topk_helper( | 
 |         g, self, k, dim, largest=largest, sorted=sorted, out=out | 
 |     ) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::sort") | 
 | @symbolic_helper.parse_args("v", "i", "i", "none") | 
 | @_beartype.beartype | 
 | def sort(g: jit_utils.GraphContext, self, dim, decending, out=None): | 
 |     return symbolic_helper._sort_helper(g, self, dim, decending=decending, out=out) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::argsort") | 
 | @symbolic_helper.parse_args("v", "i", "i", "none") | 
 | @_beartype.beartype | 
 | def argsort(g: jit_utils.GraphContext, self, dim, decending, out=None): | 
 |     _, indices = symbolic_helper._sort_helper( | 
 |         g, self, dim, decending=decending, out=out | 
 |     ) | 
 |     return indices | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::round") | 
 | @symbolic_helper.parse_args("v", "i") | 
 | @_beartype.beartype | 
 | def round(g: jit_utils.GraphContext, self, decimals=0): | 
 |     if not symbolic_helper._is_fp(self): | 
 |         return self | 
 |     if decimals == 0: | 
 |         return g.op("Round", self) | 
 |     mul = g.op("Mul", self, g.op("Constant", value_t=torch.tensor(pow(10, decimals)))) | 
 |     round = g.op("Round", mul) | 
 |     return g.op( | 
 |         "Mul", round, g.op("Constant", value_t=torch.tensor(pow(10, -1 * decimals))) | 
 |     ) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::remainder") | 
 | @_beartype.beartype | 
 | def remainder(g: jit_utils.GraphContext, input, other): | 
 |     if symbolic_helper._is_fp(input) or symbolic_helper._is_fp(other): | 
 |         return opset9.remainder(g, input, other) | 
 |     return g.op("Mod", input, other, fmod_i=0) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::split") | 
 | @symbolic_helper.parse_args("v", "v", "i", "i") | 
 | @_beartype.beartype | 
 | def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None): | 
 |     if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs): | 
 |         split_out = g.op("SplitToSequence", self, split_size_or_sizes, axis_i=dim) | 
 |         if _outputs is None: | 
 |             return split_out | 
 |         # Convert to multiple slice nodes iff number of splits and number of outputs are statically known. | 
 |         if ( | 
 |             symbolic_helper._is_packed_list(split_size_or_sizes) | 
 |             and len(symbolic_helper._unpack_list(split_size_or_sizes)) == _outputs | 
 |         ): | 
 |             split_sizes = [ | 
 |                 symbolic_helper._unsqueeze_helper(g, v, [0]) | 
 |                 for v in symbolic_helper._unpack_list(split_size_or_sizes) | 
 |             ] | 
 |             start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long)) | 
 |             axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) | 
 |             res = [] | 
 |             for i in range(_outputs): | 
 |                 end = g.op( | 
 |                     "Add", start, split_sizes[i] | 
 |                 )  # split_sizes is a list of same length as _outputs | 
 |                 res.append(g.op("Slice", self, start, end, axis)) | 
 |                 start = end | 
 |             return res | 
 |         return [ | 
 |             g.op( | 
 |                 "SequenceAt", | 
 |                 split_out, | 
 |                 g.op("Constant", value_t=torch.tensor([i], dtype=torch.long)), | 
 |             ) | 
 |             for i in range(_outputs) | 
 |         ] | 
 |     else: | 
 |         return opset9.split(g, self, split_size_or_sizes, dim, _outputs) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::split_with_sizes") | 
 | @symbolic_helper.parse_args("v", "v", "i", "i") | 
 | @_beartype.beartype | 
 | def split_with_sizes(g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None): | 
 |     return split(g, self, split_sizes, dim, _outputs) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::unbind") | 
 | @symbolic_helper.parse_args("v", "i", "i") | 
 | @_beartype.beartype | 
 | def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None): | 
 |     if _outputs is None: | 
 |         return g.op( | 
 |             "SplitToSequence", | 
 |             self, | 
 |             g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)), | 
 |             axis_i=dim, | 
 |             keepdims_i=0, | 
 |         ) | 
 |     else: | 
 |         return opset9.unbind(g, self, dim, _outputs) | 
 |  | 
 |  | 
 | @_beartype.beartype | 
 | def _prepare_onnx_paddings(g: jit_utils.GraphContext, input, pad): | 
 |     """Generate paddings in ONNX order based on pad in pytorch. | 
 |  | 
 |     Args: | 
 |         input: the input tensor. | 
 |         pad: the paddings in pytorch. | 
 |             The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ..., dim_m_begin, dim_m_end, | 
 |             where m is in range [0, n]. | 
 |     """ | 
 |     if ( | 
 |         not symbolic_helper._is_packed_list(pad) | 
 |         and symbolic_helper._is_list(pad) | 
 |         and symbolic_helper._is_scalar_list(pad) | 
 |     ): | 
 |         pad = g.op("ConcatFromSequence", pad, axis_i=0, new_axis_i=1) | 
 |     # The desired order of paddings is | 
 |     # dim_0_begin, dim_1_begin, ... , dim_0_end, ..., dim_n_end. | 
 |     # n is the dimension of input. | 
 |     # Assume zero-dimensions in the beginning, pad the "pad" sequence with zeros in the beginning | 
 |     pad_len = opset9.size(g, pad, g.op("Constant", value_t=torch.tensor([0]))) | 
 |     # Set extension = [0] * (dim * 2 - len(pad)) | 
 |     rank = symbolic_helper._get_tensor_rank(input) | 
 |     if rank is None: | 
 |         rank = g.op("Size", g.op("Shape", input)) | 
 |     else: | 
 |         rank = g.op("Constant", value_t=torch.tensor(rank, dtype=torch.int64)) | 
 |     extension = g.op( | 
 |         "Sub", | 
 |         g.op("Mul", rank, g.op("Constant", value_t=torch.tensor(2, dtype=torch.int64))), | 
 |         pad_len, | 
 |     ) | 
 |     # Concat pad with extension: paddings = [dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, 0, 0, ... ] | 
 |     # Currently ONNX only supports int64 type for Pad | 
 |     pad = g.op("Cast", pad, to_i=_C_onnx.TensorProtoDataType.INT64) | 
 |     paddings = g.op( | 
 |         "Concat", | 
 |         pad, | 
 |         g.op( | 
 |             "ConstantOfShape", extension, value_t=torch.tensor([0], dtype=torch.int64) | 
 |         ), | 
 |         axis_i=0, | 
 |     ) | 
 |     # Reshape and reverse order and collate first beginnings and then ends | 
 |     # paddings = [[..., 0, dim_n-1_begin, dim_n_begin], | 
 |     #               [..., 0, dim_n-1_end, dim_n_end]] | 
 |     # Reshape back to 1-D paddings = [..., 0, dim_n - 1_begin, dim_n_begin, ..., 0, dim_n - 1_end, dim_n_end] | 
 |     paddings = symbolic_helper._reshape_helper( | 
 |         g, paddings, g.op("Constant", value_t=torch.tensor([-1, 2])) | 
 |     ) | 
 |     paddings = g.op("Transpose", opset10.flip(g, paddings, [0]), perm_i=[1, 0]) | 
 |     paddings = symbolic_helper._reshape_helper( | 
 |         g, paddings, g.op("Constant", value_t=torch.tensor([-1])) | 
 |     ) | 
 |     padding_c = g.op("Cast", paddings, to_i=_C_onnx.TensorProtoDataType.INT64) | 
 |     return padding_c | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::constant_pad_nd") | 
 | @_beartype.beartype | 
 | def constant_pad_nd(g: jit_utils.GraphContext, input, padding, value=None): | 
 |     mode = "constant" | 
 |     value = symbolic_helper._maybe_get_scalar(value) | 
 |     value = symbolic_helper._if_scalar_type_as(value, input) | 
 |     pad = _prepare_onnx_paddings(g, input, padding) | 
 |     return g.op("Pad", input, pad, value, mode_s=mode) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::reflection_pad1d") | 
 | @_onnx_symbolic("aten::reflection_pad2d") | 
 | @_onnx_symbolic("aten::reflection_pad3d") | 
 | @_beartype.beartype | 
 | def reflection_pad(g: jit_utils.GraphContext, input, padding): | 
 |     mode = "reflect" | 
 |     paddings = _prepare_onnx_paddings(g, input, padding) | 
 |     return g.op("Pad", input, paddings, mode_s=mode) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::replication_pad1d") | 
 | @_onnx_symbolic("aten::replication_pad2d") | 
 | @_onnx_symbolic("aten::replication_pad3d") | 
 | @_beartype.beartype | 
 | def replication_pad(g: jit_utils.GraphContext, input, padding): | 
 |     mode = "edge" | 
 |     paddings = _prepare_onnx_paddings(g, input, padding) | 
 |     return g.op("Pad", input, paddings, mode_s=mode) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::pad") | 
 | @_beartype.beartype | 
 | def pad( | 
 |     g: jit_utils.GraphContext, | 
 |     input: _C.Value, | 
 |     pad: _C.Value, | 
 |     mode: _C.Value, | 
 |     value: _C.Value, | 
 | ): | 
 |     mode = symbolic_helper._parse_arg(mode, "s") | 
 |     if mode == "replicate": | 
 |         return replication_pad(g, input, pad) | 
 |     elif mode == "reflect": | 
 |         return reflection_pad(g, input, pad) | 
 |     elif mode == "constant": | 
 |         return constant_pad_nd(g, input, pad, value) | 
 |     elif mode == "circular": | 
 |         return opset9._pad_circular(g, input, pad) | 
 |     else: | 
 |         raise errors.SymbolicValueError(f"Unrecognized padding mode {mode}", input) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::linalg_det") | 
 | @_beartype.beartype | 
 | def linalg_det(g: jit_utils.GraphContext, self): | 
 |     return g.op("Det", self) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::logdet") | 
 | @_beartype.beartype | 
 | def logdet(g: jit_utils.GraphContext, input): | 
 |     return opset9.log(g, linalg_det(g, input)) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::arange") | 
 | @_beartype.beartype | 
 | def arange(g: jit_utils.GraphContext, *args): | 
 |     def _get_arange_dtype(dtype): | 
 |         dtype = symbolic_helper._maybe_get_const(dtype, "i") | 
 |         return dtype | 
 |  | 
 |     if len(args) == 2 and all(isinstance(val, int) for val in args): | 
 |         # aten::arange(Scalar start, Scalar end) | 
 |         dtype = torch.int64 | 
 |         # Start index. | 
 |         start = g.op( | 
 |             "Constant", | 
 |             value_t=torch.tensor(args[0], dtype=dtype), | 
 |         ) | 
 |         # End (exclusive) index. | 
 |         end = g.op( | 
 |             "Constant", | 
 |             value_t=torch.tensor(args[1], dtype=dtype), | 
 |         ) | 
 |         # Step size from start to end indexes. | 
 |         delta_default = g.op( | 
 |             "Constant", | 
 |             value_t=torch.tensor(1, dtype=dtype), | 
 |         ) | 
 |         return g.op("Range", start, end, delta_default) | 
 |     elif len(args) == 2 or len(args) == 5: | 
 |         if len(args) == 2: | 
 |             # aten::arange(Scalar end, Tensor out) | 
 |             dtype = None | 
 |         else: | 
 |             # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) | 
 |             dtype = _get_arange_dtype(args[1]) | 
 |         type_, end, start, step = symbolic_helper._arange_cast_helper( | 
 |             g, end=args[0], dtype=dtype | 
 |         ) | 
 |         start_default = g.op( | 
 |             "Constant", | 
 |             value_t=torch.tensor(0, dtype=type_.dtype()), | 
 |         ) | 
 |         delta_default = g.op( | 
 |             "Constant", | 
 |             value_t=torch.tensor(1, dtype=type_.dtype()), | 
 |         ) | 
 |         return g.op("Range", start_default, end, delta_default) | 
 |     elif len(args) == 4 or len(args) == 7: | 
 |         if len(args) == 4: | 
 |             # aten::arange(Scalar start, Scalar end, Scalar step, Tensor out) | 
 |             dtype = None | 
 |         else: | 
 |             # aten::arange(Scalar start, Scalar end, Scalar step, ScalarType dtype, Layout, Device, bool pin_memory) | 
 |             dtype = _get_arange_dtype(args[3]) | 
 |         _, end, start, step = symbolic_helper._arange_cast_helper( | 
 |             g, start=args[0], end=args[1], step=args[2], dtype=dtype | 
 |         ) | 
 |         return g.op("Range", start, end, step) | 
 |     elif len(args) == 6: | 
 |         # aten::arange(Scalar start, Scalar end, ScalarType dtype, Layout, Device, bool pin_memory) | 
 |         dtype = _get_arange_dtype(args[2]) | 
 |         type_, end, start, step = symbolic_helper._arange_cast_helper( | 
 |             g, start=args[0], end=args[1], dtype=dtype | 
 |         ) | 
 |         delta_default = g.op( | 
 |             "Constant", | 
 |             value_t=torch.tensor(1, dtype=type_.dtype()), | 
 |         ) | 
 |         return g.op("Range", start, end, delta_default) | 
 |     else: | 
 |         return symbolic_helper._unimplemented( | 
 |             "aten::arange", f"with {len(args)} arguments" | 
 |         ) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::_dim_arange") | 
 | @symbolic_helper.parse_args("v", "i") | 
 | @_beartype.beartype | 
 | def _dim_arange(g: jit_utils.GraphContext, like, dim): | 
 |     like_shape = g.op("Shape", like) | 
 |     stop = g.op( | 
 |         "Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0 | 
 |     ) | 
 |     if symbolic_helper.is_caffe2_aten_fallback(): | 
 |         return g.op("_caffe2::Range", stop) | 
 |     return arange(g, stop, 4, None, None, None) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::size") | 
 | @symbolic_helper.quantized_args(True, quantize_output=False) | 
 | @_beartype.beartype | 
 | def size(g: jit_utils.GraphContext, self, dim=None): | 
 |     if dim is None: | 
 |         return g.op("Shape", self) | 
 |     return symbolic_helper._size_helper(g, self, dim) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::squeeze") | 
 | @_beartype.beartype | 
 | def squeeze(g: jit_utils.GraphContext, self, dim=None): | 
 |     if dim is None: | 
 |         return g.op("Squeeze", self) | 
 |  | 
 |     # dim as a tensor | 
 |     if not symbolic_helper._is_constant(dim): | 
 |         return symbolic_helper._squeeze_helper(g, self, [dim]) | 
 |  | 
 |     dim = symbolic_helper._get_const(dim, "i", "dim") | 
 |  | 
 |     input_rank = symbolic_helper._get_tensor_rank(self) | 
 |     adjusted_dim = dim | 
 |     if input_rank is not None and dim < 0: | 
 |         adjusted_dim += input_rank | 
 |     dim_size = symbolic_helper._get_tensor_dim_size(self, adjusted_dim) | 
 |     if (dim < 0 and input_rank is None) or dim_size is None: | 
 |         # If onnx shape inference is not on, export always as dynamic. | 
 |         # Because we cannot tell if observed static shape is also static at runtime. | 
 |         # create "cond" node (condition is shape[i]==1) | 
 |         dim_constant = g.op("Constant", value_t=torch.tensor([dim])) | 
 |         size = symbolic_helper._size_helper(g, self, dim_constant) | 
 |         const_one = g.op("Constant", value_t=torch.ones(1, dtype=torch.int64)) | 
 |         cond = g.op("Equal", size, const_one) | 
 |         # create the "If" node and add the "then" and "else" blocks to it. | 
 |         if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks( | 
 |             g, "If", cond, n_blocks=2 | 
 |         ) | 
 |         squeeze_ = symbolic_helper._squeeze_helper(if_context, self, [dim]) | 
 |         utils._add_output_to_block(if_context.block, squeeze_) | 
 |         identity_ = else_context.op("Identity", self) | 
 |         utils._add_output_to_block(else_context.block, identity_) | 
 |         return if_op | 
 |  | 
 |     # For static input shape | 
 |     dim = adjusted_dim | 
 |     if dim_size > 1: | 
 |         warnings.warn( | 
 |             "This model contains a squeeze operation on dimension " | 
 |             + str(dim) | 
 |             + ". The size of " | 
 |             + "this dimension in the given input is " | 
 |             + str(dim_size) | 
 |             + ". The model will " | 
 |             + "be exported without the squeeze node. If the model is intended to be used with dynamic " | 
 |             + "input shapes, please export with dynamic_axes argument." | 
 |         ) | 
 |         return self | 
 |     return symbolic_helper._squeeze_helper(g, self, [dim]) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::unsqueeze") | 
 | @_beartype.beartype | 
 | def unsqueeze(g: jit_utils.GraphContext, self, dim): | 
 |     if symbolic_helper._is_constant(dim): | 
 |         dim = symbolic_helper._get_const(dim, "i", "dim") | 
 |  | 
 |     return symbolic_helper._unsqueeze_helper(g, self, [dim]) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::mm") | 
 | @_beartype.beartype | 
 | def mm(g: jit_utils.GraphContext, self, other): | 
 |     return g.op("Gemm", self, other, beta_f=0.0, alpha_f=1.0) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::index") | 
 | @_beartype.beartype | 
 | def index(g: jit_utils.GraphContext, self, index): | 
 |     if symbolic_helper.is_caffe2_aten_fallback(): | 
 |         return g.at("index", self, index, overload_name="Tensor") | 
 |  | 
 |     if symbolic_helper._is_packed_list(index): | 
 |         indices = symbolic_helper._unpack_list(index) | 
 |     else: | 
 |         indices = [index] | 
 |  | 
 |     # Handle single mask index. | 
 |     if len(indices) == 1: | 
 |         index = indices[0] | 
 |         if not symbolic_helper._is_none(index) and ( | 
 |             symbolic_helper._is_bool(index) | 
 |             or _type_utils.JitScalarType.from_value(index) | 
 |             == _type_utils.JitScalarType.UINT8 | 
 |         ): | 
 |             index = opset9.nonzero(g, index) | 
 |             return g.op("GatherND", self, index) | 
 |     return opset9.index(g, self, index) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::index_fill") | 
 | @_beartype.beartype | 
 | def index_fill(g: jit_utils.GraphContext, self, dim, index, value): | 
 |     dim_value = symbolic_helper._parse_arg(dim, "i") | 
 |     if symbolic_helper.is_caffe2_aten_fallback(): | 
 |         return g.at( | 
 |             "index_fill", | 
 |             self, | 
 |             index, | 
 |             value, | 
 |             overload_name="int_Scalar", | 
 |             dim_i=dim_value, | 
 |         ) | 
 |  | 
 |     expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper( | 
 |         g, self, dim, index | 
 |     ) | 
 |     value = symbolic_helper._maybe_get_scalar(value) | 
 |     value = symbolic_helper._if_scalar_type_as(value, self) | 
 |     expanded_value = opset9.expand(g, value, expanded_index_shape, None) | 
 |     return scatter(g, self, dim, expanded_index, expanded_value) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::index_copy") | 
 | @_beartype.beartype | 
 | def index_copy(g: jit_utils.GraphContext, self, dim, index, source): | 
 |     dim_value = symbolic_helper._parse_arg(dim, "i") | 
 |     if symbolic_helper.is_caffe2_aten_fallback(): | 
 |         return g.at("index_copy", self, index, source, dim_i=dim_value) | 
 |     expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper( | 
 |         g, self, dim, index | 
 |     ) | 
 |     return scatter(g, self, dim, expanded_index, source) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::__rshift_") | 
 | @_beartype.beartype | 
 | def __rshift_(g: jit_utils.GraphContext, self, other): | 
 |     # make sure to cast other to self's type | 
 |     # (when self is long, make sure that other is not float) | 
 |     if _type_utils.JitScalarType.from_value( | 
 |         other, _type_utils.JitScalarType.UNDEFINED | 
 |     ) != _type_utils.JitScalarType.from_value(self): | 
 |         other = g.op( | 
 |             "Cast", | 
 |             other, | 
 |             to_i=_type_utils.JitScalarType.from_value(self).onnx_type(), | 
 |         ) | 
 |  | 
 |     if ( | 
 |         _type_utils.JitScalarType.from_value(self, _type_utils.JitScalarType.UNDEFINED) | 
 |         == _type_utils.JitScalarType.UINT8 | 
 |     ): | 
 |         return g.op("BitShift", self, other, direction_s="RIGHT") | 
 |  | 
 |     two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32)) | 
 |     # exponent (same type as self) has to be float or double in onnx::Pow | 
 |     if not symbolic_helper._is_fp(self): | 
 |         other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.FLOAT) | 
 |     two_pow = g.op("Pow", two, other) | 
 |     two_pow = g.op( | 
 |         "Cast", | 
 |         two_pow, | 
 |         to_i=_type_utils.JitScalarType.from_value(self).onnx_type(), | 
 |     ) | 
 |     rshift = g.op("Div", self, two_pow) | 
 |     return rshift | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::__lshift_") | 
 | @_beartype.beartype | 
 | def __lshift_(g: jit_utils.GraphContext, self, other): | 
 |     # make sure to cast other to self's type | 
 |     # (when self is long, make sure that other is not float) | 
 |     if _type_utils.JitScalarType.from_value( | 
 |         other, _type_utils.JitScalarType.UNDEFINED | 
 |     ) != _type_utils.JitScalarType.from_value(self): | 
 |         other = g.op( | 
 |             "Cast", | 
 |             other, | 
 |             to_i=_type_utils.JitScalarType.from_value(self).onnx_type(), | 
 |         ) | 
 |  | 
 |     if ( | 
 |         _type_utils.JitScalarType.from_value(self, _type_utils.JitScalarType.UNDEFINED) | 
 |         == _type_utils.JitScalarType.UINT8 | 
 |     ): | 
 |         return g.op("BitShift", self, other, direction_s="LEFT") | 
 |  | 
 |     two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32)) | 
 |     # exponent (same type as self) has to be float or double in onnx::Pow | 
 |     if not symbolic_helper._is_fp(self): | 
 |         other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.FLOAT) | 
 |     two_pow = g.op("Pow", two, other) | 
 |     two_pow = g.op( | 
 |         "Cast", | 
 |         two_pow, | 
 |         to_i=_type_utils.JitScalarType.from_value(self).onnx_type(), | 
 |     ) | 
 |     lshift = g.op("Mul", self, two_pow) | 
 |     return lshift | 
 |  | 
 |  | 
 | @_beartype.beartype | 
 | def _get_im2col_indices_along_dim( | 
 |     g: jit_utils.GraphContext, input_d, kernel_size_d, dilation_d, padding_d, stride_d | 
 | ): | 
 |     # Input is always 4-D (N, C, H, W) | 
 |     # Calculate indices of sliding blocks along spatial dimension | 
 |     # Slide kernel over input each dim d: | 
 |     # each dimension d ranges from 0 to input[d]+2xpadding[d]-dilation[d]x(kernel_size[d]-1) | 
 |     # with steps = stride | 
 |  | 
 |     blocks_d = g.op( | 
 |         "Add", input_d, g.op("Constant", value_t=torch.tensor(padding_d * 2)) | 
 |     ) | 
 |     blocks_d = g.op( | 
 |         "Sub", | 
 |         blocks_d, | 
 |         g.op("Constant", value_t=torch.tensor(dilation_d * (kernel_size_d - 1))), | 
 |     ) | 
 |  | 
 |     # Stride kernel over input and find starting indices along dim d | 
 |     blocks_d_indices = g.op( | 
 |         "Range", | 
 |         g.op("Constant", value_t=torch.tensor(0)), | 
 |         blocks_d, | 
 |         g.op("Constant", value_t=torch.tensor(stride_d)), | 
 |     ) | 
 |  | 
 |     # Apply dilation on kernel and find its indices along dim d | 
 |     kernel_grid = torch.arange(0, kernel_size_d * dilation_d, dilation_d) | 
 |     kernel_grid = g.op("Constant", value_t=kernel_grid.unsqueeze(0)) | 
 |  | 
 |     # Broadcast and add kernel staring positions (indices) with | 
 |     # kernel_grid along dim d, to get block indices along dim d | 
 |     blocks_d_indices = symbolic_helper._unsqueeze_helper( | 
 |         g, blocks_d_indices, [0] | 
 |     )  # Reshape to [1, -1] | 
 |     kernel_mask = symbolic_helper._reshape_helper( | 
 |         g, kernel_grid, g.op("Constant", value_t=torch.tensor([-1, 1])) | 
 |     ) | 
 |     block_mask = g.op("Add", blocks_d_indices, kernel_mask) | 
 |  | 
 |     return block_mask | 
 |  | 
 |  | 
 | @_beartype.beartype | 
 | def _get_im2col_padded_input(g: jit_utils.GraphContext, input, padding_h, padding_w): | 
 |     # Input is always 4-D tensor (N, C, H, W) | 
 |     # Padding tensor has the following format: (padding_h, padding_w) | 
 |     # Reshape the padding to follow ONNX format: (dim1_begin, dim2_begin,...,dim1_end, dim2_end,...) | 
 |     pad = g.op("Constant", value_t=torch.LongTensor([0, 0, padding_h, padding_w] * 2)) | 
 |     return g.op("Pad", input, pad) | 
 |  | 
 |  | 
 | @_beartype.beartype | 
 | def _get_im2col_output_shape(g: jit_utils.GraphContext, input, kernel_h, kernel_w): | 
 |     batch_dim = size(g, input, g.op("Constant", value_t=torch.tensor(0))) | 
 |     channel_dim = size(g, input, g.op("Constant", value_t=torch.tensor(1))) | 
 |     channel_unfolded = g.op( | 
 |         "Mul", channel_dim, g.op("Constant", value_t=torch.tensor(kernel_h * kernel_w)) | 
 |     ) | 
 |  | 
 |     return g.op( | 
 |         "Concat", | 
 |         symbolic_helper._unsqueeze_helper(g, batch_dim, [0]), | 
 |         symbolic_helper._unsqueeze_helper(g, channel_unfolded, [0]), | 
 |         g.op("Constant", value_t=torch.tensor([-1])), | 
 |         axis_i=0, | 
 |     ) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::im2col") | 
 | @symbolic_helper.parse_args("v", "is", "is", "is", "is") | 
 | @_beartype.beartype | 
 | def im2col(g: jit_utils.GraphContext, input, kernel_size, dilation, padding, stride): | 
 |     # Input is always 4-D tensor (N, C, H, W) | 
 |     # All other args are int[2] | 
 |  | 
 |     input_h = size(g, input, g.op("Constant", value_t=torch.tensor(2))) | 
 |     input_w = size(g, input, g.op("Constant", value_t=torch.tensor(3))) | 
 |  | 
 |     stride_h, stride_w = stride[0], stride[1] | 
 |     padding_h, padding_w = padding[0], padding[1] | 
 |     dilation_h, dilation_w = dilation[0], dilation[1] | 
 |     kernel_h, kernel_w = kernel_size[0], kernel_size[1] | 
 |  | 
 |     blocks_row_indices = _get_im2col_indices_along_dim( | 
 |         g, input_h, kernel_h, dilation_h, padding_h, stride_h | 
 |     ) | 
 |     blocks_col_indices = _get_im2col_indices_along_dim( | 
 |         g, input_w, kernel_w, dilation_w, padding_w, stride_w | 
 |     ) | 
 |  | 
 |     output_shape = _get_im2col_output_shape(g, input, kernel_h, kernel_w) | 
 |     padded_input = _get_im2col_padded_input(g, input, padding_h, padding_w) | 
 |  | 
 |     # For a 4D matrix of size (1, 1, 3, 3) as below with kernel_size=2, stride=1, and dilation=1 | 
 |     # [[[[1., 2., 3.,], | 
 |     #    [4., 5., 6.,], | 
 |     #    [7., 8., 9.,]]]] | 
 |     # First gather indices along rows (dim=2) with blocks_row_indices = [[0,1], [1,2]] to get: | 
 |     # [[[[[1., 2., 3.], | 
 |     #     [4., 5., 6.]], | 
 |     #    [[4., 5., 6.], | 
 |     #     [7., 8., 9.]]]]] | 
 |     # And then gather along cols (dim=4) with blocks_row_indices = [[0,1], [1,2]] to get: | 
 |     # [[[[[[1., 2.], | 
 |     #      [4., 5.]], | 
 |     #     [[2., 3.], | 
 |     #      [5., 6]]], | 
 |     #    [[[4., 5.], | 
 |     #      [7., 8.]], | 
 |     #     [[5., 6.], | 
 |     #      [8., 9.]]]]]] | 
 |     # Transpose dims 3 (depth) and 4 (rows), and then reshape to output shape (1, 1, 4, 4) to get: | 
 |     #  [[[1., 2., 4., 5.], | 
 |     #    [2., 3., 5., 6.], | 
 |     #    [4., 5., 7., 8.], | 
 |     #    [5., 6., 8., 9.]]] | 
 |     output = g.op("Gather", padded_input, blocks_row_indices, axis_i=2) | 
 |     output = g.op("Gather", output, blocks_col_indices, axis_i=4) | 
 |     output = g.op("Transpose", output, perm_i=[0, 1, 2, 4, 3, 5]) | 
 |     return symbolic_helper._reshape_helper(g, output, output_shape) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::narrow") | 
 | @_beartype.beartype | 
 | def narrow(g: jit_utils.GraphContext, input, dim, start, length): | 
 |     end = g.op("Add", start, length) | 
 |     return symbolic_helper._slice_helper(g, input, axes=dim, starts=start, ends=end) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::flatten") | 
 | @symbolic_helper.quantized_args(True, False, False) | 
 | @symbolic_helper.parse_args("v", "i", "i") | 
 | @_beartype.beartype | 
 | def flatten(g: jit_utils.GraphContext, input, start_dim, end_dim): | 
 |     dim = symbolic_helper._get_tensor_rank(input) | 
 |     if dim == 1: | 
 |         return input | 
 |     # use ONNX's Flatten operator for cases where the output shape is 2D | 
 |     if start_dim == 1: | 
 |         if end_dim == -1 or (dim is not None and end_dim == dim - 1): | 
 |             return g.op("Flatten", input, axis_i=start_dim) | 
 |     elif start_dim == 0: | 
 |         if end_dim == -2 or (dim is not None and end_dim == dim - 2): | 
 |             return g.op("Flatten", input, axis_i=end_dim + 1) | 
 |     if dim is None: | 
 |         return symbolic_helper._unimplemented( | 
 |             "dim", | 
 |             "ONNX and PyTorch use different strategies to split the input. " | 
 |             "Input rank must be known at export time.", | 
 |         ) | 
 |     # if end_dim is negative add dim | 
 |     if end_dim < 0: | 
 |         end_dim = dim + end_dim | 
 |  | 
 |     return symbolic_helper._flatten_helper(g, input, start_dim, end_dim, dim) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::linalg_vector_norm") | 
 | @symbolic_helper.parse_args("v", "f", "is", "b", "v") | 
 | @_beartype.beartype | 
 | def linalg_vector_norm( | 
 |     g: jit_utils.GraphContext, | 
 |     self, | 
 |     ord, | 
 |     dim: Optional[Sequence[int]], | 
 |     keepdim: bool, | 
 |     dtype, | 
 | ): | 
 |     if ord == 0: | 
 |         if dim is None: | 
 |             self = symbolic_helper._reshape_helper( | 
 |                 g, self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)) | 
 |             ) | 
 |             keepdim = False | 
 |  | 
 |         cond_op = g.op( | 
 |             "Not", g.op("Equal", self, g.op("Constant", value_t=torch.LongTensor([0]))) | 
 |         ) | 
 |         cond_op = g.op( | 
 |             "Cast", | 
 |             cond_op, | 
 |             to_i=_type_utils.JitScalarType.from_value(self).onnx_type(), | 
 |         ) | 
 |         return symbolic_helper._reducesum_helper( | 
 |             g, cond_op, axes_i=dim, keepdims_i=keepdim | 
 |         ) | 
 |     else: | 
 |         return opset9.linalg_vector_norm(g, self, ord, dim, keepdim, dtype) | 
 |  | 
 |  | 
@_onnx_symbolic("aten::embedding_bag")
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
@_beartype.beartype
def embedding_bag(
    g: jit_utils.GraphContext,
    embedding_matrix,
    indices,
    offsets,
    scale_grad_by_freq,
    mode,
    sparse,
    per_sample_weights,
    include_last_offset,
    padding_idx,
):
    """Export ``aten::embedding_bag`` as an ONNX ``Loop`` that reduces one bag per iteration.

    For each consecutive pair of offsets the loop body slices ``indices``,
    gathers the matching rows of ``embedding_matrix``, optionally multiplies
    them by the corresponding ``per_sample_weights`` slice, and reduces over
    the bag: sum for ``mode == 0``, mean for ``mode == 1``, max otherwise.

    Returns:
        A 4-tuple ``(output, None, None, None)`` where ``output`` is the
        output of the ``Loop`` node.

    Raises:
        RuntimeError: if ``padding_idx`` is not ``None`` and is non-negative.

    ``scale_grad_by_freq`` combined with training-mode export is reported via
    ``symbolic_helper._onnx_unsupported``.
    """
    if scale_grad_by_freq and GLOBALS.export_training:
        return symbolic_helper._onnx_unsupported(
            "embedding_bag with scale_grad_by_freq for training mode"
        )
    if padding_idx is not None and padding_idx >= 0:
        raise RuntimeError("embedding_bag with padding_idx")

    # Constant "true" condition fed to the Loop node.
    loop_condition = g.op("Constant", value_t=torch.tensor(1))
    loop_condition = g.op("Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL)
    # Axis constant used by the Slice ops inside the loop body.
    zero = g.op("Constant", value_t=torch.tensor([0]))

    # Length of ``indices`` along dim 0, as a 1-element tensor.
    indices_len = symbolic_helper._unsqueeze_helper(
        g,
        symbolic_helper._size_helper(
            g, indices, g.op("Constant", value_t=torch.tensor(0))
        ),
        [0],
    )
    if not include_last_offset:
        # Append len(indices) so that the last bag also has an end offset.
        offsets = [offsets, indices_len]
        offsets = g.op("Concat", *offsets, axis_i=0)

    # Offsets holds the starting index position of each bag. So we create a list of the indices slices (determined by
    # offsets) and gather those indices in indices_row. Then we use this subset of indices to gather from embeddings.
    # The embeddings output is a loop scan output, so we can avoid creating a sequence and inserting elements in.
    offsets_starts = symbolic_helper._slice_helper(
        g, offsets, axes=[0], starts=[0], ends=[sys.maxsize], steps=[1]
    )
    offsets_ends = symbolic_helper._slice_helper(
        g, offsets, axes=[0], starts=[1], ends=[sys.maxsize], steps=[1]
    )

    # One loop iteration per entry of ``offsets_ends``.
    loop_len = symbolic_helper._size_helper(
        g, offsets_ends, g.op("Constant", value_t=torch.tensor(0))
    )

    loop, (loop_context,), _ = jit_utils.add_op_with_blocks(
        g, "Loop", loop_len, loop_condition, n_blocks=1
    )
    loop_block = loop_context.block

    # FIXME(justinchuby): We need to handle what happens when we call b.op on a node return
    # Two inputs are added to the loop block; the first is used below as the
    # per-iteration index, the second (``cond``) is not read again in this function.
    block_input_iter = utils._add_input_to_block(loop_block)
    cond = utils._add_input_to_block(loop_block)

    # Start/end offsets of the current bag, each unsqueezed to a 1-element tensor for Slice.
    indices_start = loop_context.op(
        "Gather", offsets_starts, block_input_iter, axis_i=0
    )
    indices_end = loop_context.op("Gather", offsets_ends, block_input_iter, axis_i=0)
    indices_start = symbolic_helper._unsqueeze_helper(loop_context, indices_start, [0])
    indices_end = symbolic_helper._unsqueeze_helper(loop_context, indices_end, [0])

    # Indices belonging to this bag, and the embedding rows they select.
    indices_row = loop_context.op("Slice", indices, indices_start, indices_end, zero)
    embeddings = loop_context.op("Gather", embedding_matrix, indices_row, axis_i=0)
    if not symbolic_helper._is_none(per_sample_weights):
        # Slice the matching weights and unsqueeze axis 1 so they multiply against each gathered row.
        per_sample_weights_row = loop_context.op(
            "Slice", per_sample_weights, indices_start, indices_end, zero
        )
        per_sample_weights_row = symbolic_helper._unsqueeze_helper(
            loop_context, per_sample_weights_row, [1]
        )
        embeddings = loop_context.op("Mul", embeddings, per_sample_weights_row)
    # Reduce over the bag axis (axis 0): mode 0 -> sum, mode 1 -> mean, anything else -> max.
    if mode == 0:
        embeddings = symbolic_helper._reducesum_helper(
            loop_context, embeddings, axes_i=[0], keepdims_i=0
        )
    elif mode == 1:
        embeddings = loop_context.op("ReduceMean", embeddings, axes_i=[0], keepdims_i=0)
    else:
        embeddings = loop_context.op("ReduceMax", embeddings, axes_i=[0], keepdims_i=0)

    # Block outputs: the (constant) condition first, then the reduced embeddings for this bag.
    cond_out = loop_context.op(
        "Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL
    )
    utils._add_output_to_block(loop_block, cond_out)
    utils._add_output_to_block(loop_block, embeddings)

    # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices.
    # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag.
    return loop.node().output(), None, None, None
 |  | 
 |  | 
 | @_onnx_symbolic("aten::embedding_renorm") | 
 | @symbolic_helper.parse_args("v", "v", "f", "f") | 
 | @_beartype.beartype | 
 | def embedding_renorm(g: jit_utils.GraphContext, weight, indices, max_norm, norm_type): | 
 |     unique_indices = g.op("Unique", indices) | 
 |     partial_weight = g.op("Gather", weight, unique_indices) | 
 |     norm_i = int(norm_type) | 
 |     if norm_i == 1: | 
 |         norm_type = "ReduceL1" | 
 |     elif norm_i == 2: | 
 |         norm_type = "ReduceL2" | 
 |     else: | 
 |         raise errors.SymbolicValueError( | 
 |             f"Unsupported: ONNX export of embedding_renorm with norm: {norm_i}. " | 
 |             "Only 1. and 2. are supported.", | 
 |             weight, | 
 |         ) | 
 |     partial_weight_norm = g.op(norm_type, partial_weight, axes_i=[1], keepdims_i=1) | 
 |     # https://github.com/pytorch/pytorch/blob/0a07488ed2c47765e337e290bd138c0e6e459cbd/aten/src/ATen/native/Embedding.cpp#L177 | 
 |     # Add 1e-7 to prevent division by zero. | 
 |     partial_weight_norm_ = g.op( | 
 |         "Add", partial_weight_norm, g.op("Constant", value_t=torch.tensor(1e-7)) | 
 |     ) | 
 |     max_norm = torch.tensor(max_norm) | 
 |     scales = g.op("Div", max_norm, partial_weight_norm_) | 
 |     partial_weight_renorm = g.op("Mul", partial_weight, scales) | 
 |     partial_weight_renorm = g.op( | 
 |         "Where", | 
 |         g.op("Greater", partial_weight_norm, max_norm), | 
 |         partial_weight_renorm, | 
 |         partial_weight, | 
 |     ) | 
 |     return g.op( | 
 |         "ScatterND", | 
 |         weight, | 
 |         symbolic_helper._unsqueeze_helper(g, unique_indices, [1]), | 
 |         partial_weight_renorm, | 
 |     ) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::chunk") | 
 | @_beartype.beartype | 
 | def chunk(g: jit_utils.GraphContext, self, chunks, dim): | 
 |     # Calculate chunk size for dynamic chunk | 
 |     dim_size = g.op("Gather", g.op("Shape", self), dim, axis_i=0) | 
 |     chunk_size_s = g.op( | 
 |         "Sub", chunks, g.op("Constant", value_t=torch.tensor([1], dtype=torch.long)) | 
 |     ) | 
 |     chunk_size = g.op("Div", g.op("Add", dim_size, chunk_size_s), chunks) | 
 |     # Create splits vector | 
 |     chunk_vec = [ | 
 |         opset9.expand(g, chunk_size, chunk_size_s, None), | 
 |         g.op("Sub", dim_size, g.op("Mul", chunk_size, chunk_size_s)), | 
 |     ] | 
 |     chunk_vec = g.op("Concat", *chunk_vec, axis_i=0) | 
 |     return split(g, self, chunk_vec, dim) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::normal") | 
 | @_beartype.beartype | 
 | def normal( | 
 |     g: jit_utils.GraphContext, | 
 |     mean, | 
 |     std, | 
 |     sizes=None, | 
 |     generator=None, | 
 |     dtype=None, | 
 |     layout=None, | 
 |     device=None, | 
 |     pin_memory=None, | 
 | ): | 
 |     # If you can sample from a given distribution with mean 0 and variance 1, then you can easily sample from a | 
 |     # scale-location transformation of that distribution, which has mean μ and variance σ's square. If x is a sample | 
 |     # from a mean 0 and variance 1 distribution then | 
 |     #       σx+μ | 
 |     # is a sample with mean μ and variance σ's square. | 
 |     if sizes is not None and not symbolic_helper._is_none(sizes): | 
 |         mean = opset9.expand(g, mean, sizes, None) | 
 |     result = opset9.mul(g, std, g.op("RandomNormalLike", mean)) | 
 |     return add(g, result, mean) | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::atleast_1d") | 
 | @_beartype.beartype | 
 | def atleast_1d(g: jit_utils.GraphContext, self: torch._C.Value): | 
 |     # NOTE: If it's 0D, reshape to 1D | 
 |  | 
 |     # NOTE: self could be a packed list or a tensor | 
 |     if symbolic_helper._is_value(self) and symbolic_helper._is_packed_list(self): | 
 |         tensor_list = symbolic_helper._unpack_list(self) | 
 |         new_tensor_list = [] | 
 |         for tensor in tensor_list: | 
 |             new_tensor = tensor | 
 |             tensor_rank = symbolic_helper._get_tensor_rank(tensor) | 
 |             if tensor_rank == 0: | 
 |                 new_tensor = symbolic_helper._reshape_helper( | 
 |                     g, new_tensor, g.op("Constant", value_t=torch.tensor([1])) | 
 |                 ) | 
 |             new_tensor_list.append(new_tensor) | 
 |         return g.op("SequenceConstruct", *new_tensor_list) | 
 |  | 
 |     tensor_rank = symbolic_helper._get_tensor_rank(self) | 
 |     if tensor_rank == 0: | 
 |         self = symbolic_helper._reshape_helper( | 
 |             g, self, g.op("Constant", value_t=torch.tensor([1])) | 
 |         ) | 
 |     return self | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::atleast_2d") | 
 | @_beartype.beartype | 
 | def atleast_2d(g: jit_utils.GraphContext, self: torch._C.Value): | 
 |     # NOTE: If it's 0D, reshape to 2D | 
 |     #       If it's 1D, unsqueeze to 2D | 
 |  | 
 |     # NOTE: self could be a packed list or a tensor | 
 |     if symbolic_helper._is_value(self) and symbolic_helper._is_packed_list(self): | 
 |         tensor_list = symbolic_helper._unpack_list(self) | 
 |         new_tensor_list = [] | 
 |         for tensor in tensor_list: | 
 |             new_tensor = tensor | 
 |             tensor_rank = symbolic_helper._get_tensor_rank(tensor) | 
 |             if tensor_rank == 0: | 
 |                 new_tensor = symbolic_helper._reshape_helper( | 
 |                     g, new_tensor, g.op("Constant", value_t=torch.tensor([1, 1])) | 
 |                 ) | 
 |             elif tensor_rank == 1: | 
 |                 new_tensor = symbolic_helper._unsqueeze_helper( | 
 |                     g, new_tensor, axes_i=[0] | 
 |                 ) | 
 |             new_tensor_list.append(new_tensor) | 
 |         return g.op("SequenceConstruct", *new_tensor_list) | 
 |  | 
 |     tensor_rank = symbolic_helper._get_tensor_rank(self) | 
 |     if tensor_rank == 0: | 
 |         self = symbolic_helper._reshape_helper( | 
 |             g, self, g.op("Constant", value_t=torch.tensor([1, 1])) | 
 |         ) | 
 |     elif tensor_rank == 1: | 
 |         self = symbolic_helper._unsqueeze_helper(g, self, axes_i=[0]) | 
 |     return self | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::atleast_3d") | 
 | @_beartype.beartype | 
 | def atleast_3d(g: jit_utils.GraphContext, self: torch._C.Value): | 
 |     # NOTE: If it's 0D, reshape to 3D | 
 |     #       If it's 1D, unsqueeze to 3D | 
 |     #       If it's 2D, unsqueeze to 3D | 
 |  | 
 |     # NOTE: self could be a packed list or a tensor | 
 |     if symbolic_helper._is_value(self) and symbolic_helper._is_packed_list(self): | 
 |         tensor_list = symbolic_helper._unpack_list(self) | 
 |         new_tensor_list = [] | 
 |         for tensor in tensor_list: | 
 |             new_tensor = tensor | 
 |             tensor_rank = symbolic_helper._get_tensor_rank(tensor) | 
 |             if tensor_rank == 0: | 
 |                 new_tensor = symbolic_helper._reshape_helper( | 
 |                     g, new_tensor, g.op("Constant", value_t=torch.tensor([1, 1, 1])) | 
 |                 ) | 
 |             elif tensor_rank == 1: | 
 |                 new_tensor = symbolic_helper._unsqueeze_helper( | 
 |                     g, new_tensor, axes_i=[0] | 
 |                 ) | 
 |                 new_tensor = symbolic_helper._unsqueeze_helper( | 
 |                     g, new_tensor, axes_i=[-1] | 
 |                 ) | 
 |             elif tensor_rank == 2: | 
 |                 new_tensor = symbolic_helper._unsqueeze_helper( | 
 |                     g, new_tensor, axes_i=[-1] | 
 |                 ) | 
 |             new_tensor_list.append(new_tensor) | 
 |         return g.op("SequenceConstruct", *new_tensor_list) | 
 |  | 
 |     tensor_rank = symbolic_helper._get_tensor_rank(self) | 
 |     if tensor_rank == 0: | 
 |         self = symbolic_helper._reshape_helper( | 
 |             g, self, g.op("Constant", value_t=torch.tensor([1, 1, 1])) | 
 |         ) | 
 |     elif tensor_rank == 1: | 
 |         self = symbolic_helper._unsqueeze_helper(g, self, axes_i=[0]) | 
 |         self = symbolic_helper._unsqueeze_helper(g, self, axes_i=[-1]) | 
 |     elif tensor_rank == 2: | 
 |         self = symbolic_helper._unsqueeze_helper(g, self, axes_i=[-1]) | 
 |     return self | 
 |  | 
 |  | 
 | @_onnx_symbolic("prim::ConstantChunk") | 
 | @_beartype.beartype | 
 | def prim_constant_chunk(g: jit_utils.GraphContext, self, chunks, dim): | 
 |     input_shape = g.op("Shape", self) | 
 |     axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)) | 
 |     input_shape_dim = g.op("Gather", input_shape, axis, axis_i=0) | 
 |     start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long)) | 
 |     chunk_size = g.op("Constant", value_t=torch.tensor([chunks], dtype=torch.long)) | 
 |     chunk_size_minus_1 = g.op( | 
 |         "Constant", value_t=torch.tensor([chunks - 1], dtype=torch.long) | 
 |     ) | 
 |     input_shape_dim_shift = g.op("Add", input_shape_dim, chunk_size_minus_1) | 
 |     chunk_dim = g.op("Div", input_shape_dim_shift, chunk_size) | 
 |     res = [] | 
 |     for i in range(chunks): | 
 |         index = g.op("Constant", value_t=torch.tensor([i + 1], dtype=torch.long)) | 
 |         end = g.op("Mul", chunk_dim, index) | 
 |         res.append(g.op("Slice", self, start, end, axis)) | 
 |         start = end | 
 |     return res | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::hstack") | 
 | @_beartype.beartype | 
 | def hstack(g: jit_utils.GraphContext, tensor_list: _C.Value): | 
 |     tensor_list = atleast_1d(g, tensor_list) | 
 |     first_tensor = g.op( | 
 |         "SequenceAt", | 
 |         tensor_list, | 
 |         g.op("Constant", value_t=torch.tensor(0, dtype=torch.long)), | 
 |     ) | 
 |     first_tensor_shape = g.op("Shape", first_tensor) | 
 |     first_tensor_dim = g.op("Size", first_tensor_shape) | 
 |  | 
 |     const_one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)) | 
 |     equal_to_one = g.op("Equal", first_tensor_dim, const_one) | 
 |  | 
 |     ( | 
 |         if_op_greater, | 
 |         (if_context_equal, else_context_equal), | 
 |         _, | 
 |     ) = jit_utils.add_op_with_blocks(g, "If", equal_to_one, n_blocks=2, outputs=1) | 
 |     result_if = if_context_equal.op( | 
 |         "ConcatFromSequence", tensor_list, axis_i=0, new_axis_i=0 | 
 |     ) | 
 |     utils._add_output_to_block(if_context_equal.block, result_if) | 
 |     result_else = else_context_equal.op( | 
 |         "ConcatFromSequence", tensor_list, axis_i=1, new_axis_i=0 | 
 |     ) | 
 |     utils._add_output_to_block(else_context_equal.block, result_else) | 
 |     result = if_op_greater.node().output() | 
 |  | 
 |     return result | 
 |  | 
 |  | 
 | @_onnx_symbolic("aten::vstack") | 
 | @_beartype.beartype | 
 | def vstack(g: jit_utils.GraphContext, tensor_list: _C.Value): | 
 |     tensor_list = atleast_2d(g, tensor_list) | 
 |     return g.op("ConcatFromSequence", tensor_list, axis_i=0, new_axis_i=0) |