[PT2E][Quant][BE] Split short-term and long-term tests into different files (#99065)
Just for better organization: the QConfigMapping/backend_config-based tests move to a new test/quantization/pt2e/test_quantize_pt2e_fx.py, while the quantizer-based tests stay in test_quantize_pt2e.py.
Differential Revision: [D44918492](https://our.internmc.facebook.com/intern/diff/D44918492/)
**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D44918492/)!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99065
Approved by: https://github.com/jerryzh168
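For reference, here is a minimal sketch of the QConfigMapping/backend_config-based flow that this PR relocates, distilled from the moved tests below (the toy module and input shape are placeholders):

```python
# Sketch of the flow exercised by the tests moving to test_quantize_pt2e_fx.py.
import copy

import torch
import torch._dynamo as torchdynamo
from torch.ao.quantization import get_default_qconfig, QConfigMapping
from torch.ao.quantization._quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.backend_config._qnnpack_pt2e import (
    get_qnnpack_pt2e_backend_config,
)
from torch.testing._internal.common_quantized import override_quantized_engine

with override_quantized_engine("qnnpack"):
    m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1)).eval()
    example_inputs = (torch.randn(1, 1, 3, 3),)

    # program capture, exactly as the tests do it
    m, guards = torchdynamo.export(
        m, *copy.deepcopy(example_inputs), aten_graph=True, tracing_mode="real"
    )

    qconfig_mapping = QConfigMapping().set_global(get_default_qconfig("qnnpack"))
    m = prepare_pt2e(
        m, qconfig_mapping, example_inputs, get_qnnpack_pt2e_backend_config()
    )
    m(*example_inputs)  # calibrate
    m = convert_pt2e(m)
    m(*example_inputs)  # run the quantized model
```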
diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py
index b99040c..913afab 100644
--- a/test/quantization/pt2e/test_quantize_pt2e.py
+++ b/test/quantization/pt2e/test_quantize_pt2e.py
@@ -1,153 +1,31 @@
# Owner(s): ["oncall: quantization"]
import copy
-import itertools
from typing import List
import torch
import torch._dynamo as torchdynamo
-import torch.nn as nn
-from torch._inductor.compile_fx import compile_fx
from torch.ao.ns.fx.utils import compute_sqnr
-from torch.ao.quantization import get_default_qconfig, observer, QConfigMapping
+from torch.ao.quantization import observer, QConfigMapping
from torch.ao.quantization._pt2e.quantizer import (
OperatorConfig,
QNNPackQuantizer,
Quantizer,
)
-from torch.ao.quantization._quantize_pt2e import (
- convert_pt2e,
- prepare_pt2e,
- prepare_pt2e_quantizer,
-)
+from torch.ao.quantization._quantize_pt2e import convert_pt2e, prepare_pt2e_quantizer
from torch.ao.quantization.backend_config import get_qnnpack_backend_config
-from torch.ao.quantization.backend_config._qnnpack_pt2e import (
- get_qnnpack_pt2e_backend_config,
-)
-from torch.ao.quantization.backend_config._x86_inductor_pt2e import (
- get_x86_inductor_pt2e_backend_config,
-)
-from torch.ao.quantization.backend_config.x86 import get_x86_backend_config
from torch.ao.quantization.qconfig import default_per_channel_symmetric_qnnpack_qconfig
-from torch.ao.quantization.quantize_fx import (
- convert_fx,
- convert_to_reference_fx,
- prepare_fx,
-)
+from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
from torch.testing._internal.common_quantization import (
NodeSpec as ns,
QuantizationTestCase,
skip_if_no_torchvision,
skipIfNoQNNPACK,
- skipIfNoX86,
)
from torch.testing._internal.common_quantized import override_quantized_engine
@skipIfNoQNNPACK
class TestQuantizePT2E(QuantizationTestCase):
- def test_qconfig_none(self):
- class M(torch.nn.Module):
- def __init__(self):
- super().__init__()
- self.conv1 = nn.Conv2d(1, 1, 1)
- self.conv2 = nn.Conv2d(1, 1, 1)
-
- def forward(self, x):
- x = self.conv1(x)
- x = self.conv2(x)
- return x
-
- with override_quantized_engine("qnnpack"):
- m = M().eval()
- example_inputs = (torch.randn(1, 1, 1, 1),)
- # program capture
- m, guards = torchdynamo.export(
- m,
- *copy.deepcopy(example_inputs),
- aten_graph=True,
- tracing_mode="real",
- )
-
- qconfig = get_default_qconfig("qnnpack")
- qconfig_mapping = (
- QConfigMapping().set_global(qconfig).set_module_name("conv2", None)
- )
- backend_config = get_qnnpack_pt2e_backend_config()
- m = prepare_pt2e(m, qconfig_mapping, example_inputs, backend_config)
- m(*example_inputs)
- m = convert_pt2e(m)
- m(*example_inputs)
-
- # first conv is quantized, second conv is not quantized
- node_occurrence = {
- # two for input of the first conv, one for output for the first conv
- ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 3,
- ns.call_function(
- torch.ops.quantized_decomposed.dequantize_per_tensor
- ): 3,
- }
- node_list = [
- ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
- ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
- ns.call_function(torch.ops.aten.convolution.default),
- ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
- ns.call_function(torch.ops.aten.convolution.default),
- ]
- self.checkGraphModuleNodes(
- m,
- expected_node_list=node_list,
- expected_node_occurrence=node_occurrence,
- )
-
- def test_qconfig_module_type(self):
- class M(torch.nn.Module):
- def __init__(self):
- super().__init__()
- self.conv = nn.Conv2d(1, 1, 1)
- self.linear = nn.Linear(9, 3)
-
- def forward(self, x):
- x = self.conv(x)
- x = x.reshape((1, -1))
- x = self.linear(x)
- return x
-
- with override_quantized_engine("qnnpack"):
- m = M().eval()
- example_inputs = (torch.randn(1, 1, 3, 3),)
-
- # program capture
- m, guards = torchdynamo.export(
- m,
- *copy.deepcopy(example_inputs),
- aten_graph=True,
- tracing_mode="real",
- )
-
- qconfig = get_default_qconfig("qnnpack")
- qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Conv2d, qconfig)
- backend_config = get_qnnpack_pt2e_backend_config()
- m = prepare_pt2e(m, qconfig_mapping, example_inputs, backend_config)
- m(*example_inputs)
- m = convert_pt2e(m)
- m(*example_inputs)
- # conv is quantized, linear is not quantized
- node_occurrence = {
- # two for input and weight of the conv, one for output for the conv
- ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 3,
- ns.call_function(
- torch.ops.quantized_decomposed.dequantize_per_tensor
- ): 3,
- }
- node_list = [
- ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
- ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
- ns.call_function(torch.ops.aten.convolution.default),
- ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
- ns.call_function(torch.ops.aten.addmm.default),
- ]
- self.checkGraphModuleNodes(m, expected_node_list=node_list)
-
def test_simple_quantizer(self):
class M(torch.nn.Module):
def __init__(self):
@@ -332,312 +210,10 @@
m, expected_node_list=node_list, expected_node_occurrence=node_occurrence
)
- def test_rearrange_weight_observer_for_decomposed_linear(self):
- """
- Check whether weight observer is correctly rearranged for decomposed linear.
- before:
- weight - t - observer \
- input - observer - addmm/mm
- after:
- weight - observer - t \
- input - observer - addmm/mm
- """
-
- class M(torch.nn.Module):
- def __init__(self, with_bias, use_relu):
- super().__init__()
- self.linear = nn.Linear(4, 4, bias=with_bias)
- self.relu = nn.ReLU()
- self.use_relu = use_relu
-
- def forward(self, x):
- x = self.linear(x)
- return self.relu(x) if self.use_relu else x
-
- with_bias_list = [True, False]
- use_relu_list = [True, False]
- cases = itertools.product(with_bias_list, use_relu_list)
- for with_bias, use_relu in cases:
- m = M(with_bias, use_relu).eval()
- example_inputs = (torch.randn(1, 4),)
-
- # program capture
- m, guards = torchdynamo.export(
- m,
- *copy.deepcopy(example_inputs),
- aten_graph=True,
- tracing_mode="real",
- )
-
- qconfig = get_default_qconfig("qnnpack")
- qconfig_mapping = QConfigMapping().set_global(qconfig)
- backend_config = get_qnnpack_pt2e_backend_config()
- m = prepare_pt2e(m, qconfig_mapping, example_inputs, backend_config)
-
- # 1. Check graph nodes:
- # - args[0] of t should be the weight observer
- # - args[-1] of addmm/mm should be t
- error_msg = (
- "Weight observer is not correctly rearranged for decomposed linear"
- )
- for node in m.graph.nodes:
- if node.target == torch.ops.aten.t.default:
- target = node.args[0].target
- self.assertTrue(
- isinstance(getattr(m, target), observer.ObserverBase), error_msg
- )
- elif node.target in (
- torch.ops.aten.addmm.default,
- torch.ops.aten.mm.default,
- ):
- target = node.args[-1].target
- self.assertTrue(target == torch.ops.aten.t.default, error_msg)
-
- # 2. Check m.code to ensure `m.recompile()` is called.
- # If weight observer is rearranged in graph but `m.recompile()` is not called,
- # m.code would be wrong.
- code_before_recompile = m.code
- m.recompile()
- code_after_recompile = m.code
- self.assertTrue(code_before_recompile == code_after_recompile, error_msg)
-
- def test_transposed_conv_bn_fusion(self):
- class M(torch.nn.Module):
- def __init__(self):
- super().__init__()
- self.conv_trans = torch.nn.ConvTranspose2d(10, 20, 3)
- # channels for batchnorm is the same as the out_channels for convtranspose
- self.bn = torch.nn.BatchNorm2d(20)
-
- def forward(self, x):
- return self.bn(self.conv_trans(x))
-
- with override_quantized_engine("qnnpack"):
- m = M().eval()
- example_inputs = (torch.randn(10, 10, 10, 10),)
- # program capture
- m, guards = torchdynamo.export(
- m,
- *copy.deepcopy(example_inputs),
- aten_graph=True,
- tracing_mode="real",
- )
-
- node_occurrence = {
- ns.call_function(torch.ops.aten.convolution.default): 1,
- ns.call_function(
- torch.ops.aten._native_batch_norm_legit_no_training.default
- ): 1,
- }
- self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
-
- qconfig = get_default_qconfig("qnnpack")
- qconfig_mapping = QConfigMapping().set_global(qconfig)
- backend_config = get_qnnpack_pt2e_backend_config()
- m = prepare_pt2e(m, qconfig_mapping, example_inputs, backend_config)
- # make sure it runs
- m(*example_inputs)
-
- # make sure bn is fused into conv
- node_occurrence = {
- ns.call_function(torch.ops.aten.convolution.default): 1,
- ns.call_function(
- torch.ops.aten._native_batch_norm_legit_no_training.default
- ): 0,
- }
- self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
-
-
-@skipIfNoQNNPACK
-class TestQuantizePT2EX86Inductor(QuantizationTestCase):
- @skipIfNoX86
- def test_inductor_backend_config_conv(self):
- class M(torch.nn.Module):
- def __init__(self, use_relu: bool = False, inplace_relu: bool = False):
- super().__init__()
- self.use_relu = use_relu
- self.conv1 = nn.Conv2d(3, 6, (2, 2), stride=(1, 1), padding=(1, 1))
- self.relu = nn.ReLU(inplace=inplace_relu)
-
- def forward(self, x):
- x = self.conv1(x)
- return self.relu(x) if self.use_relu else x
-
- use_relu_list = [True, False]
- inplace_relu_list = [True, False]
- with override_quantized_engine("x86"):
- with torch.no_grad():
- for use_relu, inplace_relu in itertools.product(
- use_relu_list, inplace_relu_list
- ):
- m = M(use_relu=use_relu, inplace_relu=inplace_relu).eval()
- example_inputs = (torch.randn(2, 3, 4, 4),)
- # program capture
- # **TODO** Add testcase for tracing_mode="symbolic" after fix issue:
- # https://github.com/pytorch/pytorch/issues/96274
- export_module, guards = torchdynamo.export(
- m,
- *copy.deepcopy(example_inputs),
- aten_graph=True,
- tracing_mode="real",
- )
-
- qconfig = get_default_qconfig("x86")
- qconfig_mapping = QConfigMapping().set_global(qconfig)
- backend_config = get_x86_inductor_pt2e_backend_config()
- prepare_module = prepare_pt2e(
- export_module, qconfig_mapping, example_inputs, backend_config
- )
- prepare_module(*example_inputs)
- convert_module = convert_pt2e(prepare_module)
- convert_module(*example_inputs)
-
- # Fake quant should only be inserted at start and end
- node_occurrence = {
- # one for input and weight of the conv, one for output for the conv
- ns.call_function(
- torch.ops.quantized_decomposed.quantize_per_tensor
- ): 2,
- ns.call_function(
- torch.ops.quantized_decomposed.quantize_per_channel
- ): 1,
- ns.call_function(
- torch.ops.quantized_decomposed.dequantize_per_channel
- ): 1,
- ns.call_function(
- torch.ops.quantized_decomposed.dequantize_per_tensor
- ): 2,
- }
- if use_relu:
- node_list = [
- ns.call_function(
- torch.ops.quantized_decomposed.quantize_per_tensor
- ),
- ns.call_function(
- torch.ops.quantized_decomposed.dequantize_per_tensor
- ),
- ns.call_function(torch.ops.aten.convolution.default),
- ns.call_function(
- torch.ops.aten.relu_.default
- if inplace_relu
- else torch.ops.aten.relu.default
- ),
- ns.call_function(
- torch.ops.quantized_decomposed.quantize_per_tensor
- ),
- ns.call_function(
- torch.ops.quantized_decomposed.dequantize_per_tensor
- ),
- ]
- else:
- node_list = [
- ns.call_function(
- torch.ops.quantized_decomposed.quantize_per_tensor
- ),
- ns.call_function(
- torch.ops.quantized_decomposed.dequantize_per_tensor
- ),
- ns.call_function(torch.ops.aten.convolution.default),
- ns.call_function(
- torch.ops.quantized_decomposed.quantize_per_tensor
- ),
- ns.call_function(
- torch.ops.quantized_decomposed.dequantize_per_tensor
- ),
- ]
- self.checkGraphModuleNodes(
- convert_module,
- expected_node_occurrence=node_occurrence,
- expected_node_list=node_list,
- )
-
- # Step1: Ref result in 1.X fx path
- backend_config_1_x = get_x86_backend_config()
- m_copy = copy.deepcopy(m)
- m_prepare_fx = prepare_fx(
- m_copy,
- qconfig_mapping,
- example_inputs,
- backend_config=backend_config_1_x,
- )
- after_prepare_result_fx = m_prepare_fx(*example_inputs)
- m_convert_fx = convert_fx(
- m_prepare_fx, backend_config=backend_config_1_x
- )
- ref_result = m_convert_fx(*example_inputs)
-
- # Step2: Start to lowering into Inductor
- run = compile_fx(convert_module, example_inputs)
- # Inductor first run
- inductor_res = run(*example_inputs)
- # Inductor second run
- inductor_res = run(*example_inputs)
- self.assertEqual(ref_result, inductor_res, atol=5e-2, rtol=5e-2)
-
class TestQuantizePT2EModels(QuantizationTestCase):
@skip_if_no_torchvision
@skipIfNoQNNPACK
- def test_resnet18(self):
- import torchvision
-
- with override_quantized_engine("qnnpack"):
- example_inputs = (torch.randn(1, 3, 224, 224),)
- m = torchvision.models.resnet18().eval()
- m_copy = copy.deepcopy(m)
- # program capture
- m, guards = torchdynamo.export(
- m,
- *copy.deepcopy(example_inputs),
- aten_graph=True,
- tracing_mode="real",
- )
-
- backend_config = get_qnnpack_pt2e_backend_config()
- # TODO: define qconfig_mapping specifically for executorch
- qconfig = get_default_qconfig("qnnpack")
- qconfig_mapping = QConfigMapping().set_global(qconfig)
- before_fusion_result = m(*example_inputs)
-
- m = prepare_pt2e(m, qconfig_mapping, example_inputs, backend_config)
-
- # checking that we inserted observers correctly for maxpool operator (input and
- # output share observer instance)
- self.assertEqual(
- id(m.activation_post_process_3), id(m.activation_post_process_2)
- )
- after_prepare_result = m(*example_inputs)
- m = convert_pt2e(m)
-
- after_quant_result = m(*example_inputs)
-
- # comparing with existing fx graph mode quantization reference flow
- backend_config = get_qnnpack_backend_config()
- m_fx = prepare_fx(
- m_copy, qconfig_mapping, example_inputs, backend_config=backend_config
- )
- after_prepare_result_fx = m_fx(*example_inputs)
- m_fx = convert_to_reference_fx(m_fx, backend_config=backend_config)
-
- after_quant_result_fx = m_fx(*example_inputs)
-
- # the result matches exactly after prepare
- self.assertEqual(after_prepare_result, after_prepare_result_fx)
- self.assertEqual(
- compute_sqnr(after_prepare_result, after_prepare_result_fx),
- torch.tensor(float("inf")),
- )
- # there are slight differences after convert due to different implementations
- # of quant/dequant
- self.assertTrue(
- torch.max(after_quant_result - after_quant_result_fx) < 1e-1
- )
- self.assertTrue(
- compute_sqnr(after_quant_result, after_quant_result_fx) > 35
- )
-
- @skip_if_no_torchvision
- @skipIfNoQNNPACK
def test_resnet18_with_quantizer_api(self):
import torchvision
diff --git a/test/quantization/pt2e/test_quantize_pt2e_fx.py b/test/quantization/pt2e/test_quantize_pt2e_fx.py
new file mode 100644
index 0000000..8b93190
--- /dev/null
+++ b/test/quantization/pt2e/test_quantize_pt2e_fx.py
@@ -0,0 +1,445 @@
+# Owner(s): ["oncall: quantization"]
+import copy
+import itertools
+
+import torch
+import torch._dynamo as torchdynamo
+import torch.nn as nn
+from torch._inductor.compile_fx import compile_fx
+from torch.ao.ns.fx.utils import compute_sqnr
+from torch.ao.quantization import get_default_qconfig, observer, QConfigMapping
+from torch.ao.quantization._quantize_pt2e import (
+ convert_pt2e,
+ prepare_pt2e,
+)
+from torch.ao.quantization.backend_config import get_qnnpack_backend_config
+from torch.ao.quantization.backend_config._qnnpack_pt2e import (
+ get_qnnpack_pt2e_backend_config,
+)
+from torch.ao.quantization.backend_config._x86_inductor_pt2e import (
+ get_x86_inductor_pt2e_backend_config,
+)
+from torch.ao.quantization.backend_config.x86 import get_x86_backend_config
+from torch.ao.quantization.quantize_fx import (
+ convert_fx,
+ convert_to_reference_fx,
+ prepare_fx,
+)
+from torch.testing._internal.common_quantization import (
+ NodeSpec as ns,
+ QuantizationTestCase,
+ skip_if_no_torchvision,
+ skipIfNoQNNPACK,
+ skipIfNoX86,
+)
+from torch.testing._internal.common_quantized import override_quantized_engine
+
+
+@skipIfNoQNNPACK
+class TestQuantizePT2EFX(QuantizationTestCase):
+ def test_qconfig_none(self):
+ class M(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.conv1 = nn.Conv2d(1, 1, 1)
+ self.conv2 = nn.Conv2d(1, 1, 1)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.conv2(x)
+ return x
+
+ with override_quantized_engine("qnnpack"):
+ m = M().eval()
+ example_inputs = (torch.randn(1, 1, 1, 1),)
+ # program capture
+ m, guards = torchdynamo.export(
+ m,
+ *copy.deepcopy(example_inputs),
+ aten_graph=True,
+ tracing_mode="real",
+ )
+
+ qconfig = get_default_qconfig("qnnpack")
+ qconfig_mapping = (
+ QConfigMapping().set_global(qconfig).set_module_name("conv2", None)
+ )
+ backend_config = get_qnnpack_pt2e_backend_config()
+ m = prepare_pt2e(m, qconfig_mapping, example_inputs, backend_config)
+ m(*example_inputs)
+ m = convert_pt2e(m)
+ m(*example_inputs)
+
+ # first conv is quantized, second conv is not quantized
+ node_occurrence = {
+                # two for the input and weight of the first conv, one for its output
+ ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 3,
+ ns.call_function(
+ torch.ops.quantized_decomposed.dequantize_per_tensor
+ ): 3,
+ }
+ node_list = [
+ ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
+ ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
+ ns.call_function(torch.ops.aten.convolution.default),
+ ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
+ ns.call_function(torch.ops.aten.convolution.default),
+ ]
+ self.checkGraphModuleNodes(
+ m,
+ expected_node_list=node_list,
+ expected_node_occurrence=node_occurrence,
+ )
+
+ def test_qconfig_module_type(self):
+ class M(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.conv = nn.Conv2d(1, 1, 1)
+ self.linear = nn.Linear(9, 3)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = x.reshape((1, -1))
+ x = self.linear(x)
+ return x
+
+ with override_quantized_engine("qnnpack"):
+ m = M().eval()
+ example_inputs = (torch.randn(1, 1, 3, 3),)
+
+ # program capture
+ m, guards = torchdynamo.export(
+ m,
+ *copy.deepcopy(example_inputs),
+ aten_graph=True,
+ tracing_mode="real",
+ )
+
+ qconfig = get_default_qconfig("qnnpack")
+ qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Conv2d, qconfig)
+ backend_config = get_qnnpack_pt2e_backend_config()
+ m = prepare_pt2e(m, qconfig_mapping, example_inputs, backend_config)
+ m(*example_inputs)
+ m = convert_pt2e(m)
+ m(*example_inputs)
+ # conv is quantized, linear is not quantized
+ node_occurrence = {
+                # two for the input and weight of the conv, one for its output
+ ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 3,
+ ns.call_function(
+ torch.ops.quantized_decomposed.dequantize_per_tensor
+ ): 3,
+ }
+ node_list = [
+ ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
+ ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
+ ns.call_function(torch.ops.aten.convolution.default),
+ ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
+ ns.call_function(torch.ops.aten.addmm.default),
+ ]
+ self.checkGraphModuleNodes(m, expected_node_list=node_list)
+
+ def test_rearrange_weight_observer_for_decomposed_linear(self):
+ """
+        Check whether the weight observer is correctly rearranged for decomposed linear.
+ before:
+ weight - t - observer \
+ input - observer - addmm/mm
+ after:
+ weight - observer - t \
+ input - observer - addmm/mm
+ """
+
+ class M(torch.nn.Module):
+ def __init__(self, with_bias, use_relu):
+ super().__init__()
+ self.linear = nn.Linear(4, 4, bias=with_bias)
+ self.relu = nn.ReLU()
+ self.use_relu = use_relu
+
+ def forward(self, x):
+ x = self.linear(x)
+ return self.relu(x) if self.use_relu else x
+
+ with_bias_list = [True, False]
+ use_relu_list = [True, False]
+ cases = itertools.product(with_bias_list, use_relu_list)
+ for with_bias, use_relu in cases:
+ m = M(with_bias, use_relu).eval()
+ example_inputs = (torch.randn(1, 4),)
+
+ # program capture
+ m, guards = torchdynamo.export(
+ m,
+ *copy.deepcopy(example_inputs),
+ aten_graph=True,
+ tracing_mode="real",
+ )
+
+ qconfig = get_default_qconfig("qnnpack")
+ qconfig_mapping = QConfigMapping().set_global(qconfig)
+ backend_config = get_qnnpack_pt2e_backend_config()
+ m = prepare_pt2e(m, qconfig_mapping, example_inputs, backend_config)
+
+ # 1. Check graph nodes:
+ # - args[0] of t should be the weight observer
+ # - args[-1] of addmm/mm should be t
+ error_msg = (
+ "Weight observer is not correctly rearranged for decomposed linear"
+ )
+ for node in m.graph.nodes:
+ if node.target == torch.ops.aten.t.default:
+ target = node.args[0].target
+ self.assertTrue(
+ isinstance(getattr(m, target), observer.ObserverBase), error_msg
+ )
+ elif node.target in (
+ torch.ops.aten.addmm.default,
+ torch.ops.aten.mm.default,
+ ):
+ target = node.args[-1].target
+ self.assertTrue(target == torch.ops.aten.t.default, error_msg)
+
+            # 2. Check m.code to ensure `m.recompile()` was called.
+            # If the weight observer is rearranged in the graph but `m.recompile()`
+            # is not called, m.code becomes stale.
+ code_before_recompile = m.code
+ m.recompile()
+ code_after_recompile = m.code
+ self.assertTrue(code_before_recompile == code_after_recompile, error_msg)
+
+ def test_transposed_conv_bn_fusion(self):
+ class M(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.conv_trans = torch.nn.ConvTranspose2d(10, 20, 3)
+                # the batchnorm channel count is the same as the convtranspose out_channels
+ self.bn = torch.nn.BatchNorm2d(20)
+
+ def forward(self, x):
+ return self.bn(self.conv_trans(x))
+
+ with override_quantized_engine("qnnpack"):
+ m = M().eval()
+ example_inputs = (torch.randn(10, 10, 10, 10),)
+ # program capture
+ m, guards = torchdynamo.export(
+ m,
+ *copy.deepcopy(example_inputs),
+ aten_graph=True,
+ tracing_mode="real",
+ )
+
+ node_occurrence = {
+ ns.call_function(torch.ops.aten.convolution.default): 1,
+ ns.call_function(
+ torch.ops.aten._native_batch_norm_legit_no_training.default
+ ): 1,
+ }
+ self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
+
+ qconfig = get_default_qconfig("qnnpack")
+ qconfig_mapping = QConfigMapping().set_global(qconfig)
+ backend_config = get_qnnpack_pt2e_backend_config()
+ m = prepare_pt2e(m, qconfig_mapping, example_inputs, backend_config)
+ # make sure it runs
+ m(*example_inputs)
+
+ # make sure bn is fused into conv
+ node_occurrence = {
+ ns.call_function(torch.ops.aten.convolution.default): 1,
+ ns.call_function(
+ torch.ops.aten._native_batch_norm_legit_no_training.default
+ ): 0,
+ }
+ self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
+
+
+@skipIfNoQNNPACK
+class TestQuantizePT2EFXX86Inductor(QuantizationTestCase):
+ @skipIfNoX86
+ def test_inductor_backend_config_conv(self):
+ class M(torch.nn.Module):
+ def __init__(self, use_relu: bool = False, inplace_relu: bool = False):
+ super().__init__()
+ self.use_relu = use_relu
+ self.conv1 = nn.Conv2d(3, 6, (2, 2), stride=(1, 1), padding=(1, 1))
+ self.relu = nn.ReLU(inplace=inplace_relu)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ return self.relu(x) if self.use_relu else x
+
+ use_relu_list = [True, False]
+ inplace_relu_list = [True, False]
+ with override_quantized_engine("x86"):
+ with torch.no_grad():
+ for use_relu, inplace_relu in itertools.product(
+ use_relu_list, inplace_relu_list
+ ):
+ m = M(use_relu=use_relu, inplace_relu=inplace_relu).eval()
+ example_inputs = (torch.randn(2, 3, 4, 4),)
+ # program capture
+                    # **TODO** Add a test case for tracing_mode="symbolic" after fixing issue:
+ # https://github.com/pytorch/pytorch/issues/96274
+ export_module, guards = torchdynamo.export(
+ m,
+ *copy.deepcopy(example_inputs),
+ aten_graph=True,
+ tracing_mode="real",
+ )
+
+ qconfig = get_default_qconfig("x86")
+ qconfig_mapping = QConfigMapping().set_global(qconfig)
+ backend_config = get_x86_inductor_pt2e_backend_config()
+ prepare_module = prepare_pt2e(
+ export_module, qconfig_mapping, example_inputs, backend_config
+ )
+ prepare_module(*example_inputs)
+ convert_module = convert_pt2e(prepare_module)
+ convert_module(*example_inputs)
+
+                    # Fake quant should only be inserted at the start and end of the graph
+ node_occurrence = {
+                        # per-tensor q/dq for the conv input and output; per-channel q/dq for the weight
+ ns.call_function(
+ torch.ops.quantized_decomposed.quantize_per_tensor
+ ): 2,
+ ns.call_function(
+ torch.ops.quantized_decomposed.quantize_per_channel
+ ): 1,
+ ns.call_function(
+ torch.ops.quantized_decomposed.dequantize_per_channel
+ ): 1,
+ ns.call_function(
+ torch.ops.quantized_decomposed.dequantize_per_tensor
+ ): 2,
+ }
+ if use_relu:
+ node_list = [
+ ns.call_function(
+ torch.ops.quantized_decomposed.quantize_per_tensor
+ ),
+ ns.call_function(
+ torch.ops.quantized_decomposed.dequantize_per_tensor
+ ),
+ ns.call_function(torch.ops.aten.convolution.default),
+ ns.call_function(
+ torch.ops.aten.relu_.default
+ if inplace_relu
+ else torch.ops.aten.relu.default
+ ),
+ ns.call_function(
+ torch.ops.quantized_decomposed.quantize_per_tensor
+ ),
+ ns.call_function(
+ torch.ops.quantized_decomposed.dequantize_per_tensor
+ ),
+ ]
+ else:
+ node_list = [
+ ns.call_function(
+ torch.ops.quantized_decomposed.quantize_per_tensor
+ ),
+ ns.call_function(
+ torch.ops.quantized_decomposed.dequantize_per_tensor
+ ),
+ ns.call_function(torch.ops.aten.convolution.default),
+ ns.call_function(
+ torch.ops.quantized_decomposed.quantize_per_tensor
+ ),
+ ns.call_function(
+ torch.ops.quantized_decomposed.dequantize_per_tensor
+ ),
+ ]
+ self.checkGraphModuleNodes(
+ convert_module,
+ expected_node_occurrence=node_occurrence,
+ expected_node_list=node_list,
+ )
+
+                    # Step 1: reference result from the 1.X FX path
+ backend_config_1_x = get_x86_backend_config()
+ m_copy = copy.deepcopy(m)
+ m_prepare_fx = prepare_fx(
+ m_copy,
+ qconfig_mapping,
+ example_inputs,
+ backend_config=backend_config_1_x,
+ )
+ after_prepare_result_fx = m_prepare_fx(*example_inputs)
+ m_convert_fx = convert_fx(
+ m_prepare_fx, backend_config=backend_config_1_x
+ )
+ ref_result = m_convert_fx(*example_inputs)
+
+                    # Step 2: lower into Inductor
+ run = compile_fx(convert_module, example_inputs)
+ # Inductor first run
+ inductor_res = run(*example_inputs)
+ # Inductor second run
+ inductor_res = run(*example_inputs)
+ self.assertEqual(ref_result, inductor_res, atol=5e-2, rtol=5e-2)
+
+
+class TestQuantizePT2EFXModels(QuantizationTestCase):
+ @skip_if_no_torchvision
+ @skipIfNoQNNPACK
+ def test_resnet18(self):
+ import torchvision
+
+ with override_quantized_engine("qnnpack"):
+ example_inputs = (torch.randn(1, 3, 224, 224),)
+ m = torchvision.models.resnet18().eval()
+ m_copy = copy.deepcopy(m)
+ # program capture
+ m, guards = torchdynamo.export(
+ m,
+ *copy.deepcopy(example_inputs),
+ aten_graph=True,
+ tracing_mode="real",
+ )
+
+ backend_config = get_qnnpack_pt2e_backend_config()
+ # TODO: define qconfig_mapping specifically for executorch
+ qconfig = get_default_qconfig("qnnpack")
+ qconfig_mapping = QConfigMapping().set_global(qconfig)
+ before_fusion_result = m(*example_inputs)
+
+ m = prepare_pt2e(m, qconfig_mapping, example_inputs, backend_config)
+
+            # check that observers were inserted correctly for the maxpool operator
+            # (its input and output share an observer instance)
+ self.assertEqual(
+ id(m.activation_post_process_3), id(m.activation_post_process_2)
+ )
+ after_prepare_result = m(*example_inputs)
+ m = convert_pt2e(m)
+
+ after_quant_result = m(*example_inputs)
+
+ # comparing with existing fx graph mode quantization reference flow
+ backend_config = get_qnnpack_backend_config()
+ m_fx = prepare_fx(
+ m_copy, qconfig_mapping, example_inputs, backend_config=backend_config
+ )
+ after_prepare_result_fx = m_fx(*example_inputs)
+ m_fx = convert_to_reference_fx(m_fx, backend_config=backend_config)
+
+ after_quant_result_fx = m_fx(*example_inputs)
+
+            # the results match exactly after prepare
+ self.assertEqual(after_prepare_result, after_prepare_result_fx)
+ self.assertEqual(
+ compute_sqnr(after_prepare_result, after_prepare_result_fx),
+ torch.tensor(float("inf")),
+ )
+ # there are slight differences after convert due to different implementations
+ # of quant/dequant
+ self.assertTrue(
+ torch.max(after_quant_result - after_quant_result_fx) < 1e-1
+ )
+ self.assertTrue(
+ compute_sqnr(after_quant_result, after_quant_result_fx) > 35
+ )
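For contrast, a rough sketch of the quantizer-based flow that stays behind in test_quantize_pt2e.py, based on the imports that file keeps (prepare_pt2e_quantizer, QNNPackQuantizer); the quantizer's configuration calls are omitted here, the signature is assumed from this era's API, and the toy module is a placeholder:

```python
# Sketch of the quantizer-based flow kept in test_quantize_pt2e.py.
import copy

import torch
import torch._dynamo as torchdynamo
from torch.ao.quantization._pt2e.quantizer import QNNPackQuantizer
from torch.ao.quantization._quantize_pt2e import (
    convert_pt2e,
    prepare_pt2e_quantizer,
)

m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1)).eval()
example_inputs = (torch.randn(1, 1, 3, 3),)
m, guards = torchdynamo.export(
    m, *copy.deepcopy(example_inputs), aten_graph=True, tracing_mode="real"
)

# A backend-specific Quantizer replaces QConfigMapping + backend_config;
# its global/per-operator configuration calls are omitted in this sketch.
quantizer = QNNPackQuantizer()
m = prepare_pt2e_quantizer(m, quantizer)  # signature assumed, not confirmed here
m(*example_inputs)  # calibrate
m = convert_pt2e(m)
```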