blob: a8a757ce9bfdb33eeebd6a87d8108d1025718972 [file]
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
from .utils import dq_ops, get_quant_attrs, q_ops
class AnnotateDecomposed(ExportPass):
    """
    Add "quant_attrs" to graph nodes' meta from the QDQ information
    generated after quantization process.

    Only nodes that belong to the source partitions of ``unbind`` and
    ``stack`` are handled here: every node of such a partition receives
    its own copy of the quant attrs taken from the adjacent Q/DQ node.
    """

    def __init__(self, edge_program: torch.export.ExportedProgram):
        """
        Args:
            edge_program: exported program that is handed to
                ``get_quant_attrs`` when quant attrs are looked up.
        """
        super().__init__()
        self.edge_program = edge_program

    def _annotate_unbind(self, graph_module: torch.fx.GraphModule) -> None:
        """
        Annotate the nodes of every ``unbind`` source partition whose first
        input node has a target listed in ``dq_ops``.
        """
        partitions = get_source_partitions(
            graph_module.graph, [torch.unbind, "unbind"]
        )
        for src_partitions in partitions.values():
            for src_partition in src_partitions:
                input_nodes = src_partition.input_nodes
                # A partition without external inputs has no Q/DQ node to
                # read from; skip it instead of raising IndexError.
                if not input_nodes:
                    continue

                dq_node = input_nodes[0]
                if dq_node.target not in dq_ops:
                    continue

                # First argument of the matched dq node -- presumably the
                # paired quantize node; confirm against get_quant_attrs.
                q_node = dq_node.args[0]
                quant_attrs = get_quant_attrs(self.edge_program, q_node)
                for n in src_partition.nodes:
                    # One copy per node so the dicts are not shared.
                    n.meta[QCOM_QUANT_ATTRS] = quant_attrs.copy()

    def _annotate_stack(self, graph_module: torch.fx.GraphModule) -> None:
        """
        Annotate the nodes of every ``stack`` source partition whose first
        output node is consumed, as its first user, by a node with a target
        listed in ``q_ops``.
        """
        partitions = get_source_partitions(graph_module.graph, [torch.stack])
        for src_partitions in partitions.values():
            for src_partition in src_partitions:
                output_nodes = src_partition.output_nodes
                # No output node means there is nothing to inspect.
                if not output_nodes:
                    continue

                # Only the first user is inspected, same as
                # ``list(output.users)[0]``; an output without any user is
                # skipped instead of raising IndexError.
                first_user = next(iter(output_nodes[0].users), None)
                if first_user is None or first_user.target not in q_ops:
                    continue

                quant_attrs = get_quant_attrs(self.edge_program, first_user)
                for n in src_partition.nodes:
                    # One copy per node so the dicts are not shared.
                    n.meta[QCOM_QUANT_ATTRS] = quant_attrs.copy()

    def call(self, graph_module: torch.fx.GraphModule):
        """
        Run the unbind and stack annotation, recompile ``graph_module`` and
        return it in a ``PassResult`` flagged as modified.
        """
        self._annotate_unbind(graph_module)
        self._annotate_stack(graph_module)
        graph_module.recompile()
        return PassResult(graph_module, True)