|  | import functools | 
|  | import sys | 
|  | import warnings | 
|  | from typing import Callable | 
|  |  | 
|  | import torch | 
|  | import torch._C._onnx as _C_onnx | 
|  | import torch.onnx | 
|  | from torch import _C | 
|  |  | 
|  | # Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics | 
|  | from torch.onnx import (  # noqa: F401 | 
|  | _constants, | 
|  | _patch_torch, | 
|  | _type_utils, | 
|  | errors, | 
|  | symbolic_helper, | 
|  | symbolic_opset9 as opset9, | 
|  | ) | 
|  | from torch.onnx._globals import GLOBALS | 
|  | from torch.onnx._internal import _beartype, jit_utils, registration | 
|  |  | 
|  | # EDITING THIS FILE? READ THIS FIRST! | 
|  | # see Note [Edit Symbolic Files] in README.md | 
|  |  | 
# This file exports ONNX ops for opset 10.
# Opset 10 is supported by ONNX release 1.5.0,
# released on 04/24/19.
|  |  | 
|  |  | 
# Names exported by `from <this module> import *`: the public symbolic functions
# defined in this file.
__all__ = [
    "dequantize",
    "div",
    "embedding_bag",
    "fake_quantize_per_tensor_affine",
    "flip",
    "fmod",
    "isfinite",
    "isinf",
    "nan_to_num",
    "quantize_per_tensor",
    "quantized_add_relu",
    "quantized_add",
    "quantized_cat",
    "quantized_conv1d_relu",
    "quantized_conv2d_relu",
    "quantized_conv2d",
    "quantized_group_norm",
    "quantized_hardswish",
    "quantized_instance_norm",
    "quantized_layer_norm",
    "quantized_leaky_relu",
    "quantized_linear",
    "quantized_mul",
    "quantized_sigmoid",
    "slice",
    "sort",
    "topk",
]


# Registration decorator with the `opset` argument pre-bound to 10; every
# symbolic in this file is registered through it.
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=10)
|  |  | 
|  |  | 
def _apply_params(*args, **kwargs):
    """Build a decorator that invokes its target with the captured arguments.

    The decorated object is expected to be a higher-order function: it is
    called immediately as ``fn(*args, **kwargs)`` and whatever it returns
    takes the place of the decorated function.
    """

    def _invoke_with_captured(factory):
        return factory(*args, **kwargs)

    return _invoke_with_captured
|  |  | 
|  |  | 
@_onnx_symbolic("aten::div")
@_beartype.beartype
def div(g: jit_utils.GraphContext, self, other, *args):
    """Symbolic for ``aten::div``.

    Any extra positional arguments are forwarded to ``_div_rounding_mode``;
    with none, the opset 9 true-division symbolic is used.
    """
    if args:
        return _div_rounding_mode(g, self, other, *args)
    return opset9.true_divide(g, self, other)
|  |  | 
|  |  | 
@symbolic_helper.parse_args("v", "v", "s")
@_beartype.beartype
def _div_rounding_mode(g: jit_utils.GraphContext, self, other, rounding_mode):
    """Dispatch a division on its ``rounding_mode`` string.

    ``"floor"`` is handled by this opset's ``_floor_divide``; every other
    value is passed through to the opset 9 implementation.
    """
    if rounding_mode != "floor":
        return opset9._div_rounding_mode(g, self, other, rounding_mode)
    return _floor_divide(g, self, other)
|  |  | 
|  |  | 
@_onnx_symbolic("aten::_floor_divide")
@_beartype.beartype
def _floor_divide(g: jit_utils.GraphContext, self, other):
    """Symbolic for ``aten::_floor_divide``.

    If either operand is floating point, the result is ``Floor`` applied to
    the true division. Otherwise the integer quotient is corrected by one
    where needed so that it rounds down.
    """
    if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other):
        return g.op("Floor", opset9.true_divide(g, self, other))

    # Integer division does truncation rounding.
    truncated = g.op("Div", self, other)

    # The quotient is negative exactly when the operand signs differ:
    # (self < 0) != (other < 0).
    zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64))
    self_is_negative = g.op("Less", self, zero)
    other_is_negative = g.op("Less", other, zero)
    signs_differ = g.op("Xor", self_is_negative, other_is_negative)

    # A negative quotient with a non-zero remainder must be lowered by one
    # so it rounds down instead of up.
    remainder = g.op("Mod", self, other, fmod_i=0)
    divides_evenly = g.op("Equal", remainder, zero)
    has_remainder = g.op("Not", divides_evenly)
    needs_fixup = g.op("And", signs_differ, has_remainder)

    one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
    rounded_down = g.op("Sub", truncated, one)
    return g.op("Where", needs_fixup, rounded_down, truncated)
|  |  | 
|  |  | 
@_onnx_symbolic("aten::sort")
@symbolic_helper.parse_args("v", "i", "i", "none")
@_beartype.beartype
def sort(g: jit_utils.GraphContext, self, dim, decending, out=None):
    """Symbolic for ``aten::sort``; defers to ``symbolic_helper._sort_helper``."""
    # "decending" (sic) is the keyword name the shared helper is called with.
    helper_kwargs = {"decending": decending, "out": out}
    return symbolic_helper._sort_helper(g, self, dim, **helper_kwargs)
|  |  | 
|  |  | 
@_onnx_symbolic("aten::topk")
@symbolic_helper.parse_args("v", "v", "i", "i", "i", "none")
@_beartype.beartype
def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None):
    """Symbolic for ``aten::topk``; defers to ``symbolic_helper._topk_helper``."""
    helper_kwargs = {"largest": largest, "sorted": sorted, "out": out}
    return symbolic_helper._topk_helper(g, self, k, dim, **helper_kwargs)
|  |  | 
|  |  | 
@_onnx_symbolic(
    "aten::max_pool1d",
    decorate=[
        _apply_params(
            "max_pool1d", torch.nn.modules.utils._single, 1, return_indices=False
        )
    ],
)
@_onnx_symbolic(
    "aten::max_pool2d",
    decorate=[
        _apply_params(
            "max_pool2d", torch.nn.modules.utils._pair, 2, return_indices=False
        )
    ],
)
@_onnx_symbolic(
    "aten::max_pool3d",
    decorate=[
        _apply_params(
            "max_pool3d", torch.nn.modules.utils._triple, 3, return_indices=False
        )
    ],
)
@_onnx_symbolic(
    "aten::max_pool1d_with_indices",
    decorate=[
        _apply_params(
            "max_pool1d_with_indices",
            torch.nn.modules.utils._single,
            1,
            return_indices=True,
        )
    ],
)
@_onnx_symbolic(
    "aten::max_pool2d_with_indices",
    decorate=[
        _apply_params(
            "max_pool2d_with_indices",
            torch.nn.modules.utils._pair,
            2,
            return_indices=True,
        )
    ],
)
@_onnx_symbolic(
    "aten::max_pool3d_with_indices",
    decorate=[
        _apply_params(
            "max_pool3d_with_indices",
            torch.nn.modules.utils._triple,
            3,
            return_indices=True,
        )
    ],
)
@_beartype.beartype
def _max_pool(name: str, tuple_fn: Callable, ndims: int, return_indices: bool):
    """Create the symbolic function for one max-pooling variant.

    Args:
        name: Name of the op being registered (not used by this implementation).
        tuple_fn: Function applied to the kernel/stride/padding/dilation
            arguments before they are passed as ONNX attributes.
        ndims: Number of spatial dimensions being pooled.
        return_indices: Whether the generated symbolic also returns indices.

    Returns:
        The symbolic function that emits the ``MaxPool`` node(s).
    """

    @symbolic_helper.quantized_args(True, False, False, False, False, False)
    @symbolic_helper.parse_args("v", "is", "is", "is", "is", "i")
    def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode):
        # An empty stride falls back to the kernel size.
        stride = stride or kernel_size
        pool_attrs = {
            "kernel_shape_i": tuple_fn(kernel_size),
            "pads_i": tuple_fn(padding) * 2,
            "strides_i": tuple_fn(stride),
            "ceil_mode_i": ceil_mode,
        }
        # Only emit the dilations attribute when some dilation differs from 1.
        dilations = tuple_fn(dilation)
        if set(dilations) != {1}:
            pool_attrs["dilations_i"] = dilations

        if not return_indices:
            return g.op("MaxPool", input, outputs=1, **pool_attrs)

        # Easy but hacky way to turn ONNX's flattened indices into the
        # per-window indices PyTorch reports.
        # ONNX computes indices over the flattened tensor, so values lie in
        # [0, N x C x D1 x ... x Dn). To convert them, run a second MaxPool with
        # kernel and stride 1 over the same input: every position then yields
        # its own flat index. Slicing out the first index along each spatial
        # axis and subtracting it from the real indices leaves values that are
        # relative to the dimension they are in.
        # For more information:
        # https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407
        pooled, indices = g.op("MaxPool", input, outputs=2, **pool_attrs)
        _, flattened_indices = g.op(
            "MaxPool",
            input,
            outputs=2,
            kernel_shape_i=[1] * ndims,
            strides_i=[1] * ndims,
        )
        # Convert indices to have non-flattened values.
        spatial_axes = list(range(2, 2 + ndims))
        first_indices = symbolic_helper._slice_helper(
            g,
            flattened_indices,
            axes=spatial_axes,
            starts=tuple_fn(0),
            ends=tuple_fn(1),
        )
        indices = opset9.sub(g, indices, first_indices)
        return pooled, indices

    return symbolic_fn
|  |  | 
|  |  | 
@_onnx_symbolic(
    "aten::avg_pool1d",
    decorate=[_apply_params("avg_pool1d", torch.nn.modules.utils._single)],
)
@_onnx_symbolic(
    "aten::avg_pool2d",
    decorate=[_apply_params("avg_pool2d", torch.nn.modules.utils._pair)],
)
@_onnx_symbolic(
    "aten::avg_pool3d",
    decorate=[_apply_params("avg_pool3d", torch.nn.modules.utils._triple)],
)
@_beartype.beartype
def _avg_pool(name, tuple_fn):
    """Create the average-pooling symbolic by reusing the opset 9 implementation.

    onnx::AvgPool does provide count_include_pad and ceil_mode, but with
    ceil_mode on PyTorch lets the sliding window go off bound, which is why
    the opset 9 version is used here as an accommodation.
    More detail on https://github.com/pytorch/pytorch/issues/57178
    """
    symbolic_fn = opset9._avg_pool(name, tuple_fn)
    return symbolic_fn
|  |  | 
|  |  | 
@_onnx_symbolic(
    "aten::upsample_nearest1d",
    decorate=[_apply_params("upsample_nearest1d", 3, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_nearest2d",
    decorate=[_apply_params("upsample_nearest2d", 4, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_nearest3d",
    decorate=[_apply_params("upsample_nearest3d", 5, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_linear1d",
    decorate=[_apply_params("upsample_linear1d", 3, "linear")],
)
@_onnx_symbolic(
    "aten::upsample_bilinear2d",
    decorate=[_apply_params("upsample_bilinear2d", 4, "linear")],
)
@_onnx_symbolic(
    "aten::upsample_trilinear3d",
    decorate=[_apply_params("upsample_trilinear3d", 5, "linear")],
)
@_beartype.beartype
def _interpolate(name, dim, interpolate_mode):
    """Create the symbolic function for one ``upsample_*`` op.

    Args:
        name: Op name, used when reporting an unimplemented configuration.
        dim: Passed to ``_interpolate_size_to_scales`` when scales must be
            derived from the requested output size.
        interpolate_mode: Value of the ``mode`` attribute on the ``Resize`` node.

    Returns:
        The symbolic function that emits the ``Resize`` node.
    """

    @symbolic_helper.quantized_args(True, False, False)
    @_beartype.beartype
    def symbolic_fn(g, input, output_size, *args):
        scales, align_corners = symbolic_helper._get_interpolate_attributes(
            g, interpolate_mode, args
        )
        symbolic_helper._interpolate_warning(interpolate_mode)

        # A truthy align_corners is reported as unimplemented.
        if symbolic_helper._maybe_get_scalar(align_corners):
            return symbolic_helper._unimplemented(name, "align_corners == True", input)

        # Without explicit scales, derive them from the requested output size.
        if scales is None:
            scales = symbolic_helper._interpolate_size_to_scales(
                g, input, output_size, dim
            )
        return g.op("Resize", input, scales, mode_s=interpolate_mode)

    return symbolic_fn
|  |  | 
|  |  | 
@_onnx_symbolic("aten::__interpolate")
@_beartype.beartype
def __interpolate(
    g: jit_utils.GraphContext,
    input,
    size,
    scale_factor,
    mode,
    align_corners,
    recompute_scale_factor,
    antialias,
):
    """Symbolic for ``aten::__interpolate``.

    Scales and the resize mode are resolved by
    ``symbolic_helper._interpolate_get_scales_and_mode``; the
    ``recompute_scale_factor`` and ``antialias`` arguments are not used here.
    """
    scales, resize_mode = symbolic_helper._interpolate_get_scales_and_mode(
        g, input, size, scale_factor, mode, align_corners
    )
    return g.op("Resize", input, scales, mode_s=resize_mode)
|  |  | 
|  |  | 
@_beartype.beartype
def _slice(
    g: jit_utils.GraphContext,
    input,
    axes,
    starts,
    ends,
    steps=None,
    dynamic_slice=False,
):
    """Emit an ONNX ``Slice`` node taking starts/ends/axes/steps as inputs.

    With ``dynamic_slice`` the bounds are passed through
    ``_unsqueeze_helper`` and used as inputs directly; otherwise they are
    Python sequences that get wrapped in ``Constant`` nodes. A static slice
    over a single axis from 0 to INT64_MAX with unit step returns ``input``
    unchanged without emitting anything.
    """
    if dynamic_slice:
        starts = symbolic_helper._unsqueeze_helper(g, starts, [0])
        ends = symbolic_helper._unsqueeze_helper(g, ends, [0])
        if isinstance(axes, int):
            axes = g.op("Constant", value_t=torch.tensor(axes))
        axes = symbolic_helper._unsqueeze_helper(g, axes, [0])
    else:
        assert len(starts) == len(ends)
        assert len(starts) == len(axes)
        assert steps is None or len(starts) == len(steps)

        covers_whole_axis = (
            len(starts) == 1 and starts[0] == 0 and ends[0] == _constants.INT64_MAX
        )
        if covers_whole_axis and (
            steps is None or (len(steps) == 1 and steps[0] == 1)
        ):
            # Slicing everything with step 1 is a no-op.
            return input

        axes = g.op("Constant", value_t=torch.tensor(axes))
        starts = g.op("Constant", value_t=torch.tensor(starts))
        ends = g.op("Constant", value_t=torch.tensor(ends))

    slice_inputs = [input, starts, ends, axes]
    if steps is not None:
        slice_inputs.append(g.op("Constant", value_t=torch.tensor(steps)))
    return g.op("Slice", *slice_inputs)
|  |  | 
|  |  | 
@_onnx_symbolic("aten::slice")
@_beartype.beartype
def slice(g: jit_utils.GraphContext, self, *args):
    """Symbolic for ``aten::slice`` (tensor and list overloads).

    The slice is treated as dynamic when start, end or dim is neither absent
    nor an ``onnx::Constant``; otherwise the bounds are parsed to Python ints.

    Raises:
        errors.SymbolicValueError: If ``args`` has neither 3 nor 4 elements.
    """
    if len(args) == 4:
        # aten::slice(Tensor self, int dim, int? start=None, int? end=None, int step=1) -> Tensor
        dim, start, end, step = args
    elif len(args) == 3:
        # aten::slice(t[] l, int? start=None, int? end=None, int step=1) -> t[]
        start, end, step = args
        dim = 0
    else:
        raise errors.SymbolicValueError("Unknown aten::slice signature", self)

    is_start_none = start.node().kind() == "prim::Constant" and isinstance(
        start.type(), _C.NoneType
    )
    is_end_none = end.node().kind() == "prim::Constant" and isinstance(
        end.type(), _C.NoneType
    )
    is_start_onnx_const = start.node().kind() == "onnx::Constant"
    is_end_onnx_const = end.node().kind() == "onnx::Constant"
    step = symbolic_helper._parse_arg(step, "i")

    start_is_dynamic = not (is_start_none or is_start_onnx_const)
    end_is_dynamic = not (isinstance(end, int) or is_end_none or is_end_onnx_const)
    dim_is_dynamic = not isinstance(dim, int) and dim.node().kind() != "onnx::Constant"
    dynamic_slice = start_is_dynamic or end_is_dynamic or dim_is_dynamic

    if dynamic_slice:
        # Replace absent bounds with constant nodes covering the full range.
        if is_start_none:
            start = g.op("Constant", value_t=torch.tensor(0))
        if is_end_none:
            end = g.op("Constant", value_t=torch.tensor(_constants.INT64_MAX))
    else:
        start = [0 if is_start_none else symbolic_helper._parse_arg(start, "i")]
        end = [
            _constants.INT64_MAX
            if is_end_none
            else symbolic_helper._parse_arg(end, "i")
        ]
        dim = [symbolic_helper._parse_arg(dim, "i")]

    return symbolic_helper._slice_helper(
        g,
        self,
        axes=dim,
        starts=start,
        ends=end,
        steps=[step],
        dynamic_slice=dynamic_slice,
    )
|  |  | 
|  |  | 
@_onnx_symbolic("aten::flip")
@symbolic_helper.parse_args("v", "is")
@_beartype.beartype
def flip(g: jit_utils.GraphContext, input, dims):
    """Symbolic for ``aten::flip``.

    Each axis in ``dims`` is sliced from -1 to -INT64_MAX with step -1.
    """
    num_axes = len(dims)
    return symbolic_helper._slice_helper(
        g,
        input,
        axes=dims,
        starts=[-1] * num_axes,
        ends=[-_constants.INT64_MAX] * num_axes,
        steps=[-1] * num_axes,
    )
|  |  | 
|  |  | 
@_onnx_symbolic("aten::fmod")
@_beartype.beartype
def fmod(g: jit_utils.GraphContext, input, other):
    """Symbolic for ``aten::fmod``: an ONNX ``Mod`` node with ``fmod=1``."""
    remainder = g.op("Mod", input, other, fmod_i=1)
    return remainder
|  |  | 
|  |  | 
@_onnx_symbolic("aten::embedding_bag")
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
@_beartype.beartype
def embedding_bag(
    g: jit_utils.GraphContext,
    embedding_matrix,
    indices,
    offsets,
    scale_grad_by_freq,
    mode,
    sparse,
    per_sample_weights,
    include_last_offset,
    padding_idx,
):
    """Symbolic for ``aten::embedding_bag``.

    One Slice/Gather/Reduce chain is emitted per bag, so the size of the
    first dimension of ``offsets`` must be known at export time.

    Args:
        g: Graph context the nodes are added to.
        embedding_matrix: Value gathered from with the sliced indices.
        indices: Value sliced per bag using consecutive offsets.
        offsets: Start position of each bag within ``indices``.
        scale_grad_by_freq: Reported as unsupported when set while exporting
            in training mode.
        mode: 0 reduces each bag with a sum, 1 with ``ReduceMean``, and any
            other value with ``ReduceMax``.
        sparse: Not used by this symbolic.
        per_sample_weights: Optional weights multiplied into the gathered rows.
        include_last_offset: When set, the last entry of ``offsets`` is the
            end of the final bag rather than the start of another one.
        padding_idx: Must be ``None`` or negative.

    Returns:
        A 4-tuple ``(output, None, None, None)``.

    Raises:
        RuntimeError: If ``padding_idx`` is a non-negative integer.
    """
    if scale_grad_by_freq and GLOBALS.export_training:
        return symbolic_helper._onnx_unsupported(
            "embedding_bag with scale_grad_by_freq for training mode"
        )
    if padding_idx is not None and padding_idx >= 0:
        raise RuntimeError("embedding_bag with padding_idx")

    # The original message ended with a stray "'" character; removed here.
    warnings.warn(
        "Export of embedding_bag with dynamic input/offsets shape is not supported in opset 10. "
        "Please use opset 11 or higher to export model for dynamic input shape."
    )
    offsets_dim_0 = symbolic_helper._get_tensor_dim_size(offsets, 0)
    if offsets_dim_0 is None:
        return symbolic_helper._onnx_unsupported(
            "embedding_bag with unknown shape of offsets for opset 10 is not supported. "
            "please use opset 11 or higher."
        )

    if include_last_offset:
        offset_len = offsets_dim_0 - 1
        offsets_extended = offsets
    else:
        offset_len = offsets_dim_0
        # Append a sentinel end offset so the last bag runs to the end of
        # ``indices``.
        offsets_extended = [
            offsets,
            g.op("Constant", value_t=torch.tensor([sys.maxsize])),
        ]
        offsets_extended = g.op("Concat", *offsets_extended, axis_i=0)

    list_ = []
    for i in range(offset_len):
        # Bag i spans indices[offsets[i]:offsets[i + 1]].
        start_ = symbolic_helper._unsqueeze_helper(
            g,
            opset9.select(g, offsets_extended, torch.tensor(0), torch.tensor(i)),
            [0],
        )
        end_ = symbolic_helper._unsqueeze_helper(
            g,
            opset9.select(
                g, offsets_extended, torch.tensor(0), torch.tensor(i + 1)
            ),
            [0],
        )
        axes_ = g.op("Constant", value_t=torch.tensor([0]))
        indices_row = g.op("Slice", indices, start_, end_, axes_)

        embeddings = g.op("Gather", embedding_matrix, indices_row)
        if not symbolic_helper._is_none(per_sample_weights):
            per_sample_weights_row = g.op(
                "Slice", per_sample_weights, start_, end_, axes_
            )
            per_sample_weights_row = symbolic_helper._unsqueeze_helper(
                g, per_sample_weights_row, [1]
            )
            embeddings = g.op("Mul", embeddings, per_sample_weights_row)

        if mode == 0:
            embeddings = symbolic_helper._reducesum_helper(
                g, embeddings, axes_i=[0], keepdims_i=0
            )
        elif mode == 1:
            embeddings = g.op("ReduceMean", embeddings, axes_i=[0], keepdims_i=0)
        else:
            embeddings = g.op("ReduceMax", embeddings, axes_i=[0], keepdims_i=0)

        embeddings = symbolic_helper._unsqueeze_helper(g, embeddings, [0])
        list_.append(embeddings)

    output = g.op("Concat", *list_, axis_i=0)
    # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices.
    # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag.
    return output, None, None, None
|  |  | 
|  |  | 
@_onnx_symbolic("aten::fake_quantize_per_tensor_affine")
@symbolic_helper.parse_args("v", "v", "v", "i", "i")
@_beartype.beartype
def fake_quantize_per_tensor_affine(
    g: jit_utils.GraphContext,
    inputs,
    scale,
    zero_point,
    quant_min=-128,
    quant_max=127,
):
    """Symbolic for ``aten::fake_quantize_per_tensor_affine``.

    Emits ``QuantizeLinear`` followed by ``DequantizeLinear``. Only the
    ranges (0, 255) and (-128, 127) are accepted, and ``scale`` must resolve
    to a scalar via ``_maybe_get_scalar``.

    Raises:
        errors.SymbolicValueError: If the range is not one of the two accepted pairs.
    """
    quant_range = (quant_min, quant_max)

    # NOTE: (0, 127) is a special case. PyTorch restricts activations to be in the range (0, 127).
    #   https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
    if quant_range == (0, 127):
        symbolic_helper._onnx_opset_unsupported_detailed(
            "fake_quantize_per_tensor_affine",
            10,
            13,
            "Quantize range (0, 127) not supported, requires opset 13 Clip",
            inputs,
        )
    if quant_range not in [(0, 255), (-128, 127)]:
        raise errors.SymbolicValueError(
            f"For (quant_min, quant_max), ONNX allows only (0, 255) and (-128, 127). "
            f"Got ({quant_min}, {quant_max})",
            inputs,
        )

    scale = symbolic_helper._maybe_get_scalar(scale)
    if scale is None:
        symbolic_helper._onnx_opset_unsupported_detailed(
            "fake_quantize_per_tensor_affine",
            10,
            13,
            "Non-constant scale not supported",
            inputs,
        )
    scale = scale.float().data  # Avoid exporter generating double type

    # The zero point is cast to UINT8 when quant_min is 0, otherwise to INT8.
    zero_point_type = (
        _C_onnx.TensorProtoDataType.UINT8
        if quant_min == 0
        else _C_onnx.TensorProtoDataType.INT8
    )
    zero_point = g.op("Cast", zero_point, to_i=zero_point_type)

    quantized = g.op("QuantizeLinear", inputs, scale, zero_point)
    return g.op("DequantizeLinear", quantized, scale, zero_point)
|  |  | 
|  |  | 
@_onnx_symbolic("aten::isinf")
@_beartype.beartype
def isinf(g: jit_utils.GraphContext, input):
    """Symbolic for ``aten::isinf``: cast the input to double, then apply ``IsInf``."""
    as_double = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.DOUBLE)
    return g.op("IsInf", as_double)
|  |  | 
|  |  | 
@_onnx_symbolic("aten::isfinite")
@_beartype.beartype
def isfinite(g: jit_utils.GraphContext, input):
    """Symbolic for ``aten::isfinite``: not (isinf or isnan)."""
    is_inf = isinf(g, input)
    is_nan = opset9.isnan(g, input)
    inf_or_nan = opset9.__or_(g, is_inf, is_nan)
    return opset9.__not_(g, inf_or_nan)
|  |  | 
|  |  | 
@_onnx_symbolic("aten::quantize_per_tensor")
@_beartype.beartype
def quantize_per_tensor(g: jit_utils.GraphContext, input, scale, zero_point, dtype):
    """Symbolic for ``aten::quantize_per_tensor``.

    The zero point is cast to the ONNX type corresponding to ``dtype`` and the
    scale to FLOAT before both are handed to ``symbolic_helper.quantize_helper``.
    """
    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
    # TODO(justinchuby): Extract all the cast ops into a helper function.
    zero_point_onnx_type = _type_utils.JitScalarType(dtype).onnx_type()
    zero_point = g.op("Cast", zero_point, to_i=zero_point_onnx_type)
    scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    return symbolic_helper.quantize_helper(g, input, scale, zero_point)
|  |  | 
|  |  | 
@_onnx_symbolic("aten::dequantize")
@_beartype.beartype
def dequantize(g: jit_utils.GraphContext, input):
    """Symbolic for ``aten::dequantize``: the first element returned by ``dequantize_helper``."""
    helper_result = symbolic_helper.dequantize_helper(g, input)
    return helper_result[0]
|  |  | 
|  |  | 
@_onnx_symbolic("aten::nan_to_num")
@symbolic_helper.parse_args("v", "f", "f", "f")
@_beartype.beartype
def nan_to_num(g: jit_utils.GraphContext, input, nan, posinf, neginf):
    """Symbolic for ``aten::nan_to_num``.

    Non-floating-point inputs are returned untouched. Otherwise NaN, +inf and
    -inf are replaced, in that order, through three chained ``Where`` nodes.
    ``nan`` defaults to 0.0; ``posinf`` and ``neginf`` default to the largest
    and smallest finite values of the input dtype.
    """
    # Cannot create an int type tensor with inf/nan values, so we simply
    # return the original tensor.
    if not symbolic_helper._is_fp(input):
        return input
    input_dtype = _type_utils.JitScalarType.from_value(input).dtype()

    # Step 1: replace NaN.
    nan_value = 0.0 if nan is None else nan
    nan_mask = opset9.isnan(g, input)
    nan_constant = g.op(
        "Constant", value_t=torch.tensor([nan_value], dtype=input_dtype)
    )
    without_nan = g.op("Where", nan_mask, nan_constant, input)

    # For None values of posinf, neginf we use the greatest/lowest finite
    # value representable by the input's dtype.
    finfo = torch.finfo(input_dtype)

    # Step 2: replace +inf (infinite and greater than zero).
    posinf_value = finfo.max if posinf is None else posinf
    is_inf = isinf(g, without_nan)
    zero = g.op("Constant", value_t=torch.LongTensor([0]))
    is_positive = opset9.gt(g, without_nan, zero)
    posinf_mask = opset9.logical_and(g, is_inf, is_positive)
    posinf_constant = g.op(
        "Constant", value_t=torch.tensor([posinf_value], dtype=input_dtype)
    )
    without_posinf = g.op("Where", posinf_mask, posinf_constant, without_nan)

    # Step 3: replace -inf (infinite and less than zero).
    neginf_value = finfo.min if neginf is None else neginf
    is_inf = isinf(g, without_posinf)
    zero = g.op("Constant", value_t=torch.LongTensor([0]))
    is_negative = opset9.lt(g, without_posinf, zero)
    neginf_mask = opset9.logical_and(g, is_inf, is_negative)
    neginf_constant = g.op(
        "Constant", value_t=torch.tensor([neginf_value], dtype=input_dtype)
    )
    return g.op("Where", neginf_mask, neginf_constant, without_posinf)
|  |  | 
|  |  | 
|  | # Quantized symbolics --------------------------------------------------------- | 
|  | # https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter#quantized-model-export | 
|  | # Support starts from opset 10 because `DequantizeLinear` and `QuantizeLinear` were | 
|  | # introduced in opset version 10. | 
@_onnx_symbolic("quantized::linear")
@_beartype.beartype
def quantized_linear(
    g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point
):
    """Symbolic for ``quantized::linear``.

    Input and weight are dequantized, the bias is requantized with their
    scales and dequantized again, the opset 9 ``linear`` is applied, and the
    result is quantized with the op's output scale and zero point.
    """
    dequantized_input, input_scale, _, _ = symbolic_helper.dequantize_helper(
        g, q_input
    )
    dequantized_weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(
        g, q_weight
    )
    requantized_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale
    )
    dequantized_bias, _, _, _ = symbolic_helper.dequantize_helper(
        g, requantized_bias
    )

    result = opset9.linear(g, dequantized_input, dequantized_weight, dequantized_bias)
    return symbolic_helper.quantize_helper(g, result, op_scale, op_zero_point)
|  |  | 
|  |  | 
@_onnx_symbolic("quantized::add")
@_beartype.beartype
def quantized_add(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
    """Symbolic for ``quantized::add``: dequantize both operands, add, requantize."""
    lhs, _, _, _ = symbolic_helper.dequantize_helper(g, x)
    rhs, _, _, _ = symbolic_helper.dequantize_helper(g, y)
    total = opset9.add(g, lhs, rhs)
    return symbolic_helper.quantize_helper(g, total, op_scale, op_zero_point)
|  |  | 
|  |  | 
@_onnx_symbolic("quantized::add_relu")
@_beartype.beartype
def quantized_add_relu(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
    """Symbolic for ``quantized::add_relu``: dequantize, add, apply ReLU, requantize."""
    lhs, _, _, _ = symbolic_helper.dequantize_helper(g, x)
    rhs, _, _, _ = symbolic_helper.dequantize_helper(g, y)
    activated = opset9.relu(g, opset9.add(g, lhs, rhs))
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
|  |  | 
|  |  | 
@_onnx_symbolic("quantized::mul")
@_beartype.beartype
def quantized_mul(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
    """Symbolic for ``quantized::mul``: dequantize both operands, multiply, requantize."""
    lhs, _, _, _ = symbolic_helper.dequantize_helper(g, x)
    rhs, _, _, _ = symbolic_helper.dequantize_helper(g, y)
    product = opset9.mul(g, lhs, rhs)
    return symbolic_helper.quantize_helper(g, product, op_scale, op_zero_point)
|  |  | 
|  |  | 
@_onnx_symbolic("quantized::hardswish")
@_beartype.beartype
def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
    """Symbolic for ``quantized::hardswish``: dequantize, apply hardswish, requantize."""
    dequantized, _, _, _ = symbolic_helper.dequantize_helper(g, x)
    activated = opset9.hardswish(g, dequantized)
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
|  |  | 
|  |  | 
@_onnx_symbolic("quantized::sigmoid")
@_beartype.beartype
def quantized_sigmoid(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
    """Symbolic for ``quantized::sigmoid``: dequantize, apply sigmoid, requantize."""
    dequantized, _, _, _ = symbolic_helper.dequantize_helper(g, x)
    activated = opset9.sigmoid(g, dequantized)
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
|  |  | 
|  |  | 
@_onnx_symbolic("quantized::leaky_relu")
@_beartype.beartype
def quantized_leaky_relu(
    g: jit_utils.GraphContext, x, negative_slope, inplace, op_scale, op_zero_point
):
    """Symbolic for ``quantized::leaky_relu``: dequantize, apply leaky ReLU, requantize."""
    dequantized, _, _, _ = symbolic_helper.dequantize_helper(g, x)
    activated = opset9.leaky_relu(g, dequantized, negative_slope, inplace)
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
|  |  | 
|  |  | 
@_onnx_symbolic("quantized::layer_norm")
@_beartype.beartype
def quantized_layer_norm(
    g: jit_utils.GraphContext,
    x,
    normalized_shape,
    weight,
    bias,
    eps,
    op_scale,
    op_zero_point,
):
    """Symbolic for ``quantized::layer_norm``.

    Dequantizes ``x``, applies the opset 9 ``layer_norm`` (with ``False`` as
    its last argument) and requantizes the result.
    """
    dequantized, _, _, _ = symbolic_helper.dequantize_helper(g, x)
    normalized = opset9.layer_norm(
        g, dequantized, normalized_shape, weight, bias, eps, False
    )
    return symbolic_helper.quantize_helper(g, normalized, op_scale, op_zero_point)
|  |  | 
|  |  | 
@_onnx_symbolic("quantized::group_norm")
@_beartype.beartype
def quantized_group_norm(
    g: jit_utils.GraphContext,
    x,
    num_groups,
    weight,
    bias,
    eps,
    op_scale,
    op_zero_point,
):
    """Symbolic for ``quantized::group_norm``.

    Dequantizes ``x``, applies the opset 9 ``group_norm`` (with ``False`` as
    its last argument) and requantizes the result.
    """
    dequantized, _, _, _ = symbolic_helper.dequantize_helper(g, x)
    normalized = opset9.group_norm(
        g, dequantized, num_groups, weight, bias, eps, False
    )
    return symbolic_helper.quantize_helper(g, normalized, op_scale, op_zero_point)
|  |  | 
|  |  | 
@_onnx_symbolic("quantized::instance_norm")
@symbolic_helper.parse_args("v", "v", "v", "f", "v", "v")
@_beartype.beartype
def quantized_instance_norm(
    g: jit_utils.GraphContext,
    q_input,
    weight,
    bias,
    eps,
    op_scale,
    op_zero_point,
):
    """Symbolic for ``quantized::instance_norm``.

    Dequantizes the input, calls the opset 9 ``instance_norm`` with fixed
    values for the arguments this op does not supply, and requantizes.
    """
    dequantized_input, _, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    normalized = opset9.instance_norm(
        g, dequantized_input, weight, bias, None, None, False, 0.0, eps, False
    )
    return symbolic_helper.quantize_helper(g, normalized, op_scale, op_zero_point)
|  |  | 
|  |  | 
@_onnx_symbolic("quantized::conv1d_relu")
@_beartype.beartype
def quantized_conv1d_relu(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Symbolic for ``quantized::conv1d_relu``.

    Input and weight are dequantized, the bias is requantized with their
    scales and dequantized again, then the opset 9 ``conv1d`` and ``relu``
    are applied and the result is quantized with the op's scale/zero point.
    """
    dequantized_input, input_scale, _, _ = symbolic_helper.dequantize_helper(
        g, q_input
    )
    dequantized_weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(
        g, q_weight
    )
    requantized_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale
    )
    dequantized_bias, _, _, _ = symbolic_helper.dequantize_helper(
        g, requantized_bias
    )

    convolved = opset9.conv1d(
        g,
        dequantized_input,
        dequantized_weight,
        dequantized_bias,
        stride,
        padding,
        dilation,
        groups,
    )
    activated = opset9.relu(g, convolved)
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
|  |  | 
|  |  | 
@_onnx_symbolic("quantized::conv2d_relu")
@_beartype.beartype
def quantized_conv2d_relu(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Symbolic for ``quantized::conv2d_relu``.

    Input and weight are dequantized, the bias is requantized with their
    scales and dequantized again, then the opset 9 ``conv2d`` and ``relu``
    are applied and the result is quantized with the op's scale/zero point.
    """
    dequantized_input, input_scale, _, _ = symbolic_helper.dequantize_helper(
        g, q_input
    )
    dequantized_weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(
        g, q_weight
    )
    requantized_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale
    )
    dequantized_bias, _, _, _ = symbolic_helper.dequantize_helper(
        g, requantized_bias
    )

    convolved = opset9.conv2d(
        g,
        dequantized_input,
        dequantized_weight,
        dequantized_bias,
        stride,
        padding,
        dilation,
        groups,
    )
    activated = opset9.relu(g, convolved)
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
|  |  | 
|  |  | 
@_onnx_symbolic("quantized::conv2d")
@_beartype.beartype
def quantized_conv2d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Symbolic for ``quantized::conv2d``.

    Input and weight are dequantized, the bias is requantized with their
    scales and dequantized again, then the opset 9 ``conv2d`` is applied and
    the result is quantized with the op's scale/zero point.
    """
    dequantized_input, input_scale, _, _ = symbolic_helper.dequantize_helper(
        g, q_input
    )
    dequantized_weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(
        g, q_weight
    )
    requantized_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale
    )
    dequantized_bias, _, _, _ = symbolic_helper.dequantize_helper(
        g, requantized_bias
    )

    convolved = opset9.conv2d(
        g,
        dequantized_input,
        dequantized_weight,
        dequantized_bias,
        stride,
        padding,
        dilation,
        groups,
    )
    return symbolic_helper.quantize_helper(g, convolved, op_scale, op_zero_point)
|  |  | 
|  |  | 
@_onnx_symbolic("quantized::cat")
@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def quantized_cat(
    g: jit_utils.GraphContext,
    q_inputs: _C.Value,
    dim: int,
    op_scale: _C.Value,
    op_zero_point: _C.Value,
) -> _C.Value:
    """Symbolic for ``quantized::cat``.

    Unpacks the input list, dequantizes every element, concatenates them
    along ``dim`` and quantizes the result with the op's scale/zero point.
    """
    dequantized_parts = []
    for q_tensor in symbolic_helper._unpack_list(q_inputs):
        dequantized_parts.append(symbolic_helper.dequantize_helper(g, q_tensor)[0])
    joined = g.op("Concat", *dequantized_parts, axis_i=dim)
    return symbolic_helper.quantize_helper(g, joined, op_scale, op_zero_point)