| import importlib | 
 | import inspect | 
 |  | 
 | from torch.onnx import symbolic_helper, symbolic_opset9 as opset9 | 
 | from torch.onnx._internal import jit_utils, registration | 
 |  | 
 |  | 
 | def register_quantized_ops(domain: str, version: int): | 
 |     # Register all quantized ops | 
 |     module = importlib.import_module("torch.onnx.symbolic_caffe2") | 
 |     quant_version_ops = inspect.getmembers(module) | 
 |     aten_q_ops = { | 
 |         "relu", | 
 |         "_empty_affine_quantized", | 
 |         "dequantize", | 
 |         "quantize_per_tensor", | 
 |         "upsample_nearest2d", | 
 |         "avg_pool2d", | 
 |         "reshape", | 
 |         "slice", | 
 |         "cat", | 
 |         "max_pool2d", | 
 |         "sigmoid", | 
 |     } | 
 |     for op, func in quant_version_ops: | 
 |         name = f"{domain}::{op}" | 
 |         if inspect.isfunction(func) and not registration.registry.is_registered_op( | 
 |             name, version | 
 |         ): | 
 |             if op in aten_q_ops: | 
 |                 # Override the builtin aten ops | 
 |                 registration.registry.register( | 
 |                     f"aten::{op}", version, func, custom=True | 
 |                 ) | 
 |             registration.registry.register(name, version, func) | 
 |  | 
 |  | 
 | def _permute_helper(g: jit_utils.GraphContext, input, axes): | 
 |     quant_args = { | 
 |         "axes_i": axes, | 
 |         "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"), | 
 |         "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"), | 
 |     } | 
 |     output = g.op("_caffe2::Int8Transpose", input, **quant_args) | 
 |     symbolic_helper._quantized_ops.add(output) | 
 |     return output | 
 |  | 
 |  | 
 | def nchw2nhwc(g: jit_utils.GraphContext, input): | 
 |     axes = [0, 2, 3, 1] | 
 |     return _permute_helper(g, input, axes) | 
 |  | 
 |  | 
 | def nhwc2nchw(g: jit_utils.GraphContext, input): | 
 |     axes = [0, 3, 1, 2] | 
 |     return _permute_helper(g, input, axes) | 
 |  | 
 |  | 
 | def linear_prepack(g: jit_utils.GraphContext, weight, bias): | 
 |     # Mapping to a dummy caffe2 prepack node. | 
 |     # During the onnx -> c2 conversion we can look up original weight and bias | 
 |     # from this node | 
 |     output = g.op("_caffe2::WeightPrepack", weight, bias) | 
 |     symbolic_helper._quantized_ops.add(output) | 
 |     return output | 
 |  | 
 |  | 
 | @symbolic_helper.parse_args("v", "v", "v", "f", "i") | 
 | def linear(g: jit_utils.GraphContext, input, weight, bias, scale, zero_point): | 
 |     kwargs = { | 
 |         "Y_scale_f": scale, | 
 |         "Y_zero_point_i": zero_point, | 
 |     } | 
 |     output = g.op("_caffe2::Int8FC", input, weight, bias, **kwargs) | 
 |     symbolic_helper._quantized_ops.add(output) | 
 |     return output | 
 |  | 
 |  | 
 | def conv_prepack( | 
 |     g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups | 
 | ): | 
 |     # Mapping to a dummy caffe2 prepack node. | 
 |     # During the onnx -> c2 conversion we can look up original weight and bias | 
 |     # from this node | 
 |     output = g.op("_caffe2::WeightPrepack", input, weight, bias) | 
 |     symbolic_helper._quantized_ops.add(output) | 
 |     return output | 
 |  | 
 |  | 
 | @symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "f", "i") | 
 | def conv2d( | 
 |     g: jit_utils.GraphContext, | 
 |     input, | 
 |     weight, | 
 |     bias, | 
 |     stride, | 
 |     padding, | 
 |     dilation, | 
 |     groups, | 
 |     scale, | 
 |     zero_point, | 
 | ): | 
 |     kernel_size = weight.node()["shape"][1:3] | 
 |     kwargs = { | 
 |         "strides_i": stride, | 
 |         "pads_i": padding + padding, | 
 |         "dilations_i": dilation, | 
 |         "group_i": groups, | 
 |         "kernels_i": kernel_size, | 
 |         "order_s": "NHWC", | 
 |         "Y_scale_f": scale, | 
 |         "Y_zero_point_i": zero_point, | 
 |     } | 
 |     output = g.op("_caffe2::Int8Conv", input, weight, bias, **kwargs) | 
 |     symbolic_helper._quantized_ops.add(output) | 
 |     return output | 
 |  | 
 |  | 
 | @symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "f", "i") | 
 | def conv2d_relu( | 
 |     g: jit_utils.GraphContext, | 
 |     input, | 
 |     weight, | 
 |     bias, | 
 |     stride, | 
 |     padding, | 
 |     dilation, | 
 |     groups, | 
 |     scale, | 
 |     zero_point, | 
 | ): | 
 |     kernel_size = weight.node()["shape"][1:3] | 
 |     kwargs = { | 
 |         "strides_i": stride, | 
 |         "pads_i": padding + padding, | 
 |         "dilations_i": dilation, | 
 |         "group_i": groups, | 
 |         "kernels_i": kernel_size, | 
 |         "order_s": "NHWC", | 
 |         "Y_scale_f": scale, | 
 |         "Y_zero_point_i": zero_point, | 
 |     } | 
 |     output = g.op("_caffe2::Int8ConvRelu", input, weight, bias, **kwargs) | 
 |     symbolic_helper._quantized_ops.add(output) | 
 |     return output | 
 |  | 
 |  | 
 | @symbolic_helper.parse_args("v", "v", "f", "i") | 
 | def add(g: jit_utils.GraphContext, input_a, input_b, scale, zero_point): | 
 |     kwargs = { | 
 |         "Y_scale_f": scale, | 
 |         "Y_zero_point_i": zero_point, | 
 |     } | 
 |     output = g.op("_caffe2::Int8Add", input_a, input_b, **kwargs) | 
 |     symbolic_helper._quantized_ops.add(output) | 
 |     return output | 
 |  | 
 |  | 
 | @symbolic_helper.parse_args("v") | 
 | def relu(g: jit_utils.GraphContext, input): | 
 |     if input not in symbolic_helper._quantized_ops: | 
 |         return opset9.relu(g, input) | 
 |     kwargs = { | 
 |         "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"), | 
 |         "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"), | 
 |     } | 
 |     output = g.op("_caffe2::Int8Relu", input, **kwargs) | 
 |     symbolic_helper._quantized_ops.add(output) | 
 |     return output | 
 |  | 
 |  | 
 | @symbolic_helper.parse_args("v", "f", "i", "t") | 
 | def quantize_per_tensor(g: jit_utils.GraphContext, input, scale, zero_point, dtype): | 
 |     kwargs = { | 
 |         "Y_scale_f": scale, | 
 |         "Y_zero_point_i": zero_point, | 
 |     } | 
 |     output = g.op("_caffe2::Int8Quantize", input, **kwargs) | 
 |     symbolic_helper._quantized_ops.add(output) | 
 |     return output | 
 |  | 
 |  | 
 | @symbolic_helper.parse_args("v") | 
 | def dequantize(g: jit_utils.GraphContext, input): | 
 |     return g.op("_caffe2::Int8Dequantize", input) | 
 |  | 
 |  | 
 | @symbolic_helper.parse_args("v", "t", "t", "t", "t", "t", "t", "t") | 
 | def _empty_affine_quantized( | 
 |     g: jit_utils.GraphContext, | 
 |     input, | 
 |     shape, | 
 |     scale, | 
 |     zero_point, | 
 |     dtype, | 
 |     pin_memory, | 
 |     memory_format, | 
 |     layout, | 
 | ): | 
 |     return input | 
 |  | 
 |  | 
 | def upsample_nearest2d( | 
 |     g: jit_utils.GraphContext, | 
 |     input, | 
 |     output_size, | 
 |     align_corners=None, | 
 |     scales_h=None, | 
 |     scales_w=None, | 
 | ): | 
 |     if input not in symbolic_helper._quantized_ops: | 
 |         return opset9.upsample_nearest2d(g, input, output_size, align_corners)  # type: ignore[attr-defined] | 
 |  | 
 |     output_size = symbolic_helper._parse_arg(output_size, "is") | 
 |     kwargs = { | 
 |         "output_size_i": output_size, | 
 |         "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"), | 
 |         "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"), | 
 |     } | 
 |     input = nchw2nhwc(g, input) | 
 |     output = g.op("_caffe2::Int8ResizeNearest", input, **kwargs) | 
 |     output = nhwc2nchw(g, output) | 
 |     symbolic_helper._quantized_ops.add(output) | 
 |     return output | 
 |  | 
 |  | 
 | @symbolic_helper.parse_args("v", "is", "is", "is", "is", "i") | 
 | def max_pool2d( | 
 |     g: jit_utils.GraphContext, | 
 |     input, | 
 |     kernel_size, | 
 |     stride, | 
 |     padding, | 
 |     dilation, | 
 |     ceil_mode, | 
 | ): | 
 |     if input not in symbolic_helper._quantized_ops: | 
 |         return opset9.max_pool2d(  # type: ignore[attr-defined] | 
 |             g, input, kernel_size, stride, padding, dilation, ceil_mode | 
 |         ) | 
 |     kwargs = { | 
 |         "strides_i": stride, | 
 |         "pads_i": padding + padding, | 
 |         "kernel_i": kernel_size[0], | 
 |         "order_s": "NHWC", | 
 |         "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"), | 
 |         "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"), | 
 |     } | 
 |     input = nchw2nhwc(g, input) | 
 |     output = g.op("_caffe2::Int8MaxPool", input, **kwargs) | 
 |     output = nhwc2nchw(g, output) | 
 |     symbolic_helper._quantized_ops.add(output) | 
 |     return output | 
 |  | 
 |  | 
 | @symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none") | 
 | def avg_pool2d( | 
 |     g: jit_utils.GraphContext, | 
 |     input, | 
 |     kernel_size, | 
 |     stride, | 
 |     padding, | 
 |     ceil_mode, | 
 |     count_include_pad, | 
 |     divisor_override=None, | 
 | ): | 
 |     if input not in symbolic_helper._quantized_ops: | 
 |         return opset9.avg_pool2d(  # type: ignore[attr-defined] | 
 |             g, | 
 |             input, | 
 |             kernel_size, | 
 |             stride, | 
 |             padding, | 
 |             ceil_mode, | 
 |             count_include_pad, | 
 |             divisor_override, | 
 |         ) | 
 |     kwargs = { | 
 |         "strides_i": stride, | 
 |         "pads_i": padding + padding, | 
 |         "kernel_i": kernel_size[0], | 
 |         "order_s": "NHWC", | 
 |         "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"), | 
 |         "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"), | 
 |     } | 
 |     input = nchw2nhwc(g, input) | 
 |     output = g.op("_caffe2::Int8AveragePool", input, **kwargs) | 
 |     output = nhwc2nchw(g, output) | 
 |     symbolic_helper._quantized_ops.add(output) | 
 |     return output | 
 |  | 
 |  | 
 | def reshape(g: jit_utils.GraphContext, input, shape): | 
 |     if input not in symbolic_helper._quantized_ops: | 
 |         return opset9.reshape(g, input, shape) | 
 |  | 
 |     kwargs = { | 
 |         "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"), | 
 |         "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"), | 
 |     } | 
 |     output = g.op("_caffe2::Int8Reshape", input, shape, **kwargs) | 
 |     symbolic_helper._quantized_ops.add(output) | 
 |     return output | 
 |  | 
 |  | 
 | @symbolic_helper.parse_args("v", "v", "v", "v", "i") | 
 | def slice(g: jit_utils.GraphContext, input, dim, start, end, step): | 
 |     if input not in symbolic_helper._quantized_ops: | 
 |         return opset9.slice(g, input, dim, start, end, step) | 
 |  | 
 |     if step != 1: | 
 |         raise RuntimeError("ONNX quantized slice export only works for step 1.") | 
 |     start = symbolic_helper._parse_arg(start, "i") | 
 |     end = symbolic_helper._parse_arg(end, "i") | 
 |     dim = symbolic_helper._parse_arg(dim, "i") | 
 |  | 
 |     kwargs = { | 
 |         "start_idx_i": start, | 
 |         "end_idx_i": end, | 
 |         "dim_i": dim, | 
 |         "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"), | 
 |         "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"), | 
 |     } | 
 |     output = g.op("_caffe2::Int8Slice", input, **kwargs) | 
 |     symbolic_helper._quantized_ops.add(output) | 
 |     return output | 
 |  | 
 |  | 
 | def cat(g: jit_utils.GraphContext, tensor_list, dim, scale=None, zero_point=None): | 
 |     tensors = symbolic_helper._unpack_list(tensor_list) | 
 |     input = tensors[0] | 
 |     if input not in symbolic_helper._quantized_ops: | 
 |         return opset9.cat(g, tensor_list, dim) | 
 |  | 
 |     dim = symbolic_helper._parse_arg(dim, "i") | 
 |     kwargs = { | 
 |         "Y_scale_f": tensors[0].node()["Y_scale"], | 
 |         "Y_zero_point_i": tensors[0].node()["Y_zero_point"], | 
 |     } | 
 |     output = g.op("_caffe2::Int8Concat", *tensors, axis_i=dim, **kwargs) | 
 |     symbolic_helper._quantized_ops.add(output) | 
 |     return output | 
 |  | 
 |  | 
 | @symbolic_helper.parse_args("v") | 
 | def sigmoid(g: jit_utils.GraphContext, input): | 
 |     if input not in symbolic_helper._quantized_ops: | 
 |         return opset9.sigmoid(g, input) | 
 |     # Caffe2 expects the output scale to be 1/2^8 | 
 |     # and output zero_point to be 0 (quint8 type) | 
 |     out_scale = 1.0 / 256 | 
 |     zero_point = 0 | 
 |     kwargs = { | 
 |         "Y_scale_f": out_scale, | 
 |         "Y_zero_point_i": zero_point, | 
 |     } | 
 |     output = g.op("_caffe2::Int8Sigmoid", input, **kwargs) | 
 |     symbolic_helper._quantized_ops.add(output) | 
 |     return output |