# Owner(s): ["oncall: quantization"]
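# Tests for the PT2E quantization flow (prepare_pt2e_quantizer,
# prepare_qat_pt2e_quantizer, convert_pt2e), checked against the FX graph mode
# quantization reference flow.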
import copy
import operator
import unittest
from typing import Any, List, Tuple

import torch
import torch._dynamo as torchdynamo
from torch.ao.ns.fx.utils import compute_sqnr
from torch.ao.quantization import (
    FusedMovingAvgObsFakeQuantize,
    observer,
    MovingAverageMinMaxObserver,
    MovingAveragePerChannelMinMaxObserver,
    QConfigMapping,
)
from torch.ao.quantization._pt2e.quantizer import (
    OperatorConfig,
    QNNPackQuantizer,
    Quantizer,
)
from torch.ao.quantization._quantize_pt2e import (
    convert_pt2e,
    prepare_pt2e_quantizer,
    prepare_qat_pt2e_quantizer,
)
from torch.ao.quantization.backend_config import get_qnnpack_backend_config
from torch.ao.quantization.qconfig import (
    default_per_channel_symmetric_qnnpack_qat_qconfig,
    default_per_channel_symmetric_qnnpack_qconfig,
    default_symmetric_qnnpack_qat_qconfig,
)
from torch.ao.quantization.quantize_fx import (
    _convert_to_reference_decomposed_fx,
    convert_to_reference_fx,
    prepare_fx,
    prepare_qat_fx,
)
from torch.testing._internal.common_quantization import (
    NodeSpec as ns,
    QuantizationTestCase,
    skip_if_no_torchvision,
    skipIfNoQNNPACK,
)
from torch.testing._internal.common_quantized import override_quantized_engine

@skipIfNoQNNPACK
class TestQuantizePT2E(QuantizationTestCase):
    def test_simple_quantizer(self):
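        # BackendAQuantizer below is a minimal Quantizer implementation that
        # annotates conv nodes directly through node.meta["target_dtype_info"],
        # exercising the prepare_pt2e_quantizer API end to end.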
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)

            def forward(self, x):
                return self.conv(x)

        class BackendAQuantizer(Quantizer):
            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                _DEFAULT_TARGET_DTYPE_INFO = {
                    "input_act_obs_or_fq_ctr": observer.PlaceholderObserver.with_args(
                        dtype=torch.float
                    ),
                    "output_act_obs_or_fq_ctr": observer.PlaceholderObserver.with_args(
                        dtype=torch.float
                    ),
                }
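                # mark every node as fp32 passthrough by default, then override
                # the annotation for conv nodes with real observers below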
                for node in model.graph.nodes:
                    node.meta["target_dtype_info"] = copy.deepcopy(
                        _DEFAULT_TARGET_DTYPE_INFO
                    )
                for node in model.graph.nodes:
                    if (
                        node.op == "call_function"
                        and node.target == torch.ops.aten.convolution.default
                    ):
                        node.meta["target_dtype_info"] = {
                            "input_act_obs_or_fq_ctr": observer.default_observer,
                            "weight_obs_or_fq_ctr": observer.default_weight_observer,
                            "bias_obs_or_fq_ctr": observer.PlaceholderObserver.with_args(
                                dtype=torch.float
                            ),
                            "output_act_obs_or_fq_ctr": observer.default_observer,
                            "weight_index": 1,
                            "bias_index": 2,
                        }

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

            @classmethod
            def get_supported_operators(cls) -> List[OperatorConfig]:
                pass
        m = M().eval()
        example_inputs = (torch.randn(1, 3, 5, 5),)
        # program capture
        m, guards = torchdynamo.export(
            m,
            *copy.deepcopy(example_inputs),
            aten_graph=True,
        )
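        # prepare inserts observers according to the quantizer's annotations;
        # running the example inputs calibrates them before convert_pt2e lowers
        # the model to quantize/dequantize ops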
        m = prepare_pt2e_quantizer(m, BackendAQuantizer())
        m(*example_inputs)
        m = convert_pt2e(m)
        node_occurrence = {
            # one quantize each for the conv's input activation and weight,
            # and one for its output
            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 3,
            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor): 3,
        }
        node_list = [
            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
            ns.call_function(torch.ops.aten.convolution.default),
            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor),
        ]
        self.checkGraphModuleNodes(
            m, expected_node_list=node_list, expected_node_occurrence=node_occurrence
        )

    def test_qnnpack_quantizer_conv(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)

            def forward(self, x):
                return self.conv(x)

        import torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer as qq

        quantizer = QNNPackQuantizer()
        operator_config = qq.get_symmetric_quantization_config(is_per_channel=True)
        quantizer.set_global(operator_config)
        m = M().eval()
        example_inputs = (torch.randn(1, 3, 5, 5),)
        # program capture
        m, guards = torchdynamo.export(
            m,
            *copy.deepcopy(example_inputs),
            aten_graph=True,
        )
        m = prepare_pt2e_quantizer(m, quantizer)
        m(*example_inputs)
        m = convert_pt2e(m)
        node_occurrence = {
            # input and output use quantize_per_tensor; the weight uses quantize_per_channel
            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 2,
            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor): 2,
            ns.call_function(torch.ops.quantized_decomposed.quantize_per_channel): 1,
            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel): 1,
        }
        node_list = [
            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel),
            ns.call_function(torch.ops.aten.convolution.default),
            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor),
        ]
        self.checkGraphModuleNodes(
            m, expected_node_list=node_list, expected_node_occurrence=node_occurrence
        )

    def test_qnnpack_quantizer_linear(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear1 = torch.nn.Linear(8, 16, bias=False)
                self.linear2 = torch.nn.Linear(16, 8)

            def forward(self, x):
                return self.linear2(self.linear1(x))

        import torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer as qq

        quantizer = QNNPackQuantizer()
        operator_config = qq.get_symmetric_quantization_config(is_per_channel=True)
        quantizer.set_global(operator_config)
        m_eager = M().eval()

        # Test with 2d, 3d, and 4d inputs
        example_inputs_2d = (torch.randn(9, 8),)
        example_inputs_3d = (torch.randn(9, 10, 8),)
        example_inputs_4d = (torch.randn(9, 10, 11, 8),)
        for example_inputs in [example_inputs_2d, example_inputs_3d, example_inputs_4d]:
            # program capture
            m = m_eager
            m, guards = torchdynamo.export(
                m,
                *copy.deepcopy(example_inputs),
                aten_graph=True,
                tracing_mode="real",
            )
            m = prepare_pt2e_quantizer(m, quantizer)
            # Calibrate
            m(*example_inputs)
            m = convert_pt2e(m)
            pt2_quant_output = m(*example_inputs)
            node_occurrence = {
                # inputs and outputs use quantize_per_tensor; weights use quantize_per_channel
                ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 3,
                ns.call_function(
                    torch.ops.quantized_decomposed.dequantize_per_tensor
                ): 3,
                ns.call_function(
                    torch.ops.quantized_decomposed.quantize_per_channel
                ): 2,
                ns.call_function(
                    torch.ops.quantized_decomposed.dequantize_per_channel
                ): 2,
            }
            self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
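            # compare against the existing FX graph mode quantization reference
            # flow, which should produce numerically identical results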
            qconfig = default_per_channel_symmetric_qnnpack_qconfig
            qconfig_mapping = QConfigMapping().set_global(qconfig)
            backend_config = get_qnnpack_backend_config()
            m_copy = copy.deepcopy(m_eager)
            m_fx = prepare_fx(
                m_copy, qconfig_mapping, example_inputs, backend_config=backend_config
            )
            m_fx(*example_inputs)
            m_fx = _convert_to_reference_decomposed_fx(
                m_fx, backend_config=backend_config
            )
            m_fx, _ = torchdynamo.export(
                m_fx,
                *copy.deepcopy(example_inputs),
                aten_graph=True,
                tracing_mode="real",
            )
            fx_quant_output = m_fx(*example_inputs)
            self.assertTrue(torch.allclose(fx_quant_output, pt2_quant_output))

    def test_qnnpack_quantizer_conv_linear_no_permute(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 16, 3)
                self.linear1 = torch.nn.Linear(64, 8, bias=False)
                self.linear2 = torch.nn.Linear(8, 8)

            def forward(self, x):
                conv_out = self.conv(x)
                reshape_out = torch.reshape(conv_out, (2, 64))
                return self.linear2(self.linear1(reshape_out))

        import torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer as qq

        quantizer = QNNPackQuantizer()
        operator_config = qq.get_symmetric_quantization_config(is_per_channel=True)
        quantizer.set_global(operator_config)
        m_eager = M().eval()

        # Test with 4d inputs
        example_inputs = (torch.randn(2, 3, 4, 4),)
        # program capture
        m = m_eager
        m, guards = torchdynamo.export(
            m,
            *copy.deepcopy(example_inputs),
            aten_graph=True,
            tracing_mode="real",
        )
        m = prepare_pt2e_quantizer(m, quantizer)
        # Calibrate
        m(*example_inputs)
        m = convert_pt2e(m)
        pt2_quant_output = m(*example_inputs)
        node_occurrence = {
            # inputs and outputs use quantize_per_tensor; weights use quantize_per_channel
            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 5,
            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor): 5,
            ns.call_function(torch.ops.quantized_decomposed.quantize_per_channel): 3,
            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel): 3,
        }
        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
        qconfig = default_per_channel_symmetric_qnnpack_qconfig
        qconfig_mapping = QConfigMapping().set_global(qconfig)
        backend_config = get_qnnpack_backend_config()
        m_copy = copy.deepcopy(m_eager)
        m_fx = prepare_fx(
            m_copy, qconfig_mapping, example_inputs, backend_config=backend_config
        )
        m_fx(*example_inputs)
        m_fx = _convert_to_reference_decomposed_fx(m_fx, backend_config=backend_config)
        fx_quant_output = m_fx(*example_inputs)
        self.assertTrue(torch.allclose(fx_quant_output, pt2_quant_output))

    @unittest.skip(
        "Skip because linear traces into a different pattern. See test comment."
    )
    def test_qnnpack_quantizer_conv_linear(self):
        """
        This test fails because the linear decomposition changes due to the presence
        of the permute node. In the graph below, linear1 is decomposed as:
        %t_default : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%_param_constant2,), kwargs = {})
        %clone_default : [#users=1] = call_function[target=torch.ops.aten.clone.default](args = (%permute_default,), kwargs = {memory_format: torch.contiguous_format})  # noqa: B950
        %_unsafe_view_default : [#users=1] = call_function[target=torch.ops.aten._unsafe_view.default](args = (%clone_default, [8, 16]), kwargs = {})  # noqa: B950
        %mm_default : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%_unsafe_view_default, %t_default), kwargs = {})  # noqa: B950
        %view_default : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%mm_default, [2, 2, 2, 8]), kwargs = {})  # noqa: B950
        Note the presence of clone and _unsafe_view above. This is due to the permute.
        """

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 16, 3)
                self.linear1 = torch.nn.Linear(16, 8, bias=False)
                self.linear2 = torch.nn.Linear(8, 8)

            def forward(self, x):
                conv_out = self.conv(x)
                permute_out = torch.permute(conv_out, (0, 2, 3, 1))
                return self.linear2(self.linear1(permute_out))

        import torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer as qq

        quantizer = QNNPackQuantizer()
        operator_config = qq.get_symmetric_quantization_config(is_per_channel=True)
        quantizer.set_global(operator_config)
        m_eager = M().eval()

        # Test with 4d inputs
        example_inputs = (torch.randn(2, 3, 4, 4),)
        # program capture
        m = m_eager
        m, guards = torchdynamo.export(
            m,
            *copy.deepcopy(example_inputs),
            aten_graph=True,
            tracing_mode="real",
        )
        m = prepare_pt2e_quantizer(m, quantizer)
        # Calibrate
        m(*example_inputs)
        m = convert_pt2e(m)
        pt2_quant_output = m(*example_inputs)
        node_occurrence = {
            # inputs and outputs use quantize_per_tensor; weights use quantize_per_channel
            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 3,
            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor): 3,
            ns.call_function(torch.ops.quantized_decomposed.quantize_per_channel): 2,
            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel): 2,
        }
        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
        qconfig = default_per_channel_symmetric_qnnpack_qconfig
        qconfig_mapping = QConfigMapping().set_global(qconfig)
        backend_config = get_qnnpack_backend_config()
        m_copy = copy.deepcopy(m_eager)
        m_fx = prepare_fx(
            m_copy, qconfig_mapping, example_inputs, backend_config=backend_config
        )
        m_fx(*example_inputs)
        m_fx = convert_to_reference_fx(m_fx, backend_config=backend_config)
        fx_quant_output = m_fx(*example_inputs)
        self.assertTrue(torch.allclose(fx_quant_output, pt2_quant_output))

    def test_qnnpack_quantizer_obs_sharing_ops(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)
                self.hardtanh = torch.nn.Hardtanh()
                self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))

            def forward(self, x):
                x = self.conv(x)
                x = self.adaptive_avg_pool2d(x)
                x = self.hardtanh(x)
                x = torch.mean(x)
                return x

        import torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer as qq

        quantizer = QNNPackQuantizer()
        operator_config = qq.get_symmetric_quantization_config(is_per_channel=True)
        quantizer.set_global(operator_config)
        m = M().eval()
        example_inputs = (torch.randn(1, 3, 5, 5),)
        # program capture
        m, guards = torchdynamo.export(
            m,
            *copy.deepcopy(example_inputs),
            aten_graph=True,
        )
        m = prepare_pt2e_quantizer(m, quantizer)
        m(*example_inputs)
        m = convert_pt2e(m)
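        # adaptive_avg_pool2d, hardtanh, and mean do not compute their own
        # quantization parameters; they share their input's observer, so each
        # still gets a q/dq pair but with the same qparams as its input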
        node_occurrence = {
            # inputs and outputs use quantize_per_tensor; the conv weight uses quantize_per_channel
            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 5,
            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor): 5,
            ns.call_function(torch.ops.quantized_decomposed.quantize_per_channel): 1,
            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel): 1,
        }
        node_list = [
            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel),
            ns.call_function(torch.ops.aten.convolution.default),
            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor),
            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
            ns.call_function(torch.ops.aten.mean.dim),
            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor),
            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
            ns.call_function(torch.ops.aten.hardtanh.default),
            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor),
            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
            ns.call_function(torch.ops.aten.mean.default),
            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor),
            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
        ]
        self.checkGraphModuleNodes(
            m, expected_node_list=node_list, expected_node_occurrence=node_occurrence
        )

    def test_prepare_qat_conv_bn_fusion(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)
                self.bn = torch.nn.BatchNorm2d(3)

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                return x

        example_inputs = (torch.randn(1, 3, 5, 5),)
        self._verify_symmetric_qnnpack_qat_graph(
            M(), example_inputs, is_per_channel=False, has_relu=False
        )
        self._verify_symmetric_qnnpack_qat_graph(
            M(), example_inputs, is_per_channel=True, has_relu=False
        )

    def test_prepare_qat_conv_bn_relu_fusion(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)
                self.bn = torch.nn.BatchNorm2d(3)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                x = self.relu(x)
                return x

        example_inputs = (torch.randn(1, 3, 5, 5),)
        self._verify_symmetric_qnnpack_qat_graph(
            M(), example_inputs, is_per_channel=False, has_relu=True
        )
        self._verify_symmetric_qnnpack_qat_graph(
            M(), example_inputs, is_per_channel=True, has_relu=True
        )

    def _verify_symmetric_qnnpack_qat_graph(
        self,
        m: torch.fx.GraphModule,
        example_inputs: Tuple[Any, ...],
        is_per_channel: bool,
        has_relu: bool,
    ):
        """
        Verify that the graph module matches the fused QAT [conv - bn (- relu)] pattern
        with fake quantizes inserted at the correct places.
        # TODO: also verify that metadata is copied over to the new nodes.
        """
        import torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer as qq

        quantizer = QNNPackQuantizer()
        quantizer.set_global(
            qq.get_symmetric_quantization_config(is_per_channel, is_qat=True)
        )
        m, guards = torchdynamo.export(
            m,
            *copy.deepcopy(example_inputs),
            aten_graph=True,
            tracing_mode="real",
        )
        m = prepare_qat_pt2e_quantizer(m, quantizer)
        m(*example_inputs)
        # Verify: getitem output activation fake quantize
        output_node = list(m.graph.nodes)[-1]
        output_fq_node = output_node.args[0][0]
        self.assertTrue(output_fq_node.target.startswith("activation_post_process_"))
        output_fq_mod = getattr(m, output_fq_node.target)
        self.assertEqual(type(output_fq_mod), FusedMovingAvgObsFakeQuantize)
        self.assertEqual(
            type(output_fq_mod.activation_post_process), MovingAverageMinMaxObserver
        )
        self.assertEqual(output_fq_mod.dtype, torch.qint8)
        self.assertEqual(output_fq_mod.quant_min, -128)
        self.assertEqual(output_fq_mod.quant_max, 127)

        # Verify: getitem(bn, 0) or relu(getitem(bn, 0))
        if has_relu:
            relu_node = output_fq_node.args[0]
            getitem_node = relu_node.args[0]
            self.assertEqual(relu_node.target, torch.ops.aten.relu.default)
        else:
            relu_node = None
            getitem_node = output_fq_node.args[0]
        bn_node = getitem_node.args[0]
        self.assertEqual(getitem_node.target, operator.getitem)
        self.assertEqual(
            bn_node.target, torch.ops.aten._native_batch_norm_legit.default
        )

        # Verify: conv / scale_factor.reshape + bias.reshape
        add_bias_node = bn_node.args[0]
        (div_scale_factor_node, bias_reshape_node) = add_bias_node.args
        (conv_node, scale_factor_reshape_node) = div_scale_factor_node.args
        self.assertEqual(add_bias_node.target, torch.ops.aten.add.Tensor)
        self.assertEqual(div_scale_factor_node.target, torch.ops.aten.div.Tensor)
        self.assertEqual(bias_reshape_node.target, torch.ops.aten.view.default)
        self.assertEqual(conv_node.target, torch.ops.aten.convolution.default)
        self.assertEqual(scale_factor_reshape_node.target, torch.ops.aten.view.default)

        # Verify: conv input activation fake quantize
        conv_input_fq_node = conv_node.args[0]
        conv_input_node = conv_input_fq_node.args[0]
        self.assertTrue(
            conv_input_fq_node.target.startswith("activation_post_process_")
        )
        conv_input_fq_mod = getattr(m, conv_input_fq_node.target)
        self.assertEqual(type(conv_input_fq_mod), FusedMovingAvgObsFakeQuantize)
        self.assertEqual(
            type(conv_input_fq_mod.activation_post_process), MovingAverageMinMaxObserver
        )
        self.assertEqual(conv_input_fq_mod.dtype, torch.qint8)
        self.assertEqual(conv_input_fq_mod.quant_min, -128)
        self.assertEqual(conv_input_fq_mod.quant_max, 127)
        self.assertEqual(conv_input_node.op, "placeholder")

        # Verify: conv weight fake quantize
        conv_weight_fq_node = conv_node.args[1]
        self.assertTrue(
            conv_weight_fq_node.target.startswith("activation_post_process_")
        )
        conv_weight_fq_mod = getattr(m, conv_weight_fq_node.target)
        if is_per_channel:
            expected_weight_observer_type = MovingAveragePerChannelMinMaxObserver
        else:
            expected_weight_observer_type = MovingAverageMinMaxObserver
        self.assertEqual(type(conv_weight_fq_mod), FusedMovingAvgObsFakeQuantize)
        self.assertEqual(
            type(conv_weight_fq_mod.activation_post_process),
            expected_weight_observer_type,
        )
        self.assertEqual(conv_weight_fq_mod.dtype, torch.qint8)
        self.assertEqual(conv_weight_fq_mod.quant_min, -127)
        self.assertEqual(conv_weight_fq_mod.quant_max, 127)

        # Verify: conv(fq(input), fq(weight * scale_factor.reshape), zero_bias)
        zero_bias_node = conv_node.args[2]
        mul_weight_scale_factor_node = conv_weight_fq_node.args[0]
        (conv_weight_node, scale_factor_reshape_node) = mul_weight_scale_factor_node.args
        self.assertEqual(zero_bias_node.target, torch.ops.aten.zeros_like.default)
        self.assertEqual(mul_weight_scale_factor_node.target, torch.ops.aten.mul.Tensor)
        self.assertEqual(scale_factor_reshape_node.target, torch.ops.aten.view.default)

        # Verify: scale_factor = bn_weight / sqrt(bn_running_var + eps)
        scale_factor_node = scale_factor_reshape_node.args[0]
        (bn_weight_node, sqrt_node) = scale_factor_node.args
        bn_running_var_add_node = sqrt_node.args[0]
        (bn_running_var_node, eps) = bn_running_var_add_node.args
        self.assertEqual(scale_factor_node.target, torch.ops.aten.div.Tensor)
        self.assertIn("param_constant", bn_weight_node.target)
        self.assertEqual(sqrt_node.target, torch.ops.aten.sqrt.default)
        self.assertEqual(bn_running_var_add_node.target, torch.ops.aten.add.Tensor)
        self.assertIn("tensor_constant", bn_running_var_node.target)
        self.assertEqual(eps, 1e-5)

    def test_prepare_qat_conv_bn_numerics(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)
                self.bn = torch.nn.BatchNorm2d(3)

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                return x

        example_inputs = (torch.randn(1, 3, 5, 5),)
        self._verify_symmetric_qnnpack_qat_numerics(
            M(), example_inputs, is_per_channel=False
        )
        self._verify_symmetric_qnnpack_qat_numerics(
            M(), example_inputs, is_per_channel=True
        )

    def test_prepare_qat_conv_bn_relu_numerics(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)
                self.bn = torch.nn.BatchNorm2d(3)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                x = self.relu(x)
                return x

        example_inputs = (torch.randn(1, 3, 5, 5),)
        self._verify_symmetric_qnnpack_qat_numerics(
            M(), example_inputs, is_per_channel=False
        )
        self._verify_symmetric_qnnpack_qat_numerics(
            M(), example_inputs, is_per_channel=True
        )

    def _verify_symmetric_qnnpack_qat_numerics(
        self,
        model: torch.nn.Module,
        example_inputs: Tuple[Any, ...],
        is_per_channel: bool,
    ):
        """
        Helper method to verify that the QAT numerics for PT2E quantization match
        those of FX graph mode quantization for symmetric qnnpack.
        """
        # PT2 export
        import torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer as qq

        model_pt2e = copy.deepcopy(model)
        quantizer = QNNPackQuantizer()
        quantizer.set_global(
            qq.get_symmetric_quantization_config(
                is_per_channel=is_per_channel, is_qat=True
            )
        )
        model_pt2e, guards = torchdynamo.export(
            model_pt2e,
            *copy.deepcopy(example_inputs),
            aten_graph=True,
        )
        model_pt2e = prepare_qat_pt2e_quantizer(model_pt2e, quantizer)
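        # this forward pass calibrates the inserted fake quantizes and produces
        # the QAT numerics that are compared against the FX flow below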
        result_pt2e = model_pt2e(*example_inputs)

        # FX
        # Note: In order to match the PT2E numerics exactly, we need to feed the
        # example inputs to the model once before calling prepare, since this is
        # what torchdynamo.export does. Otherwise, the BN running mean and variance
        # would diverge in the two flows and this test would fail. For more detail,
        # see https://github.com/pytorch/pytorch/issues/95900.
        model_fx = copy.deepcopy(model)
        model_fx(*example_inputs)
        if is_per_channel:
            default_qconfig = default_per_channel_symmetric_qnnpack_qat_qconfig
        else:
            default_qconfig = default_symmetric_qnnpack_qat_qconfig
        qconfig_mapping = QConfigMapping().set_global(default_qconfig)
        backend_config = get_qnnpack_backend_config()
        model_fx = prepare_qat_fx(
            model_fx, qconfig_mapping, example_inputs, backend_config=backend_config
        )
        result_fx = model_fx(*example_inputs)

        # Verify that the numerics match
        self.assertEqual(result_pt2e, result_fx)


class TestQuantizePT2EModels(QuantizationTestCase):
    @skip_if_no_torchvision
    @skipIfNoQNNPACK
    def test_resnet18_with_quantizer_api(self):
        import torchvision

        with override_quantized_engine("qnnpack"):
            example_inputs = (torch.randn(1, 3, 224, 224),)
            m = torchvision.models.resnet18().eval()
            m_copy = copy.deepcopy(m)
            # program capture
            m, guards = torchdynamo.export(
                m,
                *copy.deepcopy(example_inputs),
                aten_graph=True,
            )
            before_fusion_result = m(*example_inputs)

            import torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer as qq

            quantizer = QNNPackQuantizer()
            operator_config = qq.get_symmetric_quantization_config(is_per_channel=True)
            quantizer.set_global(operator_config)
            m = prepare_pt2e_quantizer(m, quantizer)
            # checking that we inserted observers correctly for the maxpool
            # operator (input and output share the same observer instance)
            self.assertEqual(
                id(m.activation_post_process_3), id(m.activation_post_process_2)
            )
            after_prepare_result = m(*example_inputs)
            m = convert_pt2e(m)
            after_quant_result = m(*example_inputs)

            # comparing with the existing fx graph mode quantization reference flow
            qconfig = default_per_channel_symmetric_qnnpack_qconfig
            qconfig_mapping = QConfigMapping().set_global(qconfig)
            backend_config = get_qnnpack_backend_config()
            m_fx = prepare_fx(
                m_copy, qconfig_mapping, example_inputs, backend_config=backend_config
            )
            after_prepare_result_fx = m_fx(*example_inputs)
            m_fx = convert_to_reference_fx(m_fx, backend_config=backend_config)
            after_quant_result_fx = m_fx(*example_inputs)

            # the results match exactly after prepare.
            # Note: this will currently always be true because observers do not
            # change the numerics, so the check only becomes useful once we add
            # QAT examples; until then we can still manually inspect the printed
            # observers to make sure they match
            self.assertEqual(after_prepare_result, after_prepare_result_fx)
            self.assertEqual(
                compute_sqnr(after_prepare_result, after_prepare_result_fx),
                torch.tensor(float("inf")),
            )
            # there are slight differences after convert due to different
            # implementations of the quantize/dequantize ops
            self.assertTrue(
                torch.max(after_quant_result - after_quant_result_fx) < 1e-1
            )
            self.assertTrue(
                compute_sqnr(after_quant_result, after_quant_result_fx) > 35
            )