# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Sequence

import torch
from torch._ops import OpOverload
from torch._subclasses import FakeTensor
from torch.ao.quantization.quantizer import (
    QuantizationAnnotation,
    QuantizationSpec,
    SharedQuantizationSpec,
)
from torch.ao.quantization.quantizer.utils import (
    _annotate_input_qspec_map,
    _annotate_output_qspec,
)
from torch.fx import Node

QUANT_ANNOTATION_KEY = "quantization_annotation"
OP_ANNOTATOR: Dict[OpOverload, Callable] = {}
@dataclass(eq=True, frozen=True)
class QuantizationConfig:
    input_activation: Optional[QuantizationSpec]
    output_activation: Optional[QuantizationSpec]
    weight: Optional[QuantizationSpec]
    bias: Optional[QuantizationSpec | Callable]
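

# Illustrative only: a minimal sketch of how a QuantizationConfig might be
# assembled from torch.ao QuantizationSpec objects. The observer choices and
# dtypes below are assumptions for the example, not the defaults used by the
# quantizer that consumes this file.
#
#     from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
#
#     example_config = QuantizationConfig(
#         input_activation=QuantizationSpec(
#             dtype=torch.uint8, observer_or_fake_quant_ctr=MinMaxObserver
#         ),
#         output_activation=QuantizationSpec(
#             dtype=torch.uint8, observer_or_fake_quant_ctr=MinMaxObserver
#         ),
#         weight=QuantizationSpec(
#             dtype=torch.int8,
#             qscheme=torch.per_channel_symmetric,
#             ch_axis=0,
#             observer_or_fake_quant_ctr=PerChannelMinMaxObserver,
#         ),
#         bias=None,
#     )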
def register_annotator(ops: List[OpOverload]):
    def decorator(annotator: Callable):
        for op in ops:
            OP_ANNOTATOR[op] = annotator

    return decorator
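

# Illustrative only: OP_ANNOTATOR is keyed by the concrete aten OpOverload, so an
# op is quantized only if its exact overload is registered below; any other node
# is left unannotated and stays in floating point. For example, after this module
# is imported, torch.ops.aten.add.Tensor maps to annotate_add, and a quantizer
# can dispatch with `OP_ANNOTATOR[node.target](node, quantization_config)`.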
def _is_input_float_tensor(node: Node):
    """Return True if the input node produces a float32 tensor; otherwise we skip
    quantization for that input, since observers only work with float Tensors.
    """
    if (
        not isinstance(node, Node)
        or "val" not in node.meta
        or not isinstance(node.meta["val"], FakeTensor)
    ):
        return False
    return node.meta["val"].dtype == torch.float32
def _is_annotated(nodes: List[Node]):
    """
    Given a list of nodes (that represents an operator pattern),
    return True if any of the nodes is annotated, otherwise return False.
    """
    annotated = False
    for node in nodes:
        annotated = annotated or (
            QUANT_ANNOTATION_KEY in node.meta
            and node.meta[QUANT_ANNOTATION_KEY]._annotated
        )
    return annotated
def _mark_nodes_as_annotated(nodes: List[Node]):
    for node in nodes:
        if QUANT_ANNOTATION_KEY not in node.meta:
            node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation()
        node.meta[QUANT_ANNOTATION_KEY]._annotated = True
def annotate_in_out_obs_sharing_op(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    if _is_annotated([node]):
        return

    input_act = node.args[0]
    assert isinstance(input_act, Node)

    # only annotate input/output observer-sharing operators
    # when the output of the input node is annotated
    if (
        QUANT_ANNOTATION_KEY not in input_act.meta
        or not input_act.meta[QUANT_ANNOTATION_KEY]._annotated
        or input_act.meta[QUANT_ANNOTATION_KEY].output_qspec is None
    ):
        return

    act_qspec = SharedQuantizationSpec(input_act)
    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map={
            input_act: act_qspec,
        },
        output_qspec=act_qspec,
        _annotated=True,
    )
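

# Illustrative only: SharedQuantizationSpec(input_act) refers to the output
# observer of the producing node, so a shape-preserving op annotated this way
# (transpose, squeeze, unsqueeze, expand, flatten, ...) introduces no new
# quantization parameters: both its input edge and its output reuse whatever was
# already observed on the producer. If the producer has not been annotated yet,
# the callers further below fall back to annotate_single_in_single_out instead.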
def annotate_single_in_single_out(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    if _is_annotated([node]):
        return

    input_qspec_map = {}
    input_act = node.args[0]
    assert isinstance(input_act, Node)
    input_qspec_map[input_act] = quantization_config.input_activation

    node_tensor = node.meta.get("val")
    if torch.is_tensor(node_tensor) and node_tensor.dtype != torch.float32:
        return

    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        output_qspec=quantization_config.output_activation,
        _annotated=True,
    )
def annotate_binary(node: Node, quantization_config: QuantizationConfig) -> None:
    if _is_annotated([node]):
        return

    input_act_qspec = quantization_config.input_activation
    output_act_qspec = quantization_config.output_activation

    input_qspec_map = {}
    input_act0 = node.args[0]
    if _is_input_float_tensor(input_act0):
        input_qspec_map[input_act0] = input_act_qspec

    input_act1 = node.args[1]
    if _is_input_float_tensor(input_act1):
        input_qspec_map[input_act1] = input_act_qspec

    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        output_qspec=output_act_qspec,
        _annotated=True,
    )
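

# Illustrative only: for a mixed binary op such as `x + 2` (traced as
# aten.add.Tensor with a Python scalar), only the float tensor operand passes
# _is_input_float_tensor, so the scalar receives no input qspec while the output
# is still annotated with output_activation. Integer tensors are skipped the same
# way, since observers can only run on float Tensors.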
@register_annotator([torch.ops.aten.add.Tensor])
def annotate_add(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_binary(node, quantization_config)


@register_annotator([torch.ops.aten.sub.Tensor])
def annotate_sub(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_binary(node, quantization_config)


@register_annotator([torch.ops.aten.mul.Tensor, torch.ops.aten.mul.Scalar])
def annotate_mul(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_binary(node, quantization_config)


@register_annotator([torch.ops.aten.div.Tensor, torch.ops.aten.divide.Tensor])
def annotate_div(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_binary(node, quantization_config)


@register_annotator([torch.ops.aten.rsub.Scalar])
def annotate_rsub(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_binary(node, quantization_config)


@register_annotator([torch.ops.aten.sum.dim_IntList])
def annotate_sum(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_binary(node, quantization_config)


@register_annotator([torch.ops.aten.ceil.default])
def annotate_ceil(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.clamp.default])
def annotate_clamp(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.relu.default, torch.ops.aten.relu_.default])
def annotate_relu(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.tanh.default])
def annotate_tanh(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator(
    [torch.ops.aten.hardswish.default, torch.ops.aten.hardswish_.default]
)
def annotate_hardswish(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator(
    [torch.ops.aten.hardsigmoid.default, torch.ops.aten.hardsigmoid_.default]
)
def annotate_hardsigmoid(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.hardtanh.default, torch.ops.aten.hardtanh_.default])
def annotate_hardtanh(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.mean.default])
def annotate_mean(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.max_pool2d.default])
def annotate_max_pool2d(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.max_pool2d_with_indices.default])
def annotate_max_pool2d_with_indices(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.adaptive_avg_pool2d.default])
def annotate_adaptive_avgpool2d(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.avg_pool2d.default])
def annotate_avgpool2d(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.permute.default])
def annotate_permute(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator(
    [
        torch.ops.aten.leaky_relu.default,
        torch.ops.aten.leaky_relu_.default,
        torch.ops.aten.prelu.default,
    ]
)
def annotate_prelu(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.view.default])
def annotate_view(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.pixel_shuffle.default])
def annotate_pixel_shuffle_default(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.pixel_unshuffle.default])
def annotate_pixel_unshuffle_default(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.upsample_bilinear2d.vec])
def annotate_upsample_bilinear2d(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.upsample_nearest2d.vec])
def annotate_upsample_nearest2d(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.softmax.int, torch.ops.aten._softmax.default])
def annotate_softmax(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.log_softmax.int])
def annotate_log_softmax(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.pad.default])
def annotate_pad(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.reshape.default])
def annotate_reshape(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.select.int])
def annotate_select(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.mean.dim])
def annotate_mean_dim(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.slice.Tensor])
def annotate_slice(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.sqrt.default])
def annotate_sqrt(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.gelu.default])
def annotate_gelu(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.scaled_dot_product_attention.default])
def annotate_scaled_dot_product_attention(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    annotate_single_in_single_out(node, quantization_config)
@register_annotator(
    [
        torch.ops.aten.squeeze.default,
        torch.ops.aten.squeeze.dim,
        torch.ops.aten.squeeze_copy.dims,
    ]
)
def annotate_squeeze(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_in_out_obs_sharing_op(node, quantization_config)
    if not _is_annotated([node]):
        annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.rsqrt.default])
def annotate_rsqrt(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.sigmoid, torch.ops.aten.sigmoid.default])
def annotate_sigmoid(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.pow.Tensor_Scalar])
def annotate_pow(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.unsqueeze.default])
def annotate_unsqueeze(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_in_out_obs_sharing_op(node, quantization_config)
    if not _is_annotated([node]):
        annotate_single_in_single_out(node, quantization_config)


@register_annotator(
    [
        torch.ops.aten.unsqueeze_copy.default,
    ]
)
def annotate_unsqueeze_copy(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    annotate_in_out_obs_sharing_op(node, quantization_config)
    if not _is_annotated([node]):
        annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.transpose.int])
def annotate_transpose(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_in_out_obs_sharing_op(node, quantization_config)
    if not _is_annotated([node]):
        annotate_single_in_single_out(node, quantization_config)
@register_annotator([torch.ops.aten.embedding.default])
def annotate_embedding(node: Node, quantization_config: QuantizationConfig) -> None:
    weight = node.args[0]

    input_qspec_map = {}
    input_qspec_map[weight] = quantization_config.input_activation

    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        output_qspec=SharedQuantizationSpec((weight, node)),
        _annotated=True,
    )
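

# Illustrative only: SharedQuantizationSpec((weight, node)) points at the edge
# from the embedding table into this node, so the gathered output reuses the
# quantization parameters chosen for the table itself. Since an embedding lookup
# only selects rows of the table, sharing those parameters avoids introducing a
# separate scale/zero-point for the output.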
@register_annotator([torch.ops.aten.expand.default])
def annotate_expand(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_in_out_obs_sharing_op(node, quantization_config)
    if not _is_annotated([node]):
        annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.flatten.using_ints])
def annotate_flatten(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_in_out_obs_sharing_op(node, quantization_config)
    if not _is_annotated([node]):
        annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.stack.default])
def annotate_stack(node: Node, quantization_config: QuantizationConfig) -> None:
    input_qspec_map = {}
    for input_act in node.args[0]:
        assert isinstance(input_act, Node)
        input_qspec_map[input_act] = quantization_config.input_activation

        node_tensor = node.meta.get("val")
        if torch.is_tensor(node_tensor) and node_tensor.dtype == torch.int64:
            continue

    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        output_qspec=quantization_config.output_activation,
        _annotated=True,
    )
@register_annotator([torch.ops.aten.matmul.default])
def annotate_matmul(node: Node, quantization_config: QuantizationConfig) -> None:
    if _is_annotated([node]):
        return

    input_act_qspec = quantization_config.input_activation
    output_act_qspec = quantization_config.output_activation

    input_qspec_map = {}
    input_act0 = node.args[0]
    if isinstance(input_act0, Node):
        input_qspec_map[input_act0] = input_act_qspec

    input_act1 = node.args[1]
    if isinstance(input_act1, Node):
        # In matmul, a QNN_DATATYPE_SFIXED_POINT_16 Input1 must pair with a
        # QNN_DATATYPE_UFIXED_POINT_16 Input0 and must be symmetrically quantized.
        if input_act_qspec.dtype == torch.int32:
            input_qspec_map[input_act1] = quantization_config.weight
            quantization_annotation = input_act1.meta.get(QUANT_ANNOTATION_KEY, None)
            if quantization_annotation:
                quantization_annotation.output_qspec = quantization_config.weight
        else:
            input_qspec_map[input_act1] = input_act_qspec

    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        output_qspec=output_act_qspec,
        _annotated=True,
    )
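

# Illustrative only: checking `input_act_qspec.dtype == torch.int32` is how the
# 16-bit activation configuration is commonly detected here, on the assumption
# that such configs carry their 16-bit affine range inside an int32 spec. Under
# that assumption the branch above routes the second matmul/bmm operand through
# the symmetric `weight` spec (and retargets its producer's output spec), while
# an 8-bit config simply reuses input_activation for both operands.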
@register_annotator([torch.ops.aten.bmm.default])
def annotate_bmm(node: Node, quantization_config: QuantizationConfig) -> None:
    if _is_annotated([node]):
        return

    input_act_qspec = quantization_config.input_activation
    output_act_qspec = quantization_config.output_activation

    input_qspec_map = {}
    input_act0 = node.args[0]
    if isinstance(input_act0, Node):
        input_qspec_map[input_act0] = input_act_qspec

    input_act1 = node.args[1]
    if isinstance(input_act1, Node):
        # In bmm, a QNN_DATATYPE_SFIXED_POINT_16 Input1 must pair with a
        # QNN_DATATYPE_UFIXED_POINT_16 Input0 and must be symmetrically quantized.
        if input_act_qspec.dtype == torch.int32:
            input_qspec_map[input_act1] = quantization_config.weight
            quantization_annotation = input_act1.meta.get(QUANT_ANNOTATION_KEY, None)
            if quantization_annotation:
                quantization_annotation.output_qspec = quantization_config.weight
        else:
            input_qspec_map[input_act1] = input_act_qspec

    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        output_qspec=output_act_qspec,
        _annotated=True,
    )

    # We rely on get_source_partitions in a later pass, but every bmm lowered from
    # MultiheadAttention reports the same source, so we override its source_fn_stack here.
    node.meta["source_fn_stack"] = [(node, torch.bmm)]
@register_annotator([torch.ops.aten.conv2d.default, torch.ops.aten.conv1d.default])
def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None:
    if _is_annotated([node]):
        return

    input_qspec_map = {}
    input_act = node.args[0]
    assert isinstance(input_act, Node)
    input_spec = quantization_config.input_activation
    input_qspec_map[input_act] = input_spec

    weight = node.args[1]
    assert isinstance(weight, Node)
    input_qspec_map[weight] = quantization_config.weight

    if len(node.args) > 2:
        bias = node.args[2]
        if isinstance(bias, Node):
            if callable(quantization_config.bias):
                input_qspec_map[bias] = quantization_config.bias(node)
            else:
                input_qspec_map[bias] = quantization_config.bias

    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        output_qspec=quantization_config.output_activation,
        _annotated=True,
    )
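

# Illustrative only: QuantizationConfig.bias may be a callable because bias
# quantization parameters are typically derived per node from the already chosen
# input and weight parameters (for example a DerivedQuantizationSpec with
# scale = input_scale * weight_scale and an int32 dtype) rather than observed
# independently. The exact derivation lives in the quantizer configuration that
# builds the QuantizationConfig, not in this file.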
@register_annotator([torch.ops.aten.linear.default])
def annotate_linear(node: Node, quantization_config: QuantizationConfig) -> None:
    act_node = node.args[0]
    weight_node = node.args[1]
    bias_node = None
    if len(node.args) > 2:
        bias_node = node.args[2]

    if _is_annotated([node]):
        return

    _annotate_input_qspec_map(
        node,
        act_node,
        quantization_config.input_activation,
    )
    _annotate_input_qspec_map(
        node,
        weight_node,
        quantization_config.weight,
    )
    nodes_to_mark_annotated = [node, weight_node]
    if bias_node:
        if callable(quantization_config.bias):
            bias_config = quantization_config.bias(node)
        else:
            bias_config = quantization_config.bias
        _annotate_input_qspec_map(node, bias_node, bias_config)
        nodes_to_mark_annotated.append(bias_node)
    _annotate_output_qspec(node, quantization_config.output_activation)
    _mark_nodes_as_annotated(nodes_to_mark_annotated)

    # We rely on get_source_partitions in a later pass, but every linear lowered from
    # MultiheadAttention reports the same source, so we override its source_fn_stack here.
    node.meta["source_fn_stack"] = [(node, torch.nn.Linear)]
@register_annotator([torch.ops.aten.layer_norm.default])
def annotate_layer_norm(node: Node, quantization_config: QuantizationConfig) -> None:
    act_node = node.args[0]
    weight_node = node.args[2]
    bias_node = None
    if len(node.args) > 3:
        bias_node = node.args[3]

    if _is_annotated([node]):
        return

    _annotate_input_qspec_map(
        node,
        act_node,
        quantization_config.input_activation,
    )
    _annotate_input_qspec_map(
        node,
        weight_node,
        quantization_config.input_activation,
    )
    nodes_to_mark_annotated = [node, weight_node]
    if bias_node:
        _annotate_input_qspec_map(
            node,
            bias_node,
            quantization_config.bias,
        )
        nodes_to_mark_annotated.append(bias_node)
    _annotate_output_qspec(node, quantization_config.output_activation)
    _mark_nodes_as_annotated(nodes_to_mark_annotated)
@register_annotator([torch.ops.aten.cat.default, torch.ops.aten.concat.default])
def annotate_cat(node: Node, quantization_config: QuantizationConfig) -> None:
    input_nodes = node.args[0]
    if _is_annotated([node]):
        return

    assert isinstance(input_nodes, Sequence)

    first_input_node = input_nodes[0]
    input_qspec_map = {}
    assert isinstance(first_input_node, Node)
    assert isinstance(node, Node)
    input_qspec_map[first_input_node] = quantization_config.input_activation
    share_qparams_with_input_act0_qspec = SharedQuantizationSpec(
        (first_input_node, node)
    )

    for input_node in input_nodes[1:]:
        if input_node not in input_qspec_map:
            assert isinstance(input_node, Node)
            input_qspec_map[input_node] = share_qparams_with_input_act0_qspec

    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        output_qspec=share_qparams_with_input_act0_qspec,
        _annotated=True,
    )
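

# Illustrative only: tying every remaining input and the output to the observer
# on the first input edge means the whole concatenation uses a single
# scale/zero-point, so no per-operand requantization is needed around the op.
# For `cat([a, b], dim)` the resulting annotation is roughly:
#
#     input_qspec_map == {a: input_activation, b: SharedQuantizationSpec((a, node))}
#     output_qspec   == SharedQuantizationSpec((a, node))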
@register_annotator([torch.ops.aten.unbind.int])
def annotate_unbind(node: Node, quantization_config: QuantizationConfig) -> None:
    if _is_annotated([node]):
        return

    input_qspec_map = {}
    input_act = node.args[0]
    assert isinstance(input_act, Node)
    input_qspec_map[input_act] = quantization_config.input_activation

    node_tensor = node.meta.get("val")
    if torch.is_tensor(node_tensor) and node_tensor.dtype == torch.int64:
        return

    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        _annotated=True,
    )
@register_annotator([torch.ops.aten.chunk.default])
def annotate_chunk(node: Node, quantization_config: QuantizationConfig) -> None:
    if _is_annotated([node]):
        return

    input_qspec_map = {}
    input_act = node.args[0]
    assert isinstance(input_act, Node)
    input_qspec_map[input_act] = quantization_config.input_activation

    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        _annotated=True,
    )

    for user in node.users:
        user.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            output_qspec=quantization_config.output_activation,
            _annotated=True,
        )
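

# Illustrative only: these annotators are meant to be driven by a Quantizer that
# walks a captured graph before prepare_pt2e. A rough end-to-end sketch under that
# assumption (ExampleQuantizer and example_config are placeholders, not APIs
# defined in this file):
#
#     from torch.ao.quantization.quantizer import Quantizer
#     from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
#
#     class ExampleQuantizer(Quantizer):
#         def __init__(self, quantization_config):
#             super().__init__()
#             self.quantization_config = quantization_config
#
#         def annotate(self, model):
#             for n in model.graph.nodes:
#                 if n.op == "call_function" and n.target in OP_ANNOTATOR:
#                     OP_ANNOTATOR[n.target](n, self.quantization_config)
#             return model
#
#         def validate(self, model):
#             pass
#
#     prepared = prepare_pt2e(exported_module, ExampleQuantizer(example_config))
#     # ... feed calibration data through `prepared`, then call convert_pt2e(prepared)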