# Copyright (c) 2024 MediaTek Inc.
#
# Licensed under the BSD License (the "License"); you may not use this file
# except in compliance with the License. See the license file in the root
# directory of this source tree for more details.

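"""Quantization annotator for the MediaTek backend PT2E (torch.export) flow.

Attaches QuantizationAnnotation metadata to the nodes of an exported FX
graph: fused patterns (RMSNorm, op + relu/relu6) are annotated first, then
individual ops are handled through the OP_TO_ANNOTATOR registry.
"""
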
from typing import Callable, Dict, List

import torch
from torch._ops import OpOverload
from torch._subclasses import FakeTensor

from torch.ao.quantization.quantizer import QuantizationAnnotation
from torch.ao.quantization.quantizer.utils import (
    _annotate_input_qspec_map,
    _annotate_output_qspec,
)

from torch.export import export_for_training
from torch.fx import Graph, Node
from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
    SubgraphMatcherWithNameNodeMap,
)

from .qconfig import QuantizationConfig


# Registry mapping an ATen op overload to the annotator function that handles it.
OP_TO_ANNOTATOR: Dict[OpOverload, Callable] = {}


def annotate(graph: Graph, quant_config: QuantizationConfig) -> None:
    """Annotate fused patterns first, then individual ops via OP_TO_ANNOTATOR."""
    # Pattern annotation
    _annotate_rmsnorm_pattern(graph, quant_config)
    _annotate_fused_activation_pattern(graph, quant_config)

    # Per-op annotation
    for node in graph.nodes:
        if node.op == "placeholder":
            annotate_placeholder(node, quant_config)
        elif node.op == "call_function":
            annotate_func = OP_TO_ANNOTATOR.get(node.target)
            if annotate_func is not None:
                annotate_func(node, quant_config)


def register_annotator(ops: List[OpOverload]) -> Callable:
    """Register the decorated annotator for every op overload in `ops`."""

    def decorator(annotator_fn: Callable) -> Callable:
        for op in ops:
            OP_TO_ANNOTATOR[op] = annotator_fn
        return annotator_fn

    return decorator
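
# Registration sketch (illustrative only; `annotate_softmax` and the chosen op
# overload are hypothetical examples, not part of this module):
#
#     @register_annotator([torch.ops.aten.softmax.int])
#     def annotate_softmax(node: Node, quant_config: QuantizationConfig) -> None:
#         if _is_annotated(node):
#             return
#         _annotate_output_qspec(node, quant_config.activation)
#         _mark_as_annotated([node])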


def _is_annotated(node: Node) -> bool:
    """Return True if the node already carries a quantization annotation."""
    KEY = "quantization_annotation"
    return KEY in node.meta and node.meta[KEY]._annotated


def _mark_as_annotated(nodes: List[Node]) -> None:
    """Flag each node as annotated, creating an empty annotation if needed."""
    KEY = "quantization_annotation"
    for node in nodes:
        if KEY not in node.meta:
            node.meta[KEY] = QuantizationAnnotation()
        node.meta[KEY]._annotated = True


def _is_float_activation_tensor(node: Node) -> bool:
    """Return True if the node's value is a float32 FakeTensor."""
    if not isinstance(node, Node):
        return False
    if "val" not in node.meta:
        return False
    if not isinstance(node.meta["val"], FakeTensor):
        return False
    return node.meta["val"].dtype == torch.float32


def _annotate_fused_activation_pattern(
    graph: Graph, quant_config: QuantizationConfig
) -> None:
    """Annotate producer + relu/relu6 pairs so they quantize as one fused op.

    Only the relu output receives an activation qspec; the producer's output
    is left un-annotated so the pair quantizes as a single fused op.
    """
    for relu_node in graph.nodes:
        # Look for a relu/relu6 node
        if relu_node.op != "call_function":
            continue
        if relu_node.target not in [
            torch.ops.aten.relu.default,
            torch.ops.aten.relu_.default,
            torch.ops.aten.relu6.default,
        ]:
            continue

        # The producer must feed only the activation to be fusable
        producer_node = relu_node.args[0]
        if not isinstance(producer_node, Node):
            continue
        if producer_node.op != "call_function":
            continue
        if len(producer_node.users) != 1:
            continue

        # Handle affine + relu fusion
        if producer_node.target in [
            torch.ops.aten.conv1d.default,
            torch.ops.aten.conv2d.default,
            torch.ops.aten.linear.default,
        ]:
            weight_node = producer_node.args[1]
            _annotate_input_qspec_map(
                producer_node,
                weight_node,
                quant_config.weight,
            )
            _annotate_output_qspec(relu_node, quant_config.activation)
            _mark_as_annotated([producer_node, weight_node, relu_node])
            continue

        # Handle arithmetic + relu fusion
        if producer_node.target in [
            torch.ops.aten.add.Scalar,
            torch.ops.aten.add.Tensor,
            torch.ops.aten.add_.Scalar,
            torch.ops.aten.add_.Tensor,
            torch.ops.aten.div.Scalar,
            torch.ops.aten.div.Tensor,
            torch.ops.aten.div_.Scalar,
            torch.ops.aten.div_.Tensor,
            torch.ops.aten.divide.Scalar,
            torch.ops.aten.divide.Tensor,
            torch.ops.aten.mul.Scalar,
            torch.ops.aten.mul.Tensor,
            torch.ops.aten.mul_.Scalar,
            torch.ops.aten.mul_.Tensor,
            torch.ops.aten.rsub.Scalar,
            torch.ops.aten.rsub.Tensor,
            torch.ops.aten.sub.Scalar,
            torch.ops.aten.sub.Tensor,
            torch.ops.aten.sub_.Scalar,
            torch.ops.aten.sub_.Tensor,
        ]:
            _annotate_output_qspec(relu_node, quant_config.activation)
            _mark_as_annotated([producer_node, relu_node])
            continue


def _annotate_rmsnorm_pattern(graph: Graph, quant_config: QuantizationConfig) -> None:
    """Match RMSNorm subgraphs and annotate them as a single unit.

    Matched intermediate ops are marked annotated without qspecs so the
    per-op annotators skip them; only the pattern's returning nodes receive
    an activation qspec.
    """

    class ExecuTorchPattern(torch.nn.Module):
        def forward(self, x):
            norm = x * torch.rsqrt((x * x).mean(-1, keepdim=True) + 1e-6)
            return norm, {}

    class MTKPattern(torch.nn.Module):
        def forward(self, x):
            norm = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
            return norm, {}

    for pattern_cls in (ExecuTorchPattern, MTKPattern):
        pattern_gm = export_for_training(pattern_cls(), (torch.randn(3, 3),)).module()
        matcher = SubgraphMatcherWithNameNodeMap(
            pattern_gm, ignore_literals=True, remove_overlapping_matches=False
        )
        matches = matcher.match(graph)
        for match in matches:
            target_nodes = []
            for node in match.nodes_map.values():
                if node in match.placeholder_nodes:
                    continue
                if node.op == "call_function" and node.target in OP_TO_ANNOTATOR:
                    target_nodes.append(node)

            if any(_is_annotated(node) for node in target_nodes):
                continue
            _mark_as_annotated(target_nodes)
            for node in match.returning_nodes:
                _annotate_output_qspec(node, quant_config.activation)


def annotate_placeholder(node: Node, quant_config: QuantizationConfig) -> None:
    """Annotate a graph input; only float32 tensors receive an activation qspec."""
    if _is_annotated(node):
        return

    if _is_float_activation_tensor(node):
        _annotate_output_qspec(node, quant_config.activation)

    _mark_as_annotated([node])


@register_annotator(
    [
        torch.ops.aten.conv1d.default,
        torch.ops.aten.conv2d.default,
        torch.ops.aten.linear.default,
    ]
)
def annotate_affine_ops(node: Node, quant_config: QuantizationConfig) -> None:
    """Annotate conv/linear: the weight gets the weight qspec, the output
    the activation qspec."""
    if _is_annotated(node):
        return

    weight_node = node.args[1]
    _annotate_input_qspec_map(
        node,
        weight_node,
        quant_config.weight,
    )
    _annotate_output_qspec(node, quant_config.activation)

    # Mark the weight as annotated as well, since it is a constant node
    _mark_as_annotated([node, weight_node])


@register_annotator(
    [
        torch.ops.aten.add.Scalar,
        torch.ops.aten.add.Tensor,
        torch.ops.aten.add_.Scalar,
        torch.ops.aten.add_.Tensor,
        torch.ops.aten.bmm.default,
        torch.ops.aten.div.Scalar,
        torch.ops.aten.div.Tensor,
        torch.ops.aten.div_.Scalar,
        torch.ops.aten.div_.Tensor,
        torch.ops.aten.divide.Scalar,
        torch.ops.aten.divide.Tensor,
        torch.ops.aten.gelu.default,
        torch.ops.aten.group_norm.default,
        torch.ops.aten.layer_norm.default,
        torch.ops.aten.leaky_relu.default,
        torch.ops.aten.matmul.default,
        torch.ops.aten.mul.Scalar,
        torch.ops.aten.mul.Tensor,
        torch.ops.aten.mul_.Scalar,
        torch.ops.aten.mul_.Tensor,
        torch.ops.aten.pow.Scalar,
        torch.ops.aten.pow.Tensor_Scalar,
        torch.ops.aten.pow.Tensor_Tensor,
        torch.ops.aten.prelu.default,
        torch.ops.aten.rsub.Scalar,
        torch.ops.aten.rsub.Tensor,
        torch.ops.aten.silu.default,
        torch.ops.aten.sub.Scalar,
        torch.ops.aten.sub.Tensor,
        torch.ops.aten.sub_.Scalar,
        torch.ops.aten.sub_.Tensor,
    ]
)
def annotate_output_qspec(node: Node, quant_config: QuantizationConfig) -> None:
    """Annotate only the output activation of simple elementwise and norm ops."""
    if _is_annotated(node):
        return
    _annotate_output_qspec(node, quant_config.activation)
    _mark_as_annotated([node])


@register_annotator([torch.ops.aten.embedding.default])
def annotate_embedding_op(node: Node, quant_config: QuantizationConfig) -> None:
    """Annotate the embedding lookup table (args[0]) as a quantized input."""
    if _is_annotated(node):
        return

    wgt_node = node.args[0]
    _annotate_input_qspec_map(node, wgt_node, quant_config.activation)
    _mark_as_annotated([node])
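

# End-to-end usage sketch (illustrative; the model, example inputs, and the
# QuantizationConfig construction are assumptions about the caller, not part
# of this module):
#
#     gm = export_for_training(model, example_inputs).module()
#     annotate(gm.graph, quant_config)
#     # The annotated module then goes through the PT2E prepare/convert flow.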