# blob: 913afab55d033d5f8a5bbbf3d3e7d89a0e7d48d6 [file] [log] [blame]
# Owner(s): ["oncall: quantization"]
import copy
from typing import List
import torch
import torch._dynamo as torchdynamo
from torch.ao.ns.fx.utils import compute_sqnr
from torch.ao.quantization import observer, QConfigMapping
from torch.ao.quantization._pt2e.quantizer import (
OperatorConfig,
QNNPackQuantizer,
Quantizer,
)
from torch.ao.quantization._quantize_pt2e import convert_pt2e, prepare_pt2e_quantizer
from torch.ao.quantization.backend_config import get_qnnpack_backend_config
from torch.ao.quantization.qconfig import default_per_channel_symmetric_qnnpack_qconfig
from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
from torch.testing._internal.common_quantization import (
NodeSpec as ns,
QuantizationTestCase,
skip_if_no_torchvision,
skipIfNoQNNPACK,
)
from torch.testing._internal.common_quantized import override_quantized_engine
@skipIfNoQNNPACK
class TestQuantizePT2E(QuantizationTestCase):
    """Graph-structure tests for the quantizer-based PT2E quantization flow.

    Every test captures a small module with ``torchdynamo.export``, runs it
    through ``prepare_pt2e_quantizer`` and ``convert_pt2e``, and then checks
    which quantize/dequantize nodes show up in the converted graph via
    ``checkGraphModuleNodes``.
    """

    def test_simple_quantizer(self):
        """Use a hand-written ``Quantizer`` that only quantizes convolutions."""

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)

            def forward(self, x):
                return self.conv(x)

        class BackendAQuantizer(Quantizer):
            """Minimal quantizer: float placeholders everywhere except conv."""

            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                """Write ``target_dtype_info`` into every node's ``meta``.

                All nodes first receive a float (placeholder observer)
                configuration; ``aten.convolution.default`` call nodes are then
                overwritten with real activation/weight observers.

                Returns:
                    The same ``model`` object, annotated in place.
                """
                _DEFAULT_TARGET_DTYPE_INFO = {
                    "input_act_obs_or_fq_ctr": observer.PlaceholderObserver.with_args(
                        dtype=torch.float
                    ),
                    "output_act_obs_or_fq_ctr": observer.PlaceholderObserver.with_args(
                        dtype=torch.float
                    ),
                }
                # First pass: give every node its own copy of the float default.
                for node in model.graph.nodes:
                    node.meta["target_dtype_info"] = copy.deepcopy(
                        _DEFAULT_TARGET_DTYPE_INFO
                    )
                # Second pass: override the annotation for convolution nodes.
                for node in model.graph.nodes:
                    if (
                        node.op == "call_function"
                        and node.target == torch.ops.aten.convolution.default
                    ):
                        node.meta["target_dtype_info"] = {
                            "input_act_obs_or_fq_ctr": observer.default_observer,
                            "weight_obs_or_fq_ctr": observer.default_weight_observer,
                            "bias_obs_or_fq_ctr": observer.PlaceholderObserver.with_args(
                                dtype=torch.float
                            ),
                            "output_act_obs_or_fq_ctr": observer.default_observer,
                            "weight_index": 1,
                            "bias_index": 2,
                        }
                # Honour the declared return type instead of returning None.
                return model

            def validate(self, model: torch.fx.GraphModule) -> None:
                """No validation is performed by this test quantizer."""
                pass

            @classmethod
            def get_supported_operators(cls) -> List[OperatorConfig]:
                """Return an empty list; this test quantizer advertises nothing."""
                return []

        m = M().eval()
        example_inputs = (torch.randn(1, 3, 5, 5),)
        # program capture (the second return value is not needed here)
        m, _ = torchdynamo.export(
            m,
            *copy.deepcopy(example_inputs),
            aten_graph=True,
            tracing_mode="real",
        )
        m = prepare_pt2e_quantizer(m, BackendAQuantizer())
        # run once on the example inputs between prepare and convert
        m(*example_inputs)
        m = convert_pt2e(m)
        node_occurrence = {
            # two for the inputs of the conv (activation and weight),
            # one for the output of the conv
            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 3,
            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor): 3,
        }
        node_list = [
            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
            ns.call_function(torch.ops.aten.convolution.default),
            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor),
        ]
        self.checkGraphModuleNodes(
            m, expected_node_list=node_list, expected_node_occurrence=node_occurrence
        )

    def test_qnnpack_quantizer_conv(self):
        """Quantize a single conv with ``QNNPackQuantizer`` (per-channel weights)."""

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)

            def forward(self, x):
                return self.conv(x)

        import torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer as qq

        quantizer = QNNPackQuantizer()
        operator_config = (
            qq.get_default_per_channel_symmetric_qnnpack_quantization_config()
        )
        quantizer.set_global(operator_config)
        m = M().eval()
        example_inputs = (torch.randn(1, 3, 5, 5),)
        # program capture (the second return value is not needed here)
        m, _ = torchdynamo.export(
            m,
            *copy.deepcopy(example_inputs),
            aten_graph=True,
            tracing_mode="real",
        )
        m = prepare_pt2e_quantizer(m, quantizer)
        # run once on the example inputs between prepare and convert
        m(*example_inputs)
        m = convert_pt2e(m)
        node_occurrence = {
            # input and output are using quantize_per_tensor and weight is
            # using quantize_per_channel
            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 2,
            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor): 2,
            ns.call_function(torch.ops.quantized_decomposed.quantize_per_channel): 1,
            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel): 1,
        }
        node_list = [
            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel),
            ns.call_function(torch.ops.aten.convolution.default),
            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor),
        ]
        self.checkGraphModuleNodes(
            m, expected_node_list=node_list, expected_node_occurrence=node_occurrence
        )

    def test_qnnpack_quantizer_obs_sharing_ops(self):
        """Quantize conv -> adaptive_avg_pool2d -> hardtanh -> mean with QNNPack."""

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)
                self.hardtanh = torch.nn.Hardtanh()
                self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))

            def forward(self, x):
                x = self.conv(x)
                x = self.adaptive_avg_pool2d(x)
                x = self.hardtanh(x)
                x = torch.mean(x)
                return x

        import torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer as qq

        quantizer = QNNPackQuantizer()
        operator_config = (
            qq.get_default_per_channel_symmetric_qnnpack_quantization_config()
        )
        quantizer.set_global(operator_config)
        m = M().eval()
        example_inputs = (torch.randn(1, 3, 5, 5),)
        # program capture (the second return value is not needed here)
        m, _ = torchdynamo.export(
            m,
            *copy.deepcopy(example_inputs),
            aten_graph=True,
            tracing_mode="real",
        )
        m = prepare_pt2e_quantizer(m, quantizer)
        # run once on the example inputs between prepare and convert
        m(*example_inputs)
        m = convert_pt2e(m)
        node_occurrence = {
            # per-tensor quantize/dequantize pairs: one for the conv input plus
            # one after each of the four ops in the expected node list below
            # (convolution, mean.dim, hardtanh, mean.default); the conv weight
            # is using quantize_per_channel
            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 5,
            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor): 5,
            ns.call_function(torch.ops.quantized_decomposed.quantize_per_channel): 1,
            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel): 1,
        }
        node_list = [
            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel),
            ns.call_function(torch.ops.aten.convolution.default),
            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor),
            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
            ns.call_function(torch.ops.aten.mean.dim),
            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor),
            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
            ns.call_function(torch.ops.aten.hardtanh.default),
            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor),
            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
            ns.call_function(torch.ops.aten.mean.default),
            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor),
            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
        ]
        self.checkGraphModuleNodes(
            m, expected_node_list=node_list, expected_node_occurrence=node_occurrence
        )
class TestQuantizePT2EModels(QuantizationTestCase):
    """Model-level comparison of the PT2E flow against FX graph mode quantization."""

    @skip_if_no_torchvision
    @skipIfNoQNNPACK
    def test_resnet18_with_quantizer_api(self):
        """Quantize torchvision resnet18 with ``QNNPackQuantizer`` and compare
        the outputs with the FX reference flow (``prepare_fx`` followed by
        ``convert_to_reference_fx``) on the same example input.
        """
        import torchvision

        with override_quantized_engine("qnnpack"):
            example_inputs = (torch.randn(1, 3, 224, 224),)
            m = torchvision.models.resnet18().eval()
            # untouched copy that is fed to the FX reference flow below
            m_copy = copy.deepcopy(m)
            # program capture (the second return value is not needed here)
            m, _ = torchdynamo.export(
                m,
                *copy.deepcopy(example_inputs),
                aten_graph=True,
                tracing_mode="real",
            )

            # run the captured graph once before preparing it; the result
            # itself is not used
            m(*example_inputs)

            import torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer as qq

            quantizer = QNNPackQuantizer()
            operator_config = (
                qq.get_default_per_channel_symmetric_qnnpack_quantization_config()
            )
            quantizer.set_global(operator_config)
            m = prepare_pt2e_quantizer(m, quantizer)
            # checking that we inserted observers correctly for maxpool operator (input and
            # output share observer instance)
            self.assertIs(m.activation_post_process_3, m.activation_post_process_2)
            after_prepare_result = m(*example_inputs)
            m = convert_pt2e(m)

            after_quant_result = m(*example_inputs)

            # comparing with existing fx graph mode quantization reference flow
            qconfig = default_per_channel_symmetric_qnnpack_qconfig
            qconfig_mapping = QConfigMapping().set_global(qconfig)
            backend_config = get_qnnpack_backend_config()
            m_fx = prepare_fx(
                m_copy, qconfig_mapping, example_inputs, backend_config=backend_config
            )
            after_prepare_result_fx = m_fx(*example_inputs)
            m_fx = convert_to_reference_fx(m_fx, backend_config=backend_config)

            after_quant_result_fx = m_fx(*example_inputs)

            # the result matches exactly after prepare
            # Note: this currently will always be true since we are inserting observers
            # the check becomes useful when we add qat examples
            # but we can still manually inspect the printed observers to make sure
            # it matches
            self.assertEqual(after_prepare_result, after_prepare_result_fx)
            self.assertEqual(
                compute_sqnr(after_prepare_result, after_prepare_result_fx),
                torch.tensor(float("inf")),
            )
            # there are slight differences after convert due to different implementations
            # of quant/dequant
            # Compare the absolute difference: the signed difference would let
            # arbitrarily large negative deviations pass the check.
            self.assertTrue(
                torch.max(torch.abs(after_quant_result - after_quant_result_fx))
                < 1e-1
            )
            self.assertTrue(
                compute_sqnr(after_quant_result, after_quant_result_fx) > 35
            )