|  | from torch.onnx.symbolic_helper import parse_args | 
|  | import torch.onnx.symbolic_helper as sym_help | 
|  | import torch.onnx.symbolic_registry as sym_registry | 
|  | import importlib | 
|  | from inspect import getmembers, isfunction | 
|  |  | 
def register_quantized_ops(domain, version):
    """Register the symbolic functions of ``torch.onnx.symbolic_caffe2``.

    First registers all default-domain ('') symbolics for ``version``. Then
    every function found on the ``torch.onnx.symbolic_caffe2`` module that is
    not already registered under ``domain``/``version`` is registered under
    ``domain``. Functions whose names appear in the aten override set below
    are additionally registered in the default ('') domain. Presumably this
    makes the default-domain lookup for those aten ops resolve to this
    module's implementations, which pick the Caffe2 Int8 path for quantized
    inputs. TODO confirm against the registry's lookup order.

    Args:
        domain: custom op domain to register the Caffe2 symbolics under.
        version: opset version to register for.
    """
    # Names of aten ops that this module also overrides in the default domain.
    # This set is loop-invariant, so it is built once instead of per member.
    aten_q_ops = frozenset([
        'relu', '_empty_affine_quantized', 'dequantize',
        'quantize_per_tensor', 'upsample_nearest2d', 'avg_pool2d',
        'reshape', 'slice', 'cat', 'max_pool2d', 'sigmoid',
    ])

    # Register all the non-quantized ops
    sym_registry.register_version('', version)

    # Register all quantized ops
    module = importlib.import_module('torch.onnx.symbolic_caffe2')
    sym_registry._symbolic_versions['caffe2'] = module
    # NOTE(review): getmembers() also yields functions that are merely imported
    # into the module (e.g. parse_args, getmembers, isfunction) and the private
    # helpers. They are registered as well, exactly as before.
    for op_name, op_fn in getmembers(module):
        if not isfunction(op_fn) or sym_registry.is_registered_op(op_name, domain, version):
            continue
        if op_name in aten_q_ops:
            sym_registry.register_op(op_name, op_fn, '', version)
        sym_registry.register_op(op_name, op_fn, domain, version)
|  |  | 
def _permute_helper(g, input, axes):
    """Emit a ``_caffe2::Int8Transpose`` of a quantized value.

    The transpose reuses the ``Y_scale`` and ``Y_zero_point`` attributes of the
    node that produced ``input``. Its output is added to
    ``sym_help._quantized_ops`` so the other symbolics in this module treat it
    as quantized.

    Args:
        g: graph being built.
        input: quantized value whose producing node carries ``Y_scale`` and
            ``Y_zero_point`` attributes.
        axes: permutation of the input dimensions.

    Returns:
        The transposed value.
    """
    source = input.node()
    transposed = g.op(
        "_caffe2::Int8Transpose",
        input,
        axes_i=axes,
        Y_scale_f=source["Y_scale"],
        Y_zero_point_i=source["Y_zero_point"],
    )
    sym_help._quantized_ops.add(transposed)
    return transposed
|  |  | 
def nchw2nhwc(g, input):
    """Permute a quantized 4-D value with axes ``[0, 2, 3, 1]`` (NCHW -> NHWC)."""
    return _permute_helper(g, input, [0, 2, 3, 1])
|  |  | 
def nhwc2nchw(g, input):
    """Permute a quantized 4-D value with axes ``[0, 3, 1, 2]`` (NHWC -> NCHW)."""
    return _permute_helper(g, input, [0, 3, 1, 2])
|  |  | 
def linear_prepack(g, weight, bias):
    """Map ``linear_prepack`` to a placeholder ``_caffe2::WeightPrepack`` node.

    The node only keeps ``weight`` and ``bias`` as its inputs, so the original
    weight and bias can be looked up again from it during the ONNX -> Caffe2
    conversion.

    Args:
        g: graph being built.
        weight: weight value.
        bias: bias value.

    Returns:
        The placeholder node's output, registered as quantized.
    """
    packed = g.op("_caffe2::WeightPrepack", weight, bias)
    sym_help._quantized_ops.add(packed)
    return packed
|  |  | 
@parse_args('v', 'v', 'v', 'f', 'i')
def linear(g, input, weight, bias, scale, zero_point):
    """Export quantized ``linear`` as ``_caffe2::Int8FC``.

    Args:
        g: graph being built.
        input: quantized input value.
        weight: weight value.
        bias: bias value.
        scale: output quantization scale (float).
        zero_point: output quantization zero point (int).

    Returns:
        The fully-connected output, registered as quantized.
    """
    fc = g.op(
        "_caffe2::Int8FC", input, weight, bias,
        Y_scale_f=scale,
        Y_zero_point_i=zero_point,
    )
    sym_help._quantized_ops.add(fc)
    return fc
|  |  | 
def conv_prepack(g, input, weight, bias, stride, padding, dilation, groups):
    """Map ``conv_prepack`` to a placeholder ``_caffe2::WeightPrepack`` node.

    The node keeps ``input``, ``weight`` and ``bias`` as its inputs, so they
    can be looked up again from it during the ONNX -> Caffe2 conversion.
    ``stride``, ``padding``, ``dilation`` and ``groups`` are accepted for
    signature compatibility but are not used.

    Returns:
        The placeholder node's output, registered as quantized.
    """
    packed = g.op("_caffe2::WeightPrepack", input, weight, bias)
    sym_help._quantized_ops.add(packed)
    return packed
|  |  | 
@parse_args('v', 'v', 'v', 'is', 'is', 'is', 'i', 'f', 'i')
def conv2d(g, input, weight, bias, stride, padding, dilation, groups, scale, zero_point):
    """Export quantized ``conv2d`` as ``_caffe2::Int8Conv`` with NHWC order.

    Args:
        g: graph being built.
        input: quantized input value.
        weight: weight value. Its producing node must carry a ``shape``
            attribute.
        bias: bias value.
        stride: per-axis strides (int list).
        padding: per-axis padding (int list), applied at both start and end.
        dilation: per-axis dilations (int list).
        groups: number of groups.
        scale: output quantization scale.
        zero_point: output quantization zero point.

    Returns:
        The convolution output, registered as quantized.
    """
    # Elements 1..2 of the weight's recorded shape are used as the kernel size.
    # Presumably the weight is laid out (out, kH, kW, in) for NHWC — TODO confirm.
    weight_shape = weight.node()["shape"]
    attrs = dict(
        strides_i=stride,
        pads_i=padding * 2,  # begin pads followed by identical end pads
        dilations_i=dilation,
        group_i=groups,
        kernels_i=weight_shape[1:3],
        order_s="NHWC",
        Y_scale_f=scale,
        Y_zero_point_i=zero_point,
    )
    conv = g.op("_caffe2::Int8Conv", input, weight, bias, **attrs)
    sym_help._quantized_ops.add(conv)
    return conv
|  |  | 
@parse_args('v', 'v', 'v', 'is', 'is', 'is', 'i', 'f', 'i')
def conv2d_relu(g, input, weight, bias, stride, padding, dilation, groups, scale, zero_point):
    """Export fused quantized ``conv2d`` + ReLU as ``_caffe2::Int8ConvRelu`` (NHWC).

    The arguments have the same meaning as for ``conv2d`` in this module. The
    weight's producing node must carry a ``shape`` attribute, whose elements
    1..2 are used as the kernel size. Presumably these are (kH, kW) of an NHWC
    weight — TODO confirm.

    Returns:
        The fused convolution output, registered as quantized.
    """
    kernels = weight.node()["shape"][1:3]
    doubled_pads = padding * 2  # begin pads followed by identical end pads
    fused = g.op(
        "_caffe2::Int8ConvRelu", input, weight, bias,
        strides_i=stride,
        pads_i=doubled_pads,
        dilations_i=dilation,
        group_i=groups,
        kernels_i=kernels,
        order_s="NHWC",
        Y_scale_f=scale,
        Y_zero_point_i=zero_point,
    )
    sym_help._quantized_ops.add(fused)
    return fused
|  |  | 
@parse_args('v', 'v', 'f', 'i')
def add(g, input_a, input_b, scale, zero_point):
    """Export quantized element-wise ``add`` as ``_caffe2::Int8Add``.

    Args:
        g: graph being built.
        input_a: first quantized operand.
        input_b: second quantized operand.
        scale: output quantization scale.
        zero_point: output quantization zero point.

    Returns:
        The sum, registered as quantized.
    """
    quant_attrs = dict(Y_scale_f=scale, Y_zero_point_i=zero_point)
    total = g.op("_caffe2::Int8Add", input_a, input_b, **quant_attrs)
    sym_help._quantized_ops.add(total)
    return total
|  |  | 
@parse_args('v')
def relu(g, input):
    """Export ``relu``; quantized inputs map to ``_caffe2::Int8Relu``.

    Inputs not recorded in ``sym_help._quantized_ops`` are delegated to the
    opset 9 symbolic. Quantized inputs keep the producing node's ``Y_scale``
    and ``Y_zero_point``.

    Returns:
        The activation output. On the quantized path it is registered as
        quantized.
    """
    if input in sym_help._quantized_ops:
        producer = input.node()
        activated = g.op(
            "_caffe2::Int8Relu", input,
            Y_scale_f=producer["Y_scale"],
            Y_zero_point_i=producer["Y_zero_point"],
        )
        sym_help._quantized_ops.add(activated)
        return activated

    from torch.onnx.symbolic_opset9 import relu as opset9_relu
    return opset9_relu(g, input)
|  |  | 
@parse_args('v', 'f', 'i', 't')
def quantize_per_tensor(g, input, scale, zero_point, dtype):
    """Export ``quantize_per_tensor`` as ``_caffe2::Int8Quantize``.

    Args:
        g: graph being built.
        input: floating-point input value.
        scale: quantization scale.
        zero_point: quantization zero point.
        dtype: target dtype. It is accepted but not forwarded to the node.

    Returns:
        The quantized value, registered in ``sym_help._quantized_ops``.
    """
    quantized = g.op(
        "_caffe2::Int8Quantize", input,
        Y_scale_f=scale,
        Y_zero_point_i=zero_point,
    )
    sym_help._quantized_ops.add(quantized)
    return quantized
|  |  | 
@parse_args('v')
def dequantize(g, input):
    """Export ``dequantize`` as ``_caffe2::Int8Dequantize``.

    The result is intentionally not added to ``sym_help._quantized_ops``, so
    the symbolics in this module take their non-quantized path for it.
    """
    dequantized = g.op("_caffe2::Int8Dequantize", input)
    return dequantized
|  |  | 
@parse_args('v', 't', 't', 't', 't', 't', 't', 't')
def _empty_affine_quantized(g, input, shape, scale, zero_point, dtype, pin_memory, memory_format, layout):
    """Export ``_empty_affine_quantized`` as a pass-through of ``input``.

    No node is emitted. All remaining arguments are accepted so that the
    ``parse_args`` descriptors line up, and are otherwise ignored.
    """
    del shape, scale, zero_point, dtype, pin_memory, memory_format, layout  # unused
    return input
|  |  | 
def upsample_nearest2d(g, input, output_size, align_corners=None, scales_h=None, scales_w=None):
    """Export ``upsample_nearest2d``; quantized inputs map to ``Int8ResizeNearest``.

    Inputs not recorded in ``sym_help._quantized_ops`` are delegated to the
    opset 9 symbolic. Quantized inputs are handled in three steps:

    1. Transpose to NHWC.
    2. Resize with ``_caffe2::Int8ResizeNearest``, keeping the input's
       ``Y_scale`` / ``Y_zero_point``.
    3. Transpose back to NCHW.

    Args:
        g: graph being built.
        input: input value.
        output_size: requested output size. On the quantized path it is
            parsed as an int list.
        align_corners, scales_h, scales_w: optional trailing arguments as
            supplied by the caller. On the non-quantized path every one that
            was supplied is forwarded positionally, unchanged.

    Returns:
        The resized value. On the quantized path it is registered as quantized.
    """
    if input not in sym_help._quantized_ops:
        from torch.onnx.symbolic_opset9 import upsample_nearest2d as upsample_nearest2d_impl
        # Forward every trailing argument the caller supplied, keeping positions.
        # Previously only ``align_corners`` was forwarded, which dropped any
        # further scale argument and could hand the opset 9 symbolic an
        # incomplete scales list.
        extra_args = [align_corners, scales_h, scales_w]
        while extra_args and extra_args[-1] is None:
            extra_args.pop()
        return upsample_nearest2d_impl(g, input, output_size, *extra_args)

    output_size = sym_help._parse_arg(output_size, 'is')
    # Quantization params are read from the original producer before the
    # NHWC transpose is inserted.
    kwargs = {
        "output_size_i": output_size,
        "Y_scale_f": input.node()["Y_scale"],
        "Y_zero_point_i": input.node()["Y_zero_point"],
    }
    input = nchw2nhwc(g, input)
    output = g.op("_caffe2::Int8ResizeNearest", input, **kwargs)
    output = nhwc2nchw(g, output)
    sym_help._quantized_ops.add(output)
    return output
@parse_args('v', 'is', 'is', 'is', 'is', 'i')
def max_pool2d(g, input, kernel_size, stride, padding, dilation, ceil_mode):
    """Export ``max_pool2d``; quantized inputs map to ``_caffe2::Int8MaxPool``.

    Inputs not recorded in ``sym_help._quantized_ops`` are delegated
    unchanged to the opset 9 symbolic. Quantized inputs are handled in three
    steps:

    1. Transpose to NHWC.
    2. Pool with ``Int8MaxPool``, keeping the input's ``Y_scale`` /
       ``Y_zero_point``.
    3. Transpose back to NCHW.

    Args:
        g: graph being built.
        input: input value.
        kernel_size: pooling window (int list).
        stride: strides (int list). An empty list means "same as
            ``kernel_size``", PyTorch's encoding of an omitted stride.
        padding: per-axis padding, applied at both start and end.
        dilation: dilations (int list). Only all-ones is supported on the
            quantized path.
        ceil_mode: output-size rounding flag (int).

    Returns:
        The pooled value. On the quantized path it is registered as quantized.

    Raises:
        RuntimeError: on the quantized path, for a non-square kernel or a
            dilation other than 1. The emitted attributes (a single
            ``kernel`` value and no dilation) cannot express either.
    """
    if input not in sym_help._quantized_ops:
        from torch.onnx.symbolic_opset9 import max_pool2d
        return max_pool2d(g, input, kernel_size, stride, padding, dilation, ceil_mode)

    # An omitted stride arrives as [] and means stride == kernel_size.
    # Emitting [] would presumably leave Caffe2 to apply its own default.
    if not stride:
        stride = kernel_size
    # Only kernel_size[0] is emitted, and dilation is not emitted at all, so
    # these configurations would otherwise be exported with wrong semantics.
    if len(set(kernel_size)) != 1:
        raise RuntimeError("ONNX quantized max_pool2d export only works for square kernels.")
    if any(d != 1 for d in dilation):
        raise RuntimeError("ONNX quantized max_pool2d export only works for dilation 1.")
    # NOTE(review): ceil_mode is not forwarded to Int8MaxPool — confirm the
    # Caffe2 default rounding matches for the shapes being exported.
    kwargs = {
        "strides_i": stride,
        "pads_i": padding + padding,
        "kernel_i": kernel_size[0],
        "order_s": "NHWC",
        "Y_scale_f": input.node()["Y_scale"],
        "Y_zero_point_i": input.node()["Y_zero_point"],
    }
    input = nchw2nhwc(g, input)
    output = g.op("_caffe2::Int8MaxPool", input, **kwargs)
    output = nhwc2nchw(g, output)
    sym_help._quantized_ops.add(output)
    return output
|  |  | 
@parse_args('v', 'is', 'is', 'is', 'i', 'i', 'none')
def avg_pool2d(g, input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override=None):
    """Export ``avg_pool2d``; quantized inputs map to ``_caffe2::Int8AveragePool``.

    Inputs not recorded in ``sym_help._quantized_ops`` are delegated
    unchanged to the opset 9 symbolic. Quantized inputs are handled in three
    steps:

    1. Transpose to NHWC.
    2. Pool with ``Int8AveragePool``, keeping the input's ``Y_scale`` /
       ``Y_zero_point``.
    3. Transpose back to NCHW.

    Args:
        g: graph being built.
        input: input value.
        kernel_size: pooling window (int list).
        stride: strides (int list). An empty list means "same as
            ``kernel_size``", PyTorch's encoding of an omitted stride.
        padding: per-axis padding, applied at both start and end.
        ceil_mode: output-size rounding flag (int).
        count_include_pad: whether padding counts toward the average (int).
        divisor_override: left unparsed (``'none'`` descriptor).

    Returns:
        The pooled value. On the quantized path it is registered as quantized.

    Raises:
        RuntimeError: on the quantized path, for a non-square kernel. Only
            ``kernel_size[0]`` is emitted, so it cannot be expressed.
    """
    if input not in sym_help._quantized_ops:
        from torch.onnx.symbolic_opset9 import avg_pool2d
        return avg_pool2d(g, input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)

    # An omitted stride arrives as [] and means stride == kernel_size.
    # Emitting [] would presumably leave Caffe2 to apply its own default.
    if not stride:
        stride = kernel_size
    if len(set(kernel_size)) != 1:
        raise RuntimeError("ONNX quantized avg_pool2d export only works for square kernels.")
    # NOTE(review): ceil_mode, count_include_pad and divisor_override are not
    # forwarded to Int8AveragePool — confirm Caffe2's defaults are acceptable.
    kwargs = {
        "strides_i": stride,
        "pads_i": padding + padding,
        "kernel_i": kernel_size[0],
        "order_s": "NHWC",
        "Y_scale_f": input.node()["Y_scale"],
        "Y_zero_point_i": input.node()["Y_zero_point"],
    }
    input = nchw2nhwc(g, input)
    output = g.op("_caffe2::Int8AveragePool", input, **kwargs)
    output = nhwc2nchw(g, output)
    sym_help._quantized_ops.add(output)
    return output
|  |  | 
def reshape(g, input, shape):
    """Export ``reshape``; quantized inputs map to ``_caffe2::Int8Reshape``.

    Inputs not recorded in ``sym_help._quantized_ops`` are delegated to the
    opset 9 symbolic. On the quantized path ``shape`` is passed as the second
    node input, and the producer's ``Y_scale`` / ``Y_zero_point`` are kept.

    Returns:
        The reshaped value. On the quantized path it is registered as quantized.
    """
    if input not in sym_help._quantized_ops:
        from torch.onnx.symbolic_opset9 import reshape as opset9_reshape
        return opset9_reshape(g, input, shape)

    producer = input.node()
    reshaped = g.op(
        "_caffe2::Int8Reshape", input, shape,
        Y_scale_f=producer["Y_scale"],
        Y_zero_point_i=producer["Y_zero_point"],
    )
    sym_help._quantized_ops.add(reshaped)
    return reshaped
|  |  | 
@parse_args('v', 'v', 'v', 'v', 'i')
def slice(g, input, dim, start, end, step):
    """Export ``slice``; quantized inputs map to ``_caffe2::Int8Slice``.

    Inputs not recorded in ``sym_help._quantized_ops`` are delegated to the
    opset 9 symbolic. On the quantized path, ``start``, ``end`` and ``dim``
    are parsed as ints, and the producer's ``Y_scale`` / ``Y_zero_point`` are
    kept.

    Returns:
        The sliced value. On the quantized path it is registered as quantized.

    Raises:
        RuntimeError: on the quantized path when ``step`` is not 1.
    """
    if input not in sym_help._quantized_ops:
        from torch.onnx.symbolic_opset9 import slice as opset9_slice
        return opset9_slice(g, input, dim, start, end, step)

    if step != 1:
        raise RuntimeError("ONNX quantized slice export only works for step 1.")

    start_idx = sym_help._parse_arg(start, 'i')
    end_idx = sym_help._parse_arg(end, 'i')
    axis = sym_help._parse_arg(dim, 'i')

    producer = input.node()
    sliced = g.op(
        "_caffe2::Int8Slice", input,
        start_idx_i=start_idx,
        end_idx_i=end_idx,
        dim_i=axis,
        Y_scale_f=producer["Y_scale"],
        Y_zero_point_i=producer["Y_zero_point"],
    )
    sym_help._quantized_ops.add(sliced)
    return sliced
|  |  | 
def cat(g, tensor_list, dim, scale=None, zero_point=None):
    """Export ``cat``; quantized inputs map to ``_caffe2::Int8Concat``.

    Only the first tensor in the list is checked against
    ``sym_help._quantized_ops``. If it is not quantized, the call is
    delegated to the opset 9 symbolic. Otherwise ``dim`` is parsed as an int
    and emitted as ``axis``, and the output uses the first tensor's
    ``Y_scale`` / ``Y_zero_point``. ``scale`` and ``zero_point`` are accepted
    but not used.

    Returns:
        The concatenated value. On the quantized path it is registered as
        quantized.
    """
    tensors = sym_help._unpack_list(tensor_list)
    first = tensors[0]
    if first not in sym_help._quantized_ops:
        from torch.onnx.symbolic_opset9 import cat as opset9_cat
        return opset9_cat(g, tensor_list, dim)

    axis = sym_help._parse_arg(dim, 'i')
    producer = first.node()
    joined = g.op(
        "_caffe2::Int8Concat", *tensors,
        axis_i=axis,
        Y_scale_f=producer["Y_scale"],
        Y_zero_point_i=producer["Y_zero_point"],
    )
    sym_help._quantized_ops.add(joined)
    return joined
|  |  | 
@parse_args('v')
def sigmoid(g, input):
    """Export ``sigmoid``; quantized inputs map to ``_caffe2::Int8Sigmoid``.

    Inputs not recorded in ``sym_help._quantized_ops`` are delegated to the
    opset 9 symbolic. On the quantized path the output quantization
    parameters are fixed rather than copied from the input.

    Returns:
        The activation output. On the quantized path it is registered as
        quantized.
    """
    if input not in sym_help._quantized_ops:
        from torch.onnx.symbolic_opset9 import sigmoid as opset9_sigmoid
        return opset9_sigmoid(g, input)

    # Caffe2 expects the output scale to be 1/2^8
    # and the output zero_point to be 0 (quint8 type).
    activated = g.op(
        "_caffe2::Int8Sigmoid", input,
        Y_scale_f=1.0 / 256,
        Y_zero_point_i=0,
    )
    sym_help._quantized_ops.add(activated)
    return activated