| # Copyright (c) Qualcomm Innovation Center, Inc. |
| # All rights reserved |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| import torch |
| from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS |
| from executorch.exir.pass_base import ExportPass, PassResult |
| from torch.fx.passes.utils.source_matcher_utils import get_source_partitions |
| |
| from .utils import dq_ops, get_quant_attrs, q_ops |
| |
| |
class AnnotateDecomposed(ExportPass):
    """
    Add "quant_attrs" to graph nodes' meta from the QDQ information
    generated after quantization process.

    Two kinds of source partitions are handled:

    * ``unbind`` partitions take their attributes from the node feeding the
      dequantize op that is the partition's first input.
    * ``stack`` partitions take their attributes from the quantize op that is
      the first user of the partition's first output.

    Partitions that do not match these patterns are left untouched.
    """

    def __init__(self, edge_program: torch.export.ExportedProgram):
        super().__init__()
        # Kept so that get_quant_attrs can be called with the exported program.
        self.edge_program = edge_program

    @staticmethod
    def _tag_partition(src_partition, quant_attrs) -> None:
        """Store an independent copy of ``quant_attrs`` on every node of the partition."""
        for node in src_partition.nodes:
            # Copy per node so a later in-place edit on one node's attributes
            # does not leak into its siblings.
            node.meta[QCOM_QUANT_ATTRS] = quant_attrs.copy()

    def _annotate_unbind(self, graph_module: torch.fx.GraphModule):
        """Annotate nodes of every ``unbind`` partition whose first input is a dequantize op."""
        partitions = get_source_partitions(graph_module.graph, [torch.unbind, "unbind"])
        for src_partitions in partitions.values():
            for src_partition in src_partitions:
                # A partition without external inputs has nothing to read the
                # quantization parameters from; skip it instead of raising
                # IndexError.
                if not src_partition.input_nodes:
                    continue
                dq_node = src_partition.input_nodes[0]
                if dq_node.target not in dq_ops:
                    continue
                # First argument of the dequantize node is the node it consumes
                # (presumably the matching quantize op -- confirm against
                # get_quant_attrs).
                q_node = dq_node.args[0]
                quant_attrs = get_quant_attrs(self.edge_program, q_node)
                self._tag_partition(src_partition, quant_attrs)

    def _annotate_stack(self, graph_module: torch.fx.GraphModule):
        """Annotate nodes of every ``stack`` partition whose output feeds a quantize op."""
        partitions = get_source_partitions(graph_module.graph, [torch.stack])
        for src_partitions in partitions.values():
            for src_partition in src_partitions:
                # Guard against partitions without outputs (IndexError before).
                if not src_partition.output_nodes:
                    continue
                output = src_partition.output_nodes[0]
                # Only the first user is inspected, as in the original code; an
                # output with no users at all is skipped rather than crashing.
                first_user = next(iter(output.users), None)
                if first_user is None or first_user.target not in q_ops:
                    continue
                quant_attrs = get_quant_attrs(self.edge_program, first_user)
                self._tag_partition(src_partition, quant_attrs)

    def call(self, graph_module: torch.fx.GraphModule):
        """Run both annotation steps, recompile the module and return the pass result.

        The result's second field is always ``True``, regardless of whether any
        node was actually annotated.
        """
        self._annotate_unbind(graph_module)
        self._annotate_stack(graph_module)
        graph_module.recompile()
        return PassResult(graph_module, True)