| # Copyright (c) Qualcomm Innovation Center, Inc. |
| # All rights reserved |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| import numbers |
| import operator |
| from functools import partial |
| from typing import Callable, Dict, List, Sequence, Tuple |
| |
| import torch |
| from torch._ops import OpOverload |
| |
| from torch._subclasses import FakeTensor |
| from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize |
| |
| from torch.ao.quantization.observer import FixedQParamsObserver |
| from torch.ao.quantization.quantizer import ( |
| DerivedQuantizationSpec, |
| QuantizationAnnotation, |
| QuantizationSpec, |
| SharedQuantizationSpec, |
| ) |
| from torch.ao.quantization.quantizer.utils import ( |
| _annotate_input_qspec_map, |
| _annotate_output_qspec, |
| ) |
| from torch.fx import Node |
| |
| from .qconfig import ( |
| get_16a16w_qnn_ptq_config, |
| get_16a4w_qnn_qat_config, |
| get_8a8w_qnn_qat_config, |
| QuantizationConfig, |
| ) |
| |
| |
| QUANT_ANNOTATION_KEY = "quantization_annotation" |
| OP_ANNOTATOR: Dict[OpOverload, Callable] = {} |
| |
| |
def register_annotator(ops: List[OpOverload]):
    def decorator(annotator: Callable):
        for op in ops:
            OP_ANNOTATOR[op] = annotator
        return annotator

    return decorator
| |
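# Illustrative sketch (not executed here): annotators registered via the decorator
# above are looked up by a node's op overload. A dispatch loop in the quantizer could
# look roughly like the following, assuming a `graph_module` and a
# `quantization_config` are in scope:
#
#     for node in graph_module.graph.nodes:
#         if annotate_fn := OP_ANNOTATOR.get(node.target):
#             annotate_fn(node, quantization_config)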
| |
| def _is_annotated(nodes: List[Node]): |
| """ |
| Given a list of nodes (that represents an operator pattern), |
| return True if any of the node |
| is annotated, otherwise return False |
| """ |
| annotated = False |
| for node in nodes: |
| annotated = annotated or ( |
| QUANT_ANNOTATION_KEY in node.meta |
| and node.meta[QUANT_ANNOTATION_KEY]._annotated |
| ) |
| return annotated |
| |
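# For example (hypothetical nodes), a fused-pattern annotator can call
# _is_annotated([conv_node, relu_node]) to bail out early if either node has
# already been claimed by another annotator.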
| |
| def _is_float_tensor(node: Node): |
| """Check if the node's tensor is a float tensor, so that we can skip quantization for the node |
| since observers only works with float Tensors |
| """ |
| if ( |
| not isinstance(node, Node) |
| or "val" not in node.meta |
| or not isinstance(node.meta["val"], FakeTensor) |
| ): |
| return False |
| return node.meta["val"].dtype == torch.float32 |
| |
| |
| def _mark_nodes_as_annotated(nodes: List[Node]): |
| for node in nodes: |
| if QUANT_ANNOTATION_KEY not in node.meta: |
| node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation() |
| node.meta[QUANT_ANNOTATION_KEY]._annotated = True |
| |
| |
| def annotate_in_out_obs_sharing_op( |
| node: Node, quantization_config: QuantizationConfig |
| ) -> None: |
| if _is_annotated([node]): |
| return |
| |
| input_act = node.args[0] |
| assert isinstance(input_act, Node) |
| |
    # Only annotate an input/output observer-sharing operator when the
    # output of its input node has already been annotated.
| if ( |
| QUANT_ANNOTATION_KEY not in input_act.meta |
| or not input_act.meta[QUANT_ANNOTATION_KEY]._annotated |
| or input_act.meta[QUANT_ANNOTATION_KEY].output_qspec is None |
| ): |
| return |
| |
| act_qspec = SharedQuantizationSpec(input_act) |
| node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( |
| input_qspec_map={ |
| input_act: act_qspec, |
| }, |
| output_qspec=act_qspec, |
| _annotated=True, |
| ) |
| |
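# Shape/layout ops below (e.g. permute, view, squeeze, transpose) try this sharing
# annotation first and only fall back to annotate_single_in_single_out when their
# producer has not been annotated yet.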
| |
| def annotate_single_in_single_out( |
| node: Node, quantization_config: QuantizationConfig |
| ) -> None: |
| if _is_annotated([node]): |
| return |
| |
| input_qspec_map = {} |
| input_act = node.args[0] |
| assert isinstance(input_act, Node) |
| input_qspec_map[input_act] = quantization_config.input_activation |
| |
| if _is_float_tensor(node): |
| node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( |
| input_qspec_map=input_qspec_map, |
| output_qspec=quantization_config.output_activation, |
| _annotated=True, |
| ) |
| |
| |
| @register_annotator([torch.ops.aten.topk.default]) |
| def annotate_topk(node: Node, quantization_config: QuantizationConfig) -> None: |
| if _is_annotated([node]): |
| return |
    # We can reuse the single-in/single-out annotation since we don't want to quantize the indices output
| annotate_single_in_single_out(node, quantization_config) |
| |
| |
| def annotate_binary(node: Node, quantization_config: QuantizationConfig) -> None: |
| if _is_annotated([node]): |
| return |
| |
| input_act_qspec = quantization_config.input_activation |
| output_act_qspec = ( |
| quantization_config.output_activation if _is_float_tensor(node) else None |
| ) |
| |
| input_qspec_map = {} |
| input_act0 = node.args[0] |
| if _is_float_tensor(input_act0): |
| input_qspec_map[input_act0] = input_act_qspec |
| |
| input_act1 = node.args[1] |
| if _is_float_tensor(input_act1): |
| input_qspec_map[input_act1] = input_act_qspec |
| |
| node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( |
| input_qspec_map=input_qspec_map, |
| output_qspec=output_act_qspec, |
| _annotated=True, |
| ) |
| |
| |
| @register_annotator([torch.ops.aten.add, torch.ops.aten.add.Tensor]) |
| def annotate_add(node: Node, quantization_config: QuantizationConfig) -> None: |
| annotate_binary(node, quantization_config) |
| |
| |
| @register_annotator([torch.ops.aten.sub, torch.ops.aten.sub.Tensor]) |
| def annotate_sub(node: Node, quantization_config: QuantizationConfig) -> None: |
| annotate_binary(node, quantization_config) |
| |
| |
| @register_annotator( |
| [torch.ops.aten.mul, torch.ops.aten.mul.Tensor, torch.ops.aten.mul.Scalar] |
| ) |
| def annotate_mul(node: Node, quantization_config: QuantizationConfig) -> None: |
| annotate_binary(node, quantization_config) |
| |
| |
| @register_annotator( |
| [torch.ops.aten.div, torch.ops.aten.div.Tensor, torch.ops.aten.divide.Tensor] |
| ) |
| def annotate_div(node: Node, quantization_config: QuantizationConfig) -> None: |
| def _derived_inp1_const_div_quant_spec( |
| node: torch.fx.Node, output_qspec: QuantizationSpec |
| ) -> DerivedQuantizationSpec: |
| def _derive_div_qparams_fn( |
| obs_or_fqs: List, |
| const_val: float, |
| ) -> Tuple[torch.Tensor, torch.Tensor]: |
| inp_0_obs_or_fq = obs_or_fqs[0] |
| inp_0_scale, inp_0_zp = inp_0_obs_or_fq.calculate_qparams() |
| derived_scale = inp_0_scale / const_val |
| return (derived_scale, inp_0_zp) |
| |
| inp_0 = node.args[0] |
| const_inp_1 = node.args[1] |
| _derive_div_qparams_with_const_fn = partial( |
| _derive_div_qparams_fn, const_val=const_inp_1 |
| ) |
| |
| q_min = ( |
| torch.iinfo(output_qspec.dtype).min |
| if output_qspec.quant_min is None |
| else output_qspec.quant_min |
| ) |
| q_max = ( |
| torch.iinfo(output_qspec.dtype).max |
| if output_qspec.quant_max is None |
| else output_qspec.quant_max |
| ) |
| return DerivedQuantizationSpec( |
| derived_from=[(inp_0, node)], |
| derive_qparams_fn=_derive_div_qparams_with_const_fn, |
| dtype=output_qspec.dtype, |
| quant_min=q_min, |
| quant_max=q_max, |
| ch_axis=0, |
| qscheme=output_qspec.qscheme, |
| ) |
| |
    if all(isinstance(arg, Node) for arg in node.args):
        annotate_binary(node, quantization_config)
    # special case: division by a constant scalar
| elif isinstance(node.args[0], Node) and isinstance(node.args[1], numbers.Number): |
| if _is_annotated([node]): |
| return |
| |
| input_act_qspec = quantization_config.input_activation |
| output_act_qspec = _derived_inp1_const_div_quant_spec( |
| node, quantization_config.output_activation |
| ) |
| input_qspec_map = {} |
| input_act0 = node.args[0] |
| if _is_float_tensor(input_act0): |
| input_qspec_map[input_act0] = input_act_qspec |
| |
| node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( |
| input_qspec_map=input_qspec_map, |
| output_qspec=output_act_qspec, |
| _annotated=True, |
| ) |
| else: |
| raise NotImplementedError(f"No quant annotation is implemented for {node}.") |
| |
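# Worked example for the constant-divisor path above (illustrative numbers): for
# y = x / 4.0 where x was observed with scale_x = 0.02 and zero_point_x = 10, the
# derived spec reuses zero_point_x and sets scale_y = 0.02 / 4.0 = 0.005, so the
# division output needs no observer of its own.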
| |
| @register_annotator([torch.ops.aten.rsub.Scalar]) |
| def annotate_rsub(node: Node, quantization_config: QuantizationConfig) -> None: |
| annotate_binary(node, quantization_config) |
| |
| |
| @register_annotator([torch.ops.aten.sum.dim_IntList]) |
| def annotate_sum(node: Node, quantization_config: QuantizationConfig) -> None: |
| annotate_binary(node, quantization_config) |
| |
| |
| @register_annotator([torch.ops.aten.ceil.default]) |
| def annotate_ceil(node: Node, quantization_config: QuantizationConfig) -> None: |
| annotate_single_in_single_out(node, quantization_config) |
| |
| |
| @register_annotator([torch.ops.aten.clamp.default]) |
| def annotate_clamp(node: Node, quantization_config: QuantizationConfig) -> None: |
| annotate_single_in_single_out(node, quantization_config) |
| |
| |
| @register_annotator([torch.ops.aten.relu.default, torch.ops.aten.relu_.default]) |
| def annotate_relu(node: Node, quantization_config: QuantizationConfig) -> None: |
| annotate_single_in_single_out(node, quantization_config) |
| |
| |
| @register_annotator([torch.ops.aten.tanh.default]) |
| def annotate_tanh(node: Node, quantization_config: QuantizationConfig) -> None: |
| annotate_single_in_single_out(node, quantization_config) |
| |
| |
| @register_annotator( |
| [torch.ops.aten.hardswish.default, torch.ops.aten.hardswish_.default] |
| ) |
| def annotate_hardswish(node: Node, quantization_config: QuantizationConfig) -> None: |
| annotate_single_in_single_out(node, quantization_config) |
| |
| |
| @register_annotator( |
| [torch.ops.aten.hardsigmoid.default, torch.ops.aten.hardsigmoid_.default] |
| ) |
| def annotate_hardsigmoid(node: Node, quantization_config: QuantizationConfig) -> None: |
| annotate_single_in_single_out(node, quantization_config) |
| |
| |
| @register_annotator([torch.ops.aten.hardtanh.default, torch.ops.aten.hardtanh_.default]) |
| def annotate_hardtanh(node: Node, quantization_config: QuantizationConfig) -> None: |
| annotate_single_in_single_out(node, quantization_config) |
| |
| |
| @register_annotator([torch.ops.aten.mean.default]) |
| def annotate_mean(node: Node, quantization_config: QuantizationConfig) -> None: |
| annotate_single_in_single_out(node, quantization_config) |
| |
| |
| @register_annotator([torch.ops.aten.max_pool2d.default]) |
| def annotate_max_pool2d(node: Node, quantization_config: QuantizationConfig) -> None: |
| annotate_single_in_single_out(node, quantization_config) |
| |
| |
| @register_annotator([torch.ops.aten.max_pool2d_with_indices.default]) |
| def annotate_max_pool2d_with_indices( |
| node: Node, quantization_config: QuantizationConfig |
| ) -> None: |
| annotate_single_in_single_out(node, quantization_config) |
| |
| |
| @register_annotator([torch.ops.aten.adaptive_avg_pool2d.default]) |
| def annotate_adaptive_avgpool2d( |
| node: Node, quantization_config: QuantizationConfig |
| ) -> None: |
| annotate_single_in_single_out(node, quantization_config) |
| |
| |
| @register_annotator([torch.ops.aten.avg_pool2d.default]) |
| def annotate_avgpool2d(node: Node, quantization_config: QuantizationConfig) -> None: |
| annotate_single_in_single_out(node, quantization_config) |
| |
| |
| @register_annotator([torch.ops.aten.permute.default]) |
| def annotate_permute(node: Node, quantization_config: QuantizationConfig) -> None: |
| annotate_in_out_obs_sharing_op(node, quantization_config) |
| if not _is_annotated([node]): |
| annotate_single_in_single_out(node, quantization_config) |
| |
| |
| @register_annotator( |
| [ |
| torch.ops.aten.leaky_relu.default, |
| torch.ops.aten.leaky_relu_.default, |
| torch.ops.aten.prelu.default, |
| ] |
| ) |
| def annotate_prelu(node: Node, quantization_config: QuantizationConfig) -> None: |
| annotate_single_in_single_out(node, quantization_config) |
| |
| |
| @register_annotator([torch.ops.aten.view.default, torch.ops.aten._unsafe_view.default]) |
| def annotate_view(node: Node, quantization_config: QuantizationConfig) -> None: |
| annotate_in_out_obs_sharing_op(node, quantization_config) |
| if not _is_annotated([node]): |
| annotate_single_in_single_out(node, quantization_config) |
| |
| |
| @register_annotator([torch.ops.aten.pixel_shuffle.default]) |
| def annotate_pixel_shuffle_default( |
| node: Node, quantization_config: QuantizationConfig |
| ) -> None: |
| annotate_single_in_single_out(node, quantization_config) |
| |
| |
| @register_annotator([torch.ops.aten.pixel_unshuffle.default]) |
| def annotate_pixel_unshuffle_default( |
| node: Node, quantization_config: QuantizationConfig |
| ) -> None: |
| annotate_single_in_single_out(node, quantization_config) |
| |
| |
| @register_annotator([torch.ops.aten.upsample_bilinear2d.vec]) |
| def annotate_upsample_bilinear2d( |
| node: Node, quantization_config: QuantizationConfig |
| ) -> None: |
| annotate_single_in_single_out(node, quantization_config) |
| |
| |
| @register_annotator([torch.ops.aten.upsample_nearest2d.vec]) |
| def annotate_upsample_nearest2d( |
| node: Node, quantization_config: QuantizationConfig |
| ) -> None: |
| annotate_single_in_single_out(node, quantization_config) |
| |
| |
| @register_annotator( |
| [ |
| torch.ops.aten.softmax.int, |
| torch.ops.aten._softmax.default, |
| torch.ops.aten._safe_softmax.default, |
| ] |
| ) |
| def annotate_softmax(node: Node, quantization_config: QuantizationConfig) -> None: |
| annotate_single_in_single_out(node, quantization_config) |
| |
| |
| @register_annotator([torch.ops.aten.log_softmax.int]) |
| def annotate_log_softmax(node: Node, quantization_config: QuantizationConfig) -> None: |
| annotate_single_in_single_out(node, quantization_config) |
| |
| |
| @register_annotator([torch.ops.aten.pad.default]) |
| def annotate_pad(node: Node, quantization_config: QuantizationConfig) -> None: |
| annotate_single_in_single_out(node, quantization_config) |
| |
| |
| @register_annotator([torch.ops.aten.reshape.default]) |
| def annotate_reshape(node: Node, quantization_config: QuantizationConfig) -> None: |
| annotate_single_in_single_out(node, quantization_config) |
| |
| |
| @register_annotator([torch.ops.aten.select.int]) |
| def annotate_select(node: Node, quantization_config: QuantizationConfig) -> None: |
| annotate_single_in_single_out(node, quantization_config) |
| |
| |
| @register_annotator([torch.ops.aten.mean.dim]) |
| def annotate_mean_dim(node: Node, quantization_config: QuantizationConfig) -> None: |
| annotate_single_in_single_out(node, quantization_config) |
| |
| |
| @register_annotator([torch.ops.aten.slice.Tensor]) |
| def annotate_slice(node: Node, quantization_config: QuantizationConfig) -> None: |
| annotate_single_in_single_out(node, quantization_config) |
| |
| |
| @register_annotator([torch.ops.aten.sqrt.default]) |
| def annotate_sqrt(node: Node, quantization_config: QuantizationConfig) -> None: |
| annotate_single_in_single_out(node, quantization_config) |
| |
| |
| @register_annotator([torch.ops.aten.gelu.default]) |
| def annotate_gelu(node: Node, quantization_config: QuantizationConfig) -> None: |
| annotate_single_in_single_out(node, quantization_config) |
| |
| |
| @register_annotator([torch.ops.aten.scaled_dot_product_attention.default]) |
| def annotate_scaled_dot_product_attention( |
| node: Node, quantization_config: QuantizationConfig |
| ) -> None: |
| annotate_single_in_single_out(node, quantization_config) |
| |
| |
| @register_annotator( |
| [ |
| torch.ops.aten.squeeze.default, |
| torch.ops.aten.squeeze.dim, |
| torch.ops.aten.squeeze_copy.dims, |
| ] |
| ) |
| def annotate_squeeze(node: Node, quantization_config: QuantizationConfig) -> None: |
| annotate_in_out_obs_sharing_op(node, quantization_config) |
| if not _is_annotated([node]): |
| annotate_single_in_single_out(node, quantization_config) |
| |
| |
| @register_annotator([torch.ops.aten.rms_norm.default]) |
| def annotate_rms_norm(node: Node, quantization_config: QuantizationConfig) -> None: |
| act_node = node.args[0] |
| weight_node = node.args[2] |
| |
| if _is_annotated([node]): |
| return |
| |
    # TODO: currently only 16a16w is supported
| _annotate_input_qspec_map( |
| node, |
| act_node, |
| quantization_config.input_activation, |
| ) |
| |
| _annotate_input_qspec_map( |
| node, |
| weight_node, |
| quantization_config.input_activation, |
| ) |
| nodes_to_mark_annotated = [node] |
| _annotate_output_qspec(node, quantization_config.output_activation) |
| _mark_nodes_as_annotated(nodes_to_mark_annotated) |
| |
| |
| @register_annotator([torch.ops.aten.rsqrt.default]) |
| def annotate_rsqrt(node: Node, quantization_config: QuantizationConfig) -> None: |
| annotate_single_in_single_out(node, quantization_config) |
| |
| |
| @register_annotator([torch.ops.aten.sigmoid, torch.ops.aten.sigmoid.default]) |
| def annotate_sigmoid(node: Node, quantization_config: QuantizationConfig) -> None: |
| if _is_annotated([node]): |
| return |
| |
| input_qspec_map = {} |
| input_act = node.args[0] |
| input_qspec_map[input_act] = quantization_config.input_activation |
| |
| assert isinstance(input_act, Node) |
| out_qconf = quantization_config.output_activation |
| |
| q_max = ( |
| torch.iinfo(out_qconf.dtype).max |
| if out_qconf.quant_max is None |
| else out_qconf.quant_max |
| ) |
| q_min = ( |
| torch.iinfo(out_qconf.dtype).min |
| if out_qconf.quant_min is None |
| else out_qconf.quant_min |
| ) |
| |
| scale = 1 / (q_max - q_min + 1) |
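    # For example, with a uint8 output (quant_min=0, quant_max=255) this yields
    # scale = 1/256 and zero_point = 0, so the dequantized grid spans [0, 255/256],
    # matching sigmoid's (0, 1) output range. (Illustrative numbers only; the actual
    # dtype and range come from quantization_config.)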
| |
| bias_obs_ctr = observer = FixedQParamsObserver.with_args( |
| scale=scale, |
| zero_point=0, |
| dtype=quantization_config.output_activation.dtype, |
        qscheme=torch.per_tensor_affine,
| quant_max=q_max, |
| quant_min=q_min, |
| ) |
| if quantization_config in ( |
| get_8a8w_qnn_qat_config(), |
| get_16a4w_qnn_qat_config(), |
| ): |
| bias_obs_ctr = FixedQParamsFakeQuantize.with_args( |
| observer=observer, |
| scale=scale, |
| zero_point=0, |
| dtype=quantization_config.output_activation.dtype, |
            qscheme=torch.per_tensor_affine,
| quant_max=q_max, |
| quant_min=q_min, |
| ) |
| |
    # Map the sigmoid output onto the fixed quantized range [0, 1]
| out_act_quantization_spec = QuantizationSpec( |
| dtype=quantization_config.output_activation.dtype, |
| quant_max=q_max, |
| quant_min=q_min, |
| observer_or_fake_quant_ctr=bias_obs_ctr, |
        qscheme=torch.per_tensor_affine,
| ) |
| |
| if _is_float_tensor(node): |
| node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( |
| input_qspec_map=input_qspec_map, |
| output_qspec=out_act_quantization_spec, |
| _annotated=True, |
| ) |
| |
| |
| @register_annotator([torch.ops.aten.pow.Tensor_Scalar]) |
| def annotate_pow(node: Node, quantization_config: QuantizationConfig) -> None: |
| annotate_single_in_single_out(node, quantization_config) |
| |
| |
| @register_annotator([torch.ops.aten.unsqueeze.default]) |
| def annotate_unsqueeze(node: Node, quantization_config: QuantizationConfig) -> None: |
| annotate_in_out_obs_sharing_op(node, quantization_config) |
| if not _is_annotated([node]): |
| annotate_single_in_single_out(node, quantization_config) |
| |
| |
| @register_annotator( |
| [ |
| torch.ops.aten.unsqueeze_copy.default, |
| ] |
| ) |
| def annotate_unsqueeze_copy( |
| node: Node, quantization_config: QuantizationConfig |
| ) -> None: |
| annotate_in_out_obs_sharing_op(node, quantization_config) |
| if not _is_annotated([node]): |
| annotate_single_in_single_out(node, quantization_config) |
| |
| |
| @register_annotator([torch.ops.aten.transpose.int]) |
| def annotate_transpose(node: Node, quantization_config: QuantizationConfig) -> None: |
| annotate_in_out_obs_sharing_op(node, quantization_config) |
| if not _is_annotated([node]): |
| annotate_single_in_single_out(node, quantization_config) |
| |
| |
| @register_annotator([torch.ops.aten.embedding.default]) |
| def annotate_embedding(node: Node, quantization_config: QuantizationConfig) -> None: |
| weight = node.args[0] |
| |
| input_qspec_map = {} |
| input_qspec_map[weight] = quantization_config.input_activation |
| |
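    # SharedQuantizationSpec is given the (weight, node) edge, so the embedding
    # output reuses the quantization parameters observed for its weight input.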
| node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( |
| input_qspec_map=input_qspec_map, |
| output_qspec=SharedQuantizationSpec((weight, node)), |
| _annotated=True, |
| ) |
| |
| |
| @register_annotator([torch.ops.aten.index.Tensor]) |
| def annotate_index(node: Node, quantization_config: QuantizationConfig) -> None: |
| annotate_in_out_obs_sharing_op(node, quantization_config) |
| if not _is_annotated([node]): |
| input_qspec_map = {} |
| input = node.args[0] |
| input_qspec_map[input] = quantization_config.input_activation |
| node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( |
| input_qspec_map=input_qspec_map, |
| output_qspec=SharedQuantizationSpec((input, node)), |
| _annotated=True, |
| ) |
| |
| |
| @register_annotator( |
| [torch.ops.aten.index_put.default, torch.ops.aten.index_put_.default] |
| ) |
| def annotate_index_put(node: Node, quantization_config: QuantizationConfig) -> None: |
| input = node.args[0] |
| value = node.args[2] |
| |
| input_qspec_map = {} |
| input_qspec_map[input] = quantization_config.input_activation |
| input_qspec_map[value] = SharedQuantizationSpec((input, node)) |
| |
| node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( |
| input_qspec_map=input_qspec_map, |
| output_qspec=SharedQuantizationSpec((input, node)), |
| _annotated=True, |
| ) |
| |
| |
| @register_annotator([torch.ops.aten.expand.default]) |
| def annotate_expand(node: Node, quantization_config: QuantizationConfig) -> None: |
| annotate_in_out_obs_sharing_op(node, quantization_config) |
| if not _is_annotated([node]): |
| annotate_single_in_single_out(node, quantization_config) |
| |
| |
| @register_annotator([torch.ops.aten.group_norm.default]) |
| def annotate_group_norm(node: Node, quantization_config: QuantizationConfig) -> None: |
| act_node = node.args[0] |
| weight_node = node.args[2] |
| bias_node = None |
    if len(node.args) > 3:
| bias_node = node.args[3] |
| |
| if _is_annotated([node]): |
| return |
| |
| _annotate_input_qspec_map( |
| node, |
| act_node, |
| quantization_config.input_activation, |
| ) |
| _annotate_input_qspec_map( |
| node, |
| weight_node, |
| quantization_config.weight, |
| ) |
| nodes_to_mark_annotated = [node, weight_node] |
| if bias_node: |
| _annotate_input_qspec_map( |
| node, |
| bias_node, |
| quantization_config.bias, |
| ) |
| nodes_to_mark_annotated.append(bias_node) |
| _annotate_output_qspec(node, quantization_config.output_activation) |
| _mark_nodes_as_annotated(nodes_to_mark_annotated) |
| |
| |
| @register_annotator([torch.ops.aten.flatten.using_ints]) |
| def annotate_flatten(node: Node, quantization_config: QuantizationConfig) -> None: |
| annotate_in_out_obs_sharing_op(node, quantization_config) |
| if not _is_annotated([node]): |
| annotate_single_in_single_out(node, quantization_config) |
| |
| |
| @register_annotator([torch.ops.aten.stack.default]) |
def annotate_stack(node: Node, quantization_config: QuantizationConfig) -> None:
    if _is_annotated([node]):
        return

    # Skip annotation when stack produces an integer tensor (e.g. indices);
    # observers only work with float tensors.
    node_tensor = node.meta.get("val")
    if torch.is_tensor(node_tensor) and node_tensor.dtype == torch.int64:
        return

    input_qspec_map = {}
    for input_act in node.args[0]:
        assert isinstance(input_act, Node)
        input_qspec_map[input_act] = quantization_config.input_activation

    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        output_qspec=quantization_config.output_activation,
        _annotated=True,
    )
| |
| |
| @register_annotator([torch.ops.aten.matmul.default]) |
| def annotate_matmul(node: Node, quantization_config: QuantizationConfig) -> None: |
| if _is_annotated([node]): |
| return |
| |
| input_act_qspec = quantization_config.input_activation |
| output_act_qspec = quantization_config.output_activation |
| |
| input_qspec_map = {} |
| input_act0 = node.args[0] |
| if isinstance(input_act0, Node): |
| input_qspec_map[input_act0] = input_act_qspec |
| |
| input_act1 = node.args[1] |
| if isinstance(input_act1, Node): |
| # In matmul, QNN_DATATYPE_SFIXED_POINT_16 Input1 must have QNN_DATATYPE_UFIXED_POINT_16 Input0 and must be symmetric quantized. |
| if input_act_qspec.dtype == torch.int32: |
| # we should use int16 for mm / bmm instead of int4 |
| input_qspec_map[input_act1] = get_16a16w_qnn_ptq_config().weight |
| else: |
| input_qspec_map[input_act1] = input_act_qspec |
| |
| node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( |
| input_qspec_map=input_qspec_map, |
| output_qspec=output_act_qspec, |
| _annotated=True, |
| ) |
| |
| |
| @register_annotator([torch.ops.aten.bmm.default]) |
| def annotate_bmm(node: Node, quantization_config: QuantizationConfig) -> None: |
| if _is_annotated([node]): |
| return |
| |
| input_act_qspec = quantization_config.input_activation |
| output_act_qspec = quantization_config.output_activation |
| |
| input_qspec_map = {} |
| input_act0 = node.args[0] |
| if isinstance(input_act0, Node): |
| input_qspec_map[input_act0] = input_act_qspec |
| |
| input_act1 = node.args[1] |
| if isinstance(input_act1, Node): |
| # In bmm, QNN_DATATYPE_SFIXED_POINT_16 Input1 must have QNN_DATATYPE_UFIXED_POINT_16 Input0 and must be symmetric quantized. |
| if input_act_qspec.dtype == torch.int32: |
| # we should use int16 for mm / bmm instead of int4 |
| input_qspec_map[input_act1] = get_16a16w_qnn_ptq_config().weight |
| else: |
| input_qspec_map[input_act1] = input_act_qspec |
| |
| node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( |
| input_qspec_map=input_qspec_map, |
| output_qspec=output_act_qspec, |
| _annotated=True, |
| ) |
| |
    # We use get_source_partition in a later pass, but MultiheadAttention shares the
    # same source, so we need to override this node's source_fn_stack.
| node.meta["source_fn_stack"] = [(node, torch.bmm)] |
| |
| |
| @register_annotator( |
| [ |
| torch.ops.aten.conv2d.default, |
| torch.ops.aten.conv1d.default, |
| torch.ops.aten.conv_transpose2d.input, |
| ] |
| ) |
| def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None: |
| if _is_annotated([node]): |
| return |
| |
| input_qspec_map = {} |
| input_act = node.args[0] |
| assert isinstance(input_act, Node) |
| input_spec = quantization_config.input_activation |
| input_qspec_map[input_act] = input_spec |
| |
| weight = node.args[1] |
| assert isinstance(weight, Node) |
| input_qspec_map[weight] = quantization_config.weight |
| |
| if len(node.args) > 2: |
| bias = node.args[2] |
| if isinstance(bias, Node): |
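            # The bias config may be a callable that derives a spec per node
            # (rather than a fixed QuantizationSpec), so resolve it here.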
| if callable(quantization_config.bias): |
| input_qspec_map[bias] = quantization_config.bias(node) |
| else: |
| input_qspec_map[bias] = quantization_config.bias |
| |
| node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( |
| input_qspec_map=input_qspec_map, |
| output_qspec=quantization_config.output_activation, |
| _annotated=True, |
| ) |
| |
| |
| @register_annotator([torch.ops.aten.linear.default]) |
| def annotate_linear(node: Node, quantization_config: QuantizationConfig) -> None: |
| act_node = node.args[0] |
| weight_node = node.args[1] |
| bias_node = None |
| if len(node.args) > 2: |
| bias_node = node.args[2] |
| |
| if _is_annotated([node]): |
| return |
| |
| _annotate_input_qspec_map( |
| node, |
| act_node, |
| quantization_config.input_activation, |
| ) |
| _annotate_input_qspec_map( |
| node, |
| weight_node, |
| quantization_config.weight, |
| ) |
| nodes_to_mark_annotated = [node, weight_node] |
| if bias_node: |
| if callable(quantization_config.bias): |
| bias_config = quantization_config.bias(node) |
| else: |
| bias_config = quantization_config.bias |
| _annotate_input_qspec_map(node, bias_node, bias_config) |
| nodes_to_mark_annotated.append(bias_node) |
| _annotate_output_qspec(node, quantization_config.output_activation) |
| _mark_nodes_as_annotated(nodes_to_mark_annotated) |
| |
    # We use get_source_partition in a later pass, but MultiheadAttention shares the
    # same source, so we need to override this node's source_fn_stack.
| node.meta["source_fn_stack"] = [(node, torch.nn.Linear)] |
| |
| |
| @register_annotator([torch.ops.aten._native_batch_norm_legit_no_training.default]) |
| def annotate_batch_norm(node: Node, quantization_config: QuantizationConfig) -> None: |
| act, weight, bias = node.args[0:3] |
| if _is_annotated([node]): |
| return |
| |
| _annotate_input_qspec_map( |
| node, |
| act, |
| quantization_config.input_activation, |
| ) |
| # QNN requires uint8 instead of int8 in 'weight' config |
| _annotate_input_qspec_map( |
| node, |
| weight, |
| quantization_config.input_activation, |
| ) |
| _annotate_input_qspec_map( |
| node, |
| bias, |
| quantization_config.bias, |
| ) |
| _annotate_output_qspec(node, quantization_config.output_activation) |
| _mark_nodes_as_annotated([node, *node.args[0:3]]) |
| |
| |
| @register_annotator([operator.getitem]) |
| def annotate_getitem(node: Node, quantization_config: QuantizationConfig) -> None: |
| if _is_annotated([node]): |
| return |
| |
| if _is_float_tensor(node): |
| _annotate_output_qspec(node, quantization_config.output_activation) |
| _mark_nodes_as_annotated([node]) |
| |
| |
| @register_annotator([torch.ops.aten.layer_norm.default]) |
| def annotate_layer_norm(node: Node, quantization_config: QuantizationConfig) -> None: |
| act_node = node.args[0] |
| weight_node = node.args[2] |
| bias_node = None |
    if len(node.args) > 3:
| bias_node = node.args[3] |
| |
| if _is_annotated([node]): |
| return |
| input_act_qspec = quantization_config.input_activation |
| |
| _annotate_input_qspec_map( |
| node, |
| act_node, |
| input_act_qspec, |
| ) |
| if input_act_qspec.dtype == torch.int32: |
| _annotate_input_qspec_map( |
| node, |
| weight_node, |
| get_16a16w_qnn_ptq_config().weight, |
| ) |
| else: |
| _annotate_input_qspec_map( |
| node, |
| weight_node, |
| input_act_qspec, |
| ) |
| nodes_to_mark_annotated = [node, weight_node] |
| if bias_node: |
| _annotate_input_qspec_map( |
| node, |
| bias_node, |
| quantization_config.bias, |
| ) |
| nodes_to_mark_annotated.append(bias_node) |
| _annotate_output_qspec(node, quantization_config.output_activation) |
| _mark_nodes_as_annotated(nodes_to_mark_annotated) |
| |
| |
| @register_annotator([torch.ops.aten.cat.default, torch.ops.aten.concat.default]) |
| def annotate_cat(node: Node, quantization_config: QuantizationConfig) -> None: |
| input_nodes = node.args[0] |
| if _is_annotated([node]): |
| return |
| |
| assert isinstance(input_nodes, Sequence) |
| |
| first_input_node = input_nodes[0] |
| input_qspec_map = {} |
| assert isinstance(first_input_node, Node) |
| assert isinstance(node, Node) |
| input_qspec_map[first_input_node] = quantization_config.input_activation |
| share_qparams_with_input_act0_qspec = SharedQuantizationSpec( |
| (first_input_node, node) |
| ) |
| |
| for input_node in input_nodes[1:]: |
| if input_node not in input_qspec_map: |
| assert isinstance(input_node, Node) |
| input_qspec_map[input_node] = share_qparams_with_input_act0_qspec |
| |
| node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( |
| input_qspec_map=input_qspec_map, |
| output_qspec=share_qparams_with_input_act0_qspec, |
| _annotated=True, |
| ) |
| |
| |
| @register_annotator([torch.ops.aten.unbind.int]) |
| def annotate_unbind(node: Node, quantization_config: QuantizationConfig) -> None: |
| if _is_annotated([node]): |
| return |
| |
| input_qspec_map = {} |
| input_act = node.args[0] |
| assert isinstance(input_act, Node) |
| input_qspec_map[input_act] = quantization_config.input_activation |
| |
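    # Skip annotation when unbind produces integer tensors (e.g. indices);
    # observers only work with float tensors.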
| node_tensor = node.meta.get("val") |
| if torch.is_tensor(node_tensor) and node_tensor.dtype == torch.int64: |
| return |
| |
| node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( |
| input_qspec_map=input_qspec_map, |
| _annotated=True, |
| ) |
| |
| |
| @register_annotator([torch.ops.aten.split.Tensor, torch.ops.aten.chunk.default]) |
| def annotate_chunk(node: Node, quantization_config: QuantizationConfig) -> None: |
| if _is_annotated([node]): |
| return |
| |
| input_qspec_map = {} |
| input_act = node.args[0] |
| assert isinstance(input_act, Node) |
| input_qspec_map[input_act] = quantization_config.input_activation |
| |
| node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( |
| input_qspec_map=input_qspec_map, |
| _annotated=True, |
| ) |
| |
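    # split/chunk have multiple outputs that are consumed through getitem users,
    # so the output qspec is annotated on each user instead of on this node.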
| for user in node.users: |
| user.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( |
| output_qspec=quantization_config.output_activation, |
| _annotated=True, |
| ) |