# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Sequence

import torch
from executorch.backends.qualcomm.quantizer.annotators import QUANT_ANNOTATION_KEY
from executorch.backends.qualcomm.quantizer.quantizer import (
get_16a8w_qnn_ptq_config,
get_8a8w_qnn_ptq_config,
get_ptq_per_channel_quant_config,
QuantizationConfig,
)
from executorch.exir.dialects._ops import ops as exir_ops
from torch.ao.quantization.observer import MinMaxObserver
from torch.ao.quantization.quantizer import (
QuantizationAnnotation,
SharedQuantizationSpec,
)
from torch.fx import Node


def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None:
    """
    Annotate aten.matmul ops with a 16a8w config: 16-bit activations for the first
    input and the output, and the 8-bit weight spec for the second input.
    """
def annotate_matmul(node: Node, quantization_config: QuantizationConfig):
input_qspec_map = {}
input_act = node.args[0]
input_spec = quantization_config.input_activation
input_qspec_map[input_act] = input_spec
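        # The second matmul input uses the weight spec, so it is quantized to 8 bits
        # while the first input and the output keep 16-bit activations.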
input_act1 = node.args[1]
input_spec1 = quantization_config.weight
input_qspec_map[input_act1] = input_spec1
node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=quantization_config.output_activation,
_annotated=True,
)
def annotate_cat(node: Node, quantization_config: QuantizationConfig):
input_nodes = node.args[0]
first_input_node = input_nodes[0]
input_qspec_map = {}
input_qspec_map[first_input_node] = quantization_config.input_activation
share_qparams_with_input_act0_qspec = SharedQuantizationSpec(
(first_input_node, node)
)
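        # All remaining cat inputs and the cat output share quantization parameters
        # with the first input.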
for input_node in input_nodes[1:]:
if input_node not in input_qspec_map:
input_qspec_map[input_node] = share_qparams_with_input_act0_qspec
node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=share_qparams_with_input_act0_qspec,
_annotated=True,
)
def annotate_single_in_single_out(
node: Node, quantization_config: QuantizationConfig
) -> None:
input_qspec_map = {}
input_act = node.args[0]
input_qspec_map[input_act] = quantization_config.input_activation
node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=quantization_config.output_activation,
_annotated=True,
)
def annotate_matmul_input1(node: Node):
quantization_config_8a8w = get_8a8w_qnn_ptq_config(
act_symmetric=True, act_observer=MinMaxObserver
)
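        # Walk upward from matmul's second input, annotating permute/transpose and cat
        # producers with 8a8w until a non-call_function node (e.g. a placeholder) is reached.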
while isinstance(node, Node) and node.op == "call_function":
if node.target in [
torch.ops.aten.permute.default,
torch.ops.aten.transpose.int,
]:
annotate_single_in_single_out(node, quantization_config_8a8w)
node = node.args[0]
elif node.target == torch.ops.aten.cat.default:
annotate_cat(node, quantization_config_8a8w)
node = node.args[0][0]
else:
node = node.args[0]
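    # Annotate every aten.matmul with 16a8w; the chain feeding its second input is
    # annotated with 8a8w by annotate_matmul_input1.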
quantization_config_16a8w = get_16a8w_qnn_ptq_config(act_observer=MinMaxObserver)
for node in gm.graph.nodes:
if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
annotate_matmul(node, quantization_config_16a8w)
annotate_matmul_input1(node.args[1])


def custom_annotate_llama_matmul_16a8w(gm: torch.fx.GraphModule) -> None:  # noqa: C901
    """
    Annotate Llama SDPA matmul ops with a 16a8w config; the chain feeding each
    matmul's second input (the KV cache path) is annotated with 8a8w.
    """
def annotate_matmul(node: Node, quantization_config: QuantizationConfig):
input_qspec_map = {}
input_act = node.args[0]
input_spec = quantization_config.input_activation
input_qspec_map[input_act] = input_spec
input_act1 = node.args[1]
input_spec1 = quantization_config.weight
input_qspec_map[input_act1] = input_spec1
node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=quantization_config.output_activation,
_annotated=True,
)
def annotate_index_put(node: Node, quantization_config: QuantizationConfig) -> None:
input = node.args[0]
value = node.args[2]
input_qspec_map = {}
input_qspec_map[input] = quantization_config.input_activation
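        # The value written into the cache and the op output share quantization
        # parameters with the cache input, keeping the KV cache in one quant domain.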
input_qspec_map[value] = SharedQuantizationSpec((input, node))
node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=SharedQuantizationSpec((input, node)),
_annotated=True,
)
def annotate_single_in_single_out(
node: Node, quantization_config: QuantizationConfig
) -> None:
input_qspec_map = {}
input_act = node.args[0]
input_qspec_map[input_act] = quantization_config.input_activation
node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=quantization_config.output_activation,
_annotated=True,
)
def annotate_cat(node: Node, quantization_config: QuantizationConfig):
input_nodes = node.args[0]
assert isinstance(input_nodes, Sequence)
first_input_node = input_nodes[0]
input_qspec_map = {}
assert isinstance(first_input_node, Node)
assert isinstance(node, Node)
input_qspec_map[first_input_node] = quantization_config.input_activation
share_qparams_with_input_act0_qspec = SharedQuantizationSpec(
(first_input_node, node)
)
for input_node in input_nodes[1:]:
if input_node not in input_qspec_map:
assert isinstance(input_node, Node)
input_qspec_map[input_node] = share_qparams_with_input_act0_qspec
node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=share_qparams_with_input_act0_qspec,
_annotated=True,
)
    def is_edge_condition(node: Node):
        # A graph edge: anything that is not a call_function Node (e.g. a placeholder).
        return not isinstance(node, Node) or node.op != "call_function"
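    # Recursively annotate the producers of matmul's second input: index_put (the KV
    # cache update) and cat keep their inputs in a shared quant domain; other ops get
    # a plain single-in/single-out annotation. Recursion stops at graph edges.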
def annotate_matmul_input1(node: Node, quantization_config: QuantizationConfig):
if is_edge_condition(node):
return
if node.target in [
torch.ops.aten.index_put.default,
torch.ops.aten.index_put_.default,
]:
annotate_index_put(node, quantization_config)
annotate_matmul_input1(node.args[0], quantization_config)
elif node.target == torch.ops.aten.cat.default:
annotate_cat(node, quantization_config)
# Expect that the inputs of the cat op are select ops
for arg in node.args[0]:
annotate_matmul_input1(arg, quantization_config)
else:
annotate_single_in_single_out(node, quantization_config)
annotate_matmul_input1(node.args[0], quantization_config)
    # Annotate matmul ops with 16a8w to get better performance
    quantization_config_16a8w = get_16a8w_qnn_ptq_config()
    # Annotate the second input of matmul (up to past_kv_cache) with 8a8w
    quantization_config_8a8w = get_8a8w_qnn_ptq_config(act_symmetric=True)
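    # Only matmul nodes inside an SDPA module (identified via nn_module_stack) are annotated.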
for node in gm.graph.nodes:
if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
if "nn_module_stack" in node.meta:
module_values_list = list(node.meta["nn_module_stack"].values())
full_qualified_name = module_values_list[-1][0]
if "SDPA" in full_qualified_name:
annotate_matmul(node, quantization_config_16a8w)
annotate_matmul_input1(node.args[1], quantization_config_8a8w)


def custom_annotate_llama_last_conv_16a8w(gm: torch.fx.GraphModule) -> None:
    """
    Annotate the final Llama output conv2d (L['self'].llama.output) with a 16a8w
    per-channel weight quantization config.
    """
def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None:
input_qspec_map = {}
input_act = node.args[0]
input_spec = quantization_config.input_activation
input_qspec_map[input_act] = input_spec
weight = node.args[1]
input_qspec_map[weight] = quantization_config.weight
node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=quantization_config.output_activation,
_annotated=True,
)
quantization_config_16a8w_per_channel = get_ptq_per_channel_quant_config(
torch.uint16, weight_dtype=torch.int8
)
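    # Only the final output projection conv (L['self'].llama.output) is annotated,
    # identified via nn_module_stack.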
for node in gm.graph.nodes:
if node.op == "call_function" and node.target == torch.ops.aten.conv2d.default:
if "nn_module_stack" in node.meta:
module_values_list = list(node.meta["nn_module_stack"].values())
full_qualified_name = module_values_list[0][0]
if full_qualified_name == "L['self'].llama.output":
annotate_conv2d(
node, quantization_config=quantization_config_16a8w_per_channel
)


def custom_annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None:
    """
    Annotate every aten.matmul op with a 16a8w quantization config.
    """
def annotate_matmul(node: Node, quantization_config: QuantizationConfig):
input_qspec_map = {}
input_act = node.args[0]
input_spec = quantization_config.input_activation
input_qspec_map[input_act] = input_spec
input_act1 = node.args[1]
input_spec1 = quantization_config.weight
input_qspec_map[input_act1] = input_spec1
node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=quantization_config.output_activation,
_annotated=True,
)
    # Annotate matmul ops with 16a8w to get better performance
quantization_config_16a8w = get_16a8w_qnn_ptq_config()
for node in gm.graph.nodes:
if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
annotate_matmul(node, quantization_config_16a8w)


def get_custom_quant_ios_dtype(
    cache_shape: torch.Size,
    node: torch.fx.Node,
    kv_dtype=torch.uint8,
    sharding_dtype=torch.uint16,
):
    """
    Return the custom quantization IO dtype for Llama-specific inputs and outputs
    (KV cache and sharding tensors), or None if the node needs no tag.
    """
if node.op == "placeholder" and "attention_sdpa_kv_cache_past_" in node.name:
return kv_dtype
    # Tag the index_put node that precedes the copy node, since copy is skipped in QNN
if (
exir_ops.edge.aten.index_put.default == node.target
and node.meta["val"].shape == cache_shape
):
return kv_dtype
    # Tag sharding IO: nodes that are, or feed into, the llama fallback op
if exir_ops.edge.llama.fallback.default in [
u.target for u in list(node.users.keys())
] + [node.target]:
return sharding_dtype
    # Tag index ops as quantized tensors; these are introduced by sharding
if exir_ops.edge.aten.index.Tensor in [
u.target for u in list(node.users.keys())
] + [node.target]:
return sharding_dtype
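

# Illustrative usage sketch (not part of this module's API): these annotators are
# meant to be registered as custom annotations on the Qualcomm quantizer before
# PT2E calibration. The method name add_custom_quant_annotations follows the
# executorch Qualcomm examples; verify it against your executorch version.
#
#     from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer
#     from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
#
#     quantizer = QnnQuantizer()
#     quantizer.add_custom_quant_annotations((annotate_matmul_16a8w,))
#     prepared = prepare_pt2e(exported_graph_module, quantizer)
#     ...  # run calibration data through `prepared`
#     quantized = convert_pt2e(prepared)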