| # Copyright (c) Meta Platforms, Inc. and affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| import operator |
| from itertools import accumulate |
| from typing import cast |
| |
| import torch |
| from executorch.exir.backend.canonical_partitioners.config_partitioner import ( |
| format_target_name, |
| ) |
| |
# Canonical (overload-qualified) names of quantize ops recognized by this
# module, as produced by ``format_target_name``. Includes the torchao
# ``quantize_affine`` op alongside the decomposed per-tensor/channel/token ops.
_Q_OPS = {
    "quantize_per_tensor.tensor",
    "quantize_per_tensor.default",
    "quantize_per_channel.default",
    "quantize_per_channel_group.default",
    "quantize_per_token.default",
    "quantize_affine.default",
}

# Dequantize counterparts of ``_Q_OPS``; same naming convention.
_DQ_OPS = {
    "dequantize_per_tensor.tensor",
    "dequantize_per_tensor.default",
    "dequantize_per_channel.default",
    "dequantize_per_channel_group.default",
    "dequantize_per_token.default",
    "dequantize_affine.default",
}


# Ops that compute quantization parameters (scale/zero-point) at runtime.
_QPARAM_OPS = {
    "choose_qparams.tensor",
    "choose_qparams_per_token_asymmetric.default",
    "choose_qparams_affine.default",
}

# Q/DQ ops whose qparams are chosen dynamically at runtime (the ``.tensor``
# overloads and the per-token ops), as opposed to statically quantized ops.
_DYNAMIC_OPS = {
    "quantize_per_tensor.tensor",
    "quantize_per_token.default",
    "dequantize_per_tensor.tensor",
    "dequantize_per_token.default",
}
| |
| |
def is_dynamic_qdq(node: torch.fx.Node) -> bool:
    """Return True if ``node`` is a dynamically-quantized q/dq op.

    A node qualifies either by name (``_DYNAMIC_OPS``) or by being an affine
    q/dq that is per-token but not per-channel-group.
    """
    if node.op != "call_function":
        return False
    target_name = format_target_name(node.target.__name__)  # pyre-ignore
    if target_name in _DYNAMIC_OPS:
        return True
    # Affine q/dq is dynamic exactly when it is per-token without grouping.
    return is_per_token(node) and not is_per_channel_group(node)
| |
| |
def is_qparam(node: torch.fx.Node) -> bool:
    """Return True if ``node`` is a choose_qparams op (see ``_QPARAM_OPS``)."""
    if node.op != "call_function":
        return False
    target_name = format_target_name(node.target.__name__)  # pyre-ignore
    return target_name in _QPARAM_OPS
| |
| |
def is_quant(node: torch.fx.Node) -> bool:
    """Return True if ``node`` is one of the recognized quantize ops."""
    if node.op != "call_function":
        return False
    formatted = format_target_name(node.target.__name__)  # pyre-ignore
    return formatted in _Q_OPS
| |
| |
def is_dequant(node: torch.fx.Node) -> bool:
    """Return True if ``node`` is one of the recognized dequantize ops."""
    if node.op != "call_function":
        return False
    formatted = format_target_name(node.target.__name__)  # pyre-ignore
    return formatted in _DQ_OPS
| |
| |
def is_per_channel(node: torch.fx.Node) -> bool:
    """Return True if ``node`` is a per-channel quantize/dequantize op.

    Matches ops with ``per_channel`` in their name (which also covers the
    ``per_channel_group`` variants) as well as affine q/dq nodes whose block
    size describes per-channel grouping.
    """
    if not (is_quant(node) or is_dequant(node)):
        return False

    # Fixed: the original bound this to a local named ``is_per_channel``,
    # shadowing the enclosing function's own name.
    has_per_channel_name = "per_channel" in node.target.__name__  # pyre-ignore

    return has_per_channel_name or is_per_channel_group(node)
| |
| |
def is_affine_qdq(node: torch.fx.Node) -> bool:
    """Return True if ``node`` is a torchao affine quantize/dequantize op."""
    qdq = is_quant(node) or is_dequant(node)
    # Both quantize_affine and dequantize_affine contain this substring.
    return qdq and "quantize_affine" in node.target.__name__  # pyre-ignore
| |
| |
def _get_block_size_input_scale(node: torch.fx.Node):
    """Return ``(block_size, input_val, scale_val)`` for an affine q/dq node.

    ``block_size`` is the op's second positional arg; the fake-tensor values
    are pulled from the first two input nodes' ``meta["val"]``.
    """
    assert is_affine_qdq(node)
    inputs = node.all_input_nodes
    return node.args[1], inputs[0].meta["val"], inputs[1].meta["val"]
| |
| |
def is_per_token(node: torch.fx.Node):
    """Return True if ``node`` quantizes/dequantizes per token.

    Matches explicit ``per_token`` ops, and affine q/dq nodes whose block
    size is 1 on every leading dim and spans the full last dim — with one
    scale per token (product of the leading input dims).
    """
    if not (is_quant(node) or is_dequant(node)):
        return False

    if "per_token" in node.target.__name__:  # pyre-ignore
        return True
    elif is_affine_qdq(node):
        block_size, input_val, scale_val = _get_block_size_input_scale(node)
        # All leading block dims must be 1; expected scale count is the
        # product of the corresponding input dims.
        matches = all(b == 1 for b in block_size[:-1])
        expected_scales = 1
        for dim in range(len(block_size) - 1):
            expected_scales *= input_val.shape[dim]
        # Deliberately non-short-circuiting (&=) to mirror the original flow.
        matches &= block_size[-1] == input_val.shape[-1]
        matches &= scale_val.numel() == expected_scales
        return matches

    return False
| |
| |
def is_per_channel_group(node: torch.fx.Node):
    """Return True if ``node`` quantizes/dequantizes per channel group.

    Matches explicit ``per_channel_group`` ops, and affine q/dq nodes with a
    2-D block size of ``[1, group_size]`` where the scales tile the input
    exactly (``input_numel == group_size * scale_numel``).
    """
    if not (is_quant(node) or is_dequant(node)):
        return False

    if "per_channel_group" in node.target.__name__:  # pyre-ignore
        return True
    elif is_affine_qdq(node):
        block_size, input_val, scale_val = _get_block_size_input_scale(node)
        matches = len(block_size) == 2
        matches &= block_size[0] == 1
        group_size = block_size[1]
        # Total element counts via a running product over the shapes.
        scale_count = list(accumulate(scale_val.shape, operator.mul))[-1]
        input_count = list(accumulate(input_val.shape, operator.mul))[-1]
        matches &= input_count == group_size * scale_count
        return matches

    return False
| |
| |
def extract_qdq_affine_op_args_for_decomposed_ops(node: torch.fx.Node):
    """Map a ``quantize_affine``/``dequantize_affine`` node's args onto the
    argument list expected by the decomposed q/dq ops.

    On success returns
    ``[input, scale, zero_point, quant_min, quant_max, dtype(, group_size), last_arg]``;
    returns ``(None, None)`` when ``node`` is not an affine q/dq op, its
    target dtype is not int8, or its zero-point domain is not "INT".

    NOTE(review): the failure path returns a 2-tuple while the success path
    returns a single list — callers must handle both shapes; left unchanged
    here since callers are not visible.
    """
    if not is_affine_qdq(node):
        return None, None
    # Affine q/dq positional layout:
    # (input, block_size, scale, zero_point, dtype, [quant_min, quant_max, ...])
    input_node = node.args[0]
    scale_node = node.args[2]
    zero_point_node = node.args[3]
    args = [input_node, scale_node, zero_point_node]
    # args[0..4] are accessed unconditionally, so at least 5 args are required.
    # (Fixed: the message previously claimed "at least 6" while the check is > 4.)
    assert (
        len(node.args) > 4
    ), f"expecting at least 5 args, got node: {node.format_node()}"

    # Only an int8 target dtype is supported by the decomposed ops.
    if node.args[4] != torch.int8:
        return None, None
    target_dtype = cast(torch.dtype, node.args[4])

    if len(node.args) > 6:
        # Explicit quant_min / quant_max were provided — pass them through.
        args.append(node.args[5])
        args.append(node.args[6])
    else:
        # Default to the full representable range of the target dtype.
        dtype_info = torch.iinfo(target_dtype)
        args.append(dtype_info.min)
        args.append(dtype_info.max)

    # Target dtype goes after quant_min/quant_max in the decomposed signature.
    args.append(target_dtype)
    # Only the integer zero-point domain is supported.
    if len(node.args) > 7 and node.args[7] != "INT":
        return None, None

    if is_per_channel_group(node):
        # Group size is the last element of the 2-D block size.
        block_sizes = cast(list[int], node.args[1])
        args.append(block_sizes[-1])

    # Forward the node's final argument unchanged.
    args.append(node.args[-1])

    return args