| # Copyright (c) Qualcomm Innovation Center, Inc. |
| # All rights reserved |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| from typing import Sequence |
| |
| import torch |
| from executorch.backends.qualcomm.quantizer.annotators import QUANT_ANNOTATION_KEY |
| from executorch.backends.qualcomm.quantizer.quantizer import ( |
| get_16a8w_qnn_ptq_config, |
| get_8a8w_qnn_ptq_config, |
| get_ptq_per_channel_quant_config, |
| QuantizationConfig, |
| ) |
| from executorch.exir.dialects._ops import ops as exir_ops |
| from torch.ao.quantization.observer import MinMaxObserver |
| from torch.ao.quantization.quantizer import ( |
| QuantizationAnnotation, |
| SharedQuantizationSpec, |
| ) |
| from torch.fx import Node |
| |
| |
| def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None: |
| """ |
    Annotate aten.matmul ops with a 16a8w (16-bit activation, 8-bit weight)
    PTQ config; the producer chain of each matmul's second input is
    annotated with 8a8w.
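
    A minimal usage sketch (assuming ``QnnQuantizer`` exposes
    ``add_custom_quant_annotations``, as in the Qualcomm llama example flow):

        quantizer = QnnQuantizer()
        quantizer.add_custom_quant_annotations((annotate_matmul_16a8w,))
        prepared_gm = prepare_pt2e(traced_gm, quantizer)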
| """ |
| |
| def annotate_matmul(node: Node, quantization_config: QuantizationConfig): |
| input_qspec_map = {} |
| input_act = node.args[0] |
| input_spec = quantization_config.input_activation |
| input_qspec_map[input_act] = input_spec |
| |
| input_act1 = node.args[1] |
| input_spec1 = quantization_config.weight |
| input_qspec_map[input_act1] = input_spec1 |
| |
| node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( |
| input_qspec_map=input_qspec_map, |
| output_qspec=quantization_config.output_activation, |
| _annotated=True, |
| ) |
| |
| def annotate_cat(node: Node, quantization_config: QuantizationConfig): |
| input_nodes = node.args[0] |
| |
| first_input_node = input_nodes[0] |
| input_qspec_map = {} |
| input_qspec_map[first_input_node] = quantization_config.input_activation |
| share_qparams_with_input_act0_qspec = SharedQuantizationSpec( |
| (first_input_node, node) |
| ) |
| |
| for input_node in input_nodes[1:]: |
| if input_node not in input_qspec_map: |
| input_qspec_map[input_node] = share_qparams_with_input_act0_qspec |
| |
| node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( |
| input_qspec_map=input_qspec_map, |
| output_qspec=share_qparams_with_input_act0_qspec, |
| _annotated=True, |
| ) |
| |
| def annotate_single_in_single_out( |
| node: Node, quantization_config: QuantizationConfig |
| ) -> None: |
| |
| input_qspec_map = {} |
| input_act = node.args[0] |
| input_qspec_map[input_act] = quantization_config.input_activation |
| |
| node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( |
| input_qspec_map=input_qspec_map, |
| output_qspec=quantization_config.output_activation, |
| _annotated=True, |
| ) |
| |
| def annotate_matmul_input1(node: Node): |
| quantization_config_8a8w = get_8a8w_qnn_ptq_config( |
| act_symmetric=True, act_observer=MinMaxObserver |
| ) |
| while isinstance(node, Node) and node.op == "call_function": |
| if node.target in [ |
| torch.ops.aten.permute.default, |
| torch.ops.aten.transpose.int, |
| ]: |
| annotate_single_in_single_out(node, quantization_config_8a8w) |
| node = node.args[0] |
| elif node.target == torch.ops.aten.cat.default: |
| annotate_cat(node, quantization_config_8a8w) |
| node = node.args[0][0] |
| else: |
| node = node.args[0] |
| |
| quantization_config_16a8w = get_16a8w_qnn_ptq_config(act_observer=MinMaxObserver) |
| |
| for node in gm.graph.nodes: |
| if node.op == "call_function" and node.target == torch.ops.aten.matmul.default: |
| annotate_matmul(node, quantization_config_16a8w) |
| annotate_matmul_input1(node.args[1]) |
| |
| |
| def custom_annotate_llama_matmul_16a8w(gm: torch.fx.GraphModule) -> None: # noqa: C901 |
| """ |
    Annotate llama matmul ops inside SDPA modules with a 16a8w config, and
    annotate the KV-cache path feeding each matmul's second input with 8a8w.
| """ |
| |
| def annotate_matmul(node: Node, quantization_config: QuantizationConfig): |
| input_qspec_map = {} |
| input_act = node.args[0] |
| input_spec = quantization_config.input_activation |
| input_qspec_map[input_act] = input_spec |
| input_act1 = node.args[1] |
| input_spec1 = quantization_config.weight |
| input_qspec_map[input_act1] = input_spec1 |
| node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( |
| input_qspec_map=input_qspec_map, |
| output_qspec=quantization_config.output_activation, |
| _annotated=True, |
| ) |
| |
| def annotate_index_put(node: Node, quantization_config: QuantizationConfig) -> None: |
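        # The written value and the output share quantization parameters with
        # the cache input so the KV cache stays in one quantization domain.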
        input_node = node.args[0]
        value = node.args[2]
        input_qspec_map = {}
        input_qspec_map[input_node] = quantization_config.input_activation
        input_qspec_map[value] = SharedQuantizationSpec((input_node, node))
        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=SharedQuantizationSpec((input_node, node)),
| _annotated=True, |
| ) |
| |
| def annotate_single_in_single_out( |
| node: Node, quantization_config: QuantizationConfig |
| ) -> None: |
| input_qspec_map = {} |
| input_act = node.args[0] |
| input_qspec_map[input_act] = quantization_config.input_activation |
| node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( |
| input_qspec_map=input_qspec_map, |
| output_qspec=quantization_config.output_activation, |
| _annotated=True, |
| ) |
| |
| def annotate_cat(node: Node, quantization_config: QuantizationConfig): |
| input_nodes = node.args[0] |
| assert isinstance(input_nodes, Sequence) |
| first_input_node = input_nodes[0] |
| input_qspec_map = {} |
| assert isinstance(first_input_node, Node) |
| assert isinstance(node, Node) |
| input_qspec_map[first_input_node] = quantization_config.input_activation |
| share_qparams_with_input_act0_qspec = SharedQuantizationSpec( |
| (first_input_node, node) |
| ) |
| for input_node in input_nodes[1:]: |
| if input_node not in input_qspec_map: |
| assert isinstance(input_node, Node) |
| input_qspec_map[input_node] = share_qparams_with_input_act0_qspec |
| node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( |
| input_qspec_map=input_qspec_map, |
| output_qspec=share_qparams_with_input_act0_qspec, |
| _annotated=True, |
| ) |
| |
    def is_edge_condition(node: Node) -> bool:
        return not isinstance(node, Node) or node.op != "call_function"
| |
| def annotate_matmul_input1(node: Node, quantization_config: QuantizationConfig): |
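        # Recursively annotate the producer chain of matmul's second input,
        # stopping at the first non-call_function node (the edge condition).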
| if is_edge_condition(node): |
| return |
| if node.target in [ |
| torch.ops.aten.index_put.default, |
| torch.ops.aten.index_put_.default, |
| ]: |
| annotate_index_put(node, quantization_config) |
| annotate_matmul_input1(node.args[0], quantization_config) |
| elif node.target == torch.ops.aten.cat.default: |
| annotate_cat(node, quantization_config) |
            # The inputs of the cat op are expected to be select ops
| for arg in node.args[0]: |
| annotate_matmul_input1(arg, quantization_config) |
| else: |
| annotate_single_in_single_out(node, quantization_config) |
| annotate_matmul_input1(node.args[0], quantization_config) |
| |
    # Annotate matmul ops with 16a8w to get better performance
| quantization_config_16a8w = get_16a8w_qnn_ptq_config() |
    # Annotate the second matmul input (and its producers, up to the past KV cache) with 8a8w
| quantization_config_8a8w = get_8a8w_qnn_ptq_config(act_symmetric=True) |
| for node in gm.graph.nodes: |
| if node.op == "call_function" and node.target == torch.ops.aten.matmul.default: |
| if "nn_module_stack" in node.meta: |
| module_values_list = list(node.meta["nn_module_stack"].values()) |
| full_qualified_name = module_values_list[-1][0] |
| if "SDPA" in full_qualified_name: |
| annotate_matmul(node, quantization_config_16a8w) |
| annotate_matmul_input1(node.args[1], quantization_config_8a8w) |
| |
| |
| def custom_annotate_llama_last_conv_16a8w(gm: torch.fx.GraphModule) -> None: |
| def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None: |
| input_qspec_map = {} |
| input_act = node.args[0] |
| input_spec = quantization_config.input_activation |
| input_qspec_map[input_act] = input_spec |
| |
| weight = node.args[1] |
| input_qspec_map[weight] = quantization_config.weight |
| |
| node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( |
| input_qspec_map=input_qspec_map, |
| output_qspec=quantization_config.output_activation, |
| _annotated=True, |
| ) |
| |
| quantization_config_16a8w_per_channel = get_ptq_per_channel_quant_config( |
| torch.uint16, weight_dtype=torch.int8 |
| ) |
| for node in gm.graph.nodes: |
| if node.op == "call_function" and node.target == torch.ops.aten.conv2d.default: |
| if "nn_module_stack" in node.meta: |
| module_values_list = list(node.meta["nn_module_stack"].values()) |
| full_qualified_name = module_values_list[0][0] |
| if full_qualified_name == "L['self'].llama.output": |
| annotate_conv2d( |
| node, quantization_config=quantization_config_16a8w_per_channel |
| ) |
| |
| |
| def custom_annotate_matmul_16a8w(gm: torch.fx.GraphModule): |
| """ |
    Annotate matmul ops with a 16a8w quantization config.
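
    Unlike ``annotate_matmul_16a8w``, the producer chain of the second
    matmul input is left untouched here.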
| """ |
| |
| def annotate_matmul(node: Node, quantization_config: QuantizationConfig): |
| input_qspec_map = {} |
| input_act = node.args[0] |
| input_spec = quantization_config.input_activation |
| input_qspec_map[input_act] = input_spec |
| input_act1 = node.args[1] |
| input_spec1 = quantization_config.weight |
| input_qspec_map[input_act1] = input_spec1 |
| node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( |
| input_qspec_map=input_qspec_map, |
| output_qspec=quantization_config.output_activation, |
| _annotated=True, |
| ) |
| |
    # Annotate matmul ops with 16a8w to get better performance
| quantization_config_16a8w = get_16a8w_qnn_ptq_config() |
| for node in gm.graph.nodes: |
| if node.op == "call_function" and node.target == torch.ops.aten.matmul.default: |
| annotate_matmul(node, quantization_config_16a8w) |
| |
| |
| def get_custom_quant_ios_dtype( |
| cache_shape: torch.Size, |
| node: torch.fx.Node, |
| kv_dtype=torch.uint8, |
| sharding_dtype=torch.uint16, |
| ): |
| """ |
    Determine custom IO dtypes for llama graph tensors (KV cache and
    sharding-related IO). Returns None when no special tagging applies.
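
    A minimal usage sketch (assuming a tagging helper such as ``tag_quant_io``
    from the Qualcomm backend utils, as used in the llama export flow;
    ``kv_cache_shape`` is a placeholder for the model's KV-cache shape):

        from functools import partial

        tag_quant_io(gm, partial(get_custom_quant_ios_dtype, kv_cache_shape))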
| """ |
| if node.op == "placeholder" and "attention_sdpa_kv_cache_past_" in node.name: |
| return kv_dtype |
| |
    # Tag the index_put node that precedes the copy node, since copy is skipped by QNN
| if ( |
| exir_ops.edge.aten.index_put.default == node.target |
| and node.meta["val"].shape == cache_shape |
| ): |
| return kv_dtype |
| |
    # Tag sharding IO tensors
| if exir_ops.edge.llama.fallback.default in [ |
| u.target for u in list(node.users.keys()) |
| ] + [node.target]: |
| return sharding_dtype |
| |
    # Tag index ops as quantized tensors; they are introduced by sharding
| if exir_ops.edge.aten.index.Tensor in [ |
| u.target for u in list(node.users.keys()) |
| ] + [node.target]: |
| return sharding_dtype |
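
    # No custom dtype applies; fall through to the default quantized dtype.
    return None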