# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

#
# Utility functions for ArmQuantizer
#

import operator
from typing import Callable, cast, List

import torch
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from torch._subclasses import FakeTensor

from torch.ao.quantization.quantizer import (
    QuantizationAnnotation,
    SharedQuantizationSpec,
)
from torch.fx import GraphModule, Node


def is_annotated(node: Node) -> bool:
    """Given a node, return whether the node is annotated."""
    return (
        "quantization_annotation" in node.meta
        and cast(
            QuantizationAnnotation, node.meta["quantization_annotation"]
        )._annotated
    )


def are_annotated(nodes: List[Node]) -> bool:
    """Given a list of nodes (that represents an operator pattern), return True
    if any of the nodes is annotated, otherwise return False.
    """
    return any(is_annotated(node) for node in nodes)


def mark_nodes_as_annotated(nodes: List[Node]) -> None:
    """Marks all nodes in the list 'nodes' as annotated. If needed, an empty
    QuantizationAnnotation is added to the quantization_annotation node meta entry.
    """
    for node in nodes:
        if node is not None:
            if "quantization_annotation" not in node.meta:
                node.meta["quantization_annotation"] = QuantizationAnnotation()
            node.meta["quantization_annotation"]._annotated = True
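

# Example (hypothetical sketch, not part of this module's API): a typical annotator
# checks whether a matched operator pattern is already annotated before processing
# it, and marks the nodes once done. 'conv_node' and 'relu_node' are assumed to
# come from a pattern match over the graph.
#
#   pattern = [conv_node, relu_node]
#   if not are_annotated(pattern):
#       ...  # attach QuantizationAnnotations to the pattern's nodes
#       mark_nodes_as_annotated(pattern)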


def get_shared_qspec(
    node: Node, gm: GraphModule, quantization_config: QuantizationConfig
):
    """Returns a quantization spec constellation with a SharedQuantizationSpec for
    the inputs and output of the parameter 'node'.
    Parameters:
        node: a node with two inputs that should share quantization parameters.
        gm: the GraphModule containing the node. Used to inspect global graph features.
        quantization_config: a QuantizationConfig with the input QuantizationSpec to share.
    Returns:
        input_qspec_map: a dict[Node, QuantizationSpec] that maps the inputs of 'node'
            to the correct QuantizationSpec.
        shared_with_input0_qspec: the SharedQuantizationSpec to be used as the output
            QuantizationSpec.

    Both outputs are None if one of the inputs is a node that can't be quantized.
    """
    input_act0 = cast(Node, node.args[0])
    input_act1 = node.args[1]

    input_act_qspec = quantization_config.get_input_act_qspec()
    shared_with_input0_qspec = SharedQuantizationSpec((input_act0, node))

    input_qspec_map = {}
    if isinstance(input_act0, Node):
        if not is_input_ok_for_quantization(input_act0, gm):
            return None, None
        input_qspec_map[input_act0] = input_act_qspec

    if isinstance(input_act1, Node):
        if not is_input_ok_for_quantization(input_act1, gm):
            return None, None
        if input_act0 is not input_act1:
            input_qspec_map[input_act1] = shared_with_input0_qspec
    return input_qspec_map, shared_with_input0_qspec
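

# Example (hypothetical sketch): annotating a two-input op such as aten.add so that
# both inputs and the output share quantization parameters. 'add_node', 'gm' and
# 'quantization_config' are assumed to be supplied by the calling annotator.
#
#   input_qspec_map, output_qspec = get_shared_qspec(add_node, gm, quantization_config)
#   if input_qspec_map is not None:
#       add_node.meta["quantization_annotation"] = QuantizationAnnotation(
#           input_qspec_map=input_qspec_map,
#           output_qspec=output_qspec,
#           _annotated=True,
#       )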


def is_input_ok_for_quantization(input_act: Node, gm: GraphModule) -> bool:
    """Check if an input can be quantized. The input cannot be quantized if:
    - The node does not output a float tensor, or
    - The node outputs a large scalar.
    """
    return not (
        is_input_non_float_tensor(input_act) or is_input_large_scalar(input_act, gm)
    )


def get_node_target(module: torch.nn.Module | GraphModule, target_str: str):
    """Resolve a dotted attribute path such as 'block.weight' on 'module' and
    return the attribute it points to, descending into submodules as needed.
    """
    targets = target_str.split(".")
    for target in targets[:-1]:
        module = module.get_submodule(target)
    return getattr(module, targets[-1])
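

# Example (hypothetical sketch): 'get_attr' node targets are dotted paths, so for a
# GraphModule 'gm' holding a submodule 'block' with a buffer 'scale', the call below
# is equivalent to gm.block.scale.
#
#   tensor = get_node_target(gm, "block.scale")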


def is_input_large_scalar(node: Node, gm: GraphModule) -> bool:
    """Check if the input is a large scalar value, so that we can skip quantization
    for the node, since the histc op (used by HistogramObserver) only works for
    values up to a certain upper bound.
    """
    if node.op == "get_attr" and isinstance(node.target, str):
        tensor = get_node_target(gm, node.target)
        # torch.histc works until this upper bound
        HISTC_UPPER_BOUND = 3.4028235e15
        return tensor.numel() == 1 and abs(tensor.item()) > HISTC_UPPER_BOUND
    return False
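

# Example (hypothetical sketch): a graph constant created as torch.tensor(1.0e20)
# shows up as a 'get_attr' node with a single element above HISTC_UPPER_BOUND, so
# is_input_large_scalar returns True and the node is left unquantized.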


def is_input_non_float_tensor(node: Node) -> bool:
    """Check if the input is not a float tensor, so that we can skip quantization
    for the node, since observers only work with float Tensors.
    """
    if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor):
        return True
    return node.meta["val"].dtype != torch.float32


def is_share_obs_or_fq_op(op: Callable) -> bool:
    """Returns whether the operation 'op' can be quantized using a shared observer
    or fake quantizer. This means that the operation can inherit its quantization
    spec from parent nodes.
    """
    return op in [
        torch.ops.aten.hardtanh.default,
        torch.ops.aten.hardtanh_.default,
        torch.ops.aten.relu.default,
        torch.ops.aten.mean.default,
        torch.ops.aten.mean.dim,
        torch.ops.aten.permute.default,
        torch.ops.aten.permute_copy.default,
        # TODO: remove?
        torch.ops.aten.adaptive_avg_pool2d.default,
        torch.ops.aten.avg_pool2d.default,
        torch.ops.aten.max_pool2d.default,
        torch.ops.aten.full.default,
        torch.ops.aten.flatten.using_ints,
        torch.ops.aten.dropout.default,
        operator.getitem,
    ]
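

# Example (hypothetical sketch): ops like relu can reuse their input's quantization
# parameters, while compute ops such as convolution need their own.
#
#   is_share_obs_or_fq_op(torch.ops.aten.relu.default)  # True
#   is_share_obs_or_fq_op(torch.ops.aten.conv2d.default)  # False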


def propagate_annotation(model: GraphModule) -> None:
    """For unannotated ops that can share an observer or fake quantizer, annotate
    with a SharedQuantizationSpec, where the shared spec is the output spec of the
    parent node. This propagates output qspecs downward in the graph until an op
    that is already annotated, or that can't share a qspec, is encountered.
    """
    for n in model.graph.nodes:
        n = cast(Node, n)
        if is_annotated(n):
            continue
        if n.op != "call_function" or not is_share_obs_or_fq_op(
            cast(Callable, n.target)
        ):
            continue

        prev_node = n.args[0]
        if not isinstance(prev_node, Node):
            continue

        quantization_annotation = cast(
            QuantizationAnnotation | None,
            prev_node.meta.get("quantization_annotation", None),
        )
        if not quantization_annotation or not quantization_annotation.output_qspec:
            continue

        # propagate the previous output_qspec to the current node
        shared_qspec = SharedQuantizationSpec(prev_node)
        n.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map={
                prev_node: shared_qspec,
            },
            output_qspec=shared_qspec,
            _annotated=True,
        )
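

# Example (hypothetical sketch): given a conv -> relu -> flatten chain where only
# the conv node has been annotated with an output qspec, propagate_annotation lets
# relu and flatten inherit that spec via SharedQuantizationSpec instead of receiving
# fresh observers. 'gm' is assumed to be the captured GraphModule.
#
#   ...annotate conv (and other compute ops) with output qspecs first...
#   propagate_annotation(gm)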