# Copyright (c) 2024 MediaTek Inc.
#
# Licensed under the BSD License (the "License"); you may not use this file
# except in compliance with the License. See the license file in the root
# directory of this source tree for more details.
from typing import Callable, Dict, List
import torch
from torch._ops import OpOverload
from torch._subclasses import FakeTensor
from torch.ao.quantization.quantizer import QuantizationAnnotation
from torch.ao.quantization.quantizer.utils import (
_annotate_input_qspec_map,
_annotate_output_qspec,
)
from torch.export import export_for_training
from torch.fx import Graph, Node
from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
SubgraphMatcherWithNameNodeMap,
)
from .qconfig import QuantizationConfig
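
# Registry mapping an ATen OpOverload to the annotator that handles it,
# populated by the `register_annotator` decorator below.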
OP_TO_ANNOTATOR: Dict[OpOverload, Callable] = {}

def annotate(graph: Graph, quant_config: QuantizationConfig) -> None:
# Pattern annotation
_annotate_rmsnorm_pattern(graph, quant_config)
_annotate_fused_activation_pattern(graph, quant_config)
# Per-op annotation
for node in graph.nodes:
if node.op == "placeholder":
annotate_placeholder(node, quant_config)
elif node.op == "call_function":
annotate_func = OP_TO_ANNOTATOR.get(node.target, None)
if annotate_func is not None:
annotate_func(node, quant_config)
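
# Example (sketch): `annotate` operates on the FX graph of an exported module.
# `model`, `example_inputs`, and `quant_config` are placeholders, not names
# defined in this module:
#
#     gm = export_for_training(model, example_inputs).module()
#     annotate(gm.graph, quant_config)
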
def register_annotator(ops: List[OpOverload]):
    def decorator(annotator_fn: Callable) -> Callable:
        for op in ops:
            OP_TO_ANNOTATOR[op] = annotator_fn
        return annotator_fn

    return decorator
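
# Example (sketch): one annotator can serve several overloads; the hypothetical
# registration below would route both overloads to the same function:
#
#     @register_annotator([torch.ops.aten.tanh.default, torch.ops.aten.tanh_.default])
#     def annotate_tanh(node: Node, quant_config: QuantizationConfig) -> None:
#         ...
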
def _is_annotated(node: Node) -> bool:
    """Return True if the given node has already been annotated."""
    KEY = "quantization_annotation"
    return KEY in node.meta and node.meta[KEY]._annotated

def _mark_as_annotated(nodes: List[Node]) -> None:
KEY = "quantization_annotation"
for node in nodes:
if KEY not in node.meta:
node.meta[KEY] = QuantizationAnnotation()
node.meta[KEY]._annotated = True
def _is_float_activation_tensor(node: Node) -> bool:
if not isinstance(node, Node):
return False
if "val" not in node.meta:
return False
if not isinstance(node.meta["val"], FakeTensor):
return False
return node.meta["val"].dtype == torch.float32
def _annotate_fused_activation_pattern(
graph: Graph, quant_config: QuantizationConfig
) -> None:
for relu_node in graph.nodes:
# Check relu/relu6 node
if relu_node.op != "call_function":
continue
if relu_node.target not in [
torch.ops.aten.relu.default,
torch.ops.aten.relu_.default,
torch.ops.aten.relu6.default,
]:
continue
producer_node = relu_node.args[0]
if not isinstance(producer_node, Node):
continue
if producer_node.op != "call_function":
continue
if len(producer_node.users) != 1:
continue
# Handle affine + relu fusion
if producer_node.target in [
torch.ops.aten.conv1d.default,
torch.ops.aten.conv2d.default,
torch.ops.aten.linear.default,
]:
weight_node = producer_node.args[1]
_annotate_input_qspec_map(
producer_node,
weight_node,
quant_config.weight,
)
_annotate_output_qspec(relu_node, quant_config.activation)
_mark_as_annotated([producer_node, weight_node, relu_node])
continue
# Handle arithmetic + relu fusion
if producer_node.target in [
torch.ops.aten.add.Scalar,
torch.ops.aten.add.Tensor,
torch.ops.aten.add_.Scalar,
torch.ops.aten.add_.Tensor,
torch.ops.aten.div.Scalar,
torch.ops.aten.div.Tensor,
torch.ops.aten.div_.Scalar,
torch.ops.aten.div_.Tensor,
torch.ops.aten.divide.Scalar,
torch.ops.aten.divide.Tensor,
torch.ops.aten.mul.Scalar,
torch.ops.aten.mul.Tensor,
torch.ops.aten.mul_.Scalar,
torch.ops.aten.mul_.Tensor,
torch.ops.aten.rsub.Scalar,
torch.ops.aten.rsub.Tensor,
torch.ops.aten.sub.Scalar,
torch.ops.aten.sub.Tensor,
torch.ops.aten.sub_.Scalar,
torch.ops.aten.sub_.Tensor,
]:
_annotate_output_qspec(relu_node, quant_config.activation)
_mark_as_annotated([producer_node, relu_node])
continue
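
# Example (sketch): for a graph captured from
# `torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())`, the loop
# above annotates conv2d and relu as one fused unit: the conv weight gets
# `quant_config.weight`, the relu output gets `quant_config.activation`, and
# the intermediate conv output is left unannotated.
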
def _annotate_rmsnorm_pattern(graph: Graph, quant_config: QuantizationConfig) -> None:
    # Two equivalent formulations of RMSNorm(x) = x * rsqrt(mean(x^2, -1) + eps).
    # Each pattern returns an empty dict as the name-node map expected by
    # SubgraphMatcherWithNameNodeMap.
    class ExecuTorchPattern(torch.nn.Module):
        def forward(self, x):
            norm = x * torch.rsqrt((x * x).mean(-1, keepdim=True) + 1e-6)
            return norm, {}

    class MTKPattern(torch.nn.Module):
        def forward(self, x):
            norm = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
            return norm, {}

for pattern_cls in (ExecuTorchPattern, MTKPattern):
pattern_gm = export_for_training(pattern_cls(), (torch.randn(3, 3),)).module()
matcher = SubgraphMatcherWithNameNodeMap(
pattern_gm, ignore_literals=True, remove_overlapping_matches=False
)
matches = matcher.match(graph)
for match in matches:
target_nodes = []
for node in match.nodes_map.values():
if node in match.placeholder_nodes:
continue
if node.op == "call_function" and node.target in OP_TO_ANNOTATOR:
target_nodes.append(node)
if any(_is_annotated(node) for node in target_nodes):
continue
_mark_as_annotated(target_nodes)
for node in match.returning_nodes:
_annotate_output_qspec(node, quant_config.activation)
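
# Example (sketch): a hand-written RMSNorm like the following would be matched
# (`ignore_literals=True` lets the epsilon differ) and only its final output
# would receive an activation qspec:
#
#     class RMSNorm(torch.nn.Module):
#         def forward(self, x):
#             return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-5)
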
def annotate_placeholder(node: Node, quant_config: QuantizationConfig) -> None:
if _is_annotated(node):
return
if _is_float_activation_tensor(node):
_annotate_output_qspec(node, quant_config.activation)
_mark_as_annotated([node])
@register_annotator(
[
torch.ops.aten.conv1d.default,
torch.ops.aten.conv2d.default,
torch.ops.aten.linear.default,
]
)
def annotate_affine_ops(node: Node, quant_config: QuantizationConfig) -> None:
if _is_annotated(node):
return
weight_node = node.args[1]
_annotate_input_qspec_map(
node,
weight_node,
quant_config.weight,
)
_annotate_output_qspec(node, quant_config.activation)
    # Mark the weight as annotated because it is a constant node
_mark_as_annotated([node, weight_node])
@register_annotator(
[
torch.ops.aten.add.Scalar,
torch.ops.aten.add.Tensor,
torch.ops.aten.add_.Scalar,
torch.ops.aten.add_.Tensor,
torch.ops.aten.bmm.default,
torch.ops.aten.div.Scalar,
torch.ops.aten.div.Tensor,
torch.ops.aten.div_.Scalar,
torch.ops.aten.div_.Tensor,
torch.ops.aten.divide.Scalar,
torch.ops.aten.divide.Tensor,
torch.ops.aten.gelu.default,
torch.ops.aten.group_norm.default,
torch.ops.aten.layer_norm.default,
torch.ops.aten.leaky_relu.default,
torch.ops.aten.matmul.default,
torch.ops.aten.mul.Scalar,
torch.ops.aten.mul.Tensor,
torch.ops.aten.mul_.Scalar,
torch.ops.aten.mul_.Tensor,
torch.ops.aten.pow.Scalar,
torch.ops.aten.pow.Tensor_Scalar,
torch.ops.aten.pow.Tensor_Tensor,
torch.ops.aten.prelu.default,
torch.ops.aten.rsub.Scalar,
torch.ops.aten.rsub.Tensor,
torch.ops.aten.silu.default,
torch.ops.aten.sub.Scalar,
torch.ops.aten.sub.Tensor,
torch.ops.aten.sub_.Scalar,
torch.ops.aten.sub_.Tensor,
]
)
def annotate_output_qspec(node: Node, quant_config: QuantizationConfig) -> None:
if _is_annotated(node):
return
_annotate_output_qspec(node, quant_config.activation)
_mark_as_annotated([node])
@register_annotator([torch.ops.aten.embedding.default])
def annotate_embedding_op(node: Node, quant_config: QuantizationConfig) -> None:
if _is_annotated(node):
return
    # args[0] of aten.embedding is the weight table (a constant tensor)
    wgt_node = node.args[0]
_annotate_input_qspec_map(node, wgt_node, quant_config.activation)
_mark_as_annotated([node])
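
# --- End-to-end sketch (illustrative only) -----------------------------------
# These annotators are consumed during PT2E quantization by a Quantizer whose
# `annotate` method calls the module-level `annotate` above. With `model`,
# `example_inputs`, and `quantizer` as placeholders:
#
#     from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
#
#     gm = export_for_training(model, example_inputs).module()
#     prepared = prepare_pt2e(gm, quantizer)  # runs annotation, inserts observers
#     prepared(*example_inputs)               # calibration pass
#     quantized = convert_pt2e(prepared)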