blob: bfe19094a60974592498676c98857cfd140d1e81 [file] [log] [blame]
#
# Copyright (c) 2024 Apple Inc. All rights reserved.
# Provided subject to the LICENSE file in the top level directory.
#
import torch
from executorch.exir.dialects._ops import ops as exir_ops
# Group-wise (per-channel-group) dequantize / quantize edge ops.
DQ_GROUP_TARGETS = {
    exir_ops.edge.quantized_decomposed.dequantize_per_channel_group.default,
}

Q_GROUP_TARGETS = {
    exir_ops.edge.quantized_decomposed.quantize_per_channel_group.default,
}

# Every dequantize op recognised by this module: the per-tensor, per-channel
# and per-token variants listed here, plus the group-wise ones above.
DQ_TARGETS = {
    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
    exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
    exir_ops.edge.quantized_decomposed.dequantize_per_token.default,
} | DQ_GROUP_TARGETS

# Every quantize op recognised by this module, mirroring DQ_TARGETS.
Q_TARGETS = {
    exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
    exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
    exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
    exir_ops.edge.quantized_decomposed.quantize_per_token.default,
} | Q_GROUP_TARGETS
def is_quant(tensor: torch.fx.Node) -> bool:
    """Return whether the node's target is one of the ops in ``Q_TARGETS``.

    Args:
        tensor: The FX graph node to inspect; only its ``target`` is read.

    Returns:
        True if ``tensor.target`` is a member of ``Q_TARGETS``, else False.
    """
    op = tensor.target
    return op in Q_TARGETS
def is_dequant(tensor: torch.fx.Node) -> bool:
    """Return whether the node's target is one of the ops in ``DQ_TARGETS``.

    Args:
        tensor: The FX graph node to inspect; only its ``target`` is read.

    Returns:
        True if ``tensor.target`` is a member of ``DQ_TARGETS``, else False.
    """
    op = tensor.target
    return op in DQ_TARGETS
def is_groupwise_q_dq(tensor: torch.fx.Node) -> bool:
    """Return whether the node's target is a group-wise quantize or dequantize op.

    Args:
        tensor: The FX graph node to inspect; only its ``target`` is read.

    Returns:
        True if ``tensor.target`` is a member of ``DQ_GROUP_TARGETS`` or of
        ``Q_GROUP_TARGETS``, else False.
    """
    op = tensor.target
    # Test membership in each set. Wrapping the two sets in a list would
    # compare the target against the set objects themselves and never match.
    return op in DQ_GROUP_TARGETS or op in Q_GROUP_TARGETS