blob: 275da567e8f31500c9b6a6cf1d25f93f1c19aff0 [file]
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import numbers
import operator
from functools import partial
from typing import Callable, Dict, List, Sequence, Tuple
import torch
from torch._ops import OpOverload
from torch._subclasses import FakeTensor
from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize
from torch.ao.quantization.observer import FixedQParamsObserver
from torch.ao.quantization.quantizer import (
DerivedQuantizationSpec,
QuantizationAnnotation,
QuantizationSpec,
SharedQuantizationSpec,
)
from torch.ao.quantization.quantizer.utils import (
_annotate_input_qspec_map,
_annotate_output_qspec,
)
from torch.fx import Node
from .qconfig import (
get_16a16w_qnn_ptq_config,
get_16a4w_qnn_qat_config,
get_8a8w_qnn_qat_config,
QuantizationConfig,
)
QUANT_ANNOTATION_KEY = "quantization_annotation"
OP_ANNOTATOR: Dict[OpOverload, Callable] = {}


def register_annotator(ops: List[OpOverload]):
    """Return a decorator that registers an annotator for each op in ``ops``.

    Every op in ``ops`` is mapped to the decorated function in the
    module-level ``OP_ANNOTATOR`` table.

    Args:
        ops: Operator overloads (or overload packets) handled by the
            decorated annotator.

    Returns:
        A decorator that records the annotator and hands it back unchanged.
        Returning it keeps the module-level name bound to the function, so it
        can still be called or imported directly instead of becoming ``None``.
    """

    def decorator(annotator: Callable) -> Callable:
        for op in ops:
            OP_ANNOTATOR[op] = annotator
        # Without this return, every decorated name was rebound to None.
        return annotator

    return decorator
def _is_annotated(nodes: List[Node]):
    """
    Report whether an operator pattern has already been annotated.

    Args:
        nodes: The nodes that make up the operator pattern.

    Returns:
        True if at least one node has a quantization annotation in its
        ``meta`` whose ``_annotated`` flag is set; False otherwise, including
        for an empty list.
    """
    return any(
        QUANT_ANNOTATION_KEY in candidate.meta
        and candidate.meta[QUANT_ANNOTATION_KEY]._annotated
        for candidate in nodes
    )
def _is_float_tensor(node: Node):
"""Check if the node's tensor is a float tensor, so that we can skip quantization for the node
since observers only works with float Tensors
"""
if (
not isinstance(node, Node)
or "val" not in node.meta
or not isinstance(node.meta["val"], FakeTensor)
):
return False
return node.meta["val"].dtype == torch.float32
def _mark_nodes_as_annotated(nodes: List[Node]):
    """Set the ``_annotated`` flag on every node in ``nodes``.

    Nodes without a quantization annotation first receive an empty
    ``QuantizationAnnotation``. Existing annotations are kept; only their
    flag is switched on.
    """
    for target in nodes:
        annotation = target.meta.setdefault(
            QUANT_ANNOTATION_KEY, QuantizationAnnotation()
        )
        annotation._annotated = True
def annotate_in_out_obs_sharing_op(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    """Make ``node`` share quantization parameters with its first input.

    The node is annotated only when it is not yet annotated and its first
    argument already has an annotated, non-``None`` output spec. In that case
    both the input edge and the output use a ``SharedQuantizationSpec`` bound
    to that input node. Otherwise nothing is written, so callers can fall
    back to another annotator. ``quantization_config`` is unused here; it is
    accepted for a uniform annotator signature.
    """
    if _is_annotated([node]):
        return
    producer = node.args[0]
    assert isinstance(producer, Node)
    # Only share observers when the producer's output is actually quantized.
    if QUANT_ANNOTATION_KEY not in producer.meta:
        return
    producer_annotation = producer.meta[QUANT_ANNOTATION_KEY]
    if not producer_annotation._annotated or producer_annotation.output_qspec is None:
        return
    shared_spec = SharedQuantizationSpec(producer)
    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map={producer: shared_spec},
        output_qspec=shared_spec,
        _annotated=True,
    )
def annotate_single_in_single_out(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    """Annotate a one-input, one-output op with the config's activation specs.

    The first argument is quantized with ``input_activation`` and the output
    with ``output_activation``. Nothing is written when the node is already
    annotated or when its own value is not a float32 tensor.
    """
    if _is_annotated([node]):
        return
    source = node.args[0]
    assert isinstance(source, Node)
    source_spec = quantization_config.input_activation
    if not _is_float_tensor(node):
        return
    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map={source: source_spec},
        output_qspec=quantization_config.output_activation,
        _annotated=True,
    )
@register_annotator([torch.ops.aten.topk.default])
def annotate_topk(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.topk`` using the single-input/single-output rule.

    The generic rule is used on purpose: the indices output should not be
    quantized.

    NOTE(review): ``annotate_single_in_single_out`` writes nothing when the
    node's own ``meta["val"]`` is not a float32 ``FakeTensor``. If topk's val
    is a (values, indices) tuple, no annotation is written on this node.
    Confirm this is intended.
    """
    if not _is_annotated([node]):
        annotate_single_in_single_out(node, quantization_config)
def annotate_binary(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate a two-operand elementwise op (add, sub, mul, ...).

    Each of the first two arguments gets ``input_activation`` only when it is
    a float32 tensor node; scalars and non-float tensors are skipped. The
    output gets ``output_activation`` when the node itself is a float32
    tensor, and no output spec otherwise. Already-annotated nodes are left
    alone.
    """
    if _is_annotated([node]):
        return
    operand_spec = quantization_config.input_activation
    result_spec = (
        quantization_config.output_activation if _is_float_tensor(node) else None
    )
    operand_specs = {
        operand: operand_spec
        for operand in (node.args[0], node.args[1])
        if _is_float_tensor(operand)
    }
    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=operand_specs,
        output_qspec=result_spec,
        _annotated=True,
    )
@register_annotator(
    [
        torch.ops.aten.add,
        torch.ops.aten.add.Tensor,
    ]
)
def annotate_add(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.add`` with the generic binary-op rule."""
    annotate_binary(node=node, quantization_config=quantization_config)
@register_annotator(
    [
        torch.ops.aten.sub,
        torch.ops.aten.sub.Tensor,
    ]
)
def annotate_sub(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.sub`` with the generic binary-op rule."""
    annotate_binary(node=node, quantization_config=quantization_config)
@register_annotator(
    [
        torch.ops.aten.mul,
        torch.ops.aten.mul.Tensor,
        torch.ops.aten.mul.Scalar,
    ]
)
def annotate_mul(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.mul`` (tensor and scalar overloads) as a binary op."""
    annotate_binary(node=node, quantization_config=quantization_config)
@register_annotator(
    [torch.ops.aten.div, torch.ops.aten.div.Tensor, torch.ops.aten.divide.Tensor]
)
def annotate_div(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate a division node.

    * ``float32 tensor / positive real constant``: the output spec is a
      ``DerivedQuantizationSpec`` computed from the observer on the dividend
      edge, with ``out_scale = in_scale / divisor`` and the dividend's zero
      point.
    * Any other division with at least one ``Node`` argument is annotated
      like a regular binary op via ``annotate_binary``.

    Raises:
        NotImplementedError: if none of the arguments is an FX ``Node``.
    """

    def _derived_inp1_const_div_quant_spec(
        node: torch.fx.Node, output_qspec: QuantizationSpec
    ) -> DerivedQuantizationSpec:
        # Build an output spec whose qparams are derived from the observer on
        # the (dividend, node) edge and the constant divisor.
        def _derive_div_qparams_fn(
            obs_or_fqs: List,
            const_val: float,
        ) -> Tuple[torch.Tensor, torch.Tensor]:
            inp_0_obs_or_fq = obs_or_fqs[0]
            inp_0_scale, inp_0_zp = inp_0_obs_or_fq.calculate_qparams()
            derived_scale = inp_0_scale / const_val
            return (derived_scale, inp_0_zp)

        inp_0 = node.args[0]
        const_inp_1 = node.args[1]
        _derive_div_qparams_with_const_fn = partial(
            _derive_div_qparams_fn, const_val=const_inp_1
        )
        q_min = (
            torch.iinfo(output_qspec.dtype).min
            if output_qspec.quant_min is None
            else output_qspec.quant_min
        )
        q_max = (
            torch.iinfo(output_qspec.dtype).max
            if output_qspec.quant_max is None
            else output_qspec.quant_max
        )
        return DerivedQuantizationSpec(
            derived_from=[(inp_0, node)],
            derive_qparams_fn=_derive_div_qparams_with_const_fn,
            dtype=output_qspec.dtype,
            quant_min=q_min,
            quant_max=q_max,
            ch_axis=0,
            qscheme=output_qspec.qscheme,
        )

    if _is_annotated([node]):
        return

    dividend, divisor = node.args[0], node.args[1]
    # The derived scale is in_scale / divisor, so only a strictly positive real
    # constant gives a valid (positive, finite) scale. bool is a numbers.Real
    # subclass and is excluded explicitly.
    has_positive_const_divisor = (
        isinstance(divisor, numbers.Real)
        and not isinstance(divisor, bool)
        and divisor > 0
    )
    # Special constant-divisor case. Previously the binary branch below was
    # tested first with "any Node argument", which made this branch unreachable.
    if _is_float_tensor(dividend) and has_positive_const_divisor:
        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map={dividend: quantization_config.input_activation},
            output_qspec=_derived_inp1_const_div_quant_spec(
                node, quantization_config.output_activation
            ),
            _annotated=True,
        )
    elif any(isinstance(arg, Node) for arg in node.args):
        annotate_binary(node, quantization_config)
    else:
        raise NotImplementedError(f"No quant annotation is implemented for {node}.")
@register_annotator(
    [
        torch.ops.aten.rsub.Scalar,
    ]
)
def annotate_rsub(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.rsub.Scalar`` with the generic binary-op rule."""
    annotate_binary(node=node, quantization_config=quantization_config)
@register_annotator(
    [
        torch.ops.aten.sum.dim_IntList,
    ]
)
def annotate_sum(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.sum.dim_IntList`` with the generic binary-op rule.

    The non-tensor second argument is skipped by that rule.
    """
    annotate_binary(node=node, quantization_config=quantization_config)
@register_annotator(
    [
        torch.ops.aten.ceil.default,
    ]
)
def annotate_ceil(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.ceil`` with the single-input/single-output rule."""
    annotate_single_in_single_out(node=node, quantization_config=quantization_config)
@register_annotator(
    [
        torch.ops.aten.clamp.default,
    ]
)
def annotate_clamp(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.clamp`` with the single-input/single-output rule."""
    annotate_single_in_single_out(node=node, quantization_config=quantization_config)
@register_annotator(
    [
        torch.ops.aten.relu.default,
        torch.ops.aten.relu_.default,
    ]
)
def annotate_relu(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.relu`` / ``aten.relu_`` with the single-in/single-out rule."""
    annotate_single_in_single_out(node=node, quantization_config=quantization_config)
@register_annotator(
    [
        torch.ops.aten.tanh.default,
    ]
)
def annotate_tanh(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.tanh`` with the single-input/single-output rule."""
    annotate_single_in_single_out(node=node, quantization_config=quantization_config)
@register_annotator(
    [
        torch.ops.aten.hardswish.default,
        torch.ops.aten.hardswish_.default,
    ]
)
def annotate_hardswish(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.hardswish`` (and in-place variant) as single-in/single-out."""
    annotate_single_in_single_out(node=node, quantization_config=quantization_config)
@register_annotator(
    [
        torch.ops.aten.hardsigmoid.default,
        torch.ops.aten.hardsigmoid_.default,
    ]
)
def annotate_hardsigmoid(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.hardsigmoid`` (and in-place variant) as single-in/single-out."""
    annotate_single_in_single_out(node=node, quantization_config=quantization_config)
@register_annotator(
    [
        torch.ops.aten.hardtanh.default,
        torch.ops.aten.hardtanh_.default,
    ]
)
def annotate_hardtanh(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.hardtanh`` (and in-place variant) as single-in/single-out."""
    annotate_single_in_single_out(node=node, quantization_config=quantization_config)
@register_annotator(
    [
        torch.ops.aten.mean.default,
    ]
)
def annotate_mean(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.mean.default`` with the single-input/single-output rule."""
    annotate_single_in_single_out(node=node, quantization_config=quantization_config)
@register_annotator(
    [
        torch.ops.aten.max_pool2d.default,
    ]
)
def annotate_max_pool2d(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.max_pool2d`` with the single-input/single-output rule."""
    annotate_single_in_single_out(node=node, quantization_config=quantization_config)
@register_annotator(
    [
        torch.ops.aten.max_pool2d_with_indices.default,
    ]
)
def annotate_max_pool2d_with_indices(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    """Annotate ``aten.max_pool2d_with_indices`` via the single-in/single-out rule."""
    annotate_single_in_single_out(node=node, quantization_config=quantization_config)
@register_annotator(
    [
        torch.ops.aten.adaptive_avg_pool2d.default,
    ]
)
def annotate_adaptive_avgpool2d(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    """Annotate ``aten.adaptive_avg_pool2d`` via the single-in/single-out rule."""
    annotate_single_in_single_out(node=node, quantization_config=quantization_config)
@register_annotator(
    [
        torch.ops.aten.avg_pool2d.default,
    ]
)
def annotate_avgpool2d(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.avg_pool2d`` with the single-input/single-output rule."""
    annotate_single_in_single_out(node=node, quantization_config=quantization_config)
@register_annotator(
    [
        torch.ops.aten.permute.default,
    ]
)
def annotate_permute(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.permute``.

    First try to share qparams with the input. If that leaves the node
    unannotated, fall back to the single-input/single-output rule.
    """
    annotate_in_out_obs_sharing_op(node, quantization_config)
    if _is_annotated([node]):
        return
    annotate_single_in_single_out(node, quantization_config)
@register_annotator(
    [
        torch.ops.aten.leaky_relu.default,
        torch.ops.aten.leaky_relu_.default,
        torch.ops.aten.prelu.default,
    ]
)
def annotate_prelu(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate leaky_relu / prelu with the single-input/single-output rule.

    Only ``args[0]`` is given an input spec by that rule; for ``prelu`` the
    second argument is not annotated here.
    """
    annotate_single_in_single_out(node=node, quantization_config=quantization_config)
@register_annotator(
    [
        torch.ops.aten.view.default,
        torch.ops.aten._unsafe_view.default,
    ]
)
def annotate_view(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.view`` / ``aten._unsafe_view``.

    First try to share qparams with the input. If that leaves the node
    unannotated, fall back to the single-input/single-output rule.
    """
    annotate_in_out_obs_sharing_op(node, quantization_config)
    if _is_annotated([node]):
        return
    annotate_single_in_single_out(node, quantization_config)
@register_annotator(
    [
        torch.ops.aten.pixel_shuffle.default,
    ]
)
def annotate_pixel_shuffle_default(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    """Annotate ``aten.pixel_shuffle`` with the single-input/single-output rule."""
    annotate_single_in_single_out(node=node, quantization_config=quantization_config)
@register_annotator(
    [
        torch.ops.aten.pixel_unshuffle.default,
    ]
)
def annotate_pixel_unshuffle_default(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    """Annotate ``aten.pixel_unshuffle`` with the single-input/single-output rule."""
    annotate_single_in_single_out(node=node, quantization_config=quantization_config)
@register_annotator(
    [
        torch.ops.aten.upsample_bilinear2d.vec,
    ]
)
def annotate_upsample_bilinear2d(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    """Annotate ``aten.upsample_bilinear2d.vec`` via the single-in/single-out rule."""
    annotate_single_in_single_out(node=node, quantization_config=quantization_config)
@register_annotator(
    [
        torch.ops.aten.upsample_nearest2d.vec,
    ]
)
def annotate_upsample_nearest2d(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    """Annotate ``aten.upsample_nearest2d.vec`` via the single-in/single-out rule."""
    annotate_single_in_single_out(node=node, quantization_config=quantization_config)
@register_annotator(
    [
        torch.ops.aten.softmax.int,
        torch.ops.aten._softmax.default,
        torch.ops.aten._safe_softmax.default,
    ]
)
def annotate_softmax(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate the softmax overloads with the single-input/single-output rule."""
    annotate_single_in_single_out(node=node, quantization_config=quantization_config)
@register_annotator(
    [
        torch.ops.aten.log_softmax.int,
    ]
)
def annotate_log_softmax(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.log_softmax.int`` with the single-input/single-output rule."""
    annotate_single_in_single_out(node=node, quantization_config=quantization_config)
@register_annotator(
    [
        torch.ops.aten.pad.default,
    ]
)
def annotate_pad(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.pad`` with the single-input/single-output rule."""
    annotate_single_in_single_out(node=node, quantization_config=quantization_config)
@register_annotator(
    [
        torch.ops.aten.reshape.default,
    ]
)
def annotate_reshape(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.reshape`` with the single-input/single-output rule."""
    annotate_single_in_single_out(node=node, quantization_config=quantization_config)
@register_annotator(
    [
        torch.ops.aten.select.int,
    ]
)
def annotate_select(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.select.int`` with the single-input/single-output rule."""
    annotate_single_in_single_out(node=node, quantization_config=quantization_config)
@register_annotator(
    [
        torch.ops.aten.mean.dim,
    ]
)
def annotate_mean_dim(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.mean.dim`` with the single-input/single-output rule."""
    annotate_single_in_single_out(node=node, quantization_config=quantization_config)
@register_annotator(
    [
        torch.ops.aten.slice.Tensor,
    ]
)
def annotate_slice(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.slice.Tensor`` with the single-input/single-output rule."""
    annotate_single_in_single_out(node=node, quantization_config=quantization_config)
@register_annotator(
    [
        torch.ops.aten.sqrt.default,
    ]
)
def annotate_sqrt(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.sqrt`` with the single-input/single-output rule."""
    annotate_single_in_single_out(node=node, quantization_config=quantization_config)
@register_annotator(
    [
        torch.ops.aten.gelu.default,
    ]
)
def annotate_gelu(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.gelu`` with the single-input/single-output rule."""
    annotate_single_in_single_out(node=node, quantization_config=quantization_config)
@register_annotator(
    [
        torch.ops.aten.scaled_dot_product_attention.default,
    ]
)
def annotate_scaled_dot_product_attention(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    """Annotate ``aten.scaled_dot_product_attention`` as single-in/single-out.

    Only ``args[0]`` receives an input spec from that rule; the remaining
    tensor arguments are not annotated here.
    """
    annotate_single_in_single_out(node=node, quantization_config=quantization_config)
@register_annotator(
    [
        torch.ops.aten.squeeze.default,
        torch.ops.aten.squeeze.dim,
        torch.ops.aten.squeeze_copy.dims,
    ]
)
def annotate_squeeze(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate the squeeze overloads.

    First try to share qparams with the input. If that leaves the node
    unannotated, fall back to the single-input/single-output rule.
    """
    annotate_in_out_obs_sharing_op(node, quantization_config)
    if _is_annotated([node]):
        return
    annotate_single_in_single_out(node, quantization_config)
@register_annotator([torch.ops.aten.rms_norm.default])
def annotate_rms_norm(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.rms_norm``.

    The activation (``args[0]``) and, when present, the weight (``args[2]``)
    are both quantized with ``input_activation``. The output uses
    ``output_activation``. Only the norm node itself is marked annotated.
    """
    if _is_annotated([node]):
        return
    act_node = node.args[0]
    # The weight argument is optional: it may be omitted from args or passed
    # as None. Guard it so no IndexError occurs and no None key is added.
    weight_node = node.args[2] if len(node.args) > 2 else None
    # TODO current only support 16a16w
    _annotate_input_qspec_map(
        node,
        act_node,
        quantization_config.input_activation,
    )
    if isinstance(weight_node, Node):
        _annotate_input_qspec_map(
            node,
            weight_node,
            quantization_config.input_activation,
        )
    _annotate_output_qspec(node, quantization_config.output_activation)
    _mark_nodes_as_annotated([node])
@register_annotator(
    [
        torch.ops.aten.rsqrt.default,
    ]
)
def annotate_rsqrt(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.rsqrt`` with the single-input/single-output rule."""
    annotate_single_in_single_out(node=node, quantization_config=quantization_config)
@register_annotator([torch.ops.aten.sigmoid, torch.ops.aten.sigmoid.default])
def annotate_sigmoid(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.sigmoid`` with fixed output qparams covering [0, 1).

    The input uses ``input_activation``. The output uses zero point 0 and
    scale ``1 / (quant_max - quant_min + 1)``. The bounds are the configured
    ``quant_min``/``quant_max``, or the integer range of the output dtype
    when those are None. For the two QAT configs checked below a
    ``FixedQParamsFakeQuantize`` is used, otherwise a
    ``FixedQParamsObserver``. Only float32 outputs are annotated.
    """
    if _is_annotated([node]):
        return
    source = node.args[0]
    input_qspec_map = {source: quantization_config.input_activation}
    assert isinstance(source, Node)

    out_cfg = quantization_config.output_activation
    upper = out_cfg.quant_max
    if upper is None:
        upper = torch.iinfo(out_cfg.dtype).max
    lower = out_cfg.quant_min
    if lower is None:
        lower = torch.iinfo(out_cfg.dtype).min
    fixed_scale = 1 / (upper - lower + 1)

    fixed_qparams = dict(
        scale=fixed_scale,
        zero_point=0,
        dtype=out_cfg.dtype,
        qscheme=torch.per_tensor_affine,
        quant_max=upper,
        quant_min=lower,
    )
    observer = FixedQParamsObserver.with_args(**fixed_qparams)
    output_ctr = observer
    # NOTE(review): this membership test compares against freshly built
    # configs. If the observer/fake-quant constructors inside them only
    # compare by identity, it may never match, and QAT would silently get the
    # plain observer. Confirm.
    qat_configs = (get_8a8w_qnn_qat_config(), get_16a4w_qnn_qat_config())
    if quantization_config in qat_configs:
        output_ctr = FixedQParamsFakeQuantize.with_args(
            observer=observer, **fixed_qparams
        )

    # make sigmoid map to the range between 0~1
    output_spec = QuantizationSpec(
        dtype=out_cfg.dtype,
        quant_max=upper,
        quant_min=lower,
        observer_or_fake_quant_ctr=output_ctr,
        qscheme=torch.per_tensor_affine,
    )
    if not _is_float_tensor(node):
        return
    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        output_qspec=output_spec,
        _annotated=True,
    )
@register_annotator(
    [
        torch.ops.aten.pow.Tensor_Scalar,
    ]
)
def annotate_pow(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.pow.Tensor_Scalar`` with the single-input/single-output rule."""
    annotate_single_in_single_out(node=node, quantization_config=quantization_config)
@register_annotator(
    [
        torch.ops.aten.unsqueeze.default,
    ]
)
def annotate_unsqueeze(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.unsqueeze``.

    First try to share qparams with the input. If that leaves the node
    unannotated, fall back to the single-input/single-output rule.
    """
    annotate_in_out_obs_sharing_op(node, quantization_config)
    if _is_annotated([node]):
        return
    annotate_single_in_single_out(node, quantization_config)
@register_annotator([torch.ops.aten.unsqueeze_copy.default])
def annotate_unsqueeze_copy(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    """Annotate ``aten.unsqueeze_copy``.

    First try to share qparams with the input. If that leaves the node
    unannotated, fall back to the single-input/single-output rule.
    """
    annotate_in_out_obs_sharing_op(node, quantization_config)
    if _is_annotated([node]):
        return
    annotate_single_in_single_out(node, quantization_config)
@register_annotator(
    [
        torch.ops.aten.transpose.int,
    ]
)
def annotate_transpose(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.transpose.int``.

    First try to share qparams with the input. If that leaves the node
    unannotated, fall back to the single-input/single-output rule.
    """
    annotate_in_out_obs_sharing_op(node, quantization_config)
    if _is_annotated([node]):
        return
    annotate_single_in_single_out(node, quantization_config)
@register_annotator([torch.ops.aten.embedding.default])
def annotate_embedding(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.embedding``.

    The table argument (``args[0]``) is quantized with ``input_activation``.
    The output shares qparams with the (table, node) input edge. A node that
    is already annotated (e.g. by a custom annotation) is left untouched, as
    the other annotators in this module do, instead of being overwritten.
    """
    if _is_annotated([node]):
        return
    weight = node.args[0]
    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map={weight: quantization_config.input_activation},
        output_qspec=SharedQuantizationSpec((weight, node)),
        _annotated=True,
    )
@register_annotator([torch.ops.aten.index.Tensor])
def annotate_index(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.index.Tensor``.

    First try to share qparams with the input. If that leaves the node
    unannotated, quantize ``args[0]`` with ``input_activation`` and let the
    output share qparams with that (source, node) edge.
    """
    annotate_in_out_obs_sharing_op(node, quantization_config)
    if _is_annotated([node]):
        return
    source = node.args[0]
    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map={source: quantization_config.input_activation},
        output_qspec=SharedQuantizationSpec((source, node)),
        _annotated=True,
    )
@register_annotator(
    [torch.ops.aten.index_put.default, torch.ops.aten.index_put_.default]
)
def annotate_index_put(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.index_put`` / ``aten.index_put_``.

    The destination tensor (``args[0]``) is quantized with
    ``input_activation``. The written values (``args[2]``) and the output
    both share qparams with the (destination, node) edge. Already-annotated
    nodes are left untouched, consistent with the other annotators in this
    module.
    """
    if _is_annotated([node]):
        return
    target = node.args[0]
    values = node.args[2]
    shared_with_target = SharedQuantizationSpec((target, node))
    input_qspec_map = {}
    input_qspec_map[target] = quantization_config.input_activation
    input_qspec_map[values] = shared_with_target
    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        output_qspec=shared_with_target,
        _annotated=True,
    )
@register_annotator(
    [
        torch.ops.aten.expand.default,
    ]
)
def annotate_expand(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.expand``.

    First try to share qparams with the input. If that leaves the node
    unannotated, fall back to the single-input/single-output rule.
    """
    annotate_in_out_obs_sharing_op(node, quantization_config)
    if _is_annotated([node]):
        return
    annotate_single_in_single_out(node, quantization_config)
@register_annotator([torch.ops.aten.group_norm.default])
def annotate_group_norm(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.group_norm``.

    The activation (``args[0]``) uses ``input_activation``. The weight
    (``args[2]``) uses ``weight`` and the bias (``args[3]``) uses ``bias``;
    each is annotated only when present as a ``Node``. The output uses
    ``output_activation``. The norm node and any annotated weight/bias nodes
    are marked annotated.
    """
    if _is_annotated([node]):
        return
    act_node = node.args[0]
    # Weight and bias are optional (absent or None). The previous guard
    # checked len(args) > 2 before reading args[3], an off-by-one, and a None
    # weight would reach _mark_nodes_as_annotated and crash.
    weight_node = node.args[2] if len(node.args) > 2 else None
    bias_node = node.args[3] if len(node.args) > 3 else None

    _annotate_input_qspec_map(
        node,
        act_node,
        quantization_config.input_activation,
    )
    nodes_to_mark_annotated = [node]
    if isinstance(weight_node, Node):
        _annotate_input_qspec_map(
            node,
            weight_node,
            quantization_config.weight,
        )
        nodes_to_mark_annotated.append(weight_node)
    if isinstance(bias_node, Node):
        _annotate_input_qspec_map(
            node,
            bias_node,
            quantization_config.bias,
        )
        nodes_to_mark_annotated.append(bias_node)
    _annotate_output_qspec(node, quantization_config.output_activation)
    _mark_nodes_as_annotated(nodes_to_mark_annotated)
@register_annotator(
    [
        torch.ops.aten.flatten.using_ints,
    ]
)
def annotate_flatten(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.flatten.using_ints``.

    First try to share qparams with the input. If that leaves the node
    unannotated, fall back to the single-input/single-output rule.
    """
    annotate_in_out_obs_sharing_op(node, quantization_config)
    if _is_annotated([node]):
        return
    annotate_single_in_single_out(node, quantization_config)
@register_annotator([torch.ops.aten.stack.default])
def annotate_stack(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.stack``.

    Each input tensor node in ``args[0]`` is quantized with
    ``input_activation`` unless its value is an int64 tensor. The output
    uses ``output_activation``. Nothing is written when the node is already
    annotated or when the stacked result itself is an int64 tensor, which
    matches ``annotate_unbind``.
    """
    if _is_annotated([node]):
        return
    # The previous int64 check sat at the end of the loop body and tested the
    # stack node itself, so its `continue` had no effect.
    node_tensor = node.meta.get("val")
    if torch.is_tensor(node_tensor) and node_tensor.dtype == torch.int64:
        return
    input_qspec_map = {}
    for input_act in node.args[0]:
        assert isinstance(input_act, Node)
        input_tensor = input_act.meta.get("val")
        if torch.is_tensor(input_tensor) and input_tensor.dtype == torch.int64:
            continue
        input_qspec_map[input_act] = quantization_config.input_activation
    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        output_qspec=quantization_config.output_activation,
        _annotated=True,
    )
@register_annotator([torch.ops.aten.matmul.default])
def annotate_matmul(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.matmul``.

    Both operands that are ``Node``s use ``input_activation``, with one
    exception: when activations are int32-typed, the second operand uses the
    16a16w PTQ ``weight`` spec. The output uses ``output_activation``.
    """
    if _is_annotated([node]):
        return
    act_spec = quantization_config.input_activation
    result_spec = quantization_config.output_activation
    operand_specs = {}
    lhs, rhs = node.args[0], node.args[1]
    if isinstance(lhs, Node):
        operand_specs[lhs] = act_spec
    if isinstance(rhs, Node):
        # In matmul, QNN_DATATYPE_SFIXED_POINT_16 Input1 must have
        # QNN_DATATYPE_UFIXED_POINT_16 Input0 and must be symmetric quantized.
        # We should use int16 for mm / bmm instead of int4.
        operand_specs[rhs] = (
            get_16a16w_qnn_ptq_config().weight
            if act_spec.dtype == torch.int32
            else act_spec
        )
    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=operand_specs,
        output_qspec=result_spec,
        _annotated=True,
    )
@register_annotator([torch.ops.aten.bmm.default])
def annotate_bmm(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.bmm``.

    The operand specs follow the same rule as ``annotate_matmul``. The node's
    ``source_fn_stack`` is also rewritten to point at ``torch.bmm``.
    """
    if _is_annotated([node]):
        return
    act_spec = quantization_config.input_activation
    result_spec = quantization_config.output_activation
    operand_specs = {}
    lhs, rhs = node.args[0], node.args[1]
    if isinstance(lhs, Node):
        operand_specs[lhs] = act_spec
    if isinstance(rhs, Node):
        # In bmm, QNN_DATATYPE_SFIXED_POINT_16 Input1 must have
        # QNN_DATATYPE_UFIXED_POINT_16 Input0 and must be symmetric quantized.
        # We should use int16 for mm / bmm instead of int4.
        operand_specs[rhs] = (
            get_16a16w_qnn_ptq_config().weight
            if act_spec.dtype == torch.int32
            else act_spec
        )
    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=operand_specs,
        output_qspec=result_spec,
        _annotated=True,
    )
    # get_source_partition is used in a pass, but MultiheadAttention shares
    # one source, so the source_fn_stack is overridden here.
    node.meta["source_fn_stack"] = [(node, torch.bmm)]
@register_annotator(
    [
        torch.ops.aten.conv2d.default,
        torch.ops.aten.conv1d.default,
        torch.ops.aten.conv_transpose2d.input,
    ]
)
def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate conv1d / conv2d / conv_transpose2d.

    The input (``args[0]``) uses ``input_activation`` and the weight
    (``args[1]``) uses ``weight``. A ``Node`` bias (``args[2]``) uses
    ``quantization_config.bias``; when that is callable it is called with the
    conv node to build the spec. The output uses ``output_activation``.
    """
    if _is_annotated([node]):
        return
    activation = node.args[0]
    assert isinstance(activation, Node)
    kernel = node.args[1]
    assert isinstance(kernel, Node)
    operand_specs = {
        activation: quantization_config.input_activation,
        kernel: quantization_config.weight,
    }
    bias = node.args[2] if len(node.args) > 2 else None
    if isinstance(bias, Node):
        bias_cfg = quantization_config.bias
        operand_specs[bias] = bias_cfg(node) if callable(bias_cfg) else bias_cfg
    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=operand_specs,
        output_qspec=quantization_config.output_activation,
        _annotated=True,
    )
@register_annotator([torch.ops.aten.linear.default])
def annotate_linear(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.linear``.

    The activation uses ``input_activation`` and the weight uses ``weight``.
    A truthy bias argument uses ``quantization_config.bias``, or its result
    on this node when it is callable. The output uses ``output_activation``.
    The linear, weight and (if any) bias nodes are marked annotated, and
    ``source_fn_stack`` is rewritten to ``torch.nn.Linear``.
    """
    act_node, weight_node = node.args[0], node.args[1]
    bias_node = node.args[2] if len(node.args) > 2 else None
    if _is_annotated([node]):
        return
    _annotate_input_qspec_map(node, act_node, quantization_config.input_activation)
    _annotate_input_qspec_map(node, weight_node, quantization_config.weight)
    marked = [node, weight_node]
    if bias_node:
        bias_spec = (
            quantization_config.bias(node)
            if callable(quantization_config.bias)
            else quantization_config.bias
        )
        _annotate_input_qspec_map(node, bias_node, bias_spec)
        marked.append(bias_node)
    _annotate_output_qspec(node, quantization_config.output_activation)
    _mark_nodes_as_annotated(marked)
    # get_source_partition is used in a pass, but MultiheadAttention shares
    # one source, so the source_fn_stack is overridden here.
    node.meta["source_fn_stack"] = [(node, torch.nn.Linear)]
@register_annotator([torch.ops.aten._native_batch_norm_legit_no_training.default])
def annotate_batch_norm(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate inference-mode batch norm.

    The activation (``args[0]``) and the weight (``args[1]``) both use
    ``input_activation``; the bias (``args[2]``) uses ``bias``. Weight and
    bias are annotated only when they are ``Node``s, so a ``None`` value
    (e.g. a non-affine norm) is skipped instead of becoming a ``None`` key
    and crashing ``_mark_nodes_as_annotated``. The output uses
    ``output_activation``. The norm node, its activation input and any
    annotated weight/bias nodes are marked annotated.
    """
    if _is_annotated([node]):
        return
    act, weight, bias = node.args[0:3]
    _annotate_input_qspec_map(
        node,
        act,
        quantization_config.input_activation,
    )
    nodes_to_mark_annotated = [node, act]
    if isinstance(weight, Node):
        # QNN requires uint8 instead of int8 in 'weight' config
        _annotate_input_qspec_map(
            node,
            weight,
            quantization_config.input_activation,
        )
        nodes_to_mark_annotated.append(weight)
    if isinstance(bias, Node):
        _annotate_input_qspec_map(
            node,
            bias,
            quantization_config.bias,
        )
        nodes_to_mark_annotated.append(bias)
    _annotate_output_qspec(node, quantization_config.output_activation)
    _mark_nodes_as_annotated(nodes_to_mark_annotated)
@register_annotator([operator.getitem])
def annotate_getitem(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate an ``operator.getitem`` node.

    A float32 result gets ``output_activation`` as its output spec. Whether
    or not it is float, the node is then marked annotated so later passes
    skip it.
    """
    if _is_annotated([node]):
        return
    is_float_output = _is_float_tensor(node)
    if is_float_output:
        _annotate_output_qspec(node, quantization_config.output_activation)
    _mark_nodes_as_annotated([node])
@register_annotator([torch.ops.aten.layer_norm.default])
def annotate_layer_norm(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.layer_norm``.

    The activation (``args[0]``) uses ``input_activation``. The weight
    (``args[2]``) uses the 16a16w PTQ ``weight`` spec when activations are
    int32-typed, and ``input_activation`` otherwise. The bias (``args[3]``)
    uses ``bias``. Weight and bias are annotated only when present as
    ``Node``s. The output uses ``output_activation``.
    """
    if _is_annotated([node]):
        return
    act_node = node.args[0]
    # Weight and bias are optional (absent or None). The previous guard
    # checked len(args) > 2 before reading args[3], an off-by-one, and a None
    # weight would reach _mark_nodes_as_annotated and crash.
    weight_node = node.args[2] if len(node.args) > 2 else None
    bias_node = node.args[3] if len(node.args) > 3 else None

    input_act_qspec = quantization_config.input_activation
    _annotate_input_qspec_map(
        node,
        act_node,
        input_act_qspec,
    )
    nodes_to_mark_annotated = [node]
    if isinstance(weight_node, Node):
        weight_qspec = (
            get_16a16w_qnn_ptq_config().weight
            if input_act_qspec.dtype == torch.int32
            else input_act_qspec
        )
        _annotate_input_qspec_map(
            node,
            weight_node,
            weight_qspec,
        )
        nodes_to_mark_annotated.append(weight_node)
    if isinstance(bias_node, Node):
        _annotate_input_qspec_map(
            node,
            bias_node,
            quantization_config.bias,
        )
        nodes_to_mark_annotated.append(bias_node)
    _annotate_output_qspec(node, quantization_config.output_activation)
    _mark_nodes_as_annotated(nodes_to_mark_annotated)
@register_annotator([torch.ops.aten.cat.default, torch.ops.aten.concat.default])
def annotate_cat(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.cat`` / ``aten.concat``.

    The first input uses ``input_activation``. Every other distinct input,
    and the output, share qparams with the (first input, node) edge, so all
    operands end up on one quantization grid.
    """
    tensors = node.args[0]
    if _is_annotated([node]):
        return
    assert isinstance(tensors, Sequence)
    head = tensors[0]
    assert isinstance(head, Node)
    assert isinstance(node, Node)
    operand_specs = {head: quantization_config.input_activation}
    shared_with_head = SharedQuantizationSpec((head, node))
    for tail in tensors[1:]:
        # Repeated inputs keep the spec they were first given.
        if tail in operand_specs:
            continue
        assert isinstance(tail, Node)
        operand_specs[tail] = shared_with_head
    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=operand_specs,
        output_qspec=shared_with_head,
        _annotated=True,
    )
@register_annotator([torch.ops.aten.unbind.int])
def annotate_unbind(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.unbind.int``.

    Only the input (``args[0]``) gets a spec (``input_activation``); no
    output spec is set on this node. Nothing is written when the node is
    already annotated or when its own value is an int64 tensor.
    """
    if _is_annotated([node]):
        return
    source = node.args[0]
    assert isinstance(source, Node)
    operand_specs = {source: quantization_config.input_activation}
    produced = node.meta.get("val")
    if torch.is_tensor(produced) and produced.dtype == torch.int64:
        return
    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=operand_specs,
        _annotated=True,
    )
@register_annotator([torch.ops.aten.split.Tensor, torch.ops.aten.chunk.default])
def annotate_chunk(node: Node, quantization_config: QuantizationConfig) -> None:
    """Annotate ``aten.split.Tensor`` / ``aten.chunk``.

    The split node gets only an input spec (``input_activation`` on
    ``args[0]``). Each of its users is then given a fresh annotation carrying
    ``output_activation`` as its output spec, replacing any annotation those
    users had.
    """
    if _is_annotated([node]):
        return
    source = node.args[0]
    assert isinstance(source, Node)
    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map={source: quantization_config.input_activation},
        _annotated=True,
    )
    for consumer in node.users:
        consumer.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            output_qspec=quantization_config.output_activation,
            _annotated=True,
        )