| # Copyright (c) Qualcomm Innovation Center, Inc. |
| # All rights reserved |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| import torch |
| |
| from executorch.backends.qualcomm.utils.constants import ( |
| QCOM_QUANT_ATTRS, |
| QCOM_QUANTIZED_IO, |
| QCOM_REQUANTIZE, |
| ) |
| |
| from executorch.exir.dialects._ops import ops as exir_ops |
| from executorch.exir.pass_base import ExportPass, PassResult |
| |
| |
class InsertRequantize(ExportPass):
    """
    Insert a convert (requantize) op after nodes whose output must be
    requantized.

    Every node whose meta carries ``QCOM_REQUANTIZE`` gets a new edge
    ``aten._to_copy`` node inserted directly after it. The new node takes the
    requantize attributes as its ``QCOM_QUANT_ATTRS`` and replaces the
    original node as the input of all of that node's users. This inserted op
    is the convert op that performs requantization in the QNN backend when
    input and output quantization specs differ.
    """

    # Ops that have multiple outputs but are still handled by
    # _single_output_annotation instead of _multi_output_annotation, e.g.
    # because the 2nd output is unused or is an integer.
    multi_output_op_ignore_set = {
        exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
        exir_ops.edge.aten.topk.default,
    }

    def __init__(
        self,
        edge_program: torch.export.ExportedProgram,
    ):
        super().__init__()
        # Stored for callers; not read by the methods of this class.
        self.edge_program = edge_program

    # TODO: Implement this function when we have an op with
    # multiple outputs that requires quant attributes.
    def _multi_output_annotation(self) -> None:
        raise NotImplementedError("requant is not implemented for multi output yet")

    def _single_output_annotation(
        self, gm: torch.fx.GraphModule, n: torch.fx.Node
    ) -> None:
        """
        Insert a ``_to_copy`` node right after ``n`` and reroute all users of
        ``n`` through it.

        ``n.meta[QCOM_REQUANTIZE]`` is moved onto the new node as its
        ``QCOM_QUANT_ATTRS``; ``QCOM_QUANTIZED_IO`` is copied when set.

        Args:
            gm: Graph module that owns ``n``; modified in place.
            n: Node tagged with ``QCOM_REQUANTIZE``.
        """
        # Snapshot the users *before* creating the new node: create_node
        # registers the new node as a user of ``n``, and it must not be
        # rewired to consume itself.
        users = list(n.users.keys())
        with gm.graph.inserting_after(n):
            inserted_n = gm.graph.create_node(
                "call_function",
                exir_ops.edge.aten._to_copy.default,
                (n,),
            )

        inserted_n.meta["val"] = n.meta["val"]
        # pop() removes the tag from ``n`` so the request is consumed and the
        # node is not processed again.
        inserted_n.meta[QCOM_QUANT_ATTRS] = n.meta.pop(QCOM_REQUANTIZE)
        if n.meta.get(QCOM_QUANTIZED_IO):
            inserted_n.meta[QCOM_QUANTIZED_IO] = n.meta[QCOM_QUANTIZED_IO]

        for user in users:
            user.replace_input_with(n, inserted_n)

    def _insert(self, graph_module: torch.fx.GraphModule) -> None:
        """
        Insert a requantize node after every node tagged with
        ``QCOM_REQUANTIZE``.

        Raises:
            NotImplementedError: if a tagged node has multiple outputs and its
                target is not in ``multi_output_op_ignore_set``.
        """
        # Iterate over a snapshot because nodes are inserted while walking.
        for n in list(graph_module.graph.nodes):
            if QCOM_REQUANTIZE not in n.meta:
                continue
            # A single FakeTensor in meta["val"] means a single-output node.
            is_single_output = isinstance(
                n.meta["val"], torch._subclasses.fake_tensor.FakeTensor
            )
            if is_single_output or n.target in self.multi_output_op_ignore_set:
                self._single_output_annotation(graph_module, n)
            else:
                self._multi_output_annotation()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Run the pass on ``graph_module`` in place.

        Inserts requantize nodes, removes dead code and recompiles the module.
        Always reports the graph as modified.
        """
        self._insert(graph_module)
        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)