# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from executorch.backends.qualcomm.utils.constants import (
QCOM_QUANT_ATTRS,
QCOM_QUANTIZED_IO,
QCOM_REQUANTIZE,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
class InsertRequantize(ExportPass):
    """
    This pass inserts convert op for operators which have
    different quantization specs in input and activation.
    Convert OP is a specific op which helps to requantize in Qnn backend
    """

    # Storing ops that has multi output but run _single_output_annotation logic
    # instead of _multi_output_annotation. Ops might be added into this set because
    # we don't use the 2nd output, 2nd output is an integer, etc.
    multi_output_op_ignore_set = {
        exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
        exir_ops.edge.aten.topk.default,
    }

    def __init__(
        self,
        edge_program: torch.export.ExportedProgram,
    ):
        super().__init__()
        self.edge_program = edge_program

    # TODO: Implement this function when we have an op with
    # multiple outputs that requires quant attributes.
    def _multi_output_annotation(self) -> None:
        raise NotImplementedError("requant is not implemented for multi output yet")

    def _single_output_annotation(
        self, gm: torch.fx.GraphModule, n: torch.fx.Node
    ) -> None:
        """Insert a `_to_copy` (convert) node right after `n` and reroute users.

        The requantize spec stored on `n` is popped and becomes the quant
        attrs of the inserted node, so the QNN backend lowers it to a
        Convert op.

        Args:
            gm: graph module being rewritten in place.
            n: node carrying a QCOM_REQUANTIZE entry in its meta.
        """
        with gm.graph.inserting_after(n):
            # Snapshot the users first: the node created below becomes a new
            # user of `n` and must not have its own input rerouted to itself.
            users = list(n.users.keys())
            inserted_n = gm.graph.create_node(
                "call_function",
                exir_ops.edge.aten._to_copy.default,
                (n,),
            )
            inserted_n.meta["val"] = n.meta["val"]
            # Move (not copy) the requantize spec onto the convert node.
            inserted_n.meta[QCOM_QUANT_ATTRS] = n.meta.pop(QCOM_REQUANTIZE)
            # Propagate the quantized-IO tag so IO handling stays consistent.
            if n.meta.get(QCOM_QUANTIZED_IO):
                inserted_n.meta[QCOM_QUANTIZED_IO] = n.meta[QCOM_QUANTIZED_IO]
            for user in users:
                user.replace_input_with(n, inserted_n)

    def _insert(self, graph_module: torch.fx.GraphModule) -> None:
        """Walk the graph and requantize every node tagged with a spec.

        Nodes with a single (FakeTensor) output — or multi-output ops we
        deliberately treat as single-output — go through
        _single_output_annotation; anything else is unsupported for now.
        """
        for n in graph_module.graph.nodes:
            if QCOM_REQUANTIZE not in n.meta:
                continue
            single_output = isinstance(
                n.meta["val"], torch._subclasses.fake_tensor.FakeTensor
            )
            if single_output or n.target in self.multi_output_op_ignore_set:
                self._single_output_annotation(graph_module, n)
            else:
                self._multi_output_annotation()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Run the pass: insert converts, drop dead code, and recompile."""
        self._insert(graph_module)
        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)