| # Owner(s): ["oncall: quantization"] |
| |
| from collections import OrderedDict |
| import contextlib |
| import torch |
| import torch.nn.functional as F |
| import torch.nn as nn |
| import torch.ao.nn.quantized as nnq |
| import torch.ao.nn.quantized.reference as nnqr |
| import torch.ao.nn.quantized.dynamic as nnqd |
| import torch.ao.nn.intrinsic as nni |
| import torch.ao.nn.intrinsic.quantized as nniq |
| import torch.ao.nn.intrinsic.quantized.dynamic as nniqd |
| import torch.multiprocessing as mp |
| from torch.fx.graph_module import _USER_PRESERVED_ATTRIBUTES_KEY |
| |
| # graph mode quantization based on fx |
| from torch.ao.quantization.quantize_fx import ( |
| prepare_fx, |
| convert_fx, |
| convert_to_reference_fx, |
| _convert_to_reference_decomposed_fx, |
| prepare_qat_fx, |
| fuse_fx, |
| ) |
| |
| |
| from torch.ao.quantization.fx.quantize_handler import DefaultNodeQuantizeHandler |
| |
| from torch.ao.quantization.fx.match_utils import ( |
| _is_match, |
| MatchAllNode, |
| ) |
| |
| from torch.ao.quantization import ( |
| QuantType, |
| ) |
| from torch.ao.quantization.quant_type import _get_quant_type_to_str |
| |
| from torch.ao.quantization import ( |
| QuantStub, |
| DeQuantStub, |
| QuantWrapper, |
| default_qconfig, |
| default_dynamic_qconfig, |
| default_per_channel_qconfig, |
| default_qat_qconfig, |
| default_reuse_input_qconfig, |
| default_symmetric_qnnpack_qconfig, |
| default_symmetric_qnnpack_qat_qconfig, |
| per_channel_dynamic_qconfig, |
| float16_dynamic_qconfig, |
| float16_static_qconfig, |
| float_qparams_weight_only_qconfig, |
| float_qparams_weight_only_qconfig_4bit, |
| get_default_qconfig, |
| get_default_qat_qconfig, |
| get_default_qconfig_mapping, |
| get_default_qat_qconfig_mapping, |
| fuse_modules, |
| fuse_modules_qat, |
| prepare, |
| prepare_qat, |
| convert, |
| quantize_dynamic, |
| default_placeholder_observer, |
| default_weight_observer, |
| PerChannelMinMaxObserver, |
| FixedQParamsFakeQuantize, |
| FixedQParamsObserver, |
| FusedMovingAvgObsFakeQuantize, |
| FakeQuantize, |
| MovingAverageMinMaxObserver, |
| HistogramObserver, |
| ReuseInputObserver, |
| QConfig, |
| default_embedding_qat_qconfig, |
| ) |
| |
| from torch.ao.quantization.backend_config import ( |
| get_fbgemm_backend_config, |
| get_qnnpack_backend_config, |
| BackendConfig, |
| BackendPatternConfig, |
| DTypeConfig, |
| DTypeWithConstraints, |
| ObservationType |
| ) |
| from torch.ao.quantization.backend_config.native import ( |
| get_test_only_legacy_native_backend_config, |
| ) |
| |
| from torch.ao.quantization.qconfig_mapping import ( |
| _get_symmetric_qnnpack_qconfig_mapping, |
| _get_symmetric_qnnpack_qat_qconfig_mapping, |
| _GLOBAL_DICT_KEY, |
| _MODULE_NAME_DICT_KEY, |
| _MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY, |
| _MODULE_NAME_REGEX_DICT_KEY, |
| _OBJECT_TYPE_DICT_KEY, |
| QConfigMapping, |
| ) |
| |
| from torch.ao.quantization.fx.qconfig_mapping_utils import ( |
| _get_object_type_qconfig, |
| _get_module_name_qconfig, |
| _get_module_name_regex_qconfig, |
| _maybe_adjust_qconfig_for_module_name_object_type_order, |
| ) |
| |
| from torch.ao.quantization.fx.pattern_utils import ( |
| _DEFAULT_FUSION_PATTERNS, |
| _DEFAULT_QUANTIZATION_PATTERNS, |
| _DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP, |
| _DEFAULT_OUTPUT_OBSERVER_MAP, |
| _register_fusion_pattern, |
| _register_quant_pattern, |
| get_default_output_activation_post_process_map |
| ) |
| |
| from torch.ao.quantization.fx.custom_config import ( |
| STANDALONE_MODULE_NAME_DICT_KEY, |
| STANDALONE_MODULE_CLASS_DICT_KEY, |
| FLOAT_TO_OBSERVED_DICT_KEY, |
| OBSERVED_TO_QUANTIZED_DICT_KEY, |
| NON_TRACEABLE_MODULE_NAME_DICT_KEY, |
| NON_TRACEABLE_MODULE_CLASS_DICT_KEY, |
| INPUT_QUANTIZED_INDEXES_DICT_KEY, |
| OUTPUT_QUANTIZED_INDEXES_DICT_KEY, |
| PRESERVED_ATTRIBUTES_DICT_KEY, |
| FuseCustomConfig, |
| ConvertCustomConfig, |
| PrepareCustomConfig, |
| StandaloneModuleConfigEntry, |
| ) |
| import torch.ao.quantization.fx.lstm_utils |
| |
| from torch.ao.quantization.fx.utils import ( |
| _reroute_tuple_getitem_pattern, |
| NodeInfo, |
| ) |
| |
| from torch.ao.quantization.fake_quantize import ( |
| default_fixed_qparams_range_0to1_fake_quant, |
| default_fixed_qparams_range_neg1to1_fake_quant, |
| ) |
| |
| from torch.ao.quantization.observer import ( |
| default_fixed_qparams_range_0to1_observer, |
| default_fixed_qparams_range_neg1to1_observer, |
| MinMaxObserver, |
| _is_activation_post_process, |
| ) |
| |
| # test utils |
| from hypothesis import given, settings |
| from hypothesis import strategies as st |
| from torch.testing._internal.common_cuda import TEST_MULTIGPU, TEST_CUDA |
| from torch.testing._internal.common_quantization import ( |
| LinearReluLinearModel, |
| LinearReluModel, |
| LinearBnLeakyReluModel, |
| LinearTanhModel, |
| ConvBnAddReluModel, |
| QuantizationTestCase, |
| skipIfNoFBGEMM, |
| skipIfNoQNNPACK, |
| skip_if_no_torchvision, |
| train_one_epoch, |
| run_ddp, |
| test_only_eval_fn, |
| test_only_train_fn, |
| ModelForConvTransposeBNFusion, |
| get_supported_device_types, |
| skipIfNoONEDNN, |
| ) |
| |
| from torch.testing._internal.common_quantization import ( |
| LinearModelWithSubmodule, |
| ResNetBase, |
| RNNDynamicModel, |
| RNNCellDynamicModel, |
| ) |
| |
| from torch.testing._internal.common_quantized import ( |
| supported_qengines, |
| override_qengines, |
| override_quantized_engine, |
| ) |
| |
| from torch.testing._internal.common_utils import ( |
| TemporaryFileName, |
| IS_ARM64, |
| ) |
| |
| from torch.testing._internal.common_quantization import NodeSpec as ns |
| |
| from torch.testing import FileCheck |
| |
| import copy |
| import itertools |
| import operator |
| import unittest |
| import io |
| from typing import Callable, Optional, List, Tuple |
| |
| class BinaryOp(torch.nn.Module): |
| def __init__(self, binary_op, ibinary_op, is_inplace, is_scalar): |
| """ ibinary_op means inplace binary op |
| """ |
| super().__init__() |
| self.conv1 = torch.nn.Conv2d(1, 1, 1).float() |
| self.conv2 = torch.nn.Conv2d(1, 1, 1).float() |
| self.is_scalar = is_scalar |
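        # fall back to the regular binary op unless an inplace variant is provided and requested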
| self.op = ibinary_op if ibinary_op and is_inplace else binary_op |
| |
| def forward(self, x, y): |
| x = self.conv1(x) |
| y = 3 if self.is_scalar else self.conv2(y) |
| # x = x + y |
| x = self.op(x, y) |
| # x = y + x |
| x = self.op(y, x) |
| return x |
| |
| class BinaryOpNonQuantizedInput(torch.nn.Module): |
| def __init__(self, binary_op, ibinary_op, is_inplace, is_scalar): |
| """ ibinary_op means inplace binary op |
| """ |
| super().__init__() |
| self.is_scalar = is_scalar |
| self.op = ibinary_op if ibinary_op and is_inplace else binary_op |
| |
| def forward(self, x, y): |
| y = 3 if self.is_scalar else y |
| x = self.op(x, y) |
| return x |
| |
| class BinaryOpRelu(torch.nn.Module): |
| def __init__(self, binary_op, ibinary_op, is_inplace, relu_callable, |
| is_scalar): |
| """ ibinary_op means inplace binary op |
| """ |
| super().__init__() |
| self.conv1 = torch.nn.Conv2d(1, 1, 1).float() |
| self.conv2 = torch.nn.Conv2d(1, 1, 1).float() |
| self.op = ibinary_op if ibinary_op and is_inplace else binary_op |
| self.relu_callable = relu_callable |
| self.is_scalar = is_scalar |
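        # nn.ReLU is a module class and needs to be instantiated; functional
        # relu variants (e.g. F.relu) are callables and can be used directly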
| if relu_callable is torch.nn.ReLU: |
| self.relu = torch.nn.ReLU() |
| else: |
| self.relu = relu_callable |
| |
| def forward(self, x, y): |
| x = self.conv1(x) |
| y = 3 if self.is_scalar else self.conv2(y) |
| x = self.op(x, y) |
| x = self.relu(x) |
| x = self.op(y, x) |
| x = self.relu(x) |
| return x |
| |
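# torch.fx.wrap keeps symbolic tracing from tracing into this function's body,
# so it shows up in the traced graph as a single call_function node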
| @torch.fx.wrap |
| def _user_func_with_complex_return_type(x): |
| return list(torch.split(x, 1, 1)) |
| |
| class TestFuseFx(QuantizationTestCase): |
| def test_fuse_conv_bn_relu(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv1d = nn.Conv1d(1, 1, 1) |
| self.conv2d = nn.Conv2d(1, 1, 1) |
| self.conv3d = nn.Conv3d(1, 1, 1) |
| self.bn1d = nn.BatchNorm1d(1) |
| self.bn2d = nn.BatchNorm2d(1) |
| self.bn3d = nn.BatchNorm3d(1) |
| self.conv1d2 = nn.Conv1d(1, 1, 1) |
| self.conv2d2 = nn.Conv2d(1, 1, 1) |
| self.conv3d2 = nn.Conv3d(1, 1, 1) |
| self.bn1d2 = nn.BatchNorm1d(1) |
| self.bn2d2 = nn.BatchNorm2d(1) |
| self.bn3d2 = nn.BatchNorm3d(1) |
| self.relu = nn.ReLU() |
| |
| def forward(self, x): |
| x = self.conv1d(x) |
| x = self.bn1d(x) |
| x = self.conv2d(x) |
| x = self.bn2d(x) |
| x = self.conv3d(x) |
| x = self.bn3d(x) |
| x = self.conv1d2(x) |
| x = self.bn1d2(x) |
| x = self.relu(x) |
| x = self.conv2d2(x) |
| x = self.bn2d2(x) |
| x = self.relu(x) |
| x = self.conv3d2(x) |
| x = self.bn3d2(x) |
| x = self.relu(x) |
| return x |
| |
| # test train mode |
| m = M().train() |
        # currently we don't check whether the modules are configured with a qconfig before fusion
| # TODO: if we decide to do that in the future, this test needs to |
| # be updated |
| # train mode fuse_fx is called in prepare_qat_fx |
| m = prepare_qat_fx(m, {}, example_inputs=(torch.randn(1, 1, 1, 1),)) |
| expected_nodes = [ |
| ns.call_module(nni.ConvBn1d), |
| ns.call_module(nni.ConvBn2d), |
| ns.call_module(nni.ConvBn3d), |
| ns.call_module(nni.ConvBnReLU1d), |
| ns.call_module(nni.ConvBnReLU2d), |
| ns.call_module(nni.ConvBnReLU3d), |
| ] |
| expected_occurrence = { |
| ns.call_module(nn.ReLU): 0 |
| } |
| self.checkGraphModuleNodes( |
| m, |
| expected_node_list=expected_nodes, |
| expected_node_occurrence=expected_occurrence) |
| |
| # test eval mode |
| m = M().eval() |
| # fuse_fx is a top level api and only supports eval mode |
| m = fuse_fx(m) |
| expected_nodes = [ |
| ns.call_module(nn.Conv1d), |
| ns.call_module(nn.Conv2d), |
| ns.call_module(nn.Conv3d), |
| ns.call_module(nni.ConvReLU1d), |
| ns.call_module(nni.ConvReLU2d), |
| ns.call_module(nni.ConvReLU3d), |
| ] |
        # in eval mode, bn is folded into the preceding conv and all standalone ReLUs are fused away
| expected_occurrence = { |
| ns.call_module(nn.ReLU): 0 |
| } |
| self.checkGraphModuleNodes( |
| m, |
| expected_node_list=expected_nodes, |
| expected_node_occurrence=expected_occurrence) |
| |
| def test_fuse_linear_bn_eval(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = nn.Linear(1, 1) |
| self.bn1d = nn.BatchNorm1d(1) |
| |
| def forward(self, x): |
| x = self.linear(x) |
| x = self.bn1d(x) |
| return x |
| |
| # test eval mode |
| m = M().eval() |
| # fuse_fx is a top level api and only supports eval mode |
| m = fuse_fx(m) |
| expected_nodes = [ |
| ns.call_module(nn.Linear), |
| ] |
| expected_occurrence = { |
| ns.call_module(nn.BatchNorm1d): 0, |
| } |
| self.checkGraphModuleNodes( |
| m, |
| expected_node_list=expected_nodes, |
| expected_node_occurrence=expected_occurrence) |
| |
| @skipIfNoONEDNN |
| def test_fuse_linear_bn_leaky_relu_onednn(self): |
| # linear - bn - leaky_relu is fused for onednn backend only |
| from torch.ao.quantization.backend_config import get_onednn_backend_config |
| expected_nodes = [ |
| ns.call_module(nni.LinearLeakyReLU), |
| ] |
| expected_occurrence = { |
| ns.call_module(nn.BatchNorm1d): 0, |
| ns.call_module(nn.LeakyReLU): 0, |
| } |
| |
| for with_bn in [True, False]: |
| # test eval mode |
| m = LinearBnLeakyReluModel(with_bn).eval() |
| # fuse_fx is a top level api and only supports eval mode |
| m = fuse_fx(m, |
| backend_config=get_onednn_backend_config()) |
| self.checkGraphModuleNodes( |
| m, |
| expected_node_list=expected_nodes, |
| expected_node_occurrence=expected_occurrence) |
| |
| def test_linear_bn_leaky_relu_not_fused_by_default(self): |
| # Make sure linear - bn - leaky_relu is not fused by default |
| for with_bn in [True, False]: |
| # test eval mode |
| m = LinearBnLeakyReluModel(with_bn).eval() |
| # fuse_fx is a top level api and only supports eval mode |
| m = fuse_fx(m) |
| expected_nodes = [ |
| ns.call_module(nn.Linear), |
| ns.call_module(nn.LeakyReLU), |
| ] |
| expected_occurrence = { |
| ns.call_module(nni.LinearLeakyReLU): 0, |
| } |
| self.checkGraphModuleNodes( |
| m, |
| expected_node_list=expected_nodes, |
| expected_node_occurrence=expected_occurrence) |
| |
| @skipIfNoONEDNN |
| def test_fuse_linear_tanh_for_onednn_backend(self): |
| # linear - tanh is fused for onednn backend only |
| from torch.ao.quantization.backend_config import get_onednn_backend_config |
| expected_nodes = [ |
| ns.call_module(nni.LinearTanh), |
| ] |
| expected_occurrence = { |
| ns.call_module(nn.Linear): 0, |
| ns.call_module(nn.Tanh): 0, |
| } |
| |
| # test eval mode |
| m = LinearTanhModel().eval() |
| # fuse_fx is a top level api and only supports eval mode |
| m = fuse_fx(m, |
| backend_config=get_onednn_backend_config()) |
| self.checkGraphModuleNodes( |
| m, |
| expected_node_list=expected_nodes, |
| expected_node_occurrence=expected_occurrence) |
| |
| def test_linear_tanh_not_fused_by_default(self): |
| # Make sure linear - tanh is not fused by default |
| # test eval mode |
| m = LinearTanhModel().eval() |
| # fuse_fx is a top level api and only supports eval mode |
| m = fuse_fx(m) |
| expected_nodes = [ |
| ns.call_module(nn.Linear), |
| ns.call_module(nn.Tanh), |
| ] |
| expected_occurrence = { |
| ns.call_module(nni.LinearTanh): 0, |
| } |
| self.checkGraphModuleNodes( |
| m, |
| expected_node_list=expected_nodes, |
| expected_node_occurrence=expected_occurrence) |
| |
    @skipIfNoONEDNN
    def test_fuse_conv_bn_add_relu_onednn(self):
| # conv - bn - add - relu is fused for onednn backend only |
| from torch.ao.quantization.backend_config import get_onednn_backend_config |
| options = itertools.product( |
| [True, False], # with_bn |
| [True, False], # with_relu |
| [True, False], # conv in the left |
| [True, False], # with_two_conv |
| [True, False], # use_torch_add |
| ) |
| for with_bn, with_relu, left_conv, two_conv, use_torch_add in options: |
| expected_nodes = [ |
| ns.call_module(nni.ConvAddReLU2d if with_relu else nni.ConvAdd2d), |
| ] |
| expected_occurrence = { |
| ns.call_module(nni.ConvAddReLU2d if with_relu else nni.ConvAdd2d): 1, |
| ns.call_module(nn.BatchNorm2d): 0, |
| } |
| |
| # test eval mode |
| m = ConvBnAddReluModel( |
| with_bn=with_bn, |
| with_relu=with_relu, |
| left_conv=left_conv, |
| two_conv=two_conv, |
| use_torch_add=use_torch_add).eval() |
| |
| m = fuse_fx(m, |
| backend_config=get_onednn_backend_config()) |
| self.checkGraphModuleNodes( |
| m, |
| expected_node_list=expected_nodes, |
| expected_node_occurrence=expected_occurrence) |
| |
| def test_fuse_conv_bn_add_relu_by_default(self): |
| options = itertools.product( |
| [True, False], # with_bn |
| [True, False], # with_relu |
| [True, False], # conv in the left |
| [True, False], # with_two_conv |
| [True, False], # use_torch_add |
| ) |
| for with_bn, with_relu, left_conv, two_conv, use_torch_add in options: |
| # test eval mode |
| expected_nodes = [ |
| ns.call_module(nn.Conv2d), |
| ] |
| expected_occurrence = { |
| ns.call_module(nni.ConvAdd2d): 0, |
| } |
| m = ConvBnAddReluModel( |
| with_bn=with_bn, |
| with_relu=with_relu, |
| left_conv=left_conv, |
| two_conv=two_conv, |
| use_torch_add=use_torch_add).eval() |
| m = fuse_fx(m) |
| self.checkGraphModuleNodes( |
| m, |
| expected_node_list=expected_nodes, |
| expected_node_occurrence=expected_occurrence) |
| |
| @skipIfNoONEDNN |
| def test_fuse_conv_bn_add_relu_lowering(self): |
| """ Test fusion and lowering of Conv2d - (bn -) ReLU |
| by FX. For onednn backedn only. |
| """ |
| from torch.ao.quantization.backend_config import get_onednn_backend_config |
| qconfig_mapping = get_default_qconfig_mapping('onednn') |
| with override_quantized_engine('onednn'): |
| options = itertools.product( |
| [True, False], # with_bn |
| [True, False], # with_relu |
| [True, False], # conv in the left |
| [True, False], # two_conv |
| [True, False], # use_torch_add |
| ) |
| for with_bn, with_relu, left_conv, two_conv, use_torch_add in options: |
| node_occurrence = { |
| ns.call_function(torch.quantize_per_tensor): 1 if two_conv else 2, |
| ns.call_method("dequantize"): 1, |
| ns.call_module(nniq.ConvAddReLU2d if with_relu else nniq.ConvAdd2d): 1, |
| ns.call_module(nn.Conv2d): 0, |
| ns.call_module(nn.ReLU): 0, |
| } |
| node_occurrence_ref = { |
| ns.call_function(torch.quantize_per_tensor): 3, |
| ns.call_method("dequantize"): 3, |
| } |
| |
| # test eval mode |
| m = ConvBnAddReluModel( |
| with_bn=with_bn, |
| with_relu=with_relu, |
| left_conv=left_conv, |
| two_conv=two_conv, |
| use_torch_add=use_torch_add).eval() |
| example_x = m.get_example_inputs() |
| m = prepare_fx(m, qconfig_mapping, |
| example_inputs=example_x, |
| backend_config=get_onednn_backend_config()) |
| m_copy = copy.deepcopy(m) |
| m = convert_fx(m, backend_config=get_onednn_backend_config()) |
| m_ref = convert_to_reference_fx(m_copy) |
| self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) |
| self.checkGraphModuleNodes(m_ref, expected_node_occurrence=node_occurrence_ref) |
| m(*example_x) |
| |
    def test_fuse_convtranspose_bn_eval(self):
        m = ModelForConvTransposeBNFusion().eval()
| m = fuse_fx(m) |
| |
| expected_nodes = [ |
| ns.call_module(nn.ConvTranspose1d), |
| ns.call_module(nn.ConvTranspose2d), |
| ns.call_module(nn.ConvTranspose3d), |
| ] |
| expected_occurrence = { |
| ns.call_module(nn.BatchNorm1d): 0, |
| ns.call_module(nn.BatchNorm2d): 0, |
| ns.call_module(nn.BatchNorm3d): 0, |
| } |
| self.checkGraphModuleNodes( |
| m, |
| expected_node_list=expected_nodes, |
| expected_node_occurrence=expected_occurrence) |
| |
| |
| def test_fuse_module_relu(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv1d = nn.Conv1d(1, 1, 1) |
| self.conv2d = nn.Conv2d(1, 1, 1) |
| self.conv3d = nn.Conv3d(1, 1, 1) |
| self.bn1d = nn.BatchNorm1d(1) |
| self.bn2d = nn.BatchNorm2d(1) |
| self.bn3d = nn.BatchNorm3d(1) |
| self.relu = nn.ReLU() |
| |
| def forward(self, x): |
| x = self.conv1d(x) |
| x = self.relu(x) |
| x = self.conv2d(x) |
| x = self.relu(x) |
| x = self.conv3d(x) |
| x = self.relu(x) |
| x = self.bn1d(x) |
| x = self.relu(x) |
| x = self.bn2d(x) |
| x = self.relu(x) |
| x = self.bn3d(x) |
| x = self.relu(x) |
| return x |
| |
| m = M().eval() |
| m = fuse_fx(m) |
| expected_nodes = [ |
| ns.call_module(nni.ConvReLU1d), |
| ns.call_module(nni.ConvReLU2d), |
| ns.call_module(nni.ConvReLU3d), |
| ns.call_module(nni.BNReLU2d), |
| ns.call_module(nni.BNReLU3d), |
| ] |
| self.checkGraphModuleNodes(m, expected_node_list=expected_nodes) |
| |
| @skipIfNoFBGEMM |
| def test_qconfig_fused_module(self): |
| """ TODO: add test for all fused modules |
| """ |
| qconfig_dict = { |
| "": None, |
| "object_type": [(nn.Linear, default_qconfig), |
| (nn.ReLU, default_qconfig), |
| (F.relu, default_qconfig)] |
| } |
| |
| linearRelu_node_list = [ |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_module(nniq.LinearReLU), |
| ns.call_method('dequantize') |
| ] |
| |
| linearReluLinear_node_list = [ |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_module(nniq.LinearReLU), |
| ns.call_module(nnq.Linear), |
| ns.call_method('dequantize') |
| ] |
| |
| tests = [(LinearReluModel, linearRelu_node_list), |
| (LinearReluLinearModel, linearReluLinear_node_list)] |
| |
| for M, node_list in tests: |
| m = M().eval() |
| example_inputs = (torch.rand(5, 5),) |
| prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) |
| |
| prepared(*example_inputs) |
| quantized = convert_fx(prepared) |
| |
| self.checkGraphModuleNodes(quantized, expected_node_list=node_list) |
| |
| def test_problematic_fuse_example(self): |
| class LinearRelu(nn.Sequential): |
| def __init__(self): |
| super().__init__( |
| nn.Linear(5, 5), |
| nn.ReLU(), |
| ) |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.lin_relu = LinearRelu() |
| self.linear = nn.Linear(5, 5) |
| |
| def forward(self, x): |
| x = self.lin_relu(x) |
| x = self.linear(x) |
| return x |
| |
| model = M().eval() |
| # these qconfigs somehow fail equality where default_qconfig does not |
| qconfig_dict = { |
| "": None, |
| "object_type": [ |
| (torch.nn.Linear, get_default_qconfig('fbgemm')), |
| (torch.nn.ReLU, get_default_qconfig('fbgemm')), |
| ], |
| } |
| m = prepare_fx(model, qconfig_dict, example_inputs=(torch.randn(1, 5),)) |
| |
| self.checkGraphModuleNodes(m, expected_node=ns.call_module(torch.ao.nn.intrinsic.modules.fused.LinearReLU)) |
| |
| @unittest.skip("Temporarily skipping the test case, will enable after the simple" |
| "pattern format is supported") |
| def test_fuse_addtional_fuser_method(self): |
| class MyConvReLU(torch.nn.Module): |
| pass |
| |
| def my_conv_relu_fuser(conv, relu): |
| return MyConvReLU() |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.Conv2d(3, 3, 3) |
| self.relu = torch.nn.ReLU() |
| |
| def forward(self, x): |
| return self.relu(self.conv(x)) |
| |
| m = M().eval() |
| m = fuse_fx(m, fuse_custom_config={ |
| "additional_fuser_method_mapping": { |
| (torch.nn.Conv2d, torch.nn.ReLU): my_conv_relu_fuser |
| } |
| }) |
| self.checkGraphModuleNodes(m, expected_node=ns.call_module(MyConvReLU)) |
| |
| def test_fuse_custom_pattern(self): |
| class M(torch.nn.Module): |
| def __init__(self, use_torch_add=True): |
| super().__init__() |
| self.conv = torch.nn.Conv2d(3, 3, 3) |
| self.bn = torch.nn.BatchNorm2d(3) |
| self.relu = torch.nn.ReLU() |
| self.maxpool = torch.nn.MaxPool2d(3) |
| if use_torch_add: |
| self.add = torch.add |
| else: |
| self.add = operator.add |
| |
| def forward(self, x): |
| y = x |
| y = self.maxpool(x) |
| x = self.conv(x) |
| x = self.bn(x) |
| x = self.add(y, x) |
| x = self.relu(x) |
| return x |
| |
| for use_torch_add in [True, False]: |
| m = M(use_torch_add).eval() |
| |
| def fuse_conv_bn_relu(is_qat, relu, add_pattern): |
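                # the matched add_pattern arrives as (add_op, extra_input, (bn, conv));
                # this test keeps only the conv as the "fused" module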
| _, _, bn_pattern = add_pattern |
| bn, conv = bn_pattern |
| return conv |
| |
| conv_bn_res_relu_config1 = BackendPatternConfig() \ |
| ._set_pattern_complex_format((nn.ReLU, (torch.add, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d)))) \ |
| .set_fuser_method(fuse_conv_bn_relu) |
| conv_bn_res_relu_config2 = BackendPatternConfig() \ |
| ._set_pattern_complex_format((nn.ReLU, (operator.add, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d)))) \ |
| .set_fuser_method(fuse_conv_bn_relu) |
| backend_config = BackendConfig() \ |
| .set_backend_pattern_config(conv_bn_res_relu_config1) \ |
| .set_backend_pattern_config(conv_bn_res_relu_config2) |
| m = fuse_fx(m, backend_config=backend_config) |
| self.assertEqual(type(m.conv), torch.nn.Conv2d) |
| # check bn and relu are gone since we replaced the whole pattern to conv |
| self.assertFalse(hasattr(m, "bn")) |
| self.assertFalse(hasattr(m, "relu")) |
| |
| def test_fusion_pattern_with_multiple_inputs(self): |
| """ This test tests two keys in backend_config: root_node_getter and |
| extra_inputs_getter, |
| root_node_getter is used to identify a "root" module in the node pattern, |
| the node that we'll keep after fusion. |
| extra_inputs_getter will return a list of node that needs to be added to the |
| fused node as extra inputs. |
| """ |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.Conv2d(3, 3, 3) |
| self.bn = torch.nn.BatchNorm2d(3) |
| self.relu = torch.nn.ReLU() |
| self.maxpool = torch.nn.MaxPool2d(3) |
| |
| def forward(self, x): |
| y = x |
| y = self.maxpool(x) |
| x = self.conv(x) |
| x = self.bn(x) |
| x = torch.add(x, y) |
| x = self.relu(x) |
| return x |
| |
| m = M().eval() |
| |
| def fuse_conv_bn_relu(is_qat, relu, add_pattern): |
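            # here add_pattern is (add_op, (bn, conv), extra_input), since the pattern
            # below lists the conv-bn subpattern before MatchAllNode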
| _, bn_pattern, _ = add_pattern |
| bn, conv = bn_pattern |
| return conv |
| |
| def conv_bn_res_relu_root_node_getter(pattern): |
| relu, add_pattern = pattern |
| _, bn_pattern, _ = add_pattern |
| bn, conv = bn_pattern |
| return conv |
| |
| def conv_bn_res_relu_extra_inputs_getter(pattern): |
| """ get inputs pattern for extra inputs, inputs for root node |
| are assumed to be copied over from root node to the fused node |
| """ |
| relu, add_pattern = pattern |
| _, bn_pattern, extra_input = add_pattern |
| bn, conv = bn_pattern |
| return [extra_input] |
| |
| conv_bn_res_relu_config = BackendPatternConfig() \ |
| ._set_pattern_complex_format((nn.ReLU, (torch.add, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode))) \ |
| .set_fuser_method(fuse_conv_bn_relu) \ |
| ._set_root_node_getter(conv_bn_res_relu_root_node_getter) \ |
| ._set_extra_inputs_getter(conv_bn_res_relu_extra_inputs_getter) |
| backend_config = BackendConfig().set_backend_pattern_config(conv_bn_res_relu_config) |
| m = fuse_fx(m, backend_config=backend_config) |
| self.assertEqual(type(m.conv), torch.nn.Conv2d) |
| # check bn and relu are gone since we replaced the whole pattern to conv |
| self.assertFalse(hasattr(m, "bn")) |
| self.assertFalse(hasattr(m, "relu")) |
| |
| # check conv module has two inputs |
| named_modules = dict(m.named_modules()) |
| for node in m.graph.nodes: |
| if node.op == "call_module" and type(named_modules[node.target]) == torch.nn.Conv2d: |
                self.assertTrue(len(node.args) == 2, "Expecting the fused op to have two arguments")
| |
| def test_fusion_pattern_with_matchallnode(self): |
| """This test tests that the node matched by MatchAllNode will be regared as an input |
| instead of a module to be fused. For instance, we have two patterns: |
| (nn.ReLU, (torch.add, MatchAllNode, nn.Conv2d)) |
| (nn.ReLU, nn.Conv2d) |
| And we wanna fuse the following model |
| Conv2d -> ReLU + |
| Conv2d ------ Add -> ReLU |
| ReLU in the first row is matched as MatchAllNode in the residual pattern. But it won't be |
| fused as part of that pattnern. It needs to be properly fused with the upstream Conv2d. |
| """ |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv1 = torch.nn.Conv2d(3, 3, 3) |
| self.relu1 = torch.nn.ReLU() |
| self.conv2 = torch.nn.Conv2d(3, 3, 3) |
| self.relu2 = torch.nn.ReLU() |
| |
| def forward(self, x): |
| y = self.conv1(x) |
| y = self.relu1(y) |
| |
| x = self.conv2(x) |
| x = torch.add(x, y) |
| x = self.relu2(x) |
| return x |
| |
| m = M().eval() |
| |
| def fuse_conv_relu(is_qat, conv, relu): |
| return conv |
| |
| def fuse_conv_res_relu(is_qat, relu, add_pattern): |
| _, conv, _ = add_pattern |
| return conv |
| |
| def conv_res_relu_root_node_getter(pattern): |
| relu, (_, conv, _) = pattern |
| return conv |
| |
| def conv_res_relu_extra_inputs_getter(pattern): |
| relu, (_, _, extra_input) = pattern |
| return [extra_input] |
| |
| conv_relu_config = BackendPatternConfig((nn.Conv2d, nn.ReLU)) \ |
| .set_fuser_method(fuse_conv_relu) |
| conv_res_relu_config = BackendPatternConfig() \ |
| ._set_pattern_complex_format((nn.ReLU, (torch.add, nn.Conv2d, MatchAllNode))) \ |
| .set_fuser_method(fuse_conv_res_relu) \ |
| ._set_root_node_getter(conv_res_relu_root_node_getter) \ |
| ._set_extra_inputs_getter(conv_res_relu_extra_inputs_getter) |
| backend_config = BackendConfig() \ |
| .set_backend_pattern_config(conv_relu_config) \ |
| .set_backend_pattern_config(conv_res_relu_config) |
| m = fuse_fx(m, backend_config=backend_config) |
| self.assertEqual(type(m.conv1), torch.nn.Conv2d) |
| self.assertEqual(type(m.conv2), torch.nn.Conv2d) |
| # check relu are gone since we replaced both patterns to conv |
| self.assertFalse(hasattr(m, "relu1")) |
| self.assertFalse(hasattr(m, "relu2")) |
| |
| |
| @skipIfNoFBGEMM |
| class TestQuantizeFx(QuantizationTestCase): |
| def test_pattern_match(self): |
| """ test MatchAllNode with |
| conv - bn - add - relu pattern |
| """ |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = nn.Conv2d(1, 1, 1) |
| self.bn = nn.BatchNorm2d(1) |
| self.relu = nn.ReLU() |
| |
| def forward(self, x, y): |
| x = self.conv(x) |
| x = self.bn(x) |
| x = x + y |
| x = self.relu(x) |
| return x |
| |
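        # patterns are written outside-in: the outermost element (nn.ReLU) is the
        # last op in the matched sequence, with its producers nested inside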
| pattern = (nn.ReLU, (operator.add, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode)) |
| m = torch.fx.symbolic_trace(M()) |
| modules = dict(m.named_modules()) |
| for n in m.graph.nodes: |
| if n.op == 'call_module' and type(modules[n.target]) == nn.ReLU: |
| self.assertTrue(_is_match(modules, n, pattern)) |
| |
| def test_pattern_match_constant(self): |
| class M(torch.nn.Module): |
| def forward(self, x): |
| x, _ = torch.ops.aten.max_pool2d_with_indices.default(x) |
| return x |
| |
| pattern = (operator.getitem, torch.ops.aten.max_pool2d_with_indices.default, 0) |
| m = torch.fx.symbolic_trace(M()) |
        # eliminate the code that gets the second output of maxpool, so that the pattern
        # can be matched
| m.graph.eliminate_dead_code() |
| modules = dict(m.named_modules()) |
| for n in m.graph.nodes: |
| if n.op == "call_function" and n.target == operator.getitem: |
| self.assertTrue(_is_match(modules, n, pattern)) |
| |
| def test_fused_module_qat_swap(self): |
| class Tmp(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.tmp = torch.nn.Linear(5, 5) |
| self.relu = torch.nn.ReLU() |
| |
| def forward(self, x): |
| x = self.tmp(x) |
| return self.relu(x) |
| |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.mods1 = torch.nn.Sequential(Tmp(), torch.nn.Linear(5, 5)) |
| self.mods2 = torch.nn.Linear(5, 5) |
| |
| def forward(self, x): |
| a = self.mods1(x) |
| x = torch.add(x, 5) |
| x = self.mods2(x) |
| x = torch.add(x, 5) |
| return a, x |
| |
| |
| model = M().train() |
| qconfig_dict = { |
| "": None, |
| "object_type": [ |
| (torch.nn.Linear, default_qat_qconfig), |
| (torch.nn.ReLU, default_qat_qconfig), |
| ], |
| } |
| prepared = prepare_qat_fx(model, qconfig_dict, example_inputs=(torch.randn(1, 5),)) |
| self.assertTrue(isinstance(getattr(prepared.mods1, "0").tmp, torch.ao.nn.intrinsic.qat.LinearReLU)) |
| |
| def _get_conv_linear_test_cases(self, is_reference): |
| """ Returns a list of test cases, with format: |
| is_dynamic, ModuleClass, module_constructor_inputs, |
| inputs, quantized_node, weight_prepack_op |
| """ |
| class FunctionalConv1d(torch.nn.Module): |
| def __init__(self, weight): |
| super().__init__() |
| self.weight = torch.nn.Parameter(weight) |
| self.stride = 1 |
| self.padding = 0 |
| self.dilation = 1 |
| self.groups = 1 |
| |
| def forward(self, x): |
| return F.conv1d(x, self.weight, None, self.stride, self.padding, self.dilation, self.groups) |
| |
| |
| class Conv1d(torch.nn.Module): |
| def __init__(self, *args): |
| super().__init__() |
| self.conv = torch.nn.Conv1d(*args) |
| |
| def forward(self, x): |
| return self.conv(x) |
| |
| conv1d_input = torch.rand(1, 3, 224) |
| conv1d_weight = torch.rand(3, 3, 3) |
| conv1d_module_args = (3, 3, 3) |
| |
| class FunctionalConv2d(torch.nn.Module): |
| def __init__(self, weight): |
| super().__init__() |
| self.weight = torch.nn.Parameter(weight) |
| self.stride = (1, 1) |
| self.padding = (0, 0) |
| self.dilation = (1, 1) |
| self.groups = 1 |
| |
| def forward(self, x): |
| return F.conv2d(x, self.weight, None, self.stride, self.padding, self.dilation, self.groups) |
| |
| class Conv2d(torch.nn.Module): |
| def __init__(self, *args): |
| super().__init__() |
| self.conv = torch.nn.Conv2d(*args) |
| |
| def forward(self, x): |
| return self.conv(x) |
| |
| conv2d_input = torch.rand(1, 3, 224, 224) |
| conv2d_weight = torch.rand(3, 3, 3, 3) |
| conv2d_module_args = (3, 3, 3) |
| |
| class FunctionalConv3d(torch.nn.Module): |
| def __init__(self, weight): |
| super().__init__() |
| self.weight = torch.nn.Parameter(weight) |
| self.stride = (1, 1, 1) |
| self.padding = (0, 0, 0) |
| self.dilation = (1, 1, 1) |
| self.groups = 1 |
| |
| def forward(self, x): |
| return F.conv3d( |
| x, |
| self.weight, |
| None, |
| self.stride, |
| self.padding, |
| self.dilation, |
| self.groups, |
| ) |
| |
| class Conv3d(torch.nn.Module): |
| def __init__(self, *args): |
| super().__init__() |
| self.conv = torch.nn.Conv3d(*args) |
| |
| def forward(self, x): |
| return self.conv(x) |
| |
| conv3d_input = torch.rand(1, 3, 32, 224, 224) |
| conv3d_weight = torch.rand(3, 3, 3, 3, 3) |
| conv3d_module_args = (3, 3, 3) |
| |
| class Linear(torch.nn.Module): |
| def __init__(self, weight): |
| super().__init__() |
| self.weight = torch.nn.Parameter(weight) |
| |
| def forward(self, x): |
| return F.linear(x, self.weight) |
| |
| linear_input = torch.rand(8, 5) |
| linear_weight = torch.rand(10, 5) |
| |
| class LinearModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(5, 10) |
| |
| def forward(self, x): |
| return self.linear(x) |
| |
| linear_module_input = torch.rand(8, 5) |
| |
| # is_dynamic, ModuleClass, module_constructor_inputs, |
| # inputs, quantized_node, weight_prepack_node |
| tests = [ |
| ( |
| False, |
| FunctionalConv1d, |
| (conv1d_weight,), |
| (conv1d_input,), |
                ns.call_function(torch.nn.functional.conv1d if is_reference else torch.ops.quantized.conv1d),
| ns.call_function(torch.ops.quantized.conv1d_prepack), |
| ), |
| ( |
| False, |
| FunctionalConv2d, |
| (conv2d_weight,), |
| (conv2d_input,), |
| ns.call_function(torch.nn.functional.conv2d if is_reference else torch.ops.quantized.conv2d), |
| ns.call_function(torch.ops.quantized.conv2d_prepack), |
| ), |
| ( |
| False, |
| FunctionalConv3d, |
| (conv3d_weight,), |
| (conv3d_input,), |
| ns.call_function(torch.nn.functional.conv3d if is_reference else torch.ops.quantized.conv3d), |
| ns.call_function(torch.ops.quantized.conv3d_prepack), |
| ), |
| ( |
| False, |
| Conv1d, |
| conv1d_module_args, |
| (conv1d_input,), |
| ns.call_module(nnqr.Conv1d if is_reference else nnq.Conv1d), |
| None |
| ), |
| ( |
| False, |
| Conv2d, |
| conv2d_module_args, |
| (conv2d_input,), |
| ns.call_module(nnqr.Conv2d if is_reference else nnq.Conv2d), |
| None |
| ), |
| ( |
| False, |
| Conv3d, |
| conv3d_module_args, |
| (conv3d_input,), |
| ns.call_module(nnqr.Conv3d if is_reference else nnq.Conv3d), |
| None |
| ), |
| ( |
| True, |
| Linear, |
| (linear_weight,), |
| (linear_input,), |
| None if is_reference else ns.call_function(torch.ops.quantized.linear_dynamic), |
| ns.call_function(torch.ops.quantized.linear_prepack), |
| ), |
| ( |
| False, |
| Linear, |
| (linear_weight,), |
| (linear_input,), |
| ns.call_function(torch.nn.functional.linear if is_reference else torch.ops.quantized.linear), |
| ns.call_function(torch.ops.quantized.linear_prepack), |
| ), |
| ( |
| True, |
| LinearModule, |
| (), |
| (linear_module_input,), |
| ns.call_module(nnqr.Linear) if is_reference else ns.call_module(nnqd.Linear), |
| None, |
| ), |
| ( |
| False, |
| LinearModule, |
| (), |
| (linear_module_input,), |
| ns.call_module(nnqr.Linear if is_reference else nnq.Linear), |
| None, |
| ), |
| ] |
| return tests |
| |
| @skipIfNoFBGEMM |
| def test_conv_linear_not_reference(self): |
| """ Test quantizing conv and linear |
| """ |
| tests = self._get_conv_linear_test_cases(is_reference=False) |
| for (is_dynamic, ModuleClass, module_constructor_inputs, |
| inputs, quantized_node, weight_prepack_node) in tests: |
| quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC |
| node_occurrence = {} |
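            # the weight prepack op is expected to be constant-folded away, so it
            # should not appear in the final quantized graph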
| if weight_prepack_node: |
| node_occurrence[weight_prepack_node] = 0 |
| self.checkGraphModeFxOp( |
| ModuleClass(*module_constructor_inputs), |
| inputs, quant_type, |
| expected_node=quantized_node, |
| expected_node_occurrence=node_occurrence, |
| is_reference=False) |
| |
| @skipIfNoFBGEMM |
| def test_conv_linear_reference(self): |
| """ Test quantizing functional conv and linear with reference option |
| """ |
| tests = self._get_conv_linear_test_cases(is_reference=True) |
| |
| def _get_keys(prefix, is_dynamic): |
| all_keys = [prefix + "." + k for k in ["weight_qscheme", "weight_dtype"]] |
| if not is_dynamic: |
| all_keys.extend([prefix + "." + k for k in ["weight_scale", "weight_zero_point"]]) |
| return all_keys |
| |
| for (is_dynamic, ModuleClass, module_constructor_inputs, |
| inputs, quantized_node, weight_prepack_node) in tests: |
| quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC |
| node_occurrence = {} |
| if weight_prepack_node: |
| node_occurrence[weight_prepack_node] = 0 |
| result_dict = self.checkGraphModeFxOp( |
| ModuleClass(*module_constructor_inputs), |
| inputs, quant_type, |
| expected_node=quantized_node, |
| expected_node_occurrence=node_occurrence, |
| is_reference=True) |
| qr = result_dict["quantized_reference"] |
| |
| def checkWeightQParams(model): |
| for module_name in ("linear", "conv"): |
| if hasattr(model, module_name): |
| self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_qscheme")) |
| self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_scale")) |
| self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_zero_point")) |
| self.assertTrue("Reference" in qr.get_submodule(module_name)._get_name()) |
| |
| def checkSerDeser(model, is_dynamic): |
| for module_name in ("linear", "conv"): |
| if hasattr(model, module_name): |
                        # make sure serialization works
| state_dict = copy.deepcopy(model.state_dict()) |
| all_keys = _get_keys(module_name, is_dynamic) |
| for key in all_keys: |
| self.assertTrue(key in state_dict) |
| # check load_state_dict restores states |
| module = getattr(model, module_name) |
| prev_scale = module.weight_scale |
| module.weight_scale = None |
| model.load_state_dict(state_dict) |
| module = getattr(model, module_name) |
| self.assertTrue(torch.equal(prev_scale, module.weight_scale)) |
| |
| |
| checkWeightQParams(qr) |
| qr = copy.deepcopy(qr) |
| # make sure the qparams are preserved after copy |
| checkWeightQParams(qr) |
| |
| checkSerDeser(qr, is_dynamic) |
| |
| def _get_conv_transpose_test_cases(self, use_relu, is_reference): |
| """ Returns a list of test cases, with format: |
| is_dynamic, ModuleClass, module_constructor_inputs, |
| inputs, quantized_node, weight_prepack_op |
| """ |
| class FunctionalConvTranspose1d(torch.nn.Module): |
| def __init__(self, weight): |
| super().__init__() |
| self.weight = torch.nn.Parameter(weight) |
| self.stride = 1 |
| self.padding = 0 |
| self.output_padding = 0 |
| self.dilation = 1 |
| self.groups = 1 |
| |
| def forward(self, x): |
| y = F.conv_transpose1d( |
| x, |
| self.weight, |
| None, |
| self.stride, |
| self.padding, |
| self.output_padding, |
| self.groups, |
| self.dilation |
| ) |
| if use_relu: |
| y = F.relu(y) |
| return y |
| |
| class ConvTranspose1d(torch.nn.Module): |
| def __init__(self, *args): |
| super().__init__() |
| self.deconv = torch.nn.ConvTranspose1d(*args) |
| self.relu = torch.nn.ReLU() |
| |
| def forward(self, x): |
| y = self.deconv(x) |
| if use_relu: |
| y = self.relu(y) |
| return y |
| |
| conv_transpose1d_input = torch.rand(1, 3, 224) |
| conv_transpose1d_weight = torch.rand(3, 3, 3) |
| conv_transpose1d_module_args = (3, 3, 3) |
| |
| class FunctionalConvTranspose2d(torch.nn.Module): |
| def __init__(self, weight): |
| super().__init__() |
| self.weight = torch.nn.Parameter(weight) |
| self.stride = (1, 1) |
| self.padding = (0, 0) |
| self.output_padding = (0, 0) |
| self.dilation = (1, 1) |
| self.groups = 1 |
| |
| def forward(self, x): |
| y = F.conv_transpose2d( |
| x, |
| self.weight, |
| None, |
| self.stride, |
| self.padding, |
| self.output_padding, |
| self.groups, |
| self.dilation |
| ) |
| if use_relu: |
| y = F.relu(y) |
| return y |
| |
| class ConvTranspose2d(torch.nn.Module): |
| def __init__(self, *args): |
| super().__init__() |
| self.deconv = torch.nn.ConvTranspose2d(*args) |
| self.relu = torch.nn.ReLU() |
| |
| def forward(self, x): |
| y = self.deconv(x) |
| if use_relu: |
| y = self.relu(y) |
| return y |
| |
| conv_transpose2d_input = torch.rand(1, 3, 224, 224) |
| conv_transpose2d_weight = torch.rand(3, 3, 3, 3) |
| conv_transpose2d_module_args = (3, 3, 3) |
| |
| class FunctionalConvTranspose3d(torch.nn.Module): |
| def __init__(self, weight): |
| super().__init__() |
| self.weight = torch.nn.Parameter(weight) |
| self.stride = (1, 1, 1) |
| self.padding = (0, 0, 0) |
| self.output_padding = (0, 0, 0) |
| self.dilation = (1, 1, 1) |
| self.groups = 1 |
| |
| def forward(self, x): |
| y = F.conv_transpose3d( |
| x, |
| self.weight, |
| None, |
| self.stride, |
| self.padding, |
| self.output_padding, |
| self.groups, |
| self.dilation |
| ) |
| if use_relu: |
| y = F.relu(y) |
| return y |
| |
| class ConvTranspose3d(torch.nn.Module): |
| def __init__(self, *args): |
| super().__init__() |
| self.deconv = torch.nn.ConvTranspose3d(*args) |
| self.relu = torch.nn.ReLU() |
| |
| def forward(self, x): |
| y = self.deconv(x) |
| if use_relu: |
| y = self.relu(y) |
| return y |
| |
| conv_transpose3d_input = torch.rand(1, 3, 32, 224, 224) |
| conv_transpose3d_weight = torch.rand(3, 3, 3, 3, 3) |
| conv_transpose3d_module_args = (3, 3, 3) |
| |
| # is_dynamic, ModuleClass, module_constructor_inputs, |
| # inputs, quantized_node, weight_prepack_node |
| tests = [ |
| ( |
| False, |
| FunctionalConvTranspose1d, |
| (conv_transpose1d_weight,), |
| (conv_transpose1d_input,), |
| ns.call_function( |
| torch.nn.functional.conv_transpose1d if is_reference else torch.ops.quantized.conv_transpose1d |
| ), |
| ns.call_function(torch.ops.quantized.conv_transpose1d_prepack), |
| ), |
| ( |
| False, |
| FunctionalConvTranspose2d, |
| (conv_transpose2d_weight,), |
| (conv_transpose2d_input,), |
| ns.call_function( |
| torch.nn.functional.conv_transpose2d if is_reference else torch.ops.quantized.conv_transpose2d |
| ), |
| ns.call_function(torch.ops.quantized.conv_transpose2d_prepack), |
| ), |
| ( |
| False, |
| FunctionalConvTranspose3d, |
| (conv_transpose3d_weight,), |
| (conv_transpose3d_input,), |
| ns.call_function( |
| torch.nn.functional.conv_transpose3d if is_reference else torch.ops.quantized.conv_transpose3d), |
| ns.call_function(torch.ops.quantized.conv_transpose3d_prepack), |
| ), |
| ( |
| False, |
| ConvTranspose1d, |
| conv_transpose1d_module_args, |
| (conv_transpose1d_input,), |
| ns.call_module(nnqr.ConvTranspose1d if is_reference else nnq.ConvTranspose1d), |
| None |
| ), |
| ( |
| False, |
| ConvTranspose2d, |
| conv_transpose2d_module_args, |
| (conv_transpose2d_input,), |
| ns.call_module(nnqr.ConvTranspose2d if is_reference else nnq.ConvTranspose2d), |
| None |
| ), |
| ( |
| False, |
| ConvTranspose3d, |
| conv_transpose3d_module_args, |
| (conv_transpose3d_input,), |
| ns.call_module(nnqr.ConvTranspose3d if is_reference else nnq.ConvTranspose3d), |
| None |
| ), |
| ] |
| return tests |
| |
| @skipIfNoFBGEMM |
| def test_conv_transpose_not_reference(self): |
| """ Test quantizing transposed conv |
| """ |
| tests = self._get_conv_transpose_test_cases(use_relu=False, is_reference=False) |
| for (is_dynamic, ModuleClass, module_constructor_inputs, |
| inputs, quantized_node, weight_prepack_node) in tests: |
| quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC |
| node_occurrence = {} |
| if weight_prepack_node: |
| node_occurrence[weight_prepack_node] = 0 |
| self.checkGraphModeFxOp( |
| ModuleClass(*module_constructor_inputs), |
| inputs, quant_type, |
| expected_node=quantized_node, |
| expected_node_occurrence=node_occurrence, |
| is_reference=False) |
| |
| @skipIfNoFBGEMM |
| def test_conv_transpose_reference(self): |
| """ Test quantizing transposed conv with reference option |
| """ |
| tests = self._get_conv_transpose_test_cases(use_relu=False, is_reference=True) |
| |
| def _get_keys(prefix, is_dynamic): |
| all_keys = [prefix + "." + k for k in ["weight_qscheme", "weight_dtype"]] |
| if not is_dynamic: |
| all_keys.extend([prefix + "." + k for k in ["weight_scale", "weight_zero_point"]]) |
| return all_keys |
| |
| for (is_dynamic, ModuleClass, module_constructor_inputs, |
| inputs, quantized_node, weight_prepack_node) in tests: |
| quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC |
| node_occurrence = {} |
| if weight_prepack_node: |
| node_occurrence[weight_prepack_node] = 0 |
| result_dict = self.checkGraphModeFxOp( |
| ModuleClass(*module_constructor_inputs), |
| inputs, quant_type, |
| expected_node=quantized_node, |
| expected_node_occurrence=node_occurrence, |
| is_reference=True) |
| qr = result_dict["quantized_reference"] |
| |
| def checkWeightQParams(model): |
| module_name = "deconv" |
| if hasattr(model, module_name): |
| self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_qscheme")) |
| self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_scale")) |
| self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_zero_point")) |
| self.assertTrue("Reference" in qr.get_submodule(module_name)._get_name()) |
| |
| def checkSerDeser(model, is_dynamic): |
| module_name = "deconv" |
| if hasattr(model, module_name): |
                    # make sure serialization works
| state_dict = copy.deepcopy(model.state_dict()) |
| all_keys = _get_keys(module_name, is_dynamic) |
| for key in all_keys: |
| self.assertTrue(key in state_dict) |
| # check load_state_dict restores states |
| module = getattr(model, module_name) |
| prev_scale = module.weight_scale |
| module.weight_scale = None |
| model.load_state_dict(state_dict) |
| module = getattr(model, module_name) |
| self.assertTrue(torch.equal(prev_scale, module.weight_scale)) |
| |
| |
| checkWeightQParams(qr) |
| qr = copy.deepcopy(qr) |
| # make sure the qparams are preserved after copy |
| checkWeightQParams(qr) |
| |
| checkSerDeser(qr, is_dynamic) |
| |
| def test_conv_transpose_relu_not_reference(self): |
| """ Test quantizing transposed conv + relu |
| Fusion with relu is not supported. |
| """ |
| tests = self._get_conv_transpose_test_cases(use_relu=True, is_reference=False) |
| for (is_dynamic, ModuleClass, module_constructor_inputs, |
| inputs, quantized_node, weight_prepack_node) in tests: |
| quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC |
| node_occurrence = {} |
| if weight_prepack_node: |
| node_occurrence[weight_prepack_node] = 0 |
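            # relu is not fused into conv_transpose, so a standalone relu
            # (module or functional) must remain in the graph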
| if quantized_node.op == 'call_module': |
| node_occurrence[ns.call_module(nn.ReLU)] = 1 |
| else: |
| node_occurrence[ns.call_function(F.relu)] = 1 |
| self.checkGraphModeFxOp( |
| ModuleClass(*module_constructor_inputs), |
| inputs, quant_type, |
| expected_node=quantized_node, |
| expected_node_occurrence=node_occurrence, |
| is_reference=False) |
| |
| @skipIfNoFBGEMM |
| def test_conv_transpose_relu_reference(self): |
| """ Test quantizing transposed conv with reference option |
| Fusion with relu is not supported. |
| """ |
| tests = self._get_conv_transpose_test_cases(use_relu=True, is_reference=True) |
| |
| def _get_keys(prefix, is_dynamic): |
| all_keys = [prefix + "." + k for k in ["weight_qscheme", "weight_dtype"]] |
| if not is_dynamic: |
| all_keys.extend([prefix + "." + k for k in ["weight_scale", "weight_zero_point"]]) |
| return all_keys |
| |
| for (is_dynamic, ModuleClass, module_constructor_inputs, |
| inputs, quantized_node, weight_prepack_node) in tests: |
| quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC |
| node_occurrence = {} |
| if weight_prepack_node: |
| node_occurrence[weight_prepack_node] = 0 |
| if quantized_node.op == 'call_module': |
| node_occurrence[ns.call_module(nn.ReLU)] = 1 |
| else: |
| node_occurrence[ns.call_function(F.relu)] = 1 |
| result_dict = self.checkGraphModeFxOp( |
| ModuleClass(*module_constructor_inputs), |
| inputs, quant_type, |
| expected_node=quantized_node, |
| expected_node_occurrence=node_occurrence, |
| is_reference=True) |
| qr = result_dict["quantized_reference"] |
| |
| def checkWeightQParams(model): |
| module_name = "deconv" |
| if hasattr(model, module_name): |
| self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_qscheme")) |
| self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_scale")) |
| self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_zero_point")) |
| self.assertTrue("Reference" in qr.get_submodule(module_name)._get_name()) |
| |
| def checkSerDeser(model, is_dynamic): |
| module_name = "deconv" |
| if hasattr(model, module_name): |
                    # make sure serialization works
| state_dict = copy.deepcopy(model.state_dict()) |
| all_keys = _get_keys(module_name, is_dynamic) |
| for key in all_keys: |
| self.assertTrue(key in state_dict) |
| # check load_state_dict restores states |
| module = getattr(model, module_name) |
| prev_scale = module.weight_scale |
| module.weight_scale = None |
| model.load_state_dict(state_dict) |
| module = getattr(model, module_name) |
| self.assertTrue(torch.equal(prev_scale, module.weight_scale)) |
| |
| |
| checkWeightQParams(qr) |
| qr = copy.deepcopy(qr) |
| # make sure the qparams are preserved after copy |
| checkWeightQParams(qr) |
| |
| checkSerDeser(qr, is_dynamic) |
| |
| @skipIfNoFBGEMM |
| def test_dynamic_quant_weight_observer(self): |
| ''' Test that weight observer is run in convert step |
| ''' |
| |
| class M(torch.nn.Module): |
| def __init__(self, weight): |
| super().__init__() |
| self.weight = torch.nn.Parameter(weight) |
| |
| def forward(self, x): |
| return F.linear(x, self.weight) |
| |
| m = M(torch.rand(1, 1)).eval() |
| qconfig = default_dynamic_qconfig |
| qconfig_dict = {'': qconfig} |
| example_inputs = (torch.rand(1, 1),) |
| prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) |
| quantized = convert_to_reference_fx(prepared) |
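        # convert_to_reference_fx stores the weight qparams computed by the weight
        # observer as attributes on the graph module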
| qparams = (quantized._scale_0, quantized._zero_point_0) |
| weight_obs = qconfig.weight() |
| weight_obs(quantized.weight) |
        # Use .item() to get scalar values and avoid a tensor size mismatch error, torch.Size([]) vs torch.Size([1])
| ref_qparams = (weight_obs.calculate_qparams()[0].item(), weight_obs.calculate_qparams()[1].item()) |
| self.assertEqual(qparams, ref_qparams) |
| |
| def test_conv_bn_relu(self): |
| """ Tests fusion and quantization for "Conv - Bn" and "Conv - Bn - ReLU" |
| """ |
| convs = { |
| 1: nn.Conv1d, |
| 2: nn.Conv2d, |
| 3: nn.Conv3d, |
| } |
| bns = { |
| 1: nn.BatchNorm1d, |
| 2: nn.BatchNorm2d, |
| 3: nn.BatchNorm3d, |
| } |
| quantized_convs = { |
| 1: nnq.Conv1d, |
| 2: nnq.Conv2d, |
| 3: nnq.Conv3d, |
| } |
| quantized_conv_relus = { |
| 1: nniq.ConvReLU1d, |
| 2: nniq.ConvReLU2d, |
| 3: nniq.ConvReLU3d, |
| } |
| |
| class M(torch.nn.Module): |
| def __init__(self, dim, has_relu): |
| super().__init__() |
| self.conv = convs[dim](3, 3, 3) |
| self.bn = bns[dim](3) |
| self.relu = nn.ReLU() if has_relu else nn.Identity() |
| self.has_relu = has_relu |
| self.quant = QuantStub() |
| self.dequant = DeQuantStub() |
| |
| def forward(self, x): |
| x = self.quant(x) |
| x = self.conv(x) |
| x = self.bn(x) |
| if self.has_relu: |
| x = self.relu(x) |
| x = self.dequant(x) |
| return x |
| |
| options = itertools.product([1, 2, 3], [True, False], self.static_quant_types) |
| for dim, has_relu, quant_type in options: |
| expected_node = ns.call_module( |
| quantized_conv_relus[dim] if has_relu |
| else quantized_convs[dim]) |
| m = M(dim, has_relu) |
| m_eager = copy.deepcopy(m) |
| result_dict = self.checkGraphModeFxOp( |
| m, |
| self.img_data_dict[dim], |
| quant_type, |
| expected_node=expected_node, |
| ) |
| result = result_dict["quantized_output"] |
| |
| # check numerics |
| qengine = torch.backends.quantized.engine |
| if quant_type == QuantType.STATIC: |
| m_eager.eval() |
| qconfig = get_default_qconfig(qengine) |
| prepare_fn = prepare |
| is_qat = False |
| else: |
| m_eager.train() |
| qconfig = get_default_qat_qconfig(qengine) |
| prepare_fn = prepare_qat |
| is_qat = True |
| |
| fuse_list = ["conv", "bn"] |
| if has_relu: |
| fuse_list.append("relu") |
| if is_qat: |
| fuse_modules_qat(m_eager, fuse_list, inplace=True) |
| else: |
| fuse_modules(m_eager, fuse_list, inplace=True) |
| m_eager.qconfig = qconfig |
| m_eager = prepare_fn(m_eager) |
| prepared_fx = result_dict["prepared"] |
| |
| m_eager(*self.img_data_dict[dim][0]) |
| m_eager = convert(m_eager) |
| result_eager = m_eager(*self.img_data_dict[dim][0]) |
| self.assertEqual(result, result_eager) |
| |
| def test_linear_bn(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = nn.Linear(4, 4) |
| self.bn = nn.BatchNorm1d(4) |
| self.quant = QuantStub() |
| self.dequant = DeQuantStub() |
| |
| def forward(self, x): |
| x = self.quant(x) |
| x = self.linear(x) |
| x = self.bn(x) |
| x = self.dequant(x) |
| return x |
| |
| data = (torch.randn(4, 4),) |
| for quant_type in self.static_quant_types: |
| expected_node = ns.call_module(nnq.Linear) |
| m = M() |
| m_eager = copy.deepcopy(m) |
| result_dict = self.checkGraphModeFxOp(m, data, quant_type, expected_node=expected_node) |
| result = result_dict["quantized_output"] |
| |
| # check numerics vs eager mode |
| fuse_list = ["linear", "bn"] |
| qengine = torch.backends.quantized.engine |
| if quant_type == QuantType.STATIC: |
| m_eager.eval() |
| qconfig = get_default_qconfig(qengine) |
| prepare_fn = prepare |
| fuse_modules(m_eager, fuse_list, inplace=True) |
| else: |
| m_eager.train() |
| qconfig = get_default_qat_qconfig(qengine) |
| prepare_fn = prepare_qat |
| fuse_modules_qat(m_eager, fuse_list, inplace=True) |
| m_eager.qconfig = qconfig |
| m_eager = prepare_fn(m_eager) |
| m_eager(*data) |
| m_eager = convert(m_eager) |
| result_eager = m_eager(*data) |
| self.assertEqual(result, result_eager) |
| |
| @skipIfNoFBGEMM |
| def test_dynamic_quant_fp16(self): |
| with override_quantized_engine('fbgemm'): |
| class Linear(torch.nn.Module): |
| def __init__(self, weight): |
| super().__init__() |
| self.weight = torch.nn.Parameter(weight) |
| |
| def forward(self, x): |
| return F.linear(x, self.weight) |
| |
| linear_input = torch.rand(8, 5) |
| linear_weight = torch.rand(10, 5) |
| |
| class LinearModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(5, 10) |
| |
| def forward(self, x): |
| return self.linear(x) |
| |
| linear_module_input = torch.rand(8, 5) |
| |
| tests = [ |
| (Linear, (linear_weight,), (linear_input,), |
| ns.call_function(torch.ops.quantized.linear_dynamic_fp16), |
| ns.call_function(torch.ops.quantized.linear_prepack_fp16)), |
| (LinearModule, (), (linear_module_input,), |
| ns.call_module(nnqd.Linear), |
| None), |
| ] |
| for (ModuleClass, module_constructor_inputs, |
| inputs, quantized_node, weight_prepack_node) in tests: |
| for is_reference in [True, False]: |
| node_occurrence = {} |
| if weight_prepack_node: |
| node_occurrence[weight_prepack_node] = 0 |
| m = ModuleClass(*module_constructor_inputs).eval() |
| qconfig_dict = {"": float16_dynamic_qconfig} |
| m = prepare_fx(m, qconfig_dict, example_inputs=inputs) |
| convert_fn = convert_to_reference_fx if is_reference else convert_fx |
| m = convert_fn(m) |
| self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) |
| |
| @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported") |
| @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") |
| @override_qengines |
| def test_qat_prepare_device_affinity(self): |
| """ |
| Tests that FX QAT prepare pass respects device affinity |
| """ |
| class Model(nn.Module): |
| |
| def __init__(self): |
| super().__init__() |
| self.conv = nn.Conv2d(1, 1, 1) |
| self.bn = nn.BatchNorm2d(1) |
| self.relu = nn.ReLU() |
| |
| def forward(self, x): |
| x = self.conv(x) |
| x = self.bn(x) |
| x = self.relu(x) |
| return x |
| |
| model = Model() |
| qengine = torch.backends.quantized.engine |
| qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig(qengine)} |
| device = torch.device('cuda:0') |
| model.to(device) |
| |
| example_inputs = (torch.randn(4, 1, 4, 4, device=device),) |
| # QAT prepare |
| model = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs) |
| |
| # ensure that running an input on CUDA works without any needed changes |
| model(*example_inputs) |
| |
| # ensure all buffers and parameters are on the device we expect |
| model_devices = {p.device for p in model.parameters()} | \ |
| {p.device for p in model.buffers()} |
| self.assertEqual(len(model_devices), 1) |
| model_device = next(iter(model_devices)) |
| self.assertEqual(model_device, device) |
| |
| @skipIfNoFBGEMM |
| def test_dict_output(self): |
| """ Make sure quantization runs for models with dictionary output |
| """ |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.Conv2d(1, 1, 1) |
| |
| def forward(self, x): |
| return {"output": self.conv(x["input"])} |
| |
| example_inputs = ({"input": torch.randn(1, 1, 1, 1)},) |
| m = M().eval() |
| qconfig_dict = {"": default_qconfig} |
| m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) |
| m(*example_inputs) |
| m = convert_fx(m) |
| m(*example_inputs) |
| |
| @override_qengines |
| def test_attention(self): |
| """ Make sure quantization runs for a corner case in attention module |
| """ |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.Conv2d(1, 1, 1) |
| |
| def forward(self, x): |
| x = self.conv(x) |
| q, k, v = x.chunk(3, dim=0) |
| q = q.contiguous().view(-1, 1).transpose(0, 1) |
| k = k.contiguous().view(-1, 1).transpose(0, 1) |
| v = v.contiguous().view(-1, 1).transpose(0, 1) |
| torch._assert( |
| k.size(1) == 1, "key size should be equal to 1" |
| ) |
| r = torch.mm(k, v) |
| return q * k + r |
| |
| example_inputs = (torch.randn(3, 1, 1, 1),) |
| m = M().eval() |
| qconfig_dict = { |
| "": None, |
| "object_type": [ |
| (nn.Conv2d, default_qconfig), |
| ] |
| } |
| # make sure it runs |
| m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) |
| m(*example_inputs) |
| m = convert_fx(m) |
| m(*example_inputs) |
| |
| def _test_standalone_module( |
| self, |
| interface_config, |
| prepare_count_check, |
| standalone_prepare_count_check, |
| convert_count_check, |
| standalone_convert_count_check): |
| """ Test standalone module with different quantized input/quantized output |
| configurations |
| """ |
| class StandaloneModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.Conv2d(1, 1, 1) |
| |
| def forward(self, x): |
| return self.conv(x) |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.Conv2d(1, 1, 1) |
| self.standalone = StandaloneModule() |
| |
| def forward(self, x): |
| x = self.conv(x) |
| x = self.standalone(x) |
| return x |
| |
| class RefM(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv1 = torch.nn.Conv2d(1, 1, 1) |
| self.conv2 = torch.nn.Conv2d(1, 1, 1) |
| |
| def forward(self, x): |
| x = self.conv1(x) |
| x = self.conv2(x) |
| return x |
| |
| example_inputs = (torch.randn(1, 1, 1, 1),) |
| # instantiate M and RefM and align the parameters |
| original_m = M().eval() |
| original_ref_m = RefM().eval() |
| original_ref_m.conv1.weight = torch.nn.Parameter(original_m.conv.weight.detach()) |
| original_ref_m.conv1.bias = torch.nn.Parameter(original_m.conv.bias.detach()) |
| original_ref_m.conv2.weight = torch.nn.Parameter(original_m.standalone.conv.weight.detach()) |
| original_ref_m.conv2.bias = torch.nn.Parameter(original_m.standalone.conv.bias.detach()) |
| |
| for is_name in [True, False]: |
| sm_example_inputs = example_inputs |
| if is_name: |
| prepare_config = { |
| "standalone_module_name": [("standalone", None, sm_example_inputs, interface_config, None)] |
| } |
| else: |
| prepare_config = { |
| "standalone_module_class": [(StandaloneModule, None, sm_example_inputs, interface_config, None)] |
| } |
| |
| original_m_copy = copy.deepcopy(original_m) |
| original_ref_m_copy = copy.deepcopy(original_ref_m) |
| |
| qconfig_dict = {"": default_qconfig} |
| # check prepared model |
| m = prepare_fx( |
| original_m_copy, |
| qconfig_dict, |
| example_inputs=example_inputs, |
| prepare_custom_config=prepare_config) |
| # calibration |
| m(*example_inputs) |
| self.checkGraphModuleNodes(m, expected_node_occurrence=prepare_count_check) |
| self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_prepare_count_check) |
| |
| # check converted/quantized model |
| m = convert_fx(m) |
| self.checkGraphModuleNodes(m, expected_node_occurrence=convert_count_check) |
| self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_convert_count_check) |
| res = m(*example_inputs) |
| |
| # quantize the reference model |
| ref_m = prepare_fx( |
| original_ref_m_copy, |
| qconfig_dict, |
| example_inputs=example_inputs, |
| ) |
| ref_m(*example_inputs) |
| ref_m = convert_fx(ref_m) |
| ref_res = ref_m(*example_inputs) |
| self.assertEqual(res, ref_res) |
| |
| def test_standalone_module_float_interface(self): |
| float_interface_config = { |
| "input_quantized_idxs": [], # float input |
| "output_quantized_idxs": [], # float output |
| } |
| interface_config = float_interface_config |
| # input and output of first conv, observer for standalone module |
| # will be inserted in the standalone module itself |
| prepare_count_check = { |
| ns.call_module(torch.ao.quantization.MinMaxObserver): 2 |
| } |
| # for input and output of conv in the standalone module |
| standalone_prepare_count_check = { |
| ns.call_module(torch.ao.quantization.MinMaxObserver): 2 |
| } |
| convert_count_check = { |
| ns.call_function(torch.quantize_per_tensor) : 1, |
| ns.call_module(nnq.Conv2d) : 1, |
| ns.call_method("dequantize") : 1, |
| } |
| standalone_convert_count_check = { |
| # standalone module will take float as input and output |
            # so we'll see quantize and dequantize in the module
| ns.call_function(torch.quantize_per_tensor) : 1, |
| ns.call_module(nnq.Conv2d): 1, |
| ns.call_method("dequantize") : 1, |
| } |
| self._test_standalone_module( |
| interface_config, |
| prepare_count_check, |
| standalone_prepare_count_check, |
| convert_count_check, |
| standalone_convert_count_check) |
| |
| def test_standalone_module_quantized_interface(self): |
| quantized_interface_config = { |
| "input_quantized_idxs": [0], # quantized input |
| "output_quantized_idxs": [0], # quantized output |
| } |
| interface_config = quantized_interface_config |
| # observer for input and output of first conv |
| prepare_count_check = { |
| ns.call_module(torch.ao.quantization.MinMaxObserver): 2 |
| } |
| # for output of conv in the standalone module |
| standalone_prepare_count_check = { |
| ns.call_module(torch.ao.quantization.MinMaxObserver): 1 |
| } |
| convert_count_check = { |
| # quantizing input for conv |
| ns.call_function(torch.quantize_per_tensor) : 1, |
| ns.call_module(nnq.Conv2d) : 1, |
| # dequantizing output of standalone module |
| ns.call_method("dequantize") : 1, |
| } |
| standalone_convert_count_check = { |
| # quantization of input happens in parent module |
| # quantization of output happens in the quantized conv module |
| ns.call_function(torch.quantize_per_tensor) : 0, |
| ns.call_module(nnq.Conv2d): 1, |
| # dequantization for output happens in parent module |
| ns.call_method("dequantize") : 0, |
| } |
| self._test_standalone_module( |
| interface_config, |
| prepare_count_check, |
| standalone_prepare_count_check, |
| convert_count_check, |
| standalone_convert_count_check) |
| |
| @skipIfNoFBGEMM |
| def test_qconfig_none(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv1 = nn.Conv2d(1, 1, 1) |
| self.conv2 = nn.Conv2d(1, 1, 1) |
| |
| def forward(self, x): |
| x = self.conv1(x) |
| x = self.conv2(x) |
| return x |
| |
| m = M().eval() |
| qconfig_dict = {"": default_qconfig, |
| "module_name": [("conv2", None)]} |
| example_inputs = (torch.randn(1, 1, 1, 1),) |
| m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) |
| m(*example_inputs) |
| m = convert_fx(m) |
| m(*example_inputs) |
| # first conv is quantized, second conv is not quantized |
| node_list = [ |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_module(nnq.Conv2d), |
| ns.call_method("dequantize"), |
| ns.call_module(nn.Conv2d), |
| ] |
| self.checkGraphModuleNodes(m, expected_node_list=node_list) |
| |
| def test_qconfig_module_type(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = nn.Conv2d(1, 1, 1) |
| self.linear = nn.Linear(9, 3) |
| |
| def forward(self, x): |
| x = self.conv(x) |
| x = x.reshape((1, -1)) |
| x = self.linear(x) |
| return x |
| |
| m = M().eval() |
| qconfig_dict = {"object_type": [(torch.nn.Conv2d, default_qconfig)]} |
| example_inputs = (torch.randn(1, 1, 3, 3),) |
| m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) |
| m(*example_inputs) |
| m = convert_fx(m) |
| m(*example_inputs) |
| # conv is quantized, linear is not quantized |
| node_list = [ |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_module(nnq.Conv2d), |
| ns.call_method("dequantize"), |
| ns.call_module(nn.Linear), |
| ] |
| self.checkGraphModuleNodes(m, expected_node_list=node_list) |
| |
| def test_qconfig_qat_module_type(self): |
| class LinearRelu(nn.Sequential): |
| def __init__(self): |
| super().__init__( |
| nn.Linear(5, 5), |
| nn.ReLU(), |
| ) |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.lin_relu = LinearRelu() |
| self.linear = nn.Linear(5, 5) |
| |
| def forward(self, x): |
| x = self.lin_relu(x) |
| x = self.linear(x) |
| return x |
| |
| model = M().train() |
| |
| qconfig_dict = { |
| "": None, |
| "object_type": [ |
| (torch.nn.Linear, default_qat_qconfig), |
| (torch.nn.ReLU, default_qat_qconfig), |
| ], |
| } |
| example_inputs = (torch.rand(5, 5),) |
| m = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs) |
| m(*example_inputs) |
| m = convert_fx(m) |
| m(*example_inputs) |
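        # Linear + ReLU is fused and quantized as LinearReLU; the second
        # Linear is quantized on its own since it is not followed by a ReLU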
| node_list = [ |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_module(nniq.LinearReLU), |
| ns.call_module(nnq.Linear), |
| ns.call_method("dequantize"), |
| ] |
| self.checkGraphModuleNodes(m, expected_node_list=node_list) |
| |
| def test_qconfig_function(self): |
| class M(torch.nn.Module): |
| def forward(self, x, y): |
| return x + y |
| |
| m = M().eval() |
| qconfig_dict = {"object_type": [(operator.add, default_qconfig)]} |
| data = torch.randn(1, 1, 1, 1) |
| example_inputs = (data, data) |
| m = prepare_fx(m, qconfig_dict, example_inputs) |
| m(*example_inputs) |
| m = convert_fx(m) |
| m(*example_inputs) |
        # add is quantized and its output is dequantized at the end
| node_list = [ |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_function(torch.ops.quantized.add), |
| ns.call_method("dequantize"), |
| ] |
| self.checkGraphModuleNodes(m, expected_node_list=node_list) |
| |
| def test_qconfig_module_name_regex(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv1 = nn.Conv2d(1, 1, 1) |
| self.conv2 = nn.Conv2d(1, 1, 1) |
| |
| def forward(self, x): |
| x = self.conv1(x) |
| x = self.conv2(x) |
| return x |
| |
| m = M().eval() |
| qconfig_dict = {"module_name_regex": [("conv*", default_qconfig)]} |
| example_inputs = (torch.randn(1, 1, 1, 1),) |
| m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) |
| m(*example_inputs) |
| m = convert_fx(m) |
| m(*example_inputs) |
        # both convs are quantized since the regex matches conv1 and conv2
| node_list = [ |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_module(nnq.Conv2d), |
| ns.call_module(nnq.Conv2d), |
| ns.call_method("dequantize"), |
| ] |
| self.checkGraphModuleNodes(m, expected_node_list=node_list) |
| |
| def test_qconfig_precedence(self): |
| for device in get_supported_device_types(): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = nn.Linear(1, 1) |
| self.conv = nn.Conv2d(1, 1, 1) |
| self.module_conv1 = nn.Conv2d(1, 1, 1) |
| self.module_conv2 = nn.Conv2d(1, 1, 1) |
| |
| def forward(self, x): |
| # global |
| x = self.linear(x) |
| # global + object_type --> object_type |
| x = self.conv(x) |
| # global + object_type + module_name_regex --> module_name_regex |
| x = self.module_conv1(x) |
| # global + object_type + module_name_regex + module_name --> module_name |
| x = self.module_conv2(x) |
| return x |
| |
| m = M().to(device).eval() |
| |
| global_qconfig = default_qconfig |
| object_type_qconfig = default_dynamic_qconfig |
| module_name_regex_qconfig = float16_dynamic_qconfig |
| module_name_qconfig = default_qat_qconfig |
| qconfig_dict = { |
| "": global_qconfig, |
| "object_type": [(nn.Conv2d, object_type_qconfig)], |
| "module_name_regex": [("module_conv*", module_name_regex_qconfig)], |
| "module_name": [("module_conv2", module_name_qconfig)]} |
| m_prep = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1),)) |
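            # compare the observer constructors (partial.func) rather than the
            # qconfig objects, since each qconfig wraps a distinct partial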
| self.assertEqual(m_prep.linear.qconfig.activation.p.func, global_qconfig.activation.p.func) |
| self.assertEqual(m_prep.linear.qconfig.weight.p.func, global_qconfig.weight.p.func) |
| self.assertEqual(m_prep.conv.qconfig.activation.p.func, object_type_qconfig.activation.p.func) |
| self.assertEqual(m_prep.conv.qconfig.weight.p.func, object_type_qconfig.weight.p.func) |
| self.assertEqual(m_prep.module_conv1.qconfig.activation.p.func, module_name_regex_qconfig.activation.p.func) |
| self.assertEqual(m_prep.module_conv1.qconfig.weight.p.func, module_name_regex_qconfig.weight.p.func) |
| self.assertEqual(m_prep.module_conv2.qconfig.activation.p.func, module_name_qconfig.activation.p.func) |
| self.assertEqual(m_prep.module_conv2.qconfig.weight.p.func, module_name_qconfig.weight.p.func) |
| |
| def test_qconfig_module_name_object_type_order(self): |
| class M1(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.fc1 = nn.Linear(1, 1) |
| self.fc2 = nn.Linear(1, 1) |
| |
| def forward(self, x): |
| x = self.fc1(x) |
| x = self.fc2(x) |
| x = torch.add(x, x) |
| x = torch.add(x, x) |
| return x |
| |
| class M2(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.fc1 = nn.Linear(1, 1) |
| self.fc2 = nn.Linear(1, 1) |
| self.m1 = M1() |
| |
| def forward(self, x): |
| x = self.fc1(x) |
| x = self.fc2(x) |
| x = torch.add(x, x) |
| x = torch.add(x, x) |
| x = self.m1(x) |
| return x |
| |
| class M3(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.fc1 = nn.Linear(1, 1) |
| self.fc2 = nn.Linear(1, 1) |
| self.m2 = M2() |
| |
| def forward(self, x): |
| x = self.fc1(x) |
| x = self.fc2(x) |
| x = torch.add(x, x) |
| x = torch.add(x, x) |
| x = self.m2(x) |
| return x |
| |
| m = M3().eval() |
| qconfig_dict = { |
| "module_name_object_type_order": [ |
| # test various FQNs: global, single child, multiple children |
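                # the integer is the zero-based index of that op type's
                # occurrence within the module given by the FQN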
| ("", nn.Linear, 0, torch.ao.quantization.default_qconfig), |
| ("", torch.add, 0, torch.ao.quantization.default_qconfig), |
| ("m2", nn.Linear, 1, torch.ao.quantization.default_qconfig), |
| ("m2", torch.add, 1, torch.ao.quantization.default_qconfig), |
| ("m2.m1", nn.Linear, 0, torch.ao.quantization.default_qconfig), |
| ("m2.m1", torch.add, 0, torch.ao.quantization.default_qconfig), |
| ], |
| } |
| example_inputs = (torch.randn(1, 1, 1, 1),) |
| m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) |
| m(*example_inputs) |
| m = convert_fx(m) |
| m(*example_inputs) |
| |
| node_list = [ |
| # m3 |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_module(nnq.Linear), |
| ns.call_method("dequantize"), |
| ns.call_module(nn.Linear), |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_function(torch.ops.quantized.add), |
| ns.call_method("dequantize"), |
| ns.call_function(torch.add), |
| # m2 |
| ns.call_module(nn.Linear), |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_module(nnq.Linear), |
| ns.call_method("dequantize"), |
| ns.call_function(torch.add), |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_function(torch.ops.quantized.add), |
| # m1 |
| ns.call_module(nnq.Linear), |
| ns.call_method("dequantize"), |
| ns.call_module(nn.Linear), |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_function(torch.ops.quantized.add), |
| ns.call_method("dequantize"), |
| ns.call_function(torch.add), |
| ] |
| self.checkGraphModuleNodes(m, expected_node_list=node_list) |
| |
| # test that function order overrides global qconfig |
| class M4(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.fc1 = nn.Linear(1, 1) |
| self.fc2 = nn.Linear(1, 1) |
| |
| def forward(self, x): |
| x = self.fc1(x) |
| x = self.fc2(x) |
| x = torch.add(x, x) |
| x = torch.add(x, x) |
| return x |
| |
| m = M4().eval() |
| qconfig_dict = { |
| "": torch.ao.quantization.default_qconfig, |
| "module_name_object_type_order": [ |
| ("", nn.Linear, 1, None), |
| ("", torch.add, 1, None), |
| ], |
| } |
| example_inputs = (torch.randn(1, 1, 1, 1),) |
| m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) |
| m(*example_inputs) |
| m = convert_fx(m) |
| m(*example_inputs) |
| |
| node_list = [ |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_module(nnq.Linear), |
| ns.call_method("dequantize"), |
| ns.call_module(nn.Linear), |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_function(torch.ops.quantized.add), |
| ns.call_method("dequantize"), |
| ns.call_function(torch.add), |
| ] |
| self.checkGraphModuleNodes(m, expected_node_list=node_list) |
| |
| @override_qengines |
| def test_qconfig_dict_with_fused_modules(self): |
| class LinearReLUModel(torch.nn.Module): |
| def __init__(self, relu): |
| super().__init__() |
| self.linear = torch.nn.Linear(3, 3) |
| self.relu = relu |
| |
| def forward(self, x): |
| x = self.linear(x) |
| x = self.relu(x) |
| return x |
| |
| class ConvReLUModel(torch.nn.Module): |
| def __init__(self, relu): |
| super().__init__() |
| self.conv = torch.nn.Conv1d(3, 3, 3) |
| self.relu = relu |
| |
| def forward(self, x): |
| x = self.conv(x) |
| x = self.relu(x) |
| return x |
| |
| class ConvBnReLUModel(torch.nn.Module): |
| def __init__(self, relu): |
| super().__init__() |
| self.conv = torch.nn.Conv1d(3, 3, 3) |
| self.bn = torch.nn.BatchNorm1d(3) |
| self.relu = relu |
| |
| def forward(self, x): |
| x = self.conv(x) |
| x = self.bn(x) |
| x = self.relu(x) |
| return x |
| |
| for model in [LinearReLUModel, ConvReLUModel, ConvBnReLUModel]: |
| for relu in [torch.nn.ReLU(), torch.nn.functional.relu, torch.relu]: |
| m = model(relu).eval() |
| qengine = torch.backends.quantized.engine |
| qconfig_dict = torch.ao.quantization.get_default_qconfig_mapping(qengine) |
| # should not crash as in https://github.com/pytorch/pytorch/issues/75825 |
| prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 3, 3),)) |
| |
| # TODO: move QConfigMapping tests to test/quantization/core |
| def test_qconfig_mapping_set_global(self): |
| qconfig = get_default_qconfig() |
| qconfig_mapping = QConfigMapping() |
| self.assertEqual(qconfig_mapping.global_qconfig, None) |
| qconfig_mapping.set_global(qconfig) |
| self.assertEqual(qconfig_mapping.global_qconfig, qconfig) |
| |
| def test_qconfig_mapping_set_object_type(self): |
| qconfig1 = get_default_qconfig() |
| qconfig2 = get_default_qconfig() |
| qconfig3 = get_default_qconfig() |
| self.assertNotEqual(qconfig1, qconfig2) |
| self.assertNotEqual(qconfig1, qconfig3) |
| qconfig_mapping = QConfigMapping() |
| self.assertEqual(len(qconfig_mapping.object_type_qconfigs), 0) |
| # Insert some entries |
| qconfig_mapping.set_object_type(torch.nn.Linear, qconfig1) |
| qconfig_mapping.set_object_type(torch.nn.ReLU, qconfig2) |
| self.assertEqual(len(qconfig_mapping.object_type_qconfigs), 2) |
| self.assertEqual(qconfig_mapping.object_type_qconfigs[torch.nn.Linear], qconfig1) |
| self.assertEqual(qconfig_mapping.object_type_qconfigs[torch.nn.ReLU], qconfig2) |
| # Override existing key |
| qconfig_mapping.set_object_type(torch.nn.Linear, qconfig3) |
| self.assertEqual(qconfig_mapping.object_type_qconfigs[torch.nn.Linear], qconfig3) |
| self.assertEqual(qconfig_mapping.object_type_qconfigs[torch.nn.ReLU], qconfig2) |
| self.assertEqual(_get_object_type_qconfig(qconfig_mapping, torch.nn.Linear, None), qconfig3) |
| self.assertEqual(_get_object_type_qconfig(qconfig_mapping, torch.nn.ReLU, None), qconfig2) |
| self.assertEqual(_get_object_type_qconfig(qconfig_mapping, "nomatch", None), None) |
| |
| def test_qconfig_mapping_set_module_name_regex(self): |
| qconfig1 = get_default_qconfig() |
| qconfig2 = get_default_qconfig() |
| qconfig3 = get_default_qconfig() |
| self.assertNotEqual(qconfig1, qconfig2) |
| self.assertNotEqual(qconfig1, qconfig3) |
| qconfig_mapping = QConfigMapping() |
| self.assertEqual(len(qconfig_mapping.module_name_regex_qconfigs), 0) |
| # Insert some entries |
| qconfig_mapping.set_module_name_regex("foo.*bar", qconfig1) |
| qconfig_mapping.set_module_name_regex("foo.*", qconfig2) |
| self.assertEqual(len(qconfig_mapping.module_name_regex_qconfigs), 2) |
| self.assertEqual(qconfig_mapping.module_name_regex_qconfigs["foo.*bar"], qconfig1) |
| self.assertEqual(qconfig_mapping.module_name_regex_qconfigs["foo.*"], qconfig2) |
| # Override existing key |
| qconfig_mapping.set_module_name_regex("foo.*bar", qconfig3) |
| self.assertEqual(qconfig_mapping.module_name_regex_qconfigs["foo.*bar"], qconfig3) |
| self.assertEqual(qconfig_mapping.module_name_regex_qconfigs["foo.*"], qconfig2) |
| self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "foo123bar", None), qconfig3) |
| self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "foobar", None), qconfig3) |
| self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "foobaz", None), qconfig2) |
| self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "foo", None), qconfig2) |
| self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "nomatch", None), None) |
| |
| def test_qconfig_mapping_set_module_name(self): |
| qconfig1 = get_default_qconfig() |
| qconfig2 = get_default_qconfig() |
| qconfig3 = get_default_qconfig() |
| self.assertNotEqual(qconfig1, qconfig2) |
| self.assertNotEqual(qconfig1, qconfig3) |
| qconfig_mapping = QConfigMapping() |
| self.assertEqual(len(qconfig_mapping.module_name_qconfigs), 0) |
| # Insert some entries |
| qconfig_mapping.set_module_name("mod1", qconfig1) |
| qconfig_mapping.set_module_name("mod2", qconfig2) |
| self.assertEqual(len(qconfig_mapping.module_name_qconfigs), 2) |
| self.assertEqual(qconfig_mapping.module_name_qconfigs["mod1"], qconfig1) |
| self.assertEqual(qconfig_mapping.module_name_qconfigs["mod2"], qconfig2) |
| # Override existing key |
| qconfig_mapping.set_module_name("mod1", qconfig3) |
| self.assertEqual(qconfig_mapping.module_name_qconfigs["mod1"], qconfig3) |
| self.assertEqual(qconfig_mapping.module_name_qconfigs["mod2"], qconfig2) |
| self.assertEqual(_get_module_name_qconfig(qconfig_mapping, "mod1", None), qconfig3) |
| self.assertEqual(_get_module_name_qconfig(qconfig_mapping, "mod2", None), qconfig2) |
| self.assertEqual(_get_module_name_qconfig(qconfig_mapping, "nomatch", None), None) |
| |
| def test_qconfig_mapping_set_module_name_object_type_order(self): |
| qconfig1 = get_default_qconfig() |
| qconfig2 = get_default_qconfig() |
| qconfig3 = get_default_qconfig() |
| self.assertNotEqual(qconfig1, qconfig2) |
| self.assertNotEqual(qconfig1, qconfig3) |
| qconfig_mapping = QConfigMapping() |
| self.assertEqual(len(qconfig_mapping.module_name_object_type_order_qconfigs), 0) |
| # Insert some entries |
| qconfig_mapping.set_module_name_object_type_order("mod1", torch.nn.Linear, 0, qconfig1) |
| qconfig_mapping.set_module_name_object_type_order("mod2", torch.nn.ReLU, 1, qconfig2) |
| self.assertEqual(len(qconfig_mapping.module_name_object_type_order_qconfigs), 2) |
| key1 = ("mod1", torch.nn.Linear, 0) |
| key2 = ("mod2", torch.nn.ReLU, 1) |
| self.assertEqual(list(qconfig_mapping.module_name_object_type_order_qconfigs)[0], key1) |
| self.assertEqual(list(qconfig_mapping.module_name_object_type_order_qconfigs)[1], key2) |
| self.assertEqual(qconfig_mapping.module_name_object_type_order_qconfigs[key1], qconfig1) |
| self.assertEqual(qconfig_mapping.module_name_object_type_order_qconfigs[key2], qconfig2) |
| self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order( |
| qconfig_mapping, "mod1", torch.nn.Linear, 0, None), qconfig1) |
| self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order( |
| qconfig_mapping, "mod2", torch.nn.ReLU, 1, None), qconfig2) |
| # Override existing key |
| qconfig_mapping.set_module_name_object_type_order("mod1", torch.nn.Linear, 0, qconfig3) |
| self.assertEqual(len(qconfig_mapping.module_name_object_type_order_qconfigs), 2) |
| self.assertEqual(list(qconfig_mapping.module_name_object_type_order_qconfigs)[0], key1) |
| self.assertEqual(list(qconfig_mapping.module_name_object_type_order_qconfigs)[1], key2) |
| self.assertEqual(qconfig_mapping.module_name_object_type_order_qconfigs[key1], qconfig3) |
| self.assertEqual(qconfig_mapping.module_name_object_type_order_qconfigs[key2], qconfig2) |
| self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order( |
| qconfig_mapping, "mod1", torch.nn.Linear, 0, None), qconfig3) |
| self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order( |
| qconfig_mapping, "mod2", torch.nn.ReLU, 1, None), qconfig2) |
| # No match |
| self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order( |
| qconfig_mapping, "mod123", torch.nn.Linear, 0, None), None) |
| self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order( |
| qconfig_mapping, "mod1", torch.nn.Linear, 35, None), None) |
| self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order( |
| qconfig_mapping, "mod2", torch.nn.Conv2d, 1, None), None) |
| |
| def _get_qconfig_dict_for_qconfig_mapping_test(self, global_qconfig, qconfig1, qconfig2): |
| """ |
| Return a dummy qconfig_dict to test QConfigMapping's to_dict and from_dict methods. |
| """ |
| return { |
| _GLOBAL_DICT_KEY: global_qconfig, |
| _OBJECT_TYPE_DICT_KEY: [ |
| (torch.nn.Linear, qconfig1), |
| (torch.nn.ReLU, qconfig2), |
| ], |
| _MODULE_NAME_REGEX_DICT_KEY: [ |
| ("foo.*bar", qconfig1), |
| ("foo.*", qconfig2), |
| ], |
| _MODULE_NAME_DICT_KEY: [ |
| ("bazbaz", qconfig1), |
| ("borbor", qconfig2), |
| ], |
| _MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY: [ |
| ("bazbaz", torch.nn.Linear, 0, qconfig1), |
| ("foofoo", torch.nn.ReLU, 1, qconfig2), |
| ], |
| } |
| |
    def test_qconfig_dict_validity(self):
        """
        Verifies that a qconfig_dict with an unsupported key (here the
        deliberate typo 'object_typo') raises a ValueError in prepare_fx.
        """
        m = nn.Sequential(nn.Conv2d(3, 3, 3)).eval()
        qconfig_dict = {"object_typo": [(torch.nn.Conv2d, default_qconfig)]}
        with self.assertRaises(ValueError) as context:
            m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 3, 3),))
        self.assertTrue(
            'Expected qconfig_dict to have the following keys:' in str(context.exception)
        )
        self.assertTrue('But found \'object_typo\' instead.' in str(context.exception))
| |
| def test_qconfig_mapping_from_dict(self): |
| global_qconfig = QConfig(123, "global") |
| qconfig1 = QConfig(1, "one") |
| qconfig2 = QConfig(2, "two") |
| qconfig_dict = self._get_qconfig_dict_for_qconfig_mapping_test(global_qconfig, qconfig1, qconfig2) |
| qconfig_dict["undefined_dict_key"] = [(123, qconfig1), (234, qconfig2)] |
| qconfig_mapping = QConfigMapping.from_dict(qconfig_dict) |
| self.assertEqual(qconfig_mapping.global_qconfig, global_qconfig) |
| self.assertEqual(qconfig_mapping.object_type_qconfigs, OrderedDict({ |
| torch.nn.Linear: qconfig1, |
| torch.nn.ReLU: qconfig2, |
| })) |
| self.assertEqual(qconfig_mapping.module_name_regex_qconfigs, OrderedDict({ |
| "foo.*bar": qconfig1, |
| "foo.*": qconfig2, |
| })) |
| self.assertEqual(qconfig_mapping.module_name_qconfigs, OrderedDict({ |
| "bazbaz": qconfig1, |
| "borbor": qconfig2, |
| })) |
| self.assertEqual(qconfig_mapping.module_name_object_type_order_qconfigs, OrderedDict({ |
| ("bazbaz", torch.nn.Linear, 0): qconfig1, |
| ("foofoo", torch.nn.ReLU, 1): qconfig2, |
| })) |
| |
| def test_qconfig_mapping_to_dict(self): |
| global_qconfig = QConfig(123, "global") |
| qconfig1 = QConfig(1, "one") |
| qconfig2 = QConfig(2, "two") |
| qconfig_mapping = QConfigMapping().set_global(global_qconfig) \ |
| .set_object_type(torch.nn.Linear, qconfig1) \ |
| .set_object_type(torch.nn.ReLU, qconfig2) \ |
| .set_module_name_regex("foo.*bar", qconfig1) \ |
| .set_module_name_regex("foo.*", qconfig2) \ |
| .set_module_name("bazbaz", qconfig1) \ |
| .set_module_name("borbor", qconfig2) \ |
| .set_module_name_object_type_order("bazbaz", torch.nn.Linear, 0, qconfig1) \ |
| .set_module_name_object_type_order("foofoo", torch.nn.ReLU, 1, qconfig2) |
| qconfig_dict = self._get_qconfig_dict_for_qconfig_mapping_test(global_qconfig, qconfig1, qconfig2) |
| self.assertEqual(qconfig_mapping.to_dict(), qconfig_dict) |
| |
| def test_qconfig_mapping_repr(self): |
| self.assertTrue(isinstance(get_default_qconfig_mapping().__repr__(), str)) |
| |
| def test_default_qconfig_mapping_override_global(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.Conv2d(1, 1, 1) |
| |
| def forward(self, x): |
| return self.conv(x) |
| |
| m = M().eval() |
| my_qconfig = QConfig(activation=MinMaxObserver, weight=default_weight_observer) |
| qconfig_mapping = get_default_qconfig_mapping() |
| # Override global qconfig |
| old_global_qconfig = qconfig_mapping.global_qconfig |
| qconfig_mapping.set_global(my_qconfig) |
| # Verify the correct qconfig was used |
| example_inputs = (torch.randn(1, 1, 1, 1),) |
| m = prepare_fx(m, qconfig_mapping, example_inputs) |
| self.assertTrue(isinstance(old_global_qconfig.activation(), HistogramObserver)) |
| self.assertTrue(isinstance(my_qconfig.activation(), MinMaxObserver)) |
| self.assertTrue(hasattr(m, "activation_post_process_0")) |
| self.assertTrue(hasattr(m, "activation_post_process_1")) |
| self.assertTrue(isinstance(m.activation_post_process_0, MinMaxObserver)) |
| self.assertTrue(isinstance(m.activation_post_process_1, MinMaxObserver)) |
| |
| # Dummy classes for PrepareCustomConfig testing |
| |
| class _DummyStandaloneModule: |
| pass |
| |
| class _DummyFloatModule: |
| pass |
| |
| class _DummyObservedModule: |
| pass |
| |
| class _DummyQuantizedModule: |
| pass |
| |
| class _DummyNonTraceableModule1: |
| pass |
| |
| class _DummyNonTraceableModule2: |
| pass |
| |
| def test_prepare_custom_config_set_standalone_module_name(self): |
| qconfig_mapping = QConfigMapping() |
| example_inputs = (torch.randn(3),) |
| child_prepare_custom_config = PrepareCustomConfig() |
| backend_config = BackendConfig("my_backend") |
| config_entry = StandaloneModuleConfigEntry( |
| qconfig_mapping, example_inputs, child_prepare_custom_config, backend_config) |
| prepare_custom_config = PrepareCustomConfig() |
| self.assertEqual(len(prepare_custom_config.standalone_module_names), 0) |
| prepare_custom_config.set_standalone_module_name( |
| "module1", qconfig_mapping, example_inputs, child_prepare_custom_config, backend_config) |
| self.assertEqual(list(prepare_custom_config.standalone_module_names.keys()), ["module1"]) |
| self.assertEqual(prepare_custom_config.standalone_module_names["module1"], config_entry) |
| |
| def test_prepare_custom_config_set_standalone_module_class(self): |
| qconfig_mapping = QConfigMapping() |
| example_inputs = (torch.randn(3),) |
| child_prepare_custom_config = PrepareCustomConfig() |
| backend_config = BackendConfig("my_backend") |
| config_entry = StandaloneModuleConfigEntry( |
| qconfig_mapping, example_inputs, child_prepare_custom_config, backend_config) |
| prepare_custom_config = PrepareCustomConfig() |
| self.assertEqual(len(prepare_custom_config.standalone_module_classes), 0) |
| prepare_custom_config.set_standalone_module_class( |
| self._DummyStandaloneModule, qconfig_mapping, example_inputs, child_prepare_custom_config, backend_config) |
| self.assertEqual(len(prepare_custom_config.standalone_module_classes), 1) |
| self.assertTrue(self._DummyStandaloneModule in prepare_custom_config.standalone_module_classes) |
| self.assertEqual(prepare_custom_config.standalone_module_classes[self._DummyStandaloneModule], config_entry) |
| |
| def test_prepare_custom_config_set_float_to_observed_mapping(self): |
| prepare_custom_config = PrepareCustomConfig() |
| self.assertEqual(len(prepare_custom_config.float_to_observed_mapping), 0) |
| prepare_custom_config.set_float_to_observed_mapping(self._DummyFloatModule, self._DummyObservedModule, QuantType.STATIC) |
| self.assertEqual(len(prepare_custom_config.float_to_observed_mapping), 1) |
| self.assertEqual(list(prepare_custom_config.float_to_observed_mapping.keys()), [QuantType.STATIC]) |
| self.assertEqual(len(prepare_custom_config.float_to_observed_mapping[QuantType.STATIC]), 1) |
| self.assertTrue(self._DummyFloatModule in prepare_custom_config.float_to_observed_mapping[QuantType.STATIC]) |
| self.assertEqual(prepare_custom_config.float_to_observed_mapping[QuantType.STATIC][self._DummyFloatModule], |
| self._DummyObservedModule) |
| |
| def test_prepare_custom_config_set_non_traceable_module_names(self): |
| prepare_custom_config = PrepareCustomConfig() |
| self.assertEqual(len(prepare_custom_config.non_traceable_module_names), 0) |
| prepare_custom_config.set_non_traceable_module_names(["module1", "module2"]) |
| self.assertEqual(prepare_custom_config.non_traceable_module_names, ["module1", "module2"]) |
| |
| def test_prepare_custom_config_set_non_traceable_module_classes(self): |
| prepare_custom_config = PrepareCustomConfig() |
| self.assertEqual(len(prepare_custom_config.non_traceable_module_classes), 0) |
| prepare_custom_config.set_non_traceable_module_classes([self._DummyNonTraceableModule1, self._DummyNonTraceableModule2]) |
| self.assertEqual(prepare_custom_config.non_traceable_module_classes, |
| [self._DummyNonTraceableModule1, self._DummyNonTraceableModule2]) |
| |
| def test_prepare_custom_config_set_input_quantized_indexes(self): |
| prepare_custom_config = PrepareCustomConfig() |
| self.assertEqual(len(prepare_custom_config.input_quantized_indexes), 0) |
| prepare_custom_config.set_input_quantized_indexes([0, 1]) |
| self.assertEqual(prepare_custom_config.input_quantized_indexes, [0, 1]) |
| |
| def test_prepare_custom_config_set_output_quantized_indexes(self): |
| prepare_custom_config = PrepareCustomConfig() |
| self.assertEqual(len(prepare_custom_config.output_quantized_indexes), 0) |
| prepare_custom_config.set_output_quantized_indexes([0, 1]) |
| self.assertEqual(prepare_custom_config.output_quantized_indexes, [0, 1]) |
| |
| def test_prepare_custom_config_set_preserved_attributes(self): |
| prepare_custom_config = PrepareCustomConfig() |
| self.assertEqual(len(prepare_custom_config.preserved_attributes), 0) |
| prepare_custom_config.set_preserved_attributes(["attr1", "attr2"]) |
| self.assertEqual(prepare_custom_config.preserved_attributes, ["attr1", "attr2"]) |
| |
| def _get_dummy_prepare_custom_config_dict(self): |
| """ |
| Return a dummy prepare_custom_config_dict to test PrepareCustomConfig's to_dict and from_dict methods. |
| """ |
| return { |
| STANDALONE_MODULE_NAME_DICT_KEY: [( |
| "module1", |
| QConfigMapping(), |
| (torch.randn(3),), |
| PrepareCustomConfig(), |
| BackendConfig("my_backend"), |
| )], |
| STANDALONE_MODULE_CLASS_DICT_KEY: [( |
| self._DummyStandaloneModule, |
| QConfigMapping(), |
| (torch.randn(10),), |
| PrepareCustomConfig(), |
| BackendConfig("my_backend"), |
| )], |
| FLOAT_TO_OBSERVED_DICT_KEY: { |
| "static": { |
| self._DummyFloatModule: self._DummyObservedModule |
| }, |
| }, |
| NON_TRACEABLE_MODULE_NAME_DICT_KEY: ["module2", "module3"], |
| NON_TRACEABLE_MODULE_CLASS_DICT_KEY: [self._DummyNonTraceableModule1, self._DummyNonTraceableModule2], |
| INPUT_QUANTIZED_INDEXES_DICT_KEY: [0, 1], |
| OUTPUT_QUANTIZED_INDEXES_DICT_KEY: [0, 1], |
| PRESERVED_ATTRIBUTES_DICT_KEY: ["attr1", "attr2"] |
| } |
| |
| def test_prepare_custom_config_from_dict(self): |
| prepare_custom_config_dict = self._get_dummy_prepare_custom_config_dict() |
| (sm_name, qm1, ei1, pcc1, bcd1) = prepare_custom_config_dict[STANDALONE_MODULE_NAME_DICT_KEY][0] |
| (sm_class, qm2, ei2, pcc2, bcd2) = prepare_custom_config_dict[STANDALONE_MODULE_CLASS_DICT_KEY][0] |
| sm_config_entry1 = StandaloneModuleConfigEntry(qm1, ei1, pcc1, bcd1) |
| sm_config_entry2 = StandaloneModuleConfigEntry(qm2, ei2, pcc2, bcd2) |
| prepare_custom_config = PrepareCustomConfig.from_dict(prepare_custom_config_dict) |
| |
| # Standalone modules |
| self.assertEqual(len(prepare_custom_config.standalone_module_names), 1) |
| self.assertTrue(sm_name in prepare_custom_config.standalone_module_names) |
| self.assertEqual(prepare_custom_config.standalone_module_names[sm_name], sm_config_entry1) |
| self.assertEqual(len(prepare_custom_config.standalone_module_classes), 1) |
| self.assertTrue(sm_class in prepare_custom_config.standalone_module_classes) |
| self.assertEqual(prepare_custom_config.standalone_module_classes[sm_class], sm_config_entry2) |
| |
| # Float to observed mapping |
| self.assertEqual(len(prepare_custom_config.float_to_observed_mapping), 1) |
| self.assertEqual(list(prepare_custom_config.float_to_observed_mapping.keys()), [QuantType.STATIC]) |
| self.assertEqual(len(prepare_custom_config.float_to_observed_mapping[QuantType.STATIC]), 1) |
| self.assertTrue(self._DummyFloatModule in prepare_custom_config.float_to_observed_mapping[QuantType.STATIC]) |
| self.assertEqual(prepare_custom_config.float_to_observed_mapping[QuantType.STATIC][self._DummyFloatModule], |
| self._DummyObservedModule) |
| |
| # Other |
| self.assertEqual(prepare_custom_config.non_traceable_module_names, ["module2", "module3"]) |
| self.assertEqual(prepare_custom_config.non_traceable_module_classes, |
| [self._DummyNonTraceableModule1, self._DummyNonTraceableModule2]) |
| self.assertEqual(prepare_custom_config.input_quantized_indexes, [0, 1]) |
| self.assertEqual(prepare_custom_config.output_quantized_indexes, [0, 1]) |
| self.assertEqual(prepare_custom_config.preserved_attributes, ["attr1", "attr2"]) |
| |
| def test_prepare_custom_config_to_dict(self): |
| prepare_custom_config_dict = self._get_dummy_prepare_custom_config_dict() |
| (sm_name, qm1, ei1, pcc1, bcd1) = prepare_custom_config_dict[STANDALONE_MODULE_NAME_DICT_KEY][0] |
| (sm_class, qm2, ei2, pcc2, bcd2) = prepare_custom_config_dict[STANDALONE_MODULE_CLASS_DICT_KEY][0] |
| prepare_custom_config = PrepareCustomConfig() \ |
| .set_standalone_module_name(sm_name, qm1, ei1, pcc1, bcd1) \ |
| .set_standalone_module_class(sm_class, qm2, ei2, pcc2, bcd2) \ |
| .set_float_to_observed_mapping(self._DummyFloatModule, self._DummyObservedModule) \ |
| .set_non_traceable_module_names(["module2", "module3"]) \ |
| .set_non_traceable_module_classes([self._DummyNonTraceableModule1, self._DummyNonTraceableModule2]) \ |
| .set_input_quantized_indexes([0, 1]) \ |
| .set_output_quantized_indexes([0, 1]) \ |
| .set_preserved_attributes(["attr1", "attr2"]) |
| # PrepareCustomConfig.to_dict also converts internal QConfigMappings and PrepareCustomConfigs to dicts |
| prepare_custom_config_dict[STANDALONE_MODULE_NAME_DICT_KEY][0] = (sm_name, qm1.to_dict(), ei1, pcc1.to_dict(), bcd1) |
| prepare_custom_config_dict[STANDALONE_MODULE_CLASS_DICT_KEY][0] = (sm_class, qm2.to_dict(), ei2, pcc2.to_dict(), bcd2) |
| self.assertEqual(prepare_custom_config.to_dict(), prepare_custom_config_dict) |
| |
| def test_convert_custom_config_set_observed_to_quantized_mapping(self): |
| convert_custom_config = ConvertCustomConfig() |
| self.assertEqual(len(convert_custom_config.observed_to_quantized_mapping), 0) |
| convert_custom_config.set_observed_to_quantized_mapping( |
| self._DummyObservedModule, self._DummyQuantizedModule, QuantType.STATIC) |
| self.assertEqual(len(convert_custom_config.observed_to_quantized_mapping), 1) |
| self.assertEqual(list(convert_custom_config.observed_to_quantized_mapping.keys()), [QuantType.STATIC]) |
| self.assertTrue(self._DummyObservedModule in convert_custom_config.observed_to_quantized_mapping[QuantType.STATIC]) |
| self.assertEqual(convert_custom_config.observed_to_quantized_mapping[QuantType.STATIC][self._DummyObservedModule], |
| self._DummyQuantizedModule) |
| |
| def test_convert_custom_config_set_preserved_attributes(self): |
| convert_custom_config = ConvertCustomConfig() |
| self.assertEqual(len(convert_custom_config.preserved_attributes), 0) |
| convert_custom_config.set_preserved_attributes(["attr1", "attr2"]) |
| self.assertEqual(convert_custom_config.preserved_attributes, ["attr1", "attr2"]) |
| |
| def _get_dummy_convert_custom_config_dict(self): |
| """ |
| Return a dummy convert_custom_config_dict to test ConvertCustomConfig's to_dict and from_dict methods. |
| """ |
| return { |
| OBSERVED_TO_QUANTIZED_DICT_KEY: { |
| "static": { |
| self._DummyObservedModule: self._DummyQuantizedModule |
| }, |
| }, |
| PRESERVED_ATTRIBUTES_DICT_KEY: ["attr1", "attr2"] |
| } |
| |
| def test_convert_custom_config_from_dict(self): |
| convert_custom_config_dict = self._get_dummy_convert_custom_config_dict() |
| convert_custom_config = ConvertCustomConfig.from_dict(convert_custom_config_dict) |
| self.assertEqual(len(convert_custom_config.observed_to_quantized_mapping), 1) |
| self.assertEqual(list(convert_custom_config.observed_to_quantized_mapping.keys()), [QuantType.STATIC]) |
| self.assertEqual(len(convert_custom_config.observed_to_quantized_mapping[QuantType.STATIC]), 1) |
| self.assertTrue(self._DummyObservedModule in convert_custom_config.observed_to_quantized_mapping[QuantType.STATIC]) |
| self.assertEqual(convert_custom_config.observed_to_quantized_mapping[QuantType.STATIC][self._DummyObservedModule], |
| self._DummyQuantizedModule) |
| self.assertEqual(convert_custom_config.preserved_attributes, ["attr1", "attr2"]) |
| |
| def test_convert_custom_config_to_dict(self): |
| convert_custom_config = ConvertCustomConfig() \ |
| .set_observed_to_quantized_mapping(self._DummyObservedModule, self._DummyQuantizedModule) \ |
| .set_preserved_attributes(["attr1", "attr2"]) |
| self.assertEqual(convert_custom_config.to_dict(), self._get_dummy_convert_custom_config_dict()) |
| |
| def test_fuse_custom_config_set_preserved_attributes(self): |
| fuse_custom_config = FuseCustomConfig() |
| self.assertEqual(len(fuse_custom_config.preserved_attributes), 0) |
| fuse_custom_config.set_preserved_attributes(["attr1", "attr2"]) |
| self.assertEqual(fuse_custom_config.preserved_attributes, ["attr1", "attr2"]) |
| |
| def test_fuse_custom_config_from_dict(self): |
| fuse_custom_config_dict = {PRESERVED_ATTRIBUTES_DICT_KEY: ["attr1", "attr2"]} |
| fuse_custom_config = FuseCustomConfig.from_dict(fuse_custom_config_dict) |
| self.assertEqual(fuse_custom_config.preserved_attributes, ["attr1", "attr2"]) |
| |
| def test_fuse_custom_config_to_dict(self): |
| fuse_custom_config_dict = {PRESERVED_ATTRIBUTES_DICT_KEY: ["attr1", "attr2"]} |
| fuse_custom_config = FuseCustomConfig().set_preserved_attributes(["attr1", "attr2"]) |
| self.assertEqual(fuse_custom_config.to_dict(), fuse_custom_config_dict) |
| |
| def test_remove_qconfig(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.avg_pool = torch.nn.AvgPool2d(1) |
| |
| def forward(self, x): |
| return self.avg_pool(x) |
| |
| m = M().eval() |
| qconfig_dict = {'': default_qconfig} |
| example_inputs = (torch.randn(1, 1, 1, 1),) |
| m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) |
| m(*example_inputs) |
| m = convert_fx(m) |
| m(*example_inputs) |
| for name, module in m.named_modules(): |
| self.assertFalse(hasattr(module, 'qconfig'), |
| 'qconfig is not removed for ' + name) |
| |
| def test_return_none(self): |
| class M(torch.nn.Module): |
| def forward(self, x): |
| pass |
| |
| m = M().eval() |
| qconfig_dict = {'': torch.ao.quantization.default_qconfig} |
| m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1),)) |
| m = convert_fx(m) |
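        # no numeric check needed: the test passes as long as prepare/convert
        # handle a forward that returns None without raising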
| |
| def test_default_quant_after_none_qconfig(self): |
| """ Make sure default quant is inserted properly""" |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv1 = torch.nn.Conv2d(1, 1, 1) |
| self.conv2 = torch.nn.Conv2d(1, 1, 1) |
| |
| def forward(self, x): |
| x = self.conv1(x) |
| x = x.transpose(1, 2) |
                x = self.conv2(x)
                return x
| |
| m = M().eval() |
| qconfig_dict = { |
| "": default_qconfig, |
| "module_name": [ |
| ("conv1", None) |
| ] |
| } |
| m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1, 1, 1),)) |
| m = convert_fx(m) |
| |
| def test_qconfig_for_call_method(self): |
| class Sub(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.Conv2d(1, 1, 1) |
| |
| def forward(self, x): |
| x = x.transpose(2, 3) |
| x = self.conv(x) |
| return x.transpose(2, 3) |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.sub = Sub() |
| self.conv1 = torch.nn.Conv2d(1, 1, 1) |
| self.conv2 = torch.nn.Conv2d(1, 1, 1) |
| |
| def forward(self, x): |
| x = self.conv1(x) |
| x = self.sub(x) |
| x = self.conv2(x) |
| return x.transpose(2, 3) |
| |
| qconfig_dict1 = {"": default_qconfig, "module_name": [("sub", None)]} |
        # since sub is configured with qconfig None, we should dequantize the
        # output of self.conv1 and quantize the input of self.conv2
        # the dequantize after conv2 should happen after the final transpose,
        # since the transpose is covered by the default_qconfig
        # nodes inside the Sub module instance are not quantized
| node_list1 = [ |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_module(nnq.Conv2d), |
| ns.call_method("dequantize"), |
| ns.call_method("transpose"), |
| ns.call_module(nn.Conv2d), |
| ns.call_method("transpose"), |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_module(nnq.Conv2d), |
| ns.call_method("transpose"), |
| ns.call_method("dequantize") |
| ] |
| |
| qconfig_dict2 = {"": None, "module_name": [("sub", default_qconfig)]} |
| # Only nodes in Sub module instance are quantized |
| # the first transpose is not quantized because the input is not quantized |
| node_list2 = [ |
| ns.call_module(nn.Conv2d), |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_method("transpose"), |
| ns.call_module(nnq.Conv2d), |
| ns.call_method("transpose"), |
| ns.call_method("dequantize"), |
| ns.call_module(nn.Conv2d), |
| ns.call_method("transpose"), |
| ] |
| |
| for qconfig_dict, node_list in [ |
| (qconfig_dict1, node_list1), |
| (qconfig_dict2, node_list2) |
| ]: |
| example_inputs = (torch.randn(2, 1, 3, 3),) |
| m = M().eval() |
| m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) |
| m(torch.randn(2, 1, 3, 3)) |
| m = convert_fx(m) |
| self.checkGraphModuleNodes(m, expected_node_list=node_list) |
| # make sure it runs |
| m(*example_inputs) |
| |
| def test_qconfig_for_call_func(self): |
| class Linear(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.w = torch.ones(5, 5) |
| self.b = torch.zeros(5) |
| |
| def forward(self, x): |
| return torch.nn.functional.linear(x, self.w, self.b) |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.mods1 = torch.nn.Sequential( |
| Linear(), |
| Linear() |
| ) |
| self.mods2 = Linear() |
| |
| def forward(self, x): |
| x = self.mods1(x) |
| x = self.mods2(x) |
| return x |
| |
| model = M().eval() |
| example_inputs = (torch.rand(5, 5),) |
| qconfig_dict = {"": default_qconfig, "module_name": [("mods2", None)]} |
| m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) |
| m(*example_inputs) |
| |
| m = convert_fx(m) |
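        # both Linears in mods1 are quantized; mods2 has qconfig None, so its
        # functional linear runs in float after a dequantize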
| node_list = [ |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_function(torch.ops.quantized.linear), |
| ns.call_function(torch.ops.quantized.linear), |
| ns.call_method('dequantize'), |
| ns.call_function(torch.nn.functional.linear) |
| ] |
| self.checkGraphModuleNodes(m, expected_node_list=node_list) |
| m(torch.rand(5, 5)) |
| |
| def test_preserve_attributes(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.Conv2d(1, 1, 1) |
| |
| def forward(self, x): |
| return self.conv(x) |
| |
| m = M() |
| m.eval() |
| m.preserved_attr = 3 |
| prepare_custom_config_dict = { |
| "preserved_attributes": ["preserved_attr"] |
| } |
| example_inputs = (torch.randn(1, 1, 1, 1),) |
| m = prepare_fx( |
| m, |
| {"": default_qconfig}, |
| example_inputs=example_inputs, |
| prepare_custom_config=prepare_custom_config_dict) |
| |
| def assertAttrPreserved(m): |
| self.assertTrue(hasattr(m, "preserved_attr")) |
| self.assertEqual(m.preserved_attr, 3) |
| |
| assertAttrPreserved(m) |
| convert_custom_config_dict = { |
| "preserved_attributes": ["preserved_attr"] |
| } |
| m = convert_fx(m, convert_custom_config=convert_custom_config_dict) |
| assertAttrPreserved(m) |
| |
| @skipIfNoFBGEMM |
| def test_qat_and_script(self): |
| model = LinearModelWithSubmodule().train() |
| qengine = torch.backends.quantized.engine |
| qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig(qengine)} |
| x = torch.randn(5, 5) |
| example_inputs = (x,) |
| model = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs) |
| |
| # ensure scripting works |
| scripted = torch.jit.script(model) |
| # run one round to make sure model runs |
| scripted(x) |
| FileCheck().check_count('FakeQuantize = prim::GetAttr[name="', 4, exactly=True) \ |
| .run(scripted.graph) |
| |
| # disable fake_quant and observer |
| for epoch in range(3): |
| if epoch == 1: |
| scripted.apply(torch.ao.quantization.disable_observer) |
| if epoch == 2: |
| scripted.apply(torch.ao.quantization.disable_fake_quant) |
| |
| # ensure the fake_quant and observer have been disabled. |
| matches = ['.fake_quant_enabled', '.observer_enabled'] |
| for key, v in scripted.state_dict().items(): |
| if any(x in key for x in matches): |
| self.assertEqual(v, torch.tensor([0], dtype=torch.int64)) |
| |
| # enable them back |
| scripted.apply(torch.ao.quantization.enable_fake_quant) |
| scripted.apply(torch.ao.quantization.enable_observer) |
| for key, v in scripted.state_dict().items(): |
| if any(x in key for x in matches): |
| self.assertEqual(v, torch.tensor([1], dtype=torch.int64)) |
| |
| @skipIfNoFBGEMM |
| def test_save_observer_state_dict(self): |
| orig = LinearModelWithSubmodule().eval() |
| model = orig |
| qconfig_dict = {'': torch.ao.quantization.get_default_qconfig('fbgemm')} |
| x = torch.randn(5, 5) |
| model = prepare_fx(model, qconfig_dict, example_inputs=(x,)) |
| |
| # run it through input |
| model(x) |
| # save state_dict of model |
| obs_dict = torch.ao.quantization.get_observer_state_dict(model) |
| |
| quant = convert_fx(model) |
| |
| b = io.BytesIO() |
| torch.save(obs_dict, b) |
| b.seek(0) |
| |
| # Load the stats into new model |
| model_2 = orig |
| model_2 = prepare_fx(model_2, qconfig_dict, example_inputs=(x,)) |
| |
| loaded_dict = torch.load(b) |
| torch.ao.quantization.load_observer_state_dict(model_2, loaded_dict) |
| |
| quant_2 = convert_fx(model_2) |
| |
| # Verify that loaded state dict produces same results. |
| self.assertEqual(quant(x), quant_2(x)) |
| |
| @skipIfNoFBGEMM |
| def test_custom_module_class(self): |
| class CustomModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(3, 3) |
| |
| def forward(self, x): |
| return self.linear(x) |
| |
| class ObservedCustomModule(torch.nn.Module): |
| def __init__(self, linear): |
| super().__init__() |
| self.linear = linear |
| |
| def forward(self, x): |
| return self.linear(x) |
| |
| @classmethod |
| def from_float(cls, float_module): |
| assert hasattr(float_module, 'qconfig') |
| observed = cls(float_module.linear) |
| observed.qconfig = float_module.qconfig |
| return observed |
| |
| class StaticQuantCustomModule(torch.nn.Module): |
| def __init__(self, linear): |
| super().__init__() |
| self.linear = linear |
| |
| def forward(self, x): |
| return self.linear(x) |
| |
| @classmethod |
| def from_observed(cls, observed_module): |
| assert hasattr(observed_module, 'qconfig') |
| assert hasattr(observed_module, 'activation_post_process') |
| observed_module.linear.activation_post_process = \ |
| observed_module.activation_post_process |
| quantized = cls(nnq.Linear.from_float(observed_module.linear)) |
| return quantized |
| |
| class DynamicQuantCustomModule(torch.nn.Module): |
| def __init__(self, linear): |
| super().__init__() |
| self.linear = linear |
| |
| def forward(self, x): |
| return self.linear(x) |
| |
| @classmethod |
| def from_observed(cls, observed_module): |
| assert hasattr(observed_module, 'qconfig') |
| observed_module.linear.qconfig = observed_module.qconfig |
| quantized = cls(nnqd.Linear.from_float(observed_module.linear)) |
| return quantized |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(3, 3) |
| self.custom = CustomModule() |
| |
| def forward(self, x): |
| x = self.linear(x) |
| x = self.custom(x) |
| return x |
| |
| class RefM(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear1 = torch.nn.Linear(3, 3) |
| self.linear2 = torch.nn.Linear(3, 3) |
| |
| def forward(self, x): |
| x = self.linear1(x) |
| x = self.linear2(x) |
| return x |
| |
| # instantiate M and RefM and align the parameters |
| original_m = M().eval() |
| original_ref_m = RefM().eval() |
| original_ref_m.linear1.weight = torch.nn.Parameter(original_m.linear.weight.detach()) |
| original_ref_m.linear1.bias = torch.nn.Parameter(original_m.linear.bias.detach()) |
| original_ref_m.linear2.weight = torch.nn.Parameter(original_m.custom.linear.weight.detach()) |
| original_ref_m.linear2.bias = torch.nn.Parameter(original_m.custom.linear.bias.detach()) |
| |
| a16_qconfig = QConfig( |
| activation=MinMaxObserver.with_args(dtype=torch.qint32, quant_min=0, quant_max=65536), |
| weight=default_weight_observer, |
| ) |
| test_configs = { |
| "static": (default_qconfig, StaticQuantCustomModule, 3), |
| "static_a16": (a16_qconfig, StaticQuantCustomModule, 3), |
| "dynamic": (default_dynamic_qconfig, DynamicQuantCustomModule, 0) |
| } |
| |
| for quant_type in [QuantType.STATIC, QuantType.DYNAMIC]: |
| key = _get_quant_type_to_str(quant_type) |
| qconfig, quantized_module_class, num_observers = test_configs[key] |
| qconfig_dict = {"": qconfig} |
| if key == "static": |
| prepare_custom_config_dict = { |
| "float_to_observed_custom_module_class": { |
| "static": { |
| CustomModule: ObservedCustomModule |
| } |
| } |
| } |
| convert_custom_config_dict = { |
| "observed_to_quantized_custom_module_class": { |
| "static": { |
| ObservedCustomModule: quantized_module_class |
| } |
| } |
| } |
| else: |
| prepare_custom_config_dict = { |
| "non_traceable_module_class": [ |
| CustomModule |
| ] |
| } |
| convert_custom_config_dict = { |
| "observed_to_quantized_custom_module_class": { |
| "dynamic": { |
| CustomModule: quantized_module_class |
| } |
| } |
| } |
| |
| example_inputs = (torch.randn(3, 3),) |
| # check prepared model |
| m = prepare_fx( |
| copy.deepcopy(original_m), |
| qconfig_dict, |
| example_inputs=example_inputs, |
| prepare_custom_config=prepare_custom_config_dict) |
| # calibration |
| m(*example_inputs) |
| # all activation observers are inserted in the top level module |
| count_check = { |
| ns.call_module(torch.ao.quantization.MinMaxObserver): num_observers |
| } |
| self.checkGraphModuleNodes(m, expected_node_occurrence=count_check) |
| |
| # check converted/quantized model |
| m = convert_fx( |
| m, |
| convert_custom_config=convert_custom_config_dict) |
| if quant_type == QuantType.STATIC: |
| count_check = { |
| ns.call_function(torch.quantize_per_tensor) : 1, |
| ns.call_module(nnq.Linear) : 1, |
| ns.call_method('dequantize') : 1, |
| } |
| self.checkGraphModuleNodes(m, expected_node_occurrence=count_check) |
| self.assertEqual(type(m.custom), quantized_module_class) |
| res = m(*example_inputs) |
| |
| # quantize the reference model |
| ref_m = prepare_fx( |
| copy.deepcopy(original_ref_m), qconfig_dict, example_inputs=example_inputs) |
| ref_m(*example_inputs) |
| ref_m = convert_fx(ref_m) |
| ref_res = ref_m(*example_inputs) |
| self.assertEqual(res, ref_res) |
| |
| @skipIfNoFBGEMM |
| def test_custom_module_class_input_has_multiple_users(self): |
| """ Tests that the flow still works when the input of custom module |
| has multiple users |
| """ |
| class CustomModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(3, 3) |
| |
| def forward(self, x): |
| return self.linear(x) |
| |
| class ObservedCustomModule(torch.nn.Module): |
| def __init__(self, linear): |
| super().__init__() |
| self.linear = linear |
| |
| def forward(self, x): |
| return self.linear(x) |
| |
| @classmethod |
| def from_float(cls, float_module): |
| assert hasattr(float_module, 'qconfig') |
| observed = cls(float_module.linear) |
| observed.qconfig = float_module.qconfig |
| return observed |
| |
| class StaticQuantCustomModule(torch.nn.Module): |
| def __init__(self, linear): |
| super().__init__() |
| self.linear = linear |
| |
| def forward(self, x): |
| return self.linear(x) |
| |
| @classmethod |
| def from_observed(cls, observed_module): |
| assert hasattr(observed_module, 'qconfig') |
| assert hasattr(observed_module, 'activation_post_process') |
| observed_module.linear.activation_post_process = \ |
| observed_module.activation_post_process |
| quantized = cls(nnq.Linear.from_float(observed_module.linear)) |
| return quantized |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(3, 3) |
| self.custom = CustomModule() |
| |
| def forward(self, x0): |
| x1 = self.custom(x0) |
| x2 = self.linear(x0) |
| return x1 + x2 |
| |
| prepare_custom_config_dict = { |
| "float_to_observed_custom_module_class": { |
| "static": { |
| CustomModule: ObservedCustomModule |
| } |
| } |
| } |
| convert_custom_config_dict = { |
| "observed_to_quantized_custom_module_class": { |
| "static": { |
| ObservedCustomModule: StaticQuantCustomModule |
| } |
| } |
| } |
| m = M().eval() |
| example_inputs = (torch.randn(3, 3),) |
| m = prepare_fx( |
| m, |
| {"": default_qconfig}, |
| example_inputs=example_inputs, |
| prepare_custom_config=prepare_custom_config_dict) |
| # make sure it works |
| m = convert_fx( |
| m, |
| convert_custom_config=convert_custom_config_dict) |
| # make sure it runs |
| m(*example_inputs) |
| |
| @skipIfNoFBGEMM |
| def test_custom_module_class_input_has_duplicate_nodes(self): |
| """ Tests that the flow still works when the graph has |
| multiple nodes with the same custom module target. |
| """ |
| class CustomModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(3, 3) |
| |
| def forward(self, x): |
| return self.linear(x) |
| |
| class ObservedCustomModule(torch.nn.Module): |
| def __init__(self, linear): |
| super().__init__() |
| self.linear = linear |
| |
| def forward(self, x): |
| return self.linear(x) |
| |
| @classmethod |
| def from_float(cls, float_module): |
| assert hasattr(float_module, 'qconfig') |
| observed = cls(float_module.linear) |
| observed.qconfig = float_module.qconfig |
| return observed |
| |
| class StaticQuantCustomModule(torch.nn.Module): |
| def __init__(self, linear): |
| super().__init__() |
| self.linear = linear |
| |
| def forward(self, x): |
| return self.linear(x) |
| |
| @classmethod |
| def from_observed(cls, observed_module): |
| assert hasattr(observed_module, 'qconfig') |
| assert hasattr(observed_module, 'activation_post_process') |
| observed_module.linear.activation_post_process = \ |
| observed_module.activation_post_process |
| quantized = cls(nnq.Linear.from_float(observed_module.linear)) |
| return quantized |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.custom = CustomModule() |
| |
| def forward(self, x0): |
| x1 = self.custom(x0) |
| x2 = self.custom(x0) |
| return x1 + x2 |
| |
| prepare_custom_config_dict = { |
| "float_to_observed_custom_module_class": { |
| "static": { |
| CustomModule: ObservedCustomModule |
| } |
| } |
| } |
| convert_custom_config_dict = { |
| "observed_to_quantized_custom_module_class": { |
| "static": { |
| ObservedCustomModule: StaticQuantCustomModule |
| } |
| } |
| } |
| m = M().eval() |
| example_inputs = (torch.randn(3, 3),) |
| m = prepare_fx( |
| m, |
| {"": default_qconfig}, |
| example_inputs=example_inputs, |
| prepare_custom_config=prepare_custom_config_dict) |
| # make sure it works |
| m = convert_fx( |
| m, |
| convert_custom_config=convert_custom_config_dict) |
| # make sure it runs |
| m(*example_inputs) |
| |
| @skipIfNoFBGEMM |
| def test_non_traceable_module(self): |
| class NonTraceable(torch.nn.Module): |
| def forward(self, x): |
| for k in x.keys(): |
| print(x[k]) |
| return x |
| |
| class NonTraceable2(torch.nn.Module): |
| def forward(self, x): |
                # data-dependent control flow is not traceable
| for i in x: |
| print(i) |
| return x |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.m1 = NonTraceable() |
| self.m2 = NonTraceable2() |
| |
| def forward(self, x): |
| x = self.m1(x) |
| x = self.m2(x) |
| return x |
| |
| m = M().eval() |
| qconfig_dict = {"": default_qconfig} |
| prepare_custom_config_dict = { |
| "non_traceable_module_name": [ |
| "m1" |
| ], |
| "non_traceable_module_class": [ |
| NonTraceable2 |
| ] |
| } |
| m = prepare_fx( |
| m, qconfig_dict, |
| example_inputs=({"key": torch.randn(1)},), |
| prepare_custom_config=prepare_custom_config_dict) |
| |
| node_occurrence = { |
            ns.call_module(NonTraceable): 1,
            ns.call_module(NonTraceable2): 1,
| } |
| # make sure these modules are not traced |
| self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) |
| |
| def test_prepared_model_deepcopy(self): |
| """Ensures that copy.deepcopy works correctly on a prepared model. |
| """ |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.Conv2d(1, 1, 1) |
| self._foobar = 'foobar' |
| self.foobar2 = 'foobar2' |
| |
| def forward(self, x): |
| x = self.conv(x) |
| return x |
| |
| m = M() |
| m.eval() |
| qconfig_dict = {'': torch.ao.quantization.default_qconfig} |
| example_inputs = (torch.randn(4, 1, 4, 4),) |
| prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) |
| # calibrate |
| prepared(*example_inputs) |
| # copy |
| prepared_copy = copy.deepcopy(prepared) |
| # quantize, should run with no errors |
| quantized = convert_fx(prepared_copy) |
| |
| def test_quantized_model_type(self): |
| """ Test state_dict and deepcopy works properly in the quantized model |
| """ |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(5, 5) |
| |
| def forward(self, x): |
| return self.linear(x) |
| |
| example_inputs = (torch.rand(8, 5),) |
| m = M().eval() |
| m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs) |
| m = convert_fx(m) |
| # test deepcopy |
| m_copy = copy.deepcopy(m) |
| self.assertEqual(m_copy(*example_inputs), m(*example_inputs)) |
| |
| # test state_dict |
| state_dict = m.state_dict() |
| m_new = M().eval() |
| m_new = prepare_fx(m_new, {"": default_qconfig}, example_inputs=example_inputs) |
| m_new = convert_fx(m_new) |
| m_new.load_state_dict(state_dict) |
| self.assertEqual(m_new(*example_inputs), m(*example_inputs)) |
| |
| def test_dequantize(self): |
| r""" Test to make sure dequantize node are placed before |
| non-quantizable node |
| """ |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.Conv2d(1, 1, 1) |
| self.act = torch.nn.GELU() |
| |
| def forward(self, x): |
| x = self.conv(x) |
| return self.act(x) |
| |
| data = torch.rand(5, 1, 3, 3, dtype=torch.float) |
| for quant_type in self.static_quant_types: |
| node_list = [ |
| ns.call_module(nnq.Conv2d), |
| ns.call_method("dequantize"), |
| ns.call_module(nn.GELU), |
| ] |
| self.checkGraphModeFxOp( |
| M().eval(), (data,), quant_type, expected_node_list=node_list) |
| |
| def test_sequential(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.convs = torch.nn.Sequential( |
| torch.nn.Conv2d(1, 1, 1), |
| torch.nn.Conv2d(1, 1, 1) |
| ) |
| |
| def forward(self, x): |
| x = self.convs(x) |
| return x |
| |
| data = torch.rand(5, 1, 3, 3, dtype=torch.float) |
| for quant_type in self.static_quant_types: |
| node_list = [ |
| ns.call_module(nnq.Conv2d), |
| ns.call_module(nnq.Conv2d), |
| ] |
| self.checkGraphModeFxOp( |
| M().eval(), (data,), quant_type, expected_node_list=node_list) |
| |
| def _test_quantized_inputs_outputs( |
| self, prepare_custom_config_dict, prepare_count_check, |
| convert_count_check): |
| """ |
| Test the option to have inputs and outputs of the graph quantized |
| """ |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv1 = torch.nn.Conv2d(1, 1, 1) |
| self.conv2 = torch.nn.Conv2d(1, 1, 1) |
| |
| def forward(self, x): |
| x = self.conv1(x) |
| x = self.conv2(x) |
| return x |
| |
| # quantized input, quantized output |
| m = M() |
| qconfig_dict = {'': torch.ao.quantization.default_qconfig} |
| example_inputs = (torch.randn(1, 1, 4, 4),) |
| m.eval() |
| mp = torch.ao.quantization.quantize_fx.prepare_fx( |
| m, qconfig_dict, |
| example_inputs=example_inputs, |
| prepare_custom_config=prepare_custom_config_dict) |
| self.checkGraphModuleNodes(mp, expected_node_occurrence=prepare_count_check) |
| mp(*example_inputs) |
| mq = torch.ao.quantization.quantize_fx.convert_fx(mp) |
| self.checkGraphModuleNodes(mq, expected_node_occurrence=convert_count_check) |
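        # For this two-conv model there are three observation points: the graph
        # input, the tensor between the convs, and the graph output. Marking
        # the input as already quantized drops the input observer at prepare
        # time; marking the output as quantized keeps the output observer (its
        # qparams are still needed) but drops the final dequantize at convert
        # time. The counts in the four tests below follow from this.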
| |
| def test_quantized_input_quantized_output(self): |
| prepare_custom_config_dict = { |
| 'input_quantized_idxs': [0], 'output_quantized_idxs': [0]} |
| prepare_count_check = { |
| ns.call_module(torch.ao.quantization.MinMaxObserver): 2, |
| } |
| convert_count_check = { |
| ns.call_function(torch.quantize_per_tensor): 0, |
| ns.call_method('dequantize'): 0, |
| } |
| self._test_quantized_inputs_outputs( |
| prepare_custom_config_dict, prepare_count_check, convert_count_check) |
| |
| def test_fp32_input_quantized_output(self): |
| prepare_custom_config_dict = { |
| 'output_quantized_idxs': [0]} |
| prepare_count_check = { |
| ns.call_module(torch.ao.quantization.MinMaxObserver): 3, |
| } |
| convert_count_check = { |
| ns.call_function(torch.quantize_per_tensor): 1, |
| ns.call_method('dequantize'): 0, |
| } |
| self._test_quantized_inputs_outputs( |
| prepare_custom_config_dict, prepare_count_check, convert_count_check) |
| |
| def test_quantized_input_fp32_output(self): |
| prepare_custom_config_dict = { |
| 'input_quantized_idxs': [0]} |
| prepare_count_check = { |
| ns.call_module(torch.ao.quantization.MinMaxObserver): 2, |
| } |
| convert_count_check = { |
| ns.call_function(torch.quantize_per_tensor): 0, |
| ns.call_method('dequantize'): 1, |
| } |
| self._test_quantized_inputs_outputs( |
| prepare_custom_config_dict, prepare_count_check, convert_count_check) |
| |
| def test_fp32_input_fp32_output(self): |
| prepare_custom_config_dict = {} |
| prepare_count_check = { |
| ns.call_module(torch.ao.quantization.MinMaxObserver): 3, |
| } |
| convert_count_check = { |
| ns.call_function(torch.quantize_per_tensor): 1, |
| ns.call_method('dequantize'): 1, |
| } |
| self._test_quantized_inputs_outputs( |
| prepare_custom_config_dict, prepare_count_check, convert_count_check) |
| |
| @skipIfNoFBGEMM |
| def test_convtranspose_per_channel_fails_early(self): |
| r""" |
        Verifies that attempting to quantize a ConvTranspose module with per-channel
| weight observers fails in the prepare step, as opposed to the convert step. |
| """ |
| m = torch.nn.Sequential(torch.nn.ConvTranspose2d(1, 1, 1)) |
| m.eval() |
| qconfig_dict = {'': torch.ao.quantization.get_default_qconfig('fbgemm')} |
| with self.assertRaises(AssertionError) as context: |
| mp = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1, 1, 1),)) |
| self.assertTrue( |
| str(context.exception) == |
| 'Per channel weight observer is not supported yet for ConvTranspose{n}d.') |
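        # By contrast, a per-tensor weight observer (e.g. the one in
        # default_qconfig) is expected to prepare the same ConvTranspose model
        # without raising.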
| |
| @skipIfNoFBGEMM |
| def test_qparams_buffers(self): |
| class Linear(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.w = torch.ones(5, 5) |
| self.b = torch.zeros(5) |
| |
| def forward(self, x): |
| return torch.nn.functional.linear(x, self.w, self.b) |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.mods1 = torch.nn.Sequential( |
| Linear(), |
| Linear() |
| ) |
| self.mods2 = Linear() |
| |
| def forward(self, x): |
| x = self.mods1(x) |
| x = self.mods2(x) |
| return x |
| |
| model = M().eval() |
| qconfig_dict = {"": default_qconfig} |
| example_inputs = (torch.rand(5, 5),) |
| m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) |
| m(*example_inputs) |
| m = convert_fx(m) |
| keys = m.state_dict().keys() |
        quant_scale_count = quant_zero_point_count = scale_count = zero_point_count = 0
        for k in keys:
            if 'input_scale' in k:
                quant_scale_count += 1
            elif 'input_zero_point' in k:
                quant_zero_point_count += 1
            elif 'scale' in k:
                scale_count += 1
            elif 'zero_point' in k:
                zero_point_count += 1
| |
| # Expect each quantized linear op to have a scale and zero point |
| self.assertTrue(scale_count == 3, "Expect each quantized linear op to have a scale in state_dict") |
| self.assertTrue(zero_point_count == 3, "Expect each quantized linear op to have a zero_point in state_dict") |
| m(*example_inputs) |
| # ensure it is scriptable |
| scripted = torch.jit.script(m) |
| scripted_keys = scripted.state_dict().keys() |
| scripted.mods1_0_packed_weight_0 = m.state_dict()["mods1_0_packed_weight_0"] |
| non_packed_weight_keys = [key for key in keys if "_packed_weight" not in key] |
| self.assertTrue( |
| set(scripted_keys) == set(non_packed_weight_keys), |
| "Expected the scripted model to preserve the state_dict for non-packed weight attributes") |
| # TODO: probably don't want to hardcode the attribute names, since they are generated |
| for attr_name in [ |
| "mods1_0_input_scale_0", "mods1_0_input_zero_point_0", |
| "mods1_0_scale_1", "mods1_0_zero_point_1", |
| "mods1_1_scale_1", "mods1_1_zero_point_1", |
| "mods2_scale_1", "mods2_zero_point_1"]: |
| self.assertTrue(hasattr(m, attr_name), attr_name + " not found.") |
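        # The generated attribute names above follow the pattern
        # <module path with '.' replaced by '_'>_<qparam name>_<index>, e.g.
        # "mods1_0_scale_1" holds the output scale of self.mods1[0].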
| |
| @skipIfNoFBGEMM |
| def test_packed_weight_fused_op(self): |
| class Linear(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.w = torch.ones(5, 5) |
| self.b = torch.zeros(5) |
| |
| def forward(self, x): |
| return F.linear(x, self.w, self.b) |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.mods1 = torch.nn.Sequential( |
| Linear(), |
| Linear() |
| ) |
| self.mods2 = Linear() |
| self.relu = F.relu |
| |
| def forward(self, x): |
| x = self.mods1(x) |
| x = self.mods2(x) |
| x = self.relu(x) |
| return x |
| |
| model = M().eval() |
| example_inputs = (torch.rand(5, 5),) |
| qconfig_dict = {"": default_qconfig} |
| m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) |
| m(*example_inputs) |
| m = convert_fx(m) |
| assert hasattr(m, "mods1_0_packed_weight_0") |
| assert hasattr(m, "mods1_1_packed_weight_0") |
| assert hasattr(m, "mods2_packed_weight_0") |
| |
| @skipIfNoFBGEMM |
| def test_mul_add_fp16_config(self): |
| with override_quantized_engine('fbgemm'): |
| class Linear(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.w = torch.ones(5, 5) |
| self.b = torch.zeros(5) |
| |
| def forward(self, x): |
| return torch.nn.functional.linear(x, self.w, self.b) |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.mods1 = torch.nn.Sequential( |
| Linear(), |
| Linear() |
| ) |
| self.mods2 = Linear() |
| |
| def forward(self, x): |
| x = x * 5 |
| x = x + 5 |
| x = self.mods1(x) |
| x = self.mods2(x) |
| return x |
| model = M().eval() |
| qconfig_dict = {"": float16_dynamic_qconfig} |
| example_inputs = (torch.rand(5, 5),) |
| m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) |
| m = convert_fx(m) |
| # make sure it runs |
| m(*example_inputs) |
| |
| def test_getattr_with_nontensor_result(self): |
| """ |
| Verifies that binary ops get quantized correctly if some |
| of the args are nodes but not Tensors, such as an `x.ndim` |
| pattern. |
| """ |
| class M1(torch.nn.Module): |
| def forward(self, x): |
| dims = x.ndim |
| dims_sub = dims - 1 |
| dims_sub2 = dims_sub - 1 |
| x = torch.add(x, dims_sub2) |
| return x |
| |
| class M2(torch.nn.Module): |
| def forward(self, x): |
| dims = x.ndim |
| dims_sub = dims - 2 |
| mul = [1] * dims_sub |
| dims_list = [-1, x.size(1)] + mul |
| x = x.view(dims_list) |
| return x |
| |
| class M3(torch.nn.Module): |
| def forward(self, x): |
| shape = x.shape |
| x = x.view(shape) |
| return x |
| |
| for cls in (M1, M2, M3): |
| m = cls().eval() |
| example_inputs = (torch.rand(4, 4, 4, 4),) |
| m(*example_inputs) |
| qconfig_dict = {'': torch.ao.quantization.default_qconfig} |
| mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) |
| mp(torch.rand(4, 4, 4, 4)) |
| mc = convert_fx(mp) |
| |
| class _NonReferenceTestModel(nn.Module): |
| def __init__(self, func, lin_in, lin_out): |
| super().__init__() |
| self.conv1 = nn.Conv2d(3, 6, 5) |
| self.pool = nn.MaxPool2d(2, 2) |
| self.lin = nn.Linear(lin_in, lin_out) |
| self.func = func |
| |
| def forward(self, x, y, z): |
| x = self.pool(F.relu(self.conv1(x))) |
| x = torch.flatten(x, 1) |
| x = self.func(x, y, z) |
| x = self.lin(x) |
| return x |
| |
    # This function looks at the node specified by the NodeInfo in the key of
    # node_info_to_non_tensor_args and checks that the args at the specified
    # indices are not observed (since they are non-tensors). If the args at
    # those indices are a tuple/list (which do not show up as nodes), the
    # function checks the individual elements of the tuple/list recursively.
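    # A typical mapping looks like
    #     {NodeInfo("call_method", "reshape"): [1]}
    # meaning that arg 1 of every matched `x.reshape(...)` node must not be
    # observed.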
| def _check_not_observed(self, model, node_info_to_non_tensor_args): |
| |
| # this is a helper function (for easier recursion) that checks whether |
| # arg_node is observed |
| def _check_node_not_observed(model, arg_node, node): |
| if isinstance(arg_node, (tuple, list)): |
| for new_node in arg_node: |
| _check_node_not_observed(model, new_node, node) |
| elif arg_node.op == "call_module": |
| self.assertTrue( |
| not _is_activation_post_process(getattr(model, arg_node.target)), |
| f"Arg: {arg_node} of node: {node} is observed but is not a float tensor", |
| ) |
| |
| for node in model.graph.nodes: |
| indices = node_info_to_non_tensor_args.get( |
| NodeInfo(node.op, node.target), [] |
| ) |
| for index in indices: |
| if index < len(node.args): |
| arg_node = node.args[index] |
| _check_node_not_observed(model, arg_node, node) |
| |
    # This test checks that the model gets prepared correctly, does not have
    # observers on specific ops (see _check_not_observed), and that the
    # prepared model runs
| def _test_dtype_propagation(self, model, node_info_to_non_tensor_args, *args): |
| model.eval() |
| qconfig_dict = {"": torch.ao.quantization.get_default_qconfig("fbgemm")} |
| prepared_model = prepare_fx(model, qconfig_dict, example_inputs=tuple(args)) |
| self._check_not_observed(prepared_model, node_info_to_non_tensor_args) |
| prepared_model(*args) |
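        # The tests below all follow the same pattern: build a
        # _NonReferenceTestModel around a small func whose non-tensor args are
        # listed by index in node_info_to_non_tensor_args, then verify that
        # prepare_fx does not observe those args and that the prepared model
        # still runs.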
| |
| def test_masked_fill_nontensor_args_not_observed(self): |
| def func(x, y, z): |
| return x.masked_fill(y, z) |
| |
| model = self._NonReferenceTestModel(func, 1176, 1) |
| args = [torch.randn(5, 3, 32, 32), torch.randn(1176) > 0, 0.1] |
| node_info_to_non_tensor_args = {NodeInfo("call_method", "masked_fill"): [1, 2]} |
| self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args) |
| |
| def test_permute_nontensor_args_not_observed(self): |
| def func(x, y, z): |
| return x.permute(y, z) |
| |
| model = self._NonReferenceTestModel(func, 1176, 1) |
| args = [torch.randn(5, 3, 32, 32), 0, 1] |
| node_info_to_non_tensor_args = {NodeInfo("call_method", "permute"): [1, 2]} |
| self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args) |
| |
| def test_repeat_nontensor_args_not_observed(self): |
| def func(x, y, z): |
| return x.repeat(y, z) |
| |
| model = self._NonReferenceTestModel(func, 1176, 1) |
| args = [torch.randn(5, 3, 32, 32), 2, 1] |
| node_info_to_non_tensor_args = {NodeInfo("call_method", "repeat"): [1, 2]} |
| self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args) |
| |
| def test_reshape_nontensor_args_not_observed(self): |
| def func(x, y, z): |
| return x.reshape(-1, y) |
| |
| model = self._NonReferenceTestModel(func, 5, 1) |
| args = [torch.randn(5, 3, 32, 32), 5, None] |
| node_info_to_non_tensor_args = {NodeInfo("call_method", "reshape"): [2]} |
| self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args) |
| |
| def test_size_nontensor_args_not_observed(self): |
| def func(x, y, z): |
| return x.reshape((-1, x.size(y))) |
| |
| model = self._NonReferenceTestModel(func, 5, 1) |
| args = [torch.randn(5, 3, 32, 32), 0, None] |
| node_info_to_non_tensor_args = {NodeInfo("call_method", "size"): [1]} |
| self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args) |
| |
| def test_transpose_nontensor_args_not_observed(self): |
| def func(x, y, z): |
| return x.transpose(y, z) |
| |
| model = self._NonReferenceTestModel(func, 5, 1) |
| args = [torch.randn(5, 3, 32, 32), 0, 1] |
| node_info_to_non_tensor_args = {NodeInfo("call_method", "transpose"): [1, 2]} |
| self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args) |
| |
| def test_torch_transpose_nontensor_args_not_observed(self): |
| # TODO: make torch.transpose traceable by fx when using |
| # variable nontensor arguments |
| # func = lambda x, y, z: torch.transpose(x, y, z) # error |
| def func(x, y, z): |
| return torch.transpose(x, 0, 1) |
| |
| model = self._NonReferenceTestModel(func, 5, 1) |
        node_info_to_non_tensor_args = {
            NodeInfo("call_function", torch.transpose): [1, 2]
        }
| args = [torch.randn(5, 3, 32, 32), 0, 1] |
| self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args) |
| |
| def test_unsqueeze_nontensor_args_not_observed(self): |
| def func(x, y, z): |
| return x.unsqueeze(y) |
| |
| model = self._NonReferenceTestModel(func, 1176, 1) |
| args = [torch.randn(5, 3, 32, 32), 1, None] |
| node_info_to_non_tensor_args = {NodeInfo("call_method", "unsqueeze"): [1]} |
| self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args) |
| |
| def test_unsqueeze__nontensor_args_not_observed(self): |
| def func(x, y, z): |
| return x.unsqueeze_(y) |
| |
| model = self._NonReferenceTestModel(func, 1176, 1) |
| args = [torch.randn(5, 3, 32, 32), 1, None] |
| node_info_to_non_tensor_args = {NodeInfo("call_method", "unsqueeze_"): [1]} |
| self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args) |
| |
| def test_torch_unsqueeze_nontensor_args_not_observed(self): |
        # TODO: make torch.unsqueeze traceable by fx when using
        # variable nontensor arguments
| # func = lambda x, y, z: torch.unsqueeze(x, y) # error |
| def func(x, y, z): |
| return torch.unsqueeze(x, 1) |
| |
| model = self._NonReferenceTestModel(func, 1176, 1) |
| args = [torch.randn(5, 3, 32, 32), 1, None] |
        node_info_to_non_tensor_args = {NodeInfo("call_function", torch.unsqueeze): [1]}
| self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args) |
| |
| def test_view_nontensor_args_not_observed(self): |
| def func(x, y, z): |
| return x.view(-1, y) |
| |
| model = self._NonReferenceTestModel(func, 5, 1) |
| args = [torch.randn(5, 3, 32, 32), 5, None] |
| node_info_to_non_tensor_args = {NodeInfo("call_method", "view"): [2]} |
| self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args) |
| |
| def test_propagate_dtypes_for_known_nodes_list_args(self): |
| def func(x, y, z): |
| return x.reshape(y) |
| |
| model = self._NonReferenceTestModel(func, 5, 1) |
| args = [torch.randn(5, 3, 32, 32), [-1, 5], None] |
| node_info_to_non_tensor_args = {NodeInfo("call_method", "reshape"): [1]} |
| self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args) |
| |
| def test_propagate_dtypes_for_known_nodes_split_list_args(self): |
| def func(x, y, z): |
| return x.reshape([y, z]) |
| |
| model = self._NonReferenceTestModel(func, 5, 1) |
| args = [torch.randn(5, 3, 32, 32), -1, 5] |
| node_info_to_non_tensor_args = {NodeInfo("call_method", "reshape"): [1]} |
| self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args) |
| |
| def test_propagate_dtypes_for_known_nodes_tuple_args(self): |
| def func(x, y, z): |
| return x.reshape(y) |
| |
| model = self._NonReferenceTestModel(func, 5, 1) |
| args = [torch.randn(5, 3, 32, 32), (-1, 5), None] |
| node_info_to_non_tensor_args = {NodeInfo("call_method", "reshape"): [1]} |
| self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args) |
| |
| def test_propagate_dtypes_for_known_nodes_split_tuple_args(self): |
| def func(x, y, z): |
| return x.reshape((y, z)) |
| |
| model = self._NonReferenceTestModel(func, 5, 1) |
| args = [torch.randn(5, 3, 32, 32), -1, 5] |
| node_info_to_non_tensor_args = {NodeInfo("call_method", "reshape"): [1]} |
| self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args) |
| |
| def test_propagate_dtypes_for_known_nodes_dict_args(self): |
| def func(x, y, z): |
| return x.transpose(y["first"], y["second"]) |
| |
| model = self._NonReferenceTestModel(func, 5, 1) |
| args = [torch.randn(5, 3, 32, 32), {"first": 0, "second": 1}, None] |
| node_info_to_non_tensor_args = {NodeInfo("call_method", "transpose"): [1, 2]} |
| self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args) |
| |
| def test_propagate_dtypes_for_known_nodes_dict_tuple_args(self): |
        class ReshapeModule(nn.Module):
| def forward(self, x, y, z): |
| return x.reshape(y["shape"]) |
| |
        model = self._NonReferenceTestModel(ReshapeModule(), 5, 1)
| args = [torch.randn(5, 3, 32, 32), {"shape": (-1, 5)}, None] |
| node_info_to_non_tensor_args = {NodeInfo("call_method", "reshape"): [1]} |
| self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args) |
| |
| def test_propagate_dtypes_for_known_nodes_dict_split_tuple_args(self): |
| def func(x, y, z): |
| return x.reshape((y["first"], y["second"])) |
| |
| model = self._NonReferenceTestModel(func, 5, 1) |
| args = [torch.randn(5, 3, 32, 32), {"first": -1, "second": 5}, None] |
| node_info_to_non_tensor_args = {NodeInfo("call_method", "transpose"): [1]} |
| self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args) |
| |
| def test_assert_on_size_after_quant_layer(self): |
| """ |
        Verifies that calculating the size of a quantized tensor works
| correctly in quantization passes. |
| """ |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv1 = nn.Conv2d(1, 1, 1) |
| |
| def forward(self, x): |
| x = self.conv1(x) |
| torch._assert(x.size(1) == 1, 'foobar') |
| return x |
| |
| m = M().eval() |
| example_inputs = (torch.rand(4, 1, 4, 4),) |
| m(*example_inputs) |
| qconfig_dict = {'': torch.ao.quantization.default_qconfig} |
| mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) |
| mp(*example_inputs) |
| mc = convert_fx(mp) |
| mc(*example_inputs) |
| |
| def test_fp32_sum(self): |
| """ |
| Verifies that fp32 sum works correctly if it's before or after |
| quantized layers. |
| """ |
| class M1(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv1 = nn.Conv2d(1, 1, 1) |
| |
| def forward(self, x): |
| x = self.conv1(x) |
| x = torch.stack([x]) |
| x = torch.sum(x) |
| return x |
| |
| class M2(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv1 = nn.Conv2d(1, 1, 1) |
| self.conv2 = nn.Conv2d(1, 1, 1) |
| |
| def forward(self, x): |
| x = self.conv1(x) |
| x1 = torch.stack([x]) |
| x1 = torch.sum(x1, dim=0) |
| x2 = self.conv2(x1) |
| return x2 |
| |
| for cls in (M1, M2): |
| m = cls().eval() |
| example_inputs = (torch.rand(4, 1, 4, 4),) |
| m(*example_inputs) |
| qconfig_dict = {'': torch.ao.quantization.default_qconfig} |
| mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) |
| mp(*example_inputs) |
| mc = convert_fx(mp) |
| mc(*example_inputs) |
| |
| def test_fusion_pattern_unquantized(self): |
| """ |
| Ensure that leaving a possible fusion pattern of multiple nodes |
| unquantized runs through the APIs without errors. |
| """ |
| class Child(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.relu = nn.ReLU() |
| |
| def forward(self, x): |
| x = torch.add(x, 1.0) |
| x = torch.nn.functional.relu(x) |
| return x |
| |
| class Parent(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.child = Child() |
| self.conv = nn.Conv2d(1, 1, 1) |
| |
| def forward(self, x): |
| x = self.child(x) |
| x = self.conv(x) |
| return x |
| |
| m = Parent().eval() |
| qconfig_dict = { |
| '': torch.ao.quantization.default_qconfig, |
| 'module_name': [ |
| ('child', None), |
| ], |
| } |
| example_inputs = (torch.rand(1, 1, 1, 1),) |
| mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) |
| mp(*example_inputs) |
| mc = convert_fx(mp) |
| |
| def test_state_dict(self): |
| """ Make sure packed params appear in state_dict |
| """ |
| |
| # test linear packed weight |
| class M1(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.w = torch.rand(4, 30) |
| self.b = torch.rand(4) |
| |
| def forward(self, x): |
| return F.linear(x, self.w, self.b) |
| |
| m = M1().eval() |
| qconfig_dict = {"": default_qconfig} |
| m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 30),)) |
| m = convert_fx(m) |
| state_dict = m.state_dict() |
| self.assertTrue("_packed_weight_0" in state_dict) |
| |
| # test conv packed weight |
| class M2(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.w = torch.rand(3, 3, 3, 3) |
| self.b = torch.rand(3) |
| self.stride = (1, 1) |
| self.padding = (0, 0) |
| self.dilation = (1, 1) |
| self.groups = 1 |
| |
| def forward(self, x): |
| return F.conv2d(x, self.w, self.b, self.stride, self.padding, self.dilation, self.groups) |
| |
| m = M2().eval() |
| qconfig_dict = {"": default_qconfig} |
| m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 3, 3),)) |
| m = convert_fx(m) |
| state_dict = m.state_dict() |
| self.assertTrue("_packed_weight_0" in state_dict) |
| |
| # test load |
| ref_weight, ref_bias = torch.ops.quantized.conv2d_unpack(state_dict["_packed_weight_0"]) |
| data = torch.rand(1, 3, 5, 5) |
| ref_res = m(data) |
| m = M2().eval() |
| m = prepare_fx(m, qconfig_dict, (data,)) |
| m = convert_fx(m) |
| res = m(data) |
| weight, bias = m._packed_weight_0.unpack() |
| # check that random model weight/bias does not match ref weight/bias |
| self.assertNotEqual(weight, ref_weight) |
| self.assertNotEqual(bias, ref_bias) |
| self.assertNotEqual(res, ref_res) |
| m.load_state_dict(state_dict) |
| |
| def checkModel(m, data, ref_weight, ref_bias, ref_res): |
| res = m(data) |
| weight, bias = m._packed_weight_0.unpack() |
            # check that weight/bias matches after loading the state_dict
| self.assertEqual(weight, ref_weight) |
| self.assertEqual(bias, ref_bias) |
| self.assertEqual(res, ref_res) |
| |
| checkModel(m, data, ref_weight, ref_bias, ref_res) |
| |
| # Test save to disk and load back |
| m = M2().eval() |
| m = prepare_fx(m, qconfig_dict, example_inputs=(data,)) |
| m = convert_fx(m) |
| m.load_state_dict(state_dict) |
| with TemporaryFileName() as fname: |
| torch.save(m.state_dict(), fname) |
| m.load_state_dict(torch.load(fname)) |
| |
| checkModel(m, data, ref_weight, ref_bias, ref_res) |
| |
| @skipIfNoFBGEMM |
| def test_preserve_qconfig(self): |
| """ |
| Test to make sure the temporary config option to preserve qconfig attributes |
| in the model works |
| """ |
| with override_quantized_engine('fbgemm'): |
| class Linear(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.w = torch.ones(5, 5) |
| self.b = torch.zeros(5) |
| |
| def forward(self, x): |
| return torch.nn.functional.linear(x, self.w, self.b) |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.mods1 = torch.nn.Sequential( |
| Linear(), |
| Linear() |
| ) |
| self.mods2 = torch.nn.Sigmoid() |
| |
| def forward(self, x): |
| x = self.mods1(x) |
| x = self.mods2(x) |
| return x |
| |
| model = M().eval() |
| qconfig_dict = { |
| "object_type": [ |
| (torch.nn.functional.linear, float16_dynamic_qconfig), |
| ], |
| } |
| example_inputs = (torch.rand(5, 5),) |
| m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) |
| m(*example_inputs) |
| m = convert_fx(m, _remove_qconfig=False) |
| |
| self.assertTrue(hasattr(m.mods2, 'qconfig')) |
| |
| def test_not_used(self): |
| """ Test quantizing a not used value""" |
| |
| class M(torch.nn.Module): |
| def forward(self, x): |
| x = x + x |
| x.sigmoid_() |
| return x |
| |
| m = M().eval() |
| qconfig_mapping = get_default_qconfig_mapping().set_global(float16_static_qconfig) |
| # make sure quantization runs |
| m = prepare_fx(m, qconfig_mapping, example_inputs=(torch.randn(1),)) |
| m = convert_fx(m) |
| |
| def test_qparams_fqn(self): |
| """ Test that the FQN of input_scale/zero_point is set |
        to that of the first linear use. """
| class Linear(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.w = torch.ones(5, 5) |
| self.b = torch.zeros(5) |
| |
| def forward(self, x): |
| return torch.nn.functional.linear(x, self.w, self.b) |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.mods1 = torch.nn.Sequential( |
| Linear(), |
| Linear() |
| ) |
| |
| def forward(self, x): |
| x = torch.cat((x,), 1) |
| tmp = x.size() |
| x = self.mods1(x) |
| y = x * tmp[0] |
| return y |
| |
| model = M().eval() |
| qconfig_dict = { |
| "": None, |
| "object_type": [ |
| (torch.nn.functional.linear, default_qconfig), |
| (torch.nn.functional.relu, default_qconfig), |
| ], |
| } |
| example_inputs = (torch.rand(5, 5),) |
| m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) |
| m(*example_inputs) |
| m = convert_fx(m) |
| keys = m.state_dict().keys() |
| m(torch.randn(5, 5)) |
| # TODO: probably don't want to hardcode the attribute names, since they are generated |
| for attr_name in [ |
| "mods1_0_input_scale_0", "mods1_0_input_zero_point_0", |
| "mods1_0_scale_0", "mods1_0_zero_point_0", |
| "mods1_1_scale_0", "mods1_1_zero_point_0"]: |
| self.assertTrue(hasattr(m, attr_name), attr_name + " not found.") |
| |
| def test_no_obs_between_unmatched_node_and_copy_node(self): |
| """ |
| Verifies that an observer is not inserted between an unmatched |
| node and a node matched to CopyNodeQuantizeHandler. This is done |
| because observers require activations to be Tensors, and there is |
| no guarantee that an output of an unmatched node is a Tensor. |
| """ |
| |
| class M(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.relu = nn.ReLU() |
| |
| def forward(self, x): |
| x = _user_func_with_complex_return_type(x) |
| x1 = x[0] + 1 |
| return x1, x[1] |
| |
| m = M().eval() |
| |
| qconfig_dict = {'': torch.ao.quantization.default_qconfig} |
| example_inputs = (torch.randn(4, 4, 4, 4),) |
| mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) |
| # if an observer is inserted after _user_func_with_complex_return_type, |
| # the following call will fail |
| mp(*example_inputs) |
| mc = convert_fx(mp) |
| mc(*example_inputs) |
| |
| def test_fold_quant_dequant(self): |
| """ Test that the sequence of quant-dequant nodes in the |
| graph, get folded and we erase the extra dequant nodes. |
| """ |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.w = torch.ones(5, 5) |
| self.b = torch.zeros(5) |
| |
| def forward(self, x): |
| x = torch.cat((x,), 1) |
| tmp = x.size() |
| x = torch.nn.functional.linear(x, self.w, self.b) |
| y = x * tmp[0] |
| return y |
| |
| model = M().eval() |
| qconfig_dict = { |
| "": None, |
| "object_type": [ |
| (torch.nn.functional.linear, default_qconfig), |
| ], |
| } |
| example_inputs = (torch.rand(5, 5),) |
| m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) |
| m(*example_inputs) |
| m = convert_fx(m) |
| keys = m.state_dict().keys() |
| m(*example_inputs) |
        dequant = 0
        quant = 0
        for n in m.graph.nodes:
            if n.op == "call_method" and n.target == "dequantize":
                dequant += 1
            if n.op == "call_function" and n.target == torch.quantize_per_tensor:
                quant += 1
| self.assertEqual(dequant, 1) |
| self.assertEqual(quant, 1) |
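        # The only surviving quant/dequant pair sits at the boundary of the
        # quantized functional linear; the intermediate quant-dequant
        # sequences produced during conversion are folded away, hence exactly
        # one of each.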
| |
| def test_quant_output_always_observed(self): |
| """ |
| If the output is hardcoded to be quantized, ensure that |
| there is always an observer, even if the last non-output node is not |
        quantizable.
| """ |
| qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')} |
| prepare_custom_config_dict = {'output_quantized_idxs': [0]} |
| example_inputs = (torch.randn(4, 1, 4, 4),) |
| |
        # non-quantizable node, quantized output
| class M1(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.identity = torch.nn.Identity() |
| |
| def forward(self, x): |
| x = self.identity(x) |
| return x |
| |
| m1 = M1() |
| self.checkGraphModeFxOp( |
| m1, example_inputs, QuantType.QAT, |
| prepare_expected_node_occurrence={ |
| ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 2, |
| }, |
| expected_node_occurrence={ |
| ns.call_function(torch.quantize_per_tensor): 1, |
| }, |
| prepare_custom_config=prepare_custom_config_dict) |
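        # Even though Identity is not itself a quantizable op,
        # output_quantized_idxs forces the graph output to be quantized, so
        # the converted model still contains one quantize_per_tensor call.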
| |
        # quantizable node, quantized output
| class M2(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.Conv2d(1, 1, 1) |
| |
| def forward(self, x): |
| x = self.conv(x) |
| return x |
| |
| m2 = M2() |
| self.checkGraphModeFxOp( |
| m2, example_inputs, QuantType.QAT, |
| prepare_expected_node_occurrence={ |
| # one for weights, one for activations |
| ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 2, |
| }, |
| expected_node_occurrence={ |
| ns.call_function(torch.quantize_per_tensor): 1, |
| }, |
| prepare_custom_config=prepare_custom_config_dict) |
| |
        # quantizable node, quantized dictionary output
| class M3(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.Conv2d(1, 1, 1) |
| |
| def forward(self, x): |
| x = self.conv(x) |
| return {"output": x} |
| |
| m3 = M3() |
| self.checkGraphModeFxOp( |
| m3, example_inputs, QuantType.QAT, |
| prepare_expected_node_occurrence={ |
| # one for weights, one for activations |
| ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 2, |
| }, |
| expected_node_occurrence={ |
| ns.call_function(torch.quantize_per_tensor): 1, |
| }, |
| prepare_custom_config=prepare_custom_config_dict) |
| |
| def test_deepcopy_preserve_attributes(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.attr = 3 |
| |
| def forward(self, x): |
| return x |
| |
| m = M().eval() |
| m = prepare_fx( |
| m, |
| {"": default_qconfig}, |
| example_inputs=(torch.randn(1),), |
| prepare_custom_config={"preserved_attributes": ["attr"]}) |
| # preserved attributes are also stored in meta so that it doesn't get lost |
| # during deepcopy |
| self.assertTrue(hasattr(m, "attr")) |
| self.assertTrue("attr" in m.meta[_USER_PRESERVED_ATTRIBUTES_KEY]) |
| m2 = copy.deepcopy(m) |
| self.assertTrue(hasattr(m2, "attr")) |
| self.assertTrue("attr" in m2.meta[_USER_PRESERVED_ATTRIBUTES_KEY]) |
| m = convert_fx(m, convert_custom_config={"preserved_attributes": ["attr"]}) |
| self.assertTrue(hasattr(m, "attr")) |
| self.assertTrue("attr" in m.meta[_USER_PRESERVED_ATTRIBUTES_KEY]) |
| m2 = copy.deepcopy(m) |
| self.assertTrue(hasattr(m2, "attr")) |
| self.assertTrue("attr" in m2.meta[_USER_PRESERVED_ATTRIBUTES_KEY]) |
| |
| def test_output_lists_and_dicts(self): |
| """Verify that specifying complicated output types does not crash. |
| """ |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = nn.Conv2d(1, 1, 1) |
| |
| def forward(self, x): |
| x = self.conv(x) |
| return {'foo': [x]}, [{'foo': [[x]]}] |
| |
| m = M().eval() |
| qconfig_dict = {'': torch.ao.quantization.default_qconfig} |
| mp = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1, 1, 1),)) |
| mc = convert_fx(mp) |
| |
| def test_shape_followed_by_quantized_op(self): |
| """ Make sure that shape does not dequantize |
| the Tensor before the next operator |
| """ |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv1 = torch.nn.Conv2d(2, 2, 2) |
| self.conv2 = torch.nn.Conv2d(2, 2, 2) |
| |
| def forward(self, x): |
| x = self.conv1(x) |
| s = x.shape |
| torch._assert(s == x.shape, "") |
| x = self.conv2(x) |
| return x |
| |
| # make sure quantization runs |
| m = M().eval() |
| example_inputs = (torch.randn(2, 2, 4, 4),) |
| m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs) |
| m = convert_fx(m) |
| m(*example_inputs) |
| node_occurrence = { |
| ns.call_function(torch.quantize_per_tensor): 1, |
| ns.call_method("dequantize"): 1 |
| } |
| self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) |
| |
| def test_trace_quantize_per_tensor(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.Conv2d(1, 1, 1) |
| |
| def forward(self, x): |
| x = self.conv(x) |
| return x |
| |
| m = M().eval() |
| m = prepare_fx(m, {"": default_qconfig}, example_inputs=(torch.randn(1, 1, 3, 3),)) |
| m = convert_fx(m) |
| # Make sure this runs without error |
| m = torch.fx.Transformer(m).transform() |
| |
| def test_copy_node_has_shared_actpp_instance(self): |
| """ Test the output of CopyNode to have the same |
| observer/fake_quant instance as the input |
| """ |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.avgpool2d = torch.nn.AvgPool2d(kernel_size=3) |
| |
| def forward(self, x): |
| x = self.avgpool2d(x) |
| return x |
| |
| for quant_type in self.static_quant_types: |
| m = M() |
| # Checks that we have an observer for both input and output |
| occurrence_map = { |
| QuantType.STATIC: { |
| ns.call_module(torch.ao.quantization.MinMaxObserver): 2 |
| }, |
| QuantType.QAT: { |
| ns.call_module(torch.ao.quantization.FakeQuantize): 2 |
| } |
| } |
| if quant_type == QuantType.QAT: |
| m.train() |
| prepare = prepare_qat_fx |
| qconfig = default_qat_qconfig |
| actpp_module_class = torch.ao.quantization.FakeQuantize |
| else: |
| m.eval() |
| prepare = prepare_fx |
| qconfig = default_qconfig |
| actpp_module_class = torch.ao.quantization.MinMaxObserver |
| |
| example_inputs = (torch.randn(1, 3, 3, 3),) |
| m = prepare(m, {"": qconfig}, example_inputs=example_inputs) |
| # check that there is a duplicated observer instance |
| actpp_module_count = 0 |
| for name, module in m.named_modules(remove_duplicate=False): |
| if isinstance(module, actpp_module_class): |
| actpp_module_count += 1 |
| self.assertEqual(actpp_module_count, 2) |
| |
| actpp_module_count = 0 |
| for name, module in m.named_modules(): |
| if isinstance(module, actpp_module_class): |
| actpp_module_count += 1 |
| self.assertEqual(actpp_module_count, 1) |
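            # With remove_duplicate=False the shared observer shows up under
            # two names (input and output of the copy node), while the default
            # named_modules() deduplicates it to a single instance; this is
            # exactly the sharing that this test verifies.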
| |
| m_copy = copy.deepcopy(m) |
| m = convert_fx(m) |
| m_reference = convert_to_reference_fx(m_copy) |
| |
| # checks for non-reference quantized model |
| node_occurrence = { |
| ns.call_function(torch.quantize_per_tensor): 1, |
| ns.call_method("dequantize"): 1 |
| } |
| node_list = [ |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_module(torch.nn.AvgPool2d), |
| ns.call_method("dequantize"), |
| ] |
| self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence, expected_node_list=node_list) |
| |
| # checks for reference quantized model, for copy nodes we'll have |
| # dequant - copy_node - quant patterns which will be fused later |
| # in the backend lowering step |
| node_occurrence = { |
| ns.call_function(torch.quantize_per_tensor): 2, |
| ns.call_method("dequantize"): 2 |
| } |
| node_list = [ |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_method("dequantize"), |
| ns.call_module(torch.nn.AvgPool2d), |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_method("dequantize"), |
| ] |
| self.checkGraphModuleNodes(m_reference, expected_node_occurrence=node_occurrence, expected_node_list=node_list) |
| |
| def test_linear_qint8_activation(self): |
| """Test support for qint8 activation in reference pattern |
| """ |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.Conv2d(1, 2, 2, 2) |
| self.linear = torch.nn.Linear(8, 5) |
| |
| def forward(self, x): |
| x = self.conv(x) |
| x = torch.flatten(x, 1) |
| x = self.linear(x) |
| return x |
| |
| m = M().eval() |
| example_inputs = (torch.rand(2, 1, 5, 5),) |
| m = prepare_fx( |
| m, |
| {"": torch.ao.quantization.QConfig( |
| activation=torch.ao.quantization.HistogramObserver.with_args( |
| qscheme=torch.per_tensor_symmetric, dtype=torch.qint8 |
| ), weight=torch.ao.quantization.default_per_channel_weight_observer)}, |
| example_inputs=example_inputs) |
| m = convert_to_reference_fx(m) |
| m(*example_inputs) |
| |
| def test_preserve_tuple(self): |
| """ Test tuple input type is preserved |
| """ |
| |
| class LSTM(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.lstm = nn.LSTM(50, 50, 1) |
| |
| def forward(self, inputs: torch.Tensor, state: List[torch.Tensor]): |
| h = state[0] |
| c = state[1] |
| return self.lstm(inputs, (h, c)) |
| |
| m = LSTM().eval() |
| example_inputs = (torch.randn(5, 3, 50), torch.randn(2, 3, 50), torch.randn(2, 3, 50)) |
| m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs) |
| # make sure the arg[1] of lstm module is a tuple |
| for n in m.graph.nodes: |
| if n.target == "lstm": |
| self.assertEqual(type(n.args[1]), tuple) |
| |
| def _test_static_lstm_helper(self, model, prepare_node_occurrence, convert_node_occurrence): |
| """ |
| Helper method to validate the graph of a model with static LSTM. |
| """ |
| qconfig_mapping = get_default_qconfig_mapping() |
| prepare_custom_config = PrepareCustomConfig() \ |
| .set_float_to_observed_mapping(torch.nn.LSTM, torch.ao.nn.quantizable.LSTM) |
| convert_custom_config = ConvertCustomConfig() \ |
| .set_observed_to_quantized_mapping(torch.ao.nn.quantizable.LSTM, torch.ao.nn.quantized.LSTM) |
| example_inputs = (torch.rand(5, 3, 50), torch.rand(1, 3, 50), torch.randn(1, 3, 50)) |
| |
| model = prepare_fx(model, qconfig_mapping, example_inputs, prepare_custom_config=prepare_custom_config) |
| self.checkGraphModuleNodes(model, expected_node_occurrence=prepare_node_occurrence) |
| model(*example_inputs) |
| |
| model = convert_fx(model, convert_custom_config=convert_custom_config) |
| self.checkGraphModuleNodes(model, expected_node_occurrence=convert_node_occurrence) |
| model(*example_inputs) |
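        # The helper mirrors the custom module flow: nn.LSTM is mapped to the
        # observed torch.ao.nn.quantizable.LSTM at prepare time and to
        # torch.ao.nn.quantized.LSTM at convert time, i.e. the usual
        # float -> observed -> quantized progression.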
| |
| def test_static_lstm(self): |
| """ |
| Test statically quantized custom module LSTM followed by ops that consume individual |
| tensors of the output tuple. |
| """ |
| class MyModel(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.lstm = nn.LSTM(50, 50, 1) |
| self.linear1 = nn.Linear(50, 10) |
| self.linear2 = nn.Linear(50, 10) |
| self.linear3 = nn.Linear(50, 10) |
| |
| def forward(self, inputs: torch.Tensor, h0: torch.Tensor, c0: torch.Tensor): |
| (out, (h0_out, c0_out)) = self.lstm(inputs, (h0, c0)) |
| out = self.linear1(out) |
| h0_out = self.linear2(h0_out) |
| c0_out = self.linear3(c0_out) |
| return (out, (h0_out, c0_out)) |
| |
| m = MyModel() |
| prepare_node_occurrence = { |
| ns.call_module(torch.ao.nn.quantizable.LSTM): 1, |
| } |
| convert_node_occurrence = { |
| ns.call_module(torch.ao.nn.quantized.LSTM): 1, |
| ns.call_function(torch.quantize_per_tensor): 3, |
| # lstm[0].dequantize() |
| # lstm[1][0].dequantize() |
| # lstm[1][1].dequantize() |
| ns.call_method("dequantize"): 3, |
| # lstm[0], lstm[1], lstm[1][0], lstm[1][1] |
| ns.call_function(operator.getitem): 4, |
| # No tuples are consumed |
| ns.call_function(tuple): 0, |
| } |
| self._test_static_lstm_helper(m, prepare_node_occurrence, convert_node_occurrence) |
| |
| def test_static_lstm_consume_tuple(self): |
| """ |
| Test statically quantized custom module LSTM followed by a module that consumes the |
| output tuple, either as a whole or part of it. |
| """ |
| class ModuleAfterLSTM(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.identity = torch.nn.Identity() |
| |
| def forward(self, x): |
| return self.identity(x) |
| |
| class ConsumeWholeTuple(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.lstm = nn.LSTM(50, 50, 1) |
| self.module_after_lstm = ModuleAfterLSTM() |
| |
| def forward(self, inputs: torch.Tensor, h0: torch.Tensor, c0: torch.Tensor): |
| x = self.lstm(inputs, (h0, c0)) |
| x = self.module_after_lstm(x) # consume tuple (output, (hidden0, hidden1)) |
| return x |
| |
| class ConsumeHiddenTuple(ConsumeWholeTuple): |
| def forward(self, inputs: torch.Tensor, h0: torch.Tensor, c0: torch.Tensor): |
| x = self.lstm(inputs, (h0, c0)) |
| x = self.module_after_lstm(x[1]) # consume tuple (hidden0, hidden1) |
| return x |
| |
| # Test consuming the whole tuple (output, (hidden0, hidden1)) |
| m1 = ConsumeWholeTuple() |
| prepare_node_occurrence = { |
| ns.call_module(torch.ao.nn.quantizable.LSTM): 1, |
| } |
| convert_node_occurrence1 = { |
| ns.call_module(torch.ao.nn.quantized.LSTM): 1, |
| ns.call_function(torch.quantize_per_tensor): 3, |
| # lstm[0].dequantize() |
| # lstm[1][0].dequantize() |
| # lstm[1][1].dequantize() |
| ns.call_method("dequantize"): 3, |
| # lstm[0], lstm[1], lstm[1][0], lstm[1][1] |
| ns.call_function(operator.getitem): 4, |
| # tuple(output_dq, tuple(hidden0_dq, hidden1_dq)) |
| ns.call_function(tuple): 2, |
| } |
| self._test_static_lstm_helper(m1, prepare_node_occurrence, convert_node_occurrence1) |
| |
| # Test consuming just the hidden tuple (hidden0, hidden1) |
| m2 = ConsumeHiddenTuple() |
| convert_node_occurrence2 = { |
| ns.call_module(torch.ao.nn.quantized.LSTM): 1, |
| ns.call_function(torch.quantize_per_tensor): 3, |
| # lstm[1][0].dequantize() |
| # lstm[1][1].dequantize() |
| ns.call_method("dequantize"): 2, |
| # lstm[1], lstm[1][0], lstm[1][1] |
| ns.call_function(operator.getitem): 3, |
| # tuple(hidden0_dq, hidden1_dq) |
| ns.call_function(tuple): 1, |
| } |
| self._test_static_lstm_helper(m2, prepare_node_occurrence, convert_node_occurrence2) |
| |
| def test_static_lstm_with_custom_fixed_qparams(self): |
| """ |
| Test statically quantized LSTM with custom fixed qparams assigned to each of the |
| inner submodules. This flow requires users to extend `torch.ao.nn.quantizable.LSTM` |
| and use the child class in the custom module mapping. |
| """ |
| class MyModel(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.my_lstm = torch.nn.LSTM(50, 50, 1) |
| |
| def forward(self, inputs: torch.Tensor, h0: torch.Tensor, c0: torch.Tensor): |
| x = self.my_lstm(inputs, (h0, c0)) |
| return x |
| |
| # Construct a BackendConfig that supports qint32 for certain ops |
| # TODO: build a BackendConfig from scratch instead of modifying an existing one |
| qint32_dtype_config = DTypeConfig(input_dtype=torch.qint32, output_dtype=torch.qint32) |
| my_backend_config = get_qnnpack_backend_config() |
| for config in my_backend_config.configs: |
| if config.pattern in [torch.nn.Sigmoid, torch.nn.Tanh, torch.add, torch.mul]: |
| config.add_dtype_config(qint32_dtype_config) |
| |
| class UserObservedLSTM(torch.ao.nn.quantizable.LSTM): |
| """ |
| Example of user provided LSTM implementation that assigns fixed qparams |
| to the inner ops. |
| """ |
| @classmethod |
| def from_float(cls, float_lstm): |
| assert isinstance(float_lstm, cls._FLOAT_MODULE) |
| # uint16, [-16, 16) |
| linear_output_obs_ctr = FixedQParamsObserver.with_args(scale=2 ** -11, zero_point=2 ** 15, dtype=torch.qint32) |
| # uint16, [0, 1) |
| sigmoid_obs_ctr = FixedQParamsObserver.with_args(scale=2 ** -16, zero_point=0, dtype=torch.qint32) |
| # uint16, [-1, 1) |
| tanh_obs_ctr = FixedQParamsObserver.with_args(scale=2 ** -15, zero_point=2 ** 15, dtype=torch.qint32) |
| # int16, [-16, 16) |
| cell_state_obs_ctr = FixedQParamsObserver.with_args(scale=2 ** -11, zero_point=0, dtype=torch.qint32) |
| # uint8, [-1, 1) |
| hidden_state_obs_ctr = FixedQParamsObserver.with_args(scale=2 ** -7, zero_point=2 ** 7, dtype=torch.quint8) |
| example_inputs = (torch.rand(5, 3, 50), (torch.rand(1, 3, 50), torch.randn(1, 3, 50))) |
| return torch.ao.quantization.fx.lstm_utils._get_lstm_with_individually_observed_parts( |
| float_lstm=float_lstm, |
| example_inputs=example_inputs, |
| backend_config=my_backend_config, |
| linear_output_obs_ctr=linear_output_obs_ctr, |
| sigmoid_obs_ctr=sigmoid_obs_ctr, |
| tanh_obs_ctr=tanh_obs_ctr, |
| cell_state_obs_ctr=cell_state_obs_ctr, |
| hidden_state_obs_ctr=hidden_state_obs_ctr, |
| ) |
| |
| class UserQuantizedLSTM(torch.ao.nn.quantized.LSTM): |
| """ |
| Example of user provided LSTM implementation that produces a reference |
| quantized module from a `UserObservedLSTM`. |
| """ |
| @classmethod |
| def from_observed(cls, observed_lstm): |
| assert isinstance(observed_lstm, cls._FLOAT_MODULE) |
| return torch.ao.quantization.fx.lstm_utils._get_reference_quantized_lstm_module( |
| observed_lstm=observed_lstm, |
| backend_config=my_backend_config, |
| ) |
| |
| # FX graph mode quantization |
| m = MyModel() |
| qconfig_mapping = get_default_qconfig_mapping("qnnpack") |
| example_inputs = (torch.rand(5, 3, 50), torch.rand(1, 3, 50), torch.randn(1, 3, 50)) |
| prepare_custom_config = PrepareCustomConfig() \ |
| .set_float_to_observed_mapping(torch.nn.LSTM, UserObservedLSTM) |
| convert_custom_config = ConvertCustomConfig() \ |
| .set_observed_to_quantized_mapping(torch.ao.nn.quantizable.LSTM, UserQuantizedLSTM) |
| prepared = prepare_fx( |
| m, |
| qconfig_mapping, |
| example_inputs, |
| prepare_custom_config, |
| backend_config=my_backend_config, |
| ) |
| prepared(*example_inputs) |
| converted = convert_fx( |
| prepared, |
| convert_custom_config, |
| backend_config=my_backend_config, |
| ) |
| converted(*example_inputs) |
| |
| # Find the patterns [dq - op - q_to_specific_dtype] in the graph and |
| # verify that qparams and dtypes are set correctly in the quantize ops |
| node_name_to_expected_quantize_args = { |
| "igates": (None, None, torch.quint8), |
| "hgates": (None, None, torch.quint8), |
| "add": (2 ** -11, 2 ** 15, torch.qint32), # gates.add |
| "input_gate": (2 ** -16, 0, torch.qint32), |
| "forget_gate": (2 ** -16, 0, torch.qint32), |
| "cell_gate": (2 ** -15, 2 ** 15, torch.qint32), |
| "output_gate": (2 ** -16, 0, torch.qint32), |
| "mul": (2 ** -11, 0, torch.qint32), # fgate_cx.mul |
| "mul_1": (2 ** -11, 0, torch.qint32), # igate_cgate.mul |
| "add_1": (2 ** -11, 0, torch.qint32), # fgate_cx_igate_cgate.add |
| "mul_2": (2 ** -7, 2 ** 7, torch.quint8), # ogate_cy.mul |
| } |
| cell = converted.my_lstm.layers.get_submodule("0").layer_fw.cell |
| matched_names = set() |
| for node in cell.graph.nodes: |
| if node.name not in node_name_to_expected_quantize_args: |
| continue |
| matched_names.add(node.name) |
| # Match preceding dequantize |
| self.assertTrue(all(arg.target == "dequantize" for arg in node.args)) |
| # Match following quantize with the specific qparams and dtypes |
| expected_scale, expected_zp, expected_dtype = node_name_to_expected_quantize_args[node.name] |
| for user in node.users.keys(): |
| self.assertEqual(user.target, torch.quantize_per_tensor) |
| if expected_scale is not None: |
| self.assertEqual(getattr(cell, user.args[1].target), expected_scale) |
| if expected_zp is not None: |
| self.assertEqual(getattr(cell, user.args[2].target), expected_zp) |
| self.assertEqual(user.args[-1], expected_dtype) |
| # Ensure all patterns were matched |
| self.assertEqual(matched_names, set(node_name_to_expected_quantize_args.keys())) |
| |
| def test_reroute_tuple_getitem_patterns(self): |
| """ |
| The following graph should redirect the output to `b`. After the transformation, |
| all other nodes, including the inputs `a` and `c`, are no longer needed. |
| |
| a b c |
| | \\ / |
| \\ tuple |
| \\ / |
| tuple |
| / \\ |
| / \\ |
| | \\ |
| | \\ |
| | \\ |
| getitem0 getitem1 |
| | / \\ |
| | getitem0 getitem1 |
| | \\ / |
| \\ tuple |
| \\ / |
| \\ / |
| tuple |
| | |
| getitem1 |
| | |
| getitem0 |
| | |
| output |
| """ |
| # Construct graph manually because symbolic_trace does not insert tuple and getitem nodes |
| graph = torch.fx.Graph() |
| a = graph.create_node("placeholder", "a") |
| b = graph.create_node("placeholder", "b") |
| c = graph.create_node("placeholder", "c") |
| bc = graph.call_function(tuple, args=([b, c],)) |
| abc = graph.call_function(tuple, args=([a, bc],)) |
| |
| # Break down tuple and reconstruct it again |
| a2 = graph.call_function(operator.getitem, args=(abc, 0)) |
| bc2 = graph.call_function(operator.getitem, args=(abc, 1)) |
| b2 = graph.call_function(operator.getitem, args=(bc2, 0)) |
| c2 = graph.call_function(operator.getitem, args=(bc2, 1)) |
| bc3 = graph.call_function(tuple, args=([b2, c2],)) |
| abc2 = graph.call_function(tuple, args=([a2, bc3],)) |
| |
| # Output tuple[1][0] |
| bc4 = graph.call_function(operator.getitem, args=(abc2, 1)) |
| b3 = graph.call_function(operator.getitem, args=(bc4, 0)) |
| output = graph.output(b3) |
| |
| # Do reroute |
| _reroute_tuple_getitem_pattern(graph) |
| |
| # Assert that output reroutes to `b` directly, and all other nodes can be removed |
| output_ancestors = [] |
| def gather_ancestors(current_node): # noqa: E306 |
| for arg in current_node.args: |
| output_ancestors.append(arg) |
| gather_ancestors(arg) |
| gather_ancestors(output) |
| self.assertEqual(output_ancestors, [b]) |
| self.assertEqual(output.args[0], b) |
| |
| def test_relu_lowering(self): |
| class M(torch.nn.Module): |
| def forward(self, x): |
| return torch.nn.functional.relu(x) |
| |
| m = M().eval() |
| m = prepare_fx(m, {"": default_qconfig}, example_inputs=(torch.randn(1),)) |
| m_copy = copy.deepcopy(m) |
| m = convert_fx(m) |
| m_ref = convert_to_reference_fx(m_copy) |
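        # In the lowered model relu runs directly on the quantized tensor, so
        # there is a single quantize at the input and a single dequantize at the
        # output; the reference model keeps relu in fp32 with an explicit
        # quantize/dequantize pair around it, doubling the counts below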
| node_occurrence = { |
| ns.call_function(torch.quantize_per_tensor): 1, |
| ns.call_method("dequantize"): 1 |
| } |
| node_occurrence_ref = { |
| ns.call_function(torch.quantize_per_tensor): 2, |
| ns.call_method("dequantize"): 2 |
| } |
| |
| self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) |
| self.checkGraphModuleNodes(m_ref, expected_node_occurrence=node_occurrence_ref) |
| |
| @skipIfNoFBGEMM |
| def test_dynamic_with_fusion(self): |
| """ |
| Tests that dynamic quantization APIs work with Linear + Relu fusion |
| """ |
| with override_quantized_engine('fbgemm'): |
| class LinearRelu(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(5, 5) |
| self.relu = torch.nn.ReLU() |
| |
| def forward(self, x): |
| x = self.linear(x) |
| return self.relu(x) |
| |
| class Linear(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.w = torch.ones(5, 5) |
| self.b = torch.zeros(5) |
| |
| def forward(self, x): |
| return torch.nn.functional.linear(x, self.w, self.b) |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.mods1 = torch.nn.Sequential(LinearRelu(), LinearRelu()) |
| self.mods2 = Linear() |
| self.relu = F.relu |
| |
| def forward(self, x): |
| x = self.mods1(x) |
| x = self.mods2(x) |
| x = self.relu(x) |
| return x |
| |
| dynamic_quantized_ops = { |
| float16_dynamic_qconfig: torch.ops.quantized.linear_relu_dynamic_fp16, |
| default_dynamic_qconfig: torch.ops.quantized.linear_relu_dynamic |
| } |
| for qconfig in [float16_dynamic_qconfig, default_dynamic_qconfig]: |
| model = M().eval() |
| qconfig_dict = { |
| "": qconfig |
| } |
| example_inputs = (torch.rand(5, 5),) |
| m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) |
| m = convert_fx(m) |
| m(*example_inputs) |
| node_list = [ |
| ns.call_module(nniqd.LinearReLU), |
| ns.call_module(nniqd.LinearReLU), |
| ns.call_function(dynamic_quantized_ops[qconfig]), |
| ] |
| self.checkGraphModuleNodes(m, expected_node_list=node_list) |
| |
| @skipIfNoFBGEMM |
| def test_dynamic_with_fusion_multiple_uses(self): |
| """ |
| Tests that dynamic quantization APIs work with Linear + Relu fusion |
| """ |
| with override_quantized_engine('fbgemm'): |
| class LinearRelu(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(5, 5) |
| self.relu = torch.nn.ReLU() |
| |
| def forward(self, x): |
| x = self.linear(x) |
| return self.relu(x) |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear_relu = LinearRelu() |
| |
| def forward(self, x): |
| x = self.linear_relu(x) |
| x = self.linear_relu(x) |
| return x |
| |
| for qconfig in [float16_dynamic_qconfig, default_dynamic_qconfig]: |
| model = M().eval() |
| qconfig_dict = { |
| "": qconfig |
| } |
| example_inputs = (torch.randn(5, 5),) |
| m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) |
| m = convert_fx(m) |
| m(*example_inputs) |
| node_list = [ |
| ns.call_module(nniqd.LinearReLU), |
| ns.call_module(nniqd.LinearReLU), |
| ] |
| self.checkGraphModuleNodes(m, expected_node_list=node_list) |
| |
| @skipIfNoFBGEMM |
| def test_dynamic_linear_input_multiple_use(self): |
| """ |
| Tests input for dynamic linear being used by multiple ops |
| """ |
| with override_quantized_engine('fbgemm'): |
| class LinearRelu(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(5, 5) |
| self.relu = torch.nn.ReLU() |
| |
| def forward(self, x): |
| x = self.linear(x) |
| return self.relu(x) |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.mod1 = LinearRelu() |
| self.mod2 = LinearRelu() |
| |
| def forward(self, x): |
| y1 = self.mod1(x) |
| y2 = self.mod2(x) |
| return y1 + y2 |
| |
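            # the same input feeds both LinearReLU modules; dynamic ops quantize
            # their input internally, so sharing it should not introduce any
            # explicit quantize nodes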
| for qconfig in [float16_dynamic_qconfig, default_dynamic_qconfig]: |
| model = M().eval() |
| qconfig_dict = { |
| "": qconfig |
| } |
| example_inputs = (torch.rand(5, 5, 5),) |
| m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) |
| m = convert_fx(m) |
| m(*example_inputs) |
| node_list = [ |
| ns.call_module(nniqd.LinearReLU), |
| ns.call_module(nniqd.LinearReLU), |
| ] |
| self.checkGraphModuleNodes(m, expected_node_list=node_list) |
| |
| def test_ref_linear_module(self): |
| """ Make sure the numerics for models with ref linear module |
| matches models with fbgemm/qnnpack module |
| """ |
| class M1(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(10, 5) |
| |
| def forward(self, x): |
| return self.linear(x) |
| |
| class M2(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(10, 5) |
| self.relu = torch.nn.ReLU() |
| |
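            # the complex pattern format lists the output op first:
            # (reshape, transpose, shape); the root node getter picks
            # `transpose`, the first op in execution order, as the root
            # of the matched pattern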
| def forward(self, x): |
| return self.relu(self.linear(x)) |
| |
| for M in [M1, M2]: |
| m = M().eval() |
| example_inputs = (torch.randn(5, 10),) |
| m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs) |
| m_copy = copy.deepcopy(m) |
| m = convert_fx(m) |
| m_ref = convert_to_reference_fx(m_copy) |
| result = m(*example_inputs) |
| result_ref = m_ref(*example_inputs) |
| self.assertTrue(torch.equal(result, result_ref)) |
| |
| def test_ref_conv_module(self): |
| """ Make sure the numerics for models with ref conv module |
| matches models with fbgemm/qnnpack module |
| """ |
| convs = { |
| 1: nn.Conv1d, |
| 2: nn.Conv2d, |
| 3: nn.Conv3d, |
| } |
| |
| class M1(torch.nn.Module): |
| def __init__(self, dim): |
| super().__init__() |
| self.conv = convs[dim](3, 3, 3) |
| |
| def forward(self, x): |
| return self.conv(x) |
| |
| class M2(torch.nn.Module): |
| def __init__(self, dim): |
| super().__init__() |
| self.conv = convs[dim](3, 3, 3) |
| self.relu = torch.nn.ReLU() |
| |
| def forward(self, x): |
| return self.relu(self.conv(x)) |
| |
| for dim, M in itertools.product([1, 2, 3], [M1, M2]): |
| m = M(dim).eval() |
| data = self.img_data_dict[dim][0][0] |
| m = prepare_fx(m, {"": default_qconfig}, example_inputs=(data,)) |
| m_copy = copy.deepcopy(m) |
| m = convert_fx(m) |
| m_ref = convert_to_reference_fx(m_copy) |
| result = m(data) |
| result_ref = m_ref(data) |
| self.assertTrue(torch.equal(result, result_ref)) |
| |
| def test_sub_scalar(self): |
| class M(torch.nn.Module): |
| def forward(self, x): |
| x = x + 1 |
| x = x - 1 |
| x = x + 3 |
| x = x - 4 |
| return x |
| |
| m = M().eval() |
| m = prepare_fx(m, {"": default_qconfig}, example_inputs=(torch.rand(3),)) |
| m = convert_fx(m) |
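        # add with a scalar has a quantized kernel but sub does not, so each
        # add/sub pair runs as quantize -> quantized add -> dequantize -> fp32
        # sub, giving two quantizes and two dequantizes in total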
| occurrence = { |
| ns.call_function(torch.quantize_per_tensor): 2, |
| ns.call_method("dequantize"): 2 |
| } |
| self.checkGraphModuleNodes(m, expected_node_occurrence=occurrence) |
| |
| def test_observer_fqn(self): |
| """ |
| Test to make sure the observer FQN is based on the quantizable op/module that it is observing |
| and uses the modules FQN to determine the observer name. |
| """ |
| class Linear(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.w = torch.ones(5, 5) |
| self.b = torch.zeros(5) |
| |
| |
| def forward(self, x): |
| return torch.nn.functional.linear(x, self.w, self.b) |
| |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.mods1 = torch.nn.Sequential( |
| Linear(), |
| Linear() |
| ) |
| self.mods2 = Linear() |
| self.mods3 = torch.nn.Linear(5, 5) |
| |
| def forward(self, x): |
| x = self.mods1(x) |
| x = torch.add(x, 4) |
| x = self.mods2(x) |
| y = torch.add(x, 2) |
| z = torch.mul(x, 5) |
| a = self.mods3(y) |
| return a, z |
| |
| model = M().eval() |
| |
        prepared = prepare_fx(model, {"": default_qconfig}, example_inputs=(torch.randn(1, 5),))
| name_list = [] |
| for name, mod in prepared.named_modules(): |
| if isinstance(mod, torch.ao.quantization.observer.MinMaxObserver): |
| name_list.append(name) |
| expected_name_list = ['activation_post_process_0', |
| 'activation_post_process_1', |
| 'activation_post_process_2', |
| 'activation_post_process_3', |
| 'activation_post_process_4', |
| 'activation_post_process_6', |
| 'activation_post_process_7', |
| 'activation_post_process_10'] |
| assert name_list == expected_name_list |
| |
| def test_conv_lowering(self): |
| convs = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d} |
| qconvs = {1: nn.quantized.Conv1d, 2: nn.quantized.Conv2d, 3: nn.quantized.Conv3d} |
| |
| class M(torch.nn.Module): |
| def __init__(self, dim): |
| super().__init__() |
| self.conv = convs[dim](3, 3, 3) |
| |
| def forward(self, x): |
| return self.conv(x) |
| |
| for dim in range(1, len(convs) + 1): |
| m = M(dim).eval() |
| data = self.img_data_dict[dim][0][0] |
| m = prepare_fx(m, {"": default_qconfig}, example_inputs=(data,)) |
| m_ref = copy.deepcopy(m) |
| m_ref = convert_to_reference_fx(m_ref) |
| m = convert_fx(m) |
| out_ref = m_ref(data) |
| out = m(data) |
| # check that reference pattern for quantized conv module is fused |
| expected_node_occurrence = { |
| ns.call_function(torch.quantize_per_tensor): 1, |
| ns.call_module(qconvs[dim]): 1, |
| ns.call_method("dequantize"): 1 |
| } |
| self.checkGraphModuleNodes(m, expected_node_occurrence=expected_node_occurrence) |
| # checking result match |
| self.assertTrue(torch.equal(out_ref, out)) |
| |
| def test_convert_qconfig_mapping(self): |
| class Linear(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.w = torch.ones(5, 5) |
| self.b = torch.zeros(5) |
| |
| def forward(self, x): |
| return torch.nn.functional.linear(x, self.w, self.b) |
| |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.mods1 = torch.nn.Sequential( |
| Linear(), |
| Linear() |
| ) |
| self.mods3 = torch.nn.Linear(5, 5) |
| |
| def forward(self, x): |
| x = self.mods1(x) |
| x = torch.add(x, 4) |
| z = torch.mul(x, 5) |
| x = self.mods3(z) |
| return x |
| |
| model = M().train() |
| |
| for check in ["module_name", "object_type"]: |
| qconfig_dict = {"": None, |
| "object_type": [ |
| (nn.functional.linear, get_default_qat_qconfig("fbgemm")), |
| (torch.add, get_default_qat_qconfig("fbgemm")), |
| (nn.Linear, get_default_qat_qconfig("fbgemm")), |
| ], |
| } |
| example_inputs = (torch.rand(5, 5),) |
| prepared = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs) |
| prepared(*example_inputs) |
| if check == "module_name": |
| convert_qconfig_dict = {"": None, |
| "object_type": [ |
| (nn.functional.linear, get_default_qat_qconfig("fbgemm")), |
| (torch.add, get_default_qat_qconfig("fbgemm")), |
| (nn.Linear, get_default_qat_qconfig("fbgemm")), |
| ], |
| "module_name": [("mods1.0", None)]} |
| |
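                # setting `mods1.0` to None at convert time leaves the first
                # linear unquantized, which is why one fp32 F.linear remains below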
| node_occurrence = { |
| ns.call_function(torch.quantize_per_tensor): 2, |
| ns.call_function(torch.nn.functional.linear): 1, |
| ns.call_function(torch.ops.quantized.linear): 1, |
| ns.call_function(torch.ops.quantized.add): 1, |
| ns.call_method("dequantize"): 2 |
| } |
| order_check = [ |
| ns.call_function(torch.nn.functional.linear), |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_function(torch.ops.quantized.linear), |
| ns.call_function(torch.ops.quantized.add), |
| ns.call_method("dequantize"), |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_module(nnq.Linear), |
| ns.call_method("dequantize"), |
| ] |
| elif check == "object_type": |
| convert_qconfig_dict = {"": None, |
| "object_type": [ |
| (nn.functional.linear, get_default_qat_qconfig("fbgemm")), |
| (torch.add, get_default_qat_qconfig("fbgemm")), |
| (nn.Linear, None), |
| ]} |
| |
| node_occurrence = { |
| ns.call_function(torch.quantize_per_tensor): 1, |
| ns.call_function(torch.ops.quantized.linear): 2, |
| ns.call_function(torch.ops.quantized.add): 1, |
| ns.call_function(torch.mul): 1, |
| ns.call_method("dequantize"): 1 |
| } |
| order_check = [ |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_function(torch.ops.quantized.linear), |
| ns.call_function(torch.ops.quantized.linear), |
| ns.call_function(torch.ops.quantized.add), |
| ns.call_method("dequantize"), |
| ns.call_function(torch.mul), |
| ns.call_module(nn.Linear), |
| ] |
| |
| converted = convert_fx(prepared, qconfig_mapping=convert_qconfig_dict) |
| converted(torch.rand(5, 5)) |
| self.checkGraphModuleNodes( |
| converted, |
| expected_node_occurrence=node_occurrence, |
| expected_node_list=order_check) |
| |
| def _assertFixedQParamsFakeQuantizeEqual(self, fq1, fq2): |
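        # instantiate both fake quantize partials and compare the observer
        # constructors they were built from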
| self.assertEqual(fq1()._observer_ctr, fq2()._observer_ctr) |
| |
| def test_register_patterns(self): |
| def cleanUp(): |
| del _DEFAULT_FUSION_PATTERNS["dummy_fusion"] |
| del _DEFAULT_QUANTIZATION_PATTERNS["dummy_quant"] |
| del _DEFAULT_QUANTIZATION_PATTERNS["dummy_quant2"] |
| del _DEFAULT_QUANTIZATION_PATTERNS["dummy_quant3"] |
| del _DEFAULT_OUTPUT_OBSERVER_MAP["dummy_quant2"] |
| del _DEFAULT_OUTPUT_OBSERVER_MAP["dummy_quant3"] |
| del _DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP["dummy_quant2"] |
| del _DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP["dummy_quant3"] |
| self.addCleanup(cleanUp) |
| |
| @_register_fusion_pattern("dummy_fusion") |
| class DummyFusion: |
| pass |
| |
| @_register_quant_pattern("dummy_quant") |
| class DummyQuant: |
| pass |
| |
| @_register_quant_pattern("dummy_quant2", default_fixed_qparams_range_0to1_observer) |
| class DummyQuant2: |
| pass |
| |
| @_register_quant_pattern("dummy_quant3", default_fixed_qparams_range_neg1to1_observer) |
| class DummyQuant3: |
| pass |
| |
| self.assertEqual(_DEFAULT_FUSION_PATTERNS["dummy_fusion"], DummyFusion) |
| self.assertEqual(_DEFAULT_QUANTIZATION_PATTERNS["dummy_quant"], DummyQuant) |
| self.assertEqual(_DEFAULT_QUANTIZATION_PATTERNS["dummy_quant2"], DummyQuant2) |
| self.assertEqual(_DEFAULT_QUANTIZATION_PATTERNS["dummy_quant3"], DummyQuant3) |
| self.assertEqual(_DEFAULT_OUTPUT_OBSERVER_MAP["dummy_quant2"], default_fixed_qparams_range_0to1_observer) |
| self.assertEqual(_DEFAULT_OUTPUT_OBSERVER_MAP["dummy_quant3"], default_fixed_qparams_range_neg1to1_observer) |
| self._assertFixedQParamsFakeQuantizeEqual(_DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP["dummy_quant2"], |
| default_fixed_qparams_range_0to1_fake_quant) |
| self._assertFixedQParamsFakeQuantizeEqual(_DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP["dummy_quant3"], |
| default_fixed_qparams_range_neg1to1_fake_quant) |
| output_fake_quantize_map = get_default_output_activation_post_process_map(is_training=True) |
| output_observer_map = get_default_output_activation_post_process_map(is_training=False) |
| self.assertEqual(output_observer_map.get("dummy_quant3"), default_fixed_qparams_range_neg1to1_observer) |
| self._assertFixedQParamsFakeQuantizeEqual(output_fake_quantize_map.get("dummy_quant3"), |
| default_fixed_qparams_range_neg1to1_fake_quant) |
| |
| def test_reuse_input_qconfig(self): |
| class M1(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.Conv2d(3, 3, 3) |
| |
| def forward(self, x): |
| x = self.conv(x) |
| x = x.reshape() |
| return x |
| |
| class M2(torch.nn.Module): |
| def forward(self, x): |
| x = x.reshape() |
| return x |
| |
        # both M1 (conv + reshape) and M2 (reshape only) are exercised below
        m = M1().eval()
        example_inputs = (torch.randn(1, 3, 3, 3),)
        m = prepare_fx(m, get_default_qconfig_mapping(), example_inputs=example_inputs)
        m = convert_fx(m)
        node_list = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_method("reshape"),
            ns.call_method("dequantize"),
        ]
        self.checkGraphModuleNodes(
            m,
            expected_node_list=node_list)

        m = M2().eval()
        m = prepare_fx(m, get_default_qconfig_mapping(), example_inputs=example_inputs)
        m = convert_fx(m)
        node_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 0,
            ns.call_method("dequantize"): 0,
        }
        node_list = [
            ns.call_method("reshape"),
        ]
        self.checkGraphModuleNodes(
            m,
            expected_node_occurrence=node_occurrence,
            expected_node_list=node_list)
| |
| def test_stack_trace_preserved_linear(self): |
| class M(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = nn.Linear(1, 1) |
| |
| def forward(self, x): |
| x = self.linear(x) |
| return x |
| |
| m = M().eval() |
| mp = prepare_fx(m, get_default_qconfig_mapping(), example_inputs=(torch.randn(1, 1),)) |
| |
| found_stack_trace = False |
| for n in mp.graph.nodes: |
| if n.op == 'call_module' and n.target == 'linear': |
| found_stack_trace = n.stack_trace is not None |
| break |
| self.assertTrue(found_stack_trace) |
| |
| # test reference model |
| mq = convert_to_reference_fx(copy.deepcopy(mp)) |
| found_stack_trace = False |
| for n in mq.graph.nodes: |
| if n.op == 'call_module' and n.target == 'linear': |
| found_stack_trace = n.stack_trace is not None |
| break |
| self.assertTrue(found_stack_trace, f"stack trace not found, node: {n.format_node()}, is_reference: True") |
| |
| # test quantized model |
| mq = convert_fx(mp) |
| found_stack_trace = False |
| for n in mq.graph.nodes: |
| if n.op == 'call_module' and n.target == 'linear': |
| found_stack_trace = n.stack_trace is not None |
| break |
| self.assertTrue(found_stack_trace, f"stack trace not found, node: {n.format_node()}, is_reference: False") |
| |
| def test_qat_skip_untraced(self): |
| class UnTraceableModuleClass(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = nn.Linear(2, 2) |
| |
| def forward(self, x): |
| return self.linear(x) |
| |
| class UnTraceableModuleName(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = nn.Linear(2, 2) |
| |
| def forward(self, x): |
| return self.linear(x) |
| |
| class M(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.untraceable_module_class = UnTraceableModuleClass() |
                self.untraceable_module_name = UnTraceableModuleName()
| |
| def forward(self, x): |
| x = self.untraceable_module_class(x) |
| x = self.untraceable_module_name(x) |
| return x |
| |
| mod = M() |
| |
| qconfig_dict = {"": torch.ao.quantization.get_default_qat_qconfig()} |
| prepare_custom_config_dict = { |
| "non_traceable_module_class": [UnTraceableModuleClass], |
| "non_traceable_module_name": ["untraceable_module_name"], |
| } |
| example_inputs = (torch.randn(2, 2),) |
| mod_prep = torch.ao.quantization.quantize_fx.prepare_qat_fx( |
| mod.train(), qconfig_dict, example_inputs=example_inputs, |
| prepare_custom_config=prepare_custom_config_dict |
| ) |
| self.assertTrue( |
| isinstance(mod_prep.untraceable_module_class.linear, torch.nn.Linear) |
| ) |
| self.assertTrue( |
| isinstance(mod_prep.untraceable_module_name.linear, torch.nn.Linear) |
| ) |
| self.assertTrue( |
| type(mod_prep.untraceable_module_class.linear) |
| is not torch.ao.nn.qat.modules.linear.Linear, |
| "prepare_qat_fx shold not convert anything inside untraced module classes", |
| ) |
| self.assertTrue( |
| type(mod_prep.untraceable_module_name.linear) |
| is not torch.ao.nn.qat.modules.linear.Linear, |
| "prepare_qat_fx shold not convert anything inside modules named in untraced_module_names", |
| ) |
| |
| def test_qconfig_dict_setup(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.Conv1d = torch.nn.Conv1d(1, 1, 1) |
| self.Conv2d = torch.nn.Conv2d(1, 1, 1) |
| self.Conv3d = torch.nn.Conv3d(1, 1, 1) |
| self.ConvTranspose1d = torch.nn.ConvTranspose1d(1, 1, 1) |
| self.ConvTranspose2d = torch.nn.ConvTranspose2d(1, 1, 1) |
| self.ConvTranspose3d = torch.nn.ConvTranspose3d(1, 1, 1) |
| self.Linear = torch.nn.Linear(1, 1, 1) |
| |
| def forward(self, x): |
| x = self.Conv1d(x) |
| x = self.Conv2d(x) |
| x = self.Conv3d(x) |
| x = self.ConvTranspose1d(x) |
| x = self.ConvTranspose2d(x) |
| x = self.ConvTranspose3d(x) |
| x = self.Linear(x) |
| x = torch.nn.functional.conv1d(x, torch.rand(2, 2)) |
| x = torch.nn.functional.conv2d(x, torch.rand(2, 2)) |
| x = torch.nn.functional.conv3d(x, torch.rand(2, 2)) |
| x = torch.nn.functional.linear(x, torch.rand(2, 2)) |
| return x |
| |
| backends = ["qnnpack", "fbgemm"] |
| for func in [get_default_qconfig_mapping, get_default_qat_qconfig_mapping]: |
| for backend in backends: |
| m = M().eval() |
| qconfig_dict = func(backend) |
                m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1, 1, 1),))
| for name, mod in m.named_modules(): |
| if _is_activation_post_process(mod) and mod.dtype == torch.quint8: |
| if backend == "fbgemm": |
| lower_bnd = 0 |
| upper_bnd = 127 |
| else: |
| lower_bnd = 0 |
| upper_bnd = 255 |
| if issubclass(type(mod), FakeQuantize): |
| self.assertEqual(mod.activation_post_process.quant_min, lower_bnd) |
| self.assertEqual(mod.activation_post_process.quant_max, upper_bnd) |
| else: |
| self.assertEqual(mod.quant_min, lower_bnd) |
| self.assertEqual(mod.quant_max, upper_bnd) |
| |
| def test_prepare_mode(self): |
| class LinearModel(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(5, 10) |
| |
| def forward(self, x): |
| return self.linear(x) |
| |
| def _test(prepare_fn, qconfig_dict): |
| m = LinearModel() |
| m1 = copy.deepcopy(m) |
| m1.train() |
| example_inputs = (torch.randn(1, 5),) |
| prepare_fn(m1, qconfig_dict, example_inputs=example_inputs) |
| m2 = copy.deepcopy(m) |
| m2.eval() |
| prepare_fn(m2, qconfig_dict, example_inputs=example_inputs) |
| |
| # Ensure prepare_fx and prepare_qat_fx work in both training and eval modes |
| _test(prepare_fx, get_default_qconfig_mapping()) |
| _test(prepare_qat_fx, get_default_qat_qconfig_mapping()) |
| |
| def _validate_qconfig_against_backend_config_constraints( |
| self, |
| model: torch.nn.Module, |
| qconfig: QConfig, |
| backend_config: BackendConfig, |
| satisfies_constraints: bool, |
| qconfig_name: Optional[str] = None): |
| """ |
| Helper method to validate whether `qconfig` satisfies the constraints specified in `backend_config`. |
| """ |
| qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Linear, qconfig) |
| example_inputs = (torch.rand((1, 30), dtype=torch.float),) |
| model = prepare_fx(model, qconfig_mapping, example_inputs, backend_config=backend_config) |
| model(*example_inputs) |
| model = convert_fx(model, backend_config=backend_config) |
| if satisfies_constraints: |
| expected_node_occurrence = { |
| ns.call_module(torch.ao.nn.quantized.Linear) : 1, |
| ns.call_module(torch.nn.Linear) : 0, |
| } |
| else: |
| expected_node_occurrence = { |
| ns.call_module(torch.ao.nn.quantized.Linear) : 0, |
| ns.call_module(torch.nn.Linear) : 1, |
| } |
| try: |
| self.checkGraphModuleNodes(model, expected_node_occurrence=expected_node_occurrence) |
| except AssertionError as e: |
| if qconfig_name is not None: |
| print("ERROR: Validation for QConfig '%s' failed" % qconfig_name) |
| raise e |
| |
| def test_backend_config_quantization_range(self): |
| """ |
| Check that quantization ranges specified through the BackendConfig are reflected in |
| the observers inserted into the model. |
| """ |
| class MyModel(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(30, 4).float() |
| |
| def forward(self, x): |
| return self.linear(x) |
| |
| dtype_config = DTypeConfig( |
| input_dtype=DTypeWithConstraints( |
| dtype=torch.quint8, |
| quant_min_lower_bound=0, |
| quant_max_upper_bound=31, |
| ), |
| output_dtype=DTypeWithConstraints( |
| dtype=torch.quint8, |
| quant_min_lower_bound=0, |
| quant_max_upper_bound=31, |
| ), |
| weight_dtype=DTypeWithConstraints( |
| dtype=torch.qint8, |
| quant_min_lower_bound=-64, |
| quant_max_upper_bound=63, |
| ), |
| bias_dtype=torch.float, |
| ) |
| backend_config = BackendConfig() \ |
| .set_backend_pattern_config(BackendPatternConfig(torch.nn.Linear) |
| .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E128 |
| .add_dtype_config(dtype_config) |
| .set_root_module(torch.nn.Linear) |
| .set_reference_quantized_module(nnqr.Linear)) |
| |
| def validate_qconfig(qconfig: QConfig, satisfies_constraints: bool): |
| self._validate_qconfig_against_backend_config_constraints( |
| MyModel(), qconfig, backend_config, satisfies_constraints) |
| |
| # Case 1: QConfig ranges fit within backend ranges, OK |
| qconfig1 = QConfig( |
| activation=MinMaxObserver.with_args(quant_min=0, quant_max=15, dtype=torch.quint8), |
| weight=MinMaxObserver.with_args(quant_min=-32, quant_max=31, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)) |
| validate_qconfig(qconfig1, satisfies_constraints=True) |
| |
| # Case 2: QConfig activation range falls outside backend range, should fail |
| qconfig2 = QConfig( |
| activation=MinMaxObserver.with_args(quant_min=0, quant_max=63, dtype=torch.quint8), |
| weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)) |
| validate_qconfig(qconfig2, satisfies_constraints=False) |
| |
| # Case 3: QConfig weight range falls outside backend range, should fail |
| qconfig3 = QConfig( |
| activation=MinMaxObserver.with_args(dtype=torch.quint8), |
| weight=MinMaxObserver.with_args(quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)) |
| validate_qconfig(qconfig3, satisfies_constraints=False) |
| |
| # Case 4: QConfig doesn't specify range, should fail |
| qconfig4 = QConfig(activation=ReuseInputObserver, weight=ReuseInputObserver) |
| validate_qconfig(qconfig4, satisfies_constraints=False) |
| |
| def test_backend_config_scale_min(self): |
| """ |
| Test QConfig eps validation against the BackendConfig's min scale value. |
| """ |
| class MyModel(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(30, 4).float() |
| |
| def forward(self, x): |
| return self.linear(x) |
| |
| dtype_config = DTypeConfig( |
| input_dtype=DTypeWithConstraints(dtype=torch.quint8, scale_min_lower_bound=2 ** -12), |
| output_dtype=DTypeWithConstraints(dtype=torch.quint8, scale_min_lower_bound=2 ** -12), |
| weight_dtype=DTypeWithConstraints(dtype=torch.qint8, scale_min_lower_bound=2 ** -12), |
| bias_dtype=torch.float, |
| ) |
| |
| backend_config = BackendConfig() \ |
| .set_backend_pattern_config(BackendPatternConfig(torch.nn.Linear) |
| .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E128 |
| .add_dtype_config(dtype_config) |
| .set_root_module(torch.nn.Linear) |
| .set_reference_quantized_module(nnqr.Linear)) |
| |
| def validate_qconfig(qconfig: QConfig, satisfies_constraints: bool): |
| self._validate_qconfig_against_backend_config_constraints( |
| MyModel(), qconfig, backend_config, satisfies_constraints) |
| |
| # Case 1: QConfig min scale value == backend min scale value, OK |
| qconfig1 = QConfig( |
| activation=MinMaxObserver.with_args(dtype=torch.quint8, eps=2 ** -12), |
| weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, eps=2 ** -12)) |
| validate_qconfig(qconfig1, satisfies_constraints=True) |
| |
| # Case 2: QConfig min scale value > backend min scale value, OK |
| qconfig2 = QConfig( |
| activation=MinMaxObserver.with_args(dtype=torch.quint8, eps=2 ** -10), |
| weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, eps=2 ** -10)) |
| validate_qconfig(qconfig2, satisfies_constraints=True) |
| |
| # Case 3: QConfig activation min scale value < backend min scale value, should fail |
| qconfig3 = QConfig( |
| activation=MinMaxObserver.with_args(dtype=torch.quint8, eps=2 ** -14), |
| weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)) |
| validate_qconfig(qconfig3, satisfies_constraints=False) |
| |
        # Case 4: QConfig weight min scale value < backend min scale value, should fail
| qconfig4 = QConfig( |
| activation=MinMaxObserver.with_args(dtype=torch.quint8), |
| weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, eps=2 ** -14)) |
| validate_qconfig(qconfig4, satisfies_constraints=False) |
| |
| # Case 5: QConfig doesn't specify eps, should fail |
| qconfig5 = QConfig( |
| activation=FixedQParamsObserver.with_args(scale=1.0, zero_point=0), |
| weight=FixedQParamsObserver.with_args(scale=1.0, zero_point=0)) |
| validate_qconfig(qconfig5, satisfies_constraints=False) |
| |
| def test_qnnpack_backend_config(self): |
| """ |
| Test whether default QNNPACK QConfigs are compatible with the QNNPACK BackendConfig. |
| """ |
| class MyModel(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(30, 4).float() |
| |
| def forward(self, x): |
| return self.linear(x) |
| |
| all_qconfigs: List[Tuple[QConfig, str]] = [ |
| (get_default_qconfig("qnnpack", version=0), "default_qnnpack_qconfig_v0"), |
| (get_default_qat_qconfig("qnnpack", version=0), "default_qat_qnnpack_qconfig_v0"), |
| (get_default_qat_qconfig("qnnpack", version=1), "default_qat_qnnpack_qconfig_v1"), |
| (default_symmetric_qnnpack_qconfig, "default_symmetric_qnnpack_qconfig"), |
| (default_symmetric_qnnpack_qat_qconfig, "default_symmetric_qnnpack_qat_qconfig"), |
| # TODO: Test these QConfigs once they are fixed, see https://github.com/pytorch/pytorch/issues/85862 |
| # (default_per_channel_symmetric_qnnpack_qconfig, "default_per_channel_symmetric_qnnpack_qconfig"), |
| # (default_per_channel_symmetric_qnnpack_qat_qconfig, "default_per_channel_symmetric_qnnpack_qat_qconfig"), |
| ] |
| backend_config = get_qnnpack_backend_config() |
| for qconfig, qconfig_name in all_qconfigs: |
| self._validate_qconfig_against_backend_config_constraints( |
| MyModel(), qconfig, backend_config, satisfies_constraints=True, qconfig_name=qconfig_name) |
| |
| def test_symmetric_qnnpack_qconfig_mapping(self): |
| """ |
| Test whether `torch.ao.quantization.qconfig_mapping._get_symmetric_qnnpack_qconfig_mapping` |
| works with the QNNPACK BackendConfig. |
| """ |
| if "qnnpack" not in supported_qengines: |
| return |
| |
| class MyModel(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(30, 4).float() |
| |
| def forward(self, x): |
| return self.linear(x) |
| |
| with override_quantized_engine("qnnpack"): |
| qconfig_mapping = _get_symmetric_qnnpack_qconfig_mapping() |
| example_inputs = (torch.rand((1, 30), dtype=torch.float),) |
| backend_config = get_qnnpack_backend_config() |
| model = MyModel() |
| model = prepare_fx(model, qconfig_mapping, example_inputs, backend_config=backend_config) |
| model(*example_inputs) |
| model = convert_fx(model, backend_config=backend_config) |
| expected_node_occurrence = { |
| ns.call_module(torch.ao.nn.quantized.Linear) : 1, |
| ns.call_module(torch.nn.Linear) : 0, |
| } |
| self.checkGraphModuleNodes(model, expected_node_occurrence=expected_node_occurrence) |
| model(*example_inputs) |
| |
| def test_symmetric_qnnpack_qat_qconfig_mapping(self): |
| """ |
| Test whether `torch.ao.quantization.qconfig_mapping._get_symmetric_qnnpack_qat_qconfig_mapping` |
| works with the QNNPACK BackendConfig. |
| """ |
| if "qnnpack" not in supported_qengines: |
| return |
| |
| class MyModel(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(30, 4).float() |
| |
| def forward(self, x): |
| return self.linear(x) |
| |
| with override_quantized_engine("qnnpack"): |
| qconfig_mapping = _get_symmetric_qnnpack_qat_qconfig_mapping() |
| example_inputs = (torch.rand((1, 30), dtype=torch.float),) |
| backend_config = get_qnnpack_backend_config() |
| model = MyModel() |
| model = prepare_fx(model, qconfig_mapping, example_inputs, backend_config=backend_config) |
| model(*example_inputs) |
| model = convert_fx(model, backend_config=backend_config) |
| expected_node_occurrence = { |
| ns.call_module(torch.ao.nn.quantized.Linear) : 1, |
| ns.call_module(torch.nn.Linear) : 0, |
| } |
| self.checkGraphModuleNodes(model, expected_node_occurrence=expected_node_occurrence) |
| model(*example_inputs) |
| |
| |
| def test_get_executorch_backend_config(self): |
| from torch.ao.quantization.backend_config import get_executorch_backend_config |
| # make sure this runs |
| executorch_backend_config = get_executorch_backend_config() |
| |
| def test_backend_config_check_for_weight_and_bias(self): |
| """ Test to make sure the backend_config check for weight and bias |
| runs when the qconfig is None for the ops with weight and bias |
| previously the error was not hit because we first check input, and |
| the check for weight and bias are skipped. |
| """ |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.weight = torch.tensor((5, 5)) |
| self.bias = torch.tensor((5,)) |
| |
| def forward(self, x): |
| return torch.addmm(self.bias, x, self.weight) |
| |
| m = M().eval() |
| qconfig_mapping = QConfigMapping() |
| observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT |
| weighted_op_quint8_dtype_config = DTypeConfig( |
| input_dtype=torch.quint8, |
| output_dtype=torch.quint8, |
| weight_dtype=torch.qint8, |
| bias_dtype=torch.float, |
| ) |
| dtype_configs = [weighted_op_quint8_dtype_config] |
| backend_pattern_config = BackendPatternConfig(torch.addmm) \ |
| .set_observation_type(observation_type) \ |
| .set_dtype_configs(dtype_configs) \ |
| ._set_input_type_to_index({"weight": 2, "bias": 0}) |
| backend_config = BackendConfig() \ |
| .set_backend_pattern_config(backend_pattern_config) |
| example_inputs = (torch.rand(1, 5),) |
| # make sure this runs |
| m = prepare_fx(m, qconfig_mapping, example_inputs, backend_config=backend_config) |
| |
| def test_get_default_qconfig_valid_backend(self): |
| """ Checks that AssertionError is raised when non expected backend input is specified |
| """ |
| invalid_backends = ["imaginary_backend", 3] |
| for invalid_backend in invalid_backends: |
| with self.assertRaisesRegex(AssertionError, "not supported"): |
| qconfig = get_default_qconfig(invalid_backend) |
| with self.assertRaisesRegex(AssertionError, "not supported"): |
| qconfig = get_default_qat_qconfig(invalid_backend) |
| with self.assertRaisesRegex(AssertionError, "not supported"): |
| qconfig_mapping = get_default_qconfig_mapping(invalid_backend) |
| with self.assertRaisesRegex(AssertionError, "not supported"): |
| qconfig_mapping = get_default_qat_qconfig_mapping(invalid_backend) |
| |
| def test__convert_to_reference_decomposed_fx(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(5, 10) |
| |
| def forward(self, x): |
| return self.linear(x) |
| |
| m = M().eval() |
| qconfig_mapping = get_default_qconfig_mapping("fbgemm") |
| example_inputs = (torch.randn(1, 5),) |
| m = prepare_fx(m, qconfig_mapping, example_inputs) |
| m_ref = copy.deepcopy(m) |
| m_ref = convert_to_reference_fx(m_ref) |
| m = _convert_to_reference_decomposed_fx(m) |
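        # the decomposed flow replaces torch.quantize_per_tensor / dequantize
        # with quantized_decomposed ops that operate on regular (non-quantized)
        # tensors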
| expected_occurrence = { |
| ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2, |
| ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 2, |
| } |
| self.checkGraphModuleNodes( |
| m, |
| expected_node_occurrence=expected_occurrence) |
| # make sure it runs |
| res_ref = m_ref(*example_inputs) |
| res = m(*example_inputs) |
| self.assertEqual(res, res_ref) |
| |
| @skipIfNoQNNPACK |
| def test__convert_to_reference_decomposed_fx_dynamic_quant(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(5, 10) |
| |
| def forward(self, x): |
| return self.linear(x) |
| |
| # to avoid reduce_range |
| with override_quantized_engine("qnnpack"): |
| m = M().eval() |
| qconfig_mapping = get_default_qconfig_mapping("fbgemm") \ |
| .set_object_type(torch.nn.Linear, default_dynamic_qconfig) |
| example_inputs = (torch.randn(1, 5),) |
| m = prepare_fx(m, qconfig_mapping, example_inputs) |
| m(*example_inputs) |
| m_ref = copy.deepcopy(m) |
| m_ref = convert_to_reference_fx(m_ref) |
| m = _convert_to_reference_decomposed_fx(m) |
| expected_occurrence = { |
| ns.call_function(torch.ops.quantized_decomposed.choose_qparams.tensor): 1, |
| ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.tensor): 1, |
| ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor): 1, |
| } |
| self.checkGraphModuleNodes( |
| m, |
| expected_node_occurrence=expected_occurrence) |
| # make sure it runs |
| res_ref = m_ref(*example_inputs) |
| res = m(*example_inputs) |
| self.assertEqual(res, res_ref) |
| |
| def test__convert_to_reference_decomposed_fx_per_channel_quant(self): |
| class M(torch.nn.Module): |
| def forward(self, x, weight, bias): |
| return F.linear(x, weight, bias) |
| |
| m = M().eval() |
| qconfig_mapping = get_default_qconfig_mapping("fbgemm") \ |
| .set_object_type(F.linear, default_per_channel_qconfig) |
| example_inputs = (torch.randn(1, 5), torch.randn(10, 5), torch.randn(10,)) |
| m = prepare_fx(m, qconfig_mapping, example_inputs) |
| m(*example_inputs) |
| m_ref = copy.deepcopy(m) |
| m_ref = convert_to_reference_fx(m_ref) |
| m = _convert_to_reference_decomposed_fx(m) |
| expected_occurrence = { |
| # for input and output activations |
| ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2, |
| ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 2, |
| # for weight |
| ns.call_function(torch.ops.quantized_decomposed.quantize_per_channel.default): 1, |
| ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel.default): 1, |
| } |
| self.checkGraphModuleNodes( |
| m, |
| expected_node_occurrence=expected_occurrence) |
| # make sure it runs |
| res_ref = m_ref(*example_inputs) |
| res = m(*example_inputs) |
| self.assertEqual(res, res_ref) |
| |
| def test_change_backend_config_for_fixed_qparam_ops(self): |
| """ Making sure we can skip validation of qconfigs for fixedqparam ops based |
| on BackendConfig |
| """ |
| class M(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.tanh = torch.nn.Tanh() |
| |
| def forward(self, x: torch.Tensor): |
| x = self.tanh(x) |
| return x |
| |
| model = M().eval() |
| # we set a global default_qconfig, which will be ignored since the backend |
| # we defined doesn't support anything |
| # this is to make sure we don't validate the qconfig when BackendConfig does not |
| # have fixed qparam op related configurations |
| qconfig_mapping = QConfigMapping().set_global(default_qconfig) |
| backend_config = BackendConfig() |
| # make sure this runs |
| model = prepare_fx( |
| model, |
| qconfig_mapping=qconfig_mapping, |
| example_inputs=(torch.randn(1, 2, 3, 4),), |
| backend_config=backend_config |
| ) |
| |
| def test_channel_shuffle_lowering(self): |
| # Three versions of channel shuffle |
| class M1(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.op = torch.nn.ChannelShuffle(2) |
| |
| def forward(self, x): |
| return self.op(x + x) + x |
| |
| class M2(torch.nn.Module): |
| def forward(self, x): |
| return torch.channel_shuffle(x + x, 2) + x |
| |
| class M3(torch.nn.Module): |
| def forward(self, x): |
| return torch.nn.functional.channel_shuffle(x + x, 2) + x |
| |
| x = torch.randn(4, 4, 4, 4) |
| # torch.channel_shuffle is equivalent to torch.nn.functional.channel_shuffle |
| model_node_pairs = [ |
| (M1().eval(), ns.call_module(torch.nn.ChannelShuffle)), |
| (M2().eval(), ns.call_function(torch.channel_shuffle)), |
| (M3().eval(), ns.call_function(torch.channel_shuffle)) |
| ] |
| for m, node in model_node_pairs: |
| m = prepare_fx(m, {"": default_qconfig}, example_inputs=(x,)) |
| m_copy = copy.deepcopy(m) |
| m = convert_fx(m) |
| m_ref = convert_to_reference_fx(m_copy) |
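            # the lowered model stays quantized between a single input quantize
            # and output dequantize; the reference model requantizes each
            # intermediate result (input plus three op outputs), hence 4 of each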
| node_occurrence = { |
| node: 1, |
| ns.call_function(torch.quantize_per_tensor): 1, |
| ns.call_method("dequantize"): 1 |
| } |
| node_occurrence_ref = { |
| node: 1, |
| ns.call_function(torch.quantize_per_tensor): 4, |
| ns.call_method("dequantize"): 4 |
| } |
| self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) |
| self.checkGraphModuleNodes(m_ref, expected_node_occurrence=node_occurrence_ref) |
| |
| def test_match_pattern_with_multiple_args(self): |
| """ Test that we can match a pattern that has multiple arguments |
| Pattern: |
| shape \ |
| transpose (observed) -> reshape -> output (observed) -> |
| |
| where `reshape` has two arguments |
| """ |
| |
| def _get_pattern_configs(): |
| backend_pattern_configs = [] |
| observation_type = ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT |
| weighted_op_quint8_dtype_config = DTypeConfig( |
| input_dtype=torch.quint8, |
| output_dtype=torch.quint8, |
| weight_dtype=torch.qint8, |
| bias_dtype=torch.float, |
| ) |
| dtype_configs = [weighted_op_quint8_dtype_config] |
| |
| def root_node_getter(node_pattern): |
| reshape, transpose, shape = node_pattern |
| return transpose |
| |
| backend_pattern_configs.append( |
| BackendPatternConfig() |
| ._set_pattern_complex_format((torch.reshape, torch.transpose, MatchAllNode)) # noqa: E131 |
| .set_observation_type(observation_type) |
| .set_dtype_configs(dtype_configs) |
| ._set_root_node_getter(root_node_getter) |
| ) |
| return backend_pattern_configs |
| |
| backend_config = BackendConfig().set_backend_pattern_configs(_get_pattern_configs()) |
| |
| class M(torch.nn.Module): |
| def forward(self, x): |
| x = torch.transpose(x, 0, 1) |
| x = torch.reshape(x, (-1,)) |
| return x |
| |
| m = M().eval() |
| qconfig_mapping = QConfigMapping().set_global(default_qconfig) |
| example_inputs = (torch.randn(1, 3, 3, 3),) |
| m = prepare_fx(m, qconfig_mapping, example_inputs, backend_config=backend_config) |
| node_occurrence = { |
| # one for input of the pattern and one for output of the pattern |
| ns.call_module(MinMaxObserver): 2 |
| } |
| self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) |
| |
| def _test_linear_activation_fusion_lowering_helper( |
| self, module, example_inputs, qconfig_mapping, |
| backend_config, fused_module, root_module, activation_module): |
| node_occurrence = { |
| ns.call_function(torch.quantize_per_tensor): 1, |
| ns.call_method("dequantize"): 1, |
| ns.call_module(fused_module): 1, |
| ns.call_module(root_module): 0, |
| ns.call_module(activation_module): 0, |
| } |
| node_occurrence_ref = { |
| ns.call_function(torch.quantize_per_tensor): 2, |
| ns.call_method("dequantize"): 2, |
| } |
| m = module.eval() |
| m = prepare_fx(m, qconfig_mapping, |
| example_inputs=example_inputs, |
| backend_config=backend_config) |
| m_copy = copy.deepcopy(m) |
| m = convert_fx(m, backend_config=backend_config) |
| m_ref = convert_to_reference_fx(m_copy) |
| |
| self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) |
| self.checkGraphModuleNodes(m_ref, expected_node_occurrence=node_occurrence_ref) |
| m(*example_inputs) |
| |
| @skipIfNoONEDNN |
| def test_linear_leaky_relu_lowering(self): |
| """ Test fusion and lowering of Linear - (bn -) LeakyReLU |
| by FX. For onednn backedn only. |
| """ |
| from torch.ao.quantization.backend_config import get_onednn_backend_config |
| qconfig_mapping = get_default_qconfig_mapping('onednn') |
| with override_quantized_engine('onednn'): |
| for with_bn in [True, False]: |
| m = LinearBnLeakyReluModel(with_bn) |
| self._test_linear_activation_fusion_lowering_helper( |
| m, |
| m.get_example_inputs(), |
| qconfig_mapping, |
| get_onednn_backend_config(), |
| nniq.LinearLeakyReLU, |
| nn.Linear, |
| nn.LeakyReLU) |
| |
| @skipIfNoONEDNN |
| def test_linear_tanh_lowering(self): |
| """ Test fusion and lowering of Linear - Tanh |
| by FX. For onednn backedn only. |
| """ |
| from torch.ao.quantization.backend_config import get_onednn_backend_config |
| qconfig_mapping = get_default_qconfig_mapping('onednn') |
| # TODO Currently it's required that separate ops in a fused op/module have the same qconfig. |
| # Need to be able to support fusion of ops with different qconfigs |
| # Since tanh must have 'fixed_qparams_qconfig' while linear should use |
| # the global qconfig, we need to set qconfigs for them manually here for |
| # fusion and cannot put such configs in onednn's default qconfig_mapping. |
| # Known issue: |
| # Cannot fuse linear - tanh and quantize standalone tanh at the same time. |
| qconfig = get_default_qconfig('onednn') |
| qconfig_mapping.set_object_type(torch.nn.Linear, qconfig) |
| qconfig_mapping.set_object_type(torch.nn.Tanh, qconfig) |
| with override_quantized_engine('onednn'): |
| m = LinearTanhModel() |
| self._test_linear_activation_fusion_lowering_helper( |
| m, |
| m.get_example_inputs(), |
| qconfig_mapping, |
| get_onednn_backend_config(), |
| nniq.LinearTanh, |
| nn.Linear, |
| nn.Tanh) |
| |
| @override_qengines |
| def test_linear_size_view(self): |
| class M(torch.nn.Module): |
| def __init__(self, use_relu=False): |
| super().__init__() |
| self.linear = torch.nn.Linear(16, 32) |
| self.relu = torch.nn.ReLU() |
| self.use_relu = use_relu |
| |
| def forward(self, x): |
| x = self.linear(x) |
| if self.use_relu: |
| x = self.relu(x) |
| return x.view(x.size(0), 1, 4, 8) |
| |
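        # view() consumes the int returned by x.size(0), so the size/view calls
        # should not introduce extra observers or break the linear(-relu) lowering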
| for use_relu in [False, True]: |
| model_fp32 = M(use_relu).eval() |
| qengine = torch.backends.quantized.engine |
| qconfig_mapping = get_default_qconfig_mapping(qengine) |
| x = torch.randn((5, 16)) |
| model_fp32(x) |
| prepared_model = prepare_fx(model_fp32, qconfig_mapping, x) |
| prepared_model(x) |
| quantized_model = convert_fx(prepared_model) |
| node_occurrence = { |
| ns.call_module(nnq.Linear): 0 if use_relu else 1, |
| ns.call_module(nniq.LinearReLU): 1 if use_relu else 0, |
| ns.call_function(torch.quantize_per_tensor): 1, |
| ns.call_method("dequantize"): 1 |
| } |
| self.checkGraphModuleNodes(quantized_model, expected_node_occurrence=node_occurrence) |
| |
| @override_qengines |
| def test_linear_shape_view(self): |
| class M(torch.nn.Module): |
| def __init__(self, use_relu=False): |
| super().__init__() |
| self.linear = torch.nn.Linear(16, 32) |
| self.relu = torch.nn.ReLU() |
| self.use_relu = use_relu |
| |
| def forward(self, x): |
| x = self.linear(x) |
| if self.use_relu: |
| x = self.relu(x) |
| return x.view(x.shape[0], 1, 4, 8) |
| |
| for use_relu in [False, True]: |
| model_fp32 = M(use_relu).eval() |
| qengine = torch.backends.quantized.engine |
| qconfig_mapping = get_default_qconfig_mapping(qengine) |
| x = torch.randn((5, 16)) |
| model_fp32(x) |
| prepared_model = prepare_fx(model_fp32, qconfig_mapping, x) |
| prepared_model(x) |
| quantized_model = convert_fx(prepared_model) |
| node_occurrence = { |
| ns.call_module(nnq.Linear): 0 if use_relu else 1, |
| ns.call_module(nniq.LinearReLU): 1 if use_relu else 0, |
| ns.call_function(torch.quantize_per_tensor): 1, |
| ns.call_method("dequantize"): 1 |
| } |
| self.checkGraphModuleNodes(quantized_model, expected_node_occurrence=node_occurrence) |
| |
| def test_mixed_dtypes(self): |
| """ |
| Test that multiple dtypes can be used in the same model for different layers, |
| and the dtypes will be converted correctly between the layers. |
| """ |
| class MyModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear1 = torch.nn.Linear(5, 5) |
| self.linear2 = torch.nn.Linear(5, 5) |
| self.sigmoid = torch.nn.Sigmoid() |
| self.tanh = torch.nn.Tanh() |
| self.float_functional = torch.ao.nn.quantized.FloatFunctional() |
| |
| def forward(self, x: torch.Tensor): |
| x = self.linear1(x) # qint32 |
| x = self.linear2(x) # quint8 |
| linear2 = x |
| x = self.sigmoid(x) # back to qint32 |
| x = self.tanh(x) # back to quint8 |
| x = self.float_functional.add(linear2, x) # adding two quint8's together |
| return x |
| |
| def make_qconfig(scale, zp, dtype): |
| return QConfig( |
| activation=FixedQParamsObserver.with_args(scale=scale, zero_point=zp, dtype=dtype), |
| weight=torch.ao.quantization.default_weight_observer) |
| |
| # Set up a QConfigMapping that specifies different qparams and dtypes for different layers |
| qconfig_mapping = QConfigMapping() \ |
| .set_global(get_default_qconfig("qnnpack")) \ |
| .set_module_name("linear1", make_qconfig(1234, 11, torch.qint32)) \ |
| .set_module_name("linear2", make_qconfig(2345, 22, torch.quint8)) \ |
| .set_object_type(torch.nn.Sigmoid, make_qconfig(3456, 33, torch.qint32)) \ |
| .set_object_type(torch.nn.Tanh, make_qconfig(4567, 44, torch.quint8)) |
| |
| # Set up BackendConfig that supports the dtypes configured in the above QConfigMapping |
| weighted_op_qint32_dtype_config = DTypeConfig( |
| input_dtype=torch.qint32, |
| output_dtype=torch.qint32, |
| weight_dtype=torch.qint8, |
| bias_dtype=torch.float, |
| ) |
| fixed_qparams_op_quint8_dtype_config = DTypeConfig( |
| input_dtype=torch.quint8, |
| output_dtype=torch.quint8, |
| ) |
| fixed_qparams_op_qint32_dtype_config = DTypeConfig( |
| input_dtype=torch.qint32, |
| output_dtype=torch.qint32, |
| ) |
| backend_config = get_qnnpack_backend_config() |
| for config in backend_config.configs: |
| if config.pattern == torch.nn.Linear: |
| config.add_dtype_config(weighted_op_qint32_dtype_config) |
| elif config.pattern in [torch.nn.Sigmoid, torch.nn.Tanh]: |
| config.add_dtype_config(fixed_qparams_op_quint8_dtype_config) |
| config.add_dtype_config(fixed_qparams_op_qint32_dtype_config) |
| |
| # Produce the reference quantized model |
| m = MyModule() |
| example_inputs = (torch.rand(5, 5),) |
| prepared = prepare_fx(m, qconfig_mapping, example_inputs, backend_config=backend_config) |
| prepared(*example_inputs) # calibrate |
| converted = convert_to_reference_fx(prepared, backend_config=backend_config) |
| converted(*example_inputs) |
| |
| # Verify that the reference model is correct |
| # |
| # Reference model until add should be: |
| # fp32_input -> q_to_int32 -> [dq -> linear1_fp32 -> q_to_int32] -> dq -> |
| # q_to_uint8 -> [dq -> linear2_fp32 -> q_to_uint8] -> dq (linear2_dq) -> |
| # q_to_int32 -> [dq -> sigmoid_fp32 -> q_to_int32] -> dq -> |
| # q_to_uint8 -> [dq -> tanh_fp32 -> q_to_uint8] -> dq (tanh_dq) |
| # |
| # Complete reference model with add should be: |
| # [(linear2_dq, tanh_dq) -> add_fp32 -> q_to_uint8] -> dq -> fp32_output |
| |
| target_to_expected_dtypes = { |
| "linear1": torch.qint32, |
| "linear2": torch.quint8, |
| "sigmoid": torch.qint32, |
| "tanh": torch.quint8, |
| torch.add: torch.quint8, |
| } |
| # Find the patterns [dq - op_fp32 - q_to_specific_dtype] in the graph |
| linear2_node = tanh_node = None |
| for node in converted.graph.nodes: |
| if node.target not in target_to_expected_dtypes: |
| continue |
| |
| # Match preceding dequantize |
| self.assertTrue(len(node.args) == 1 or len(node.args) == 2) |
| self.assertTrue(all(arg.target == "dequantize" for arg in node.args)) |
| |
| # Match following quantize with the specific dtypes |
| self.assertEqual(len(node.users), 1) |
| user = list(node.users.keys())[0] |
| self.assertEqual(user.target, torch.quantize_per_tensor) |
| self.assertEqual(user.args[-1], target_to_expected_dtypes[node.target]) |
| |
| # Match [dq - torch.add(linear2_dq, tanh_dq) - q] |
| if node.target == "linear2": |
| linear2_node = node |
| elif node.target == "tanh": |
| tanh_node = node |
| elif node.target == torch.add: |
| linear2_dq, tanh_dq = node.args |
| self.assertEqual(tanh_dq.args[0].args[0], tanh_node) |
| self.assertEqual(linear2_dq.args[0].args[0], linear2_node) |
| |
| def test_lowering_functional_conv_with_kwargs(self): |
| dim_to_op = { |
| 1: F.conv1d, |
| 2: F.conv2d, |
| 3: F.conv3d, |
| } |
| dim_to_qop = { |
| 1: torch.ops.quantized.conv1d, |
| 2: torch.ops.quantized.conv2d, |
| 3: torch.ops.quantized.conv3d, |
| } |
| |
| class Mod(nn.Module): |
| def __init__(self, in_channels, out_channels, kernel_size, dimension): |
| super().__init__() |
| self.dim = dimension |
| self.op = dim_to_op[dimension] |
| kernel_sizes = [kernel_size] * self.dim |
| self.weight = nn.Parameter(torch.randn(out_channels, in_channels, *kernel_sizes)) |
| |
| def forward(self, input): |
| return self.op(input, self.weight, bias=None, stride=[1] * self.dim, |
| padding=[0] * self.dim, dilation=[1] * self.dim, groups=1) |
| |
| for dimension in [1, 2, 3]: |
| model = Mod(3, 16, 3, dimension) |
| model.eval() |
| qconfig_mapping = get_default_qconfig_mapping() |
| input_shape = (1, 3, *([8] * dimension)) |
| example_inputs = torch.randn(input_shape) |
| prepared_model = prepare_fx(model, qconfig_mapping, example_inputs) |
| prepared_model(example_inputs) |
| quantized_model = convert_fx(prepared_model) |
| # This should pass |
| quantized_model(example_inputs) |
| # Ensure the quantized model has the expected op |
| node_occurrence = { |
| ns.call_function(dim_to_qop[dimension]): 1, |
| } |
| self.checkGraphModuleNodes(quantized_model, expected_node_occurrence=node_occurrence) |
| |
| def test_lowering_functional_conv_transpose_with_kwargs(self): |
| dim_to_op = { |
| 1: F.conv_transpose1d, |
| 2: F.conv_transpose2d, |
| 3: F.conv_transpose3d, |
| } |
| dim_to_qop = { |
| 1: torch.ops.quantized.conv_transpose1d, |
| 2: torch.ops.quantized.conv_transpose2d, |
| 3: torch.ops.quantized.conv_transpose3d, |
| } |
| |
| class Mod(nn.Module): |
| def __init__(self, in_channels, out_channels, kernel_size, dimension): |
| super().__init__() |
| self.dim = dimension |
| self.op = dim_to_op[dimension] |
| kernel_sizes = [kernel_size] * self.dim |
| self.weight = nn.Parameter(torch.randn(in_channels, out_channels, *kernel_sizes)) |
| |
| def forward(self, input): |
| return self.op(input, self.weight, bias=None, stride=[1] * self.dim, |
| padding=[0] * self.dim, output_padding=[0] * self.dim, |
| dilation=[1] * self.dim, groups=1) |
| |
| for dimension in [1, 2, 3]: |
| model = Mod(3, 16, 3, dimension) |
| model.eval() |
| qconfig_mapping = get_default_qconfig_mapping() |
| input_shape = (1, 3, *([8] * dimension)) |
| example_inputs = torch.randn(input_shape) |
| prepared_model = prepare_fx(model, qconfig_mapping, example_inputs) |
| prepared_model(example_inputs) |
| quantized_model = convert_fx(prepared_model) |
| # This should pass |
| quantized_model(example_inputs) |
| # Ensure the quantized model has the expected op |
| node_occurrence = { |
| ns.call_function(dim_to_qop[dimension]): 1, |
| } |
| self.checkGraphModuleNodes(quantized_model, expected_node_occurrence=node_occurrence) |
| |
| def test_lowering_functional_linear_with_kwargs(self): |
| class Mod(nn.Module): |
| def __init__(self, in_channels, out_channels): |
| super().__init__() |
| self.weight = nn.Parameter(torch.randn(out_channels, in_channels)) |
| |
| def forward(self, input): |
| return F.linear(input, self.weight, bias=None) |
| |
| model = Mod(8, 4) |
| model.eval() |
| qconfig_mapping = get_default_qconfig_mapping() |
| example_inputs = torch.randn(1, 8) |
| prepared_model = prepare_fx(model, qconfig_mapping, example_inputs) |
| prepared_model(example_inputs) |
| quantized_model = convert_fx(prepared_model) |
| # This should pass |
| quantized_model(example_inputs) |
| # Ensure the quantized model has the expected op |
| node_occurrence = { |
| ns.call_function(torch.ops.quantized.linear): 1, |
| } |
| self.checkGraphModuleNodes(quantized_model, expected_node_occurrence=node_occurrence) |
| |
| @skipIfNoFBGEMM |
| class TestQuantizeFxOps(QuantizationTestCase): |
| def setUp(self): |
| super().setUp() |
| self.custom_qconfig = torch.ao.quantization.QConfig( |
| activation=torch.ao.quantization.observer.HistogramObserver.with_args( |
| qscheme=torch.per_tensor_symmetric, dtype=torch.qint8 |
| ), |
| weight=torch.ao.quantization.default_per_channel_weight_observer |
| ) |
| self.common_quant_patterns = { |
| torch.nn.ConvTranspose1d: DefaultNodeQuantizeHandler, |
| torch.nn.ConvTranspose2d: DefaultNodeQuantizeHandler, |
| torch.nn.ELU: DefaultNodeQuantizeHandler, |
| torch.nn.LeakyReLU: DefaultNodeQuantizeHandler, |
| torch.nn.Hardswish: DefaultNodeQuantizeHandler, |
| torch.nn.InstanceNorm1d: DefaultNodeQuantizeHandler, |
| torch.nn.InstanceNorm2d: DefaultNodeQuantizeHandler, |
| torch.nn.InstanceNorm3d: DefaultNodeQuantizeHandler, |
| torch.nn.LayerNorm: DefaultNodeQuantizeHandler, |
| torch.nn.SiLU: DefaultNodeQuantizeHandler, |
| torch.nn.Mish: DefaultNodeQuantizeHandler, |
| torch.nn.GELU: DefaultNodeQuantizeHandler, |
| torch.nn.Softmax: DefaultNodeQuantizeHandler, |
| torch.nn.functional.elu: DefaultNodeQuantizeHandler, |
| torch.nn.functional.hardswish: DefaultNodeQuantizeHandler, |
| torch.nn.functional.instance_norm: DefaultNodeQuantizeHandler, |
| torch.nn.functional.layer_norm: DefaultNodeQuantizeHandler, |
| torch.nn.functional.leaky_relu: DefaultNodeQuantizeHandler, |
| torch.nn.functional.silu: DefaultNodeQuantizeHandler, |
| torch.nn.functional.mish: DefaultNodeQuantizeHandler, |
| torch.nn.functional.gelu: DefaultNodeQuantizeHandler, |
| torch.nn.functional.softmax: DefaultNodeQuantizeHandler, |
| torch.sum: DefaultNodeQuantizeHandler |
| } |
| |
| """Unit tests for individual ops |
| """ |
| @skipIfNoFBGEMM |
| def test_linear_module(self): |
| with override_quantized_engine('fbgemm'): |
| class LinearModel(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(30, 4).float() |
| |
| def forward(self, x): |
| return self.linear(x) |
| |
| class LinearReLUModel(torch.nn.Module): |
| def __init__(self, f_relu=False): |
| super().__init__() |
| self.linear = torch.nn.Linear(30, 4).float() |
| if f_relu: |
| self.relu = F.relu |
| else: |
| self.relu = torch.nn.ReLU() |
| |
| def forward(self, x): |
| x = self.linear(x) |
| x = self.relu(x) |
| return x |
| |
| class LinearBnModel(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(4, 4).float() |
| self.bn = torch.nn.BatchNorm1d(4) |
| |
| def forward(self, x): |
| x = self.linear(x) |
| x = self.bn(x) |
| return x |
| |
| # Test linear |
| data = (torch.rand((1, 30), dtype=torch.float),) |
| for quant_type in self.all_quant_types: |
| model = LinearModel() |
| quantized_module = nnqd.Linear if quant_type == QuantType.DYNAMIC else nnq.Linear |
| quantized_node = ns.call_module(quantized_module) |
| result_dict = self.checkGraphModeFxOp(model, data, quant_type, quantized_node) |
| if quant_type in self.static_quant_types: |
| self.assertEqual(result_dict["quantized_output"], result_dict["quantized_reference_output"]) |
| |
| # TODO: enable test for dynamic quant |
| # Test linear-relu |
| for f_relu, quant_type in itertools.product([True, False], [QuantType.STATIC, QuantType.QAT]): |
| model = LinearReLUModel(f_relu) |
| quantized_node = ns.call_module(nniq.LinearReLU) |
| result_dict = self.checkGraphModeFxOp(model, data, quant_type, quantized_node) |
| self.assertEqual(result_dict["quantized_output"], result_dict["quantized_reference_output"]) |
| |
| # Test linear-bn |
| data = (torch.rand((4, 4), dtype=torch.float),) |
| for quant_type in self.static_quant_types: |
| model = LinearBnModel() |
| quantized_node = ns.call_module(nnq.Linear) |
| result_dict = self.checkGraphModeFxOp(model, data, quant_type, quantized_node) |
| self.assertEqual(result_dict["quantized_output"], result_dict["quantized_reference_output"]) |
| |
| @skipIfNoFBGEMM |
| def test_functional_linear(self): |
| with override_quantized_engine('fbgemm'): |
| class FuncLinear(torch.nn.Module): |
| def __init__(self, use_bias, has_relu, f_relu): |
| super().__init__() |
| self.w = torch.randn(4, 30) |
| self.b = torch.randn(4) |
| self.use_bias = use_bias |
| if has_relu: |
| if f_relu: |
| self.relu_or_id = F.relu |
| else: |
| self.relu_or_id = torch.nn.ReLU() |
| else: |
| self.relu_or_id = torch.nn.Identity() |
| |
| def forward(self, x): |
| if self.use_bias: |
| x = F.linear(x, self.w, self.b) |
| else: |
| x = F.linear(x, self.w) |
| x = self.relu_or_id(x) |
| return x |
| |
| data = (torch.rand((1, 30), dtype=torch.float),) |
| quant_type_to_qlinear_fun = { |
| QuantType.DYNAMIC: ns.call_function(torch.ops.quantized.linear_dynamic), |
| QuantType.STATIC: ns.call_function(torch.ops.quantized.linear), |
| QuantType.QAT: ns.call_function(torch.ops.quantized.linear), |
| } |
| quant_type_to_qlinear_relu_fun = { |
| # we don't have linear_relu_dynamic |
| QuantType.DYNAMIC: ns.call_function(torch.ops.quantized.linear_relu_dynamic), |
| QuantType.STATIC: ns.call_function(torch.ops.quantized.linear_relu), |
| QuantType.QAT: ns.call_function(torch.ops.quantized.linear_relu), |
| } |
| |
| options = itertools.product( |
| self.all_quant_types, |
| (True, False), # use_bias |
| (True, False), # has_relu |
| (True, False), # functional relu |
| ) |
| for quant_type, use_bias, has_relu, f_relu in options: |
                # When has_relu is False, the model ends with an nn.Identity.
                # Since nn.Identity is a copy node, an observer/fake_quant is
                # also inserted for its output, which is why the expected
                # counts below are one higher when has_relu is False.
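                # A rough illustration (hypothetical graphs, for intuition only;
                # node names are schematic):
                #   has_relu=True:  x -> obs -> linear -> relu -> obs
                #   has_relu=False: x -> obs -> linear -> obs -> identity -> obs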
| quant_type_to_prepare_expected_node_occurrence = { |
| QuantType.DYNAMIC: { |
| ns.call_module(torch.ao.quantization.PlaceholderObserver): 1, |
| ns.call_module(torch.ao.quantization.MinMaxObserver): 1, |
| }, |
                    # There should be 3 observers: input activation, weight, and
                    # output activation, plus one more observer for the output of
                    # torch.nn.Identity when there is no relu.
| QuantType.STATIC: { |
| ns.call_module(torch.ao.quantization.HistogramObserver): 2 if has_relu else 3, |
| ns.call_module(torch.ao.quantization.PerChannelMinMaxObserver): 1, |
| }, |
                    # There should be 3 fake quantizes: input activation, weight,
                    # and output activation (one more when there is no relu).
| QuantType.QAT: { |
| ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 3 if has_relu else 4, |
| }, |
| } |
| model = FuncLinear(use_bias, has_relu, f_relu) |
| if has_relu: |
| qlinear_fun = quant_type_to_qlinear_relu_fun[quant_type] |
| else: |
| qlinear_fun = quant_type_to_qlinear_fun[quant_type] |
| |
| if quant_type != QuantType.DYNAMIC: |
| num_dequantize = 1 |
| else: |
                    # we currently get an extra quantize_per_tensor_dynamic +
                    # dequantize pair for nn.Identity; this will be fixed once
                    # backend_config is used to configure the default PyTorch backend
| num_dequantize = int(not has_relu) |
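
                # Note: the occurrence check below pins dequantize to 0 for
                # dynamic quant, so num_dequantize currently only takes effect
                # on the non-dynamic path.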
| |
| convert_node_occurrence = { |
| ns.call_function(torch.quantize_per_tensor): 1 if quant_type != QuantType.DYNAMIC else 0, |
| qlinear_fun: 1, |
| ns.call_method("dequantize"): num_dequantize if quant_type != QuantType.DYNAMIC else 0, |
| } |
| prepare_expected_node_occurrence = \ |
| quant_type_to_prepare_expected_node_occurrence[quant_type] |
| result_dict = self.checkGraphModeFxOp( |
| model, data, quant_type, qlinear_fun, |
| prepare_expected_node_occurrence=prepare_expected_node_occurrence, |
| expected_node_occurrence=convert_node_occurrence) |
| if quant_type != QuantType.DYNAMIC: |
| self.assertEqual(result_dict["quantized_output"], result_dict["quantized_reference_output"]) |
| # Ensure packed weights in lowered models are folded |
| self.assertIn("_packed_weight_0", result_dict["quantized"].state_dict().keys()) |
| |
| @skipIfNoFBGEMM |
| def test_linear_dynamic_fp16(self): |
| with override_quantized_engine('fbgemm'): |
| class FuncLinear(torch.nn.Module): |
| def __init__(self, use_bias, has_relu, f_relu): |
| super().__init__() |
| self.w = torch.randn(4, 30) |
| self.b = torch.randn(4) |
| self.use_bias = use_bias |
| if has_relu: |
| if f_relu: |
| self.relu = F.relu |
| else: |
| self.relu = torch.nn.ReLU() |
| else: |
| self.relu = torch.nn.Identity() |
| |
| def forward(self, x): |
| if self.use_bias: |
| x = F.linear(x, self.w, self.b) |
| else: |
| x = F.linear(x, self.w) |
| x = self.relu(x) |
| return x |
| |
| data = (torch.rand((1, 30), dtype=torch.float),) |
| options = itertools.product( |
| (True, False), # use_bias |
| (True, False), # has_relu |
| (True, False), # functional relu |
| (True, False), # is_reference |
| ) |
| for use_bias, has_relu, f_relu, is_reference in options: |
| model = FuncLinear(use_bias, has_relu, f_relu) |
| if is_reference: |
| qlinear_fun = ns.call_function(torch.nn.functional.linear) |
| else: |
| if has_relu: |
| qlinear_fun = ns.call_function(torch.ops.quantized.linear_relu_dynamic_fp16) |
| else: |
| qlinear_fun = ns.call_function(torch.ops.quantized.linear_dynamic_fp16) |
| prepare_node_occurrence = { |
| # activation and weight |
| ns.call_module(torch.ao.quantization.PlaceholderObserver): 2 |
| } |
| convert_node_occurrence = { |
| qlinear_fun: 1, |
| # weight |
| ns.call_method("to"): 1 if is_reference else 0 |
| } |
| self.checkGraphModeFxOp( |
| model, data, QuantType.DYNAMIC, qlinear_fun, |
| is_reference=is_reference, |
| custom_qconfig_dict={"": float16_dynamic_qconfig}, |
| prepare_expected_node_occurrence=prepare_node_occurrence, |
| expected_node_occurrence=convert_node_occurrence) |
| |
| def test_linear_static_fp16(self): |
| class FuncLinear(torch.nn.Module): |
| def __init__(self, use_bias, has_relu, f_relu): |
| super().__init__() |
| self.w = torch.randn(4, 30) |
| self.b = torch.randn(4) |
| self.use_bias = use_bias |
| if has_relu: |
| if f_relu: |
| self.relu = F.relu |
| else: |
| self.relu = torch.nn.ReLU() |
| else: |
| self.relu = torch.nn.Identity() |
| |
| def forward(self, x): |
| if self.use_bias: |
| x = F.linear(x, self.w, self.b) |
| else: |
| x = F.linear(x, self.w) |
| x = self.relu(x) |
| return x |
| |
| data = (torch.rand((1, 30), dtype=torch.float),) |
| options = itertools.product( |
| (True, False), # use_bias |
| (True, False), # has_relu |
| (True, False), # functional relu |
| (True, False), # is_reference |
| ) |
| backend_config = get_test_only_legacy_native_backend_config() |
| for use_bias, has_relu, f_relu, is_reference in options: |
| model = FuncLinear(use_bias, has_relu, f_relu) |
| linear_fun = ns.call_function(torch.nn.functional.linear) |
            # When has_relu is False, the model ends with an nn.Identity.
            # Since nn.Identity is a copy node, an observer/fake_quant is also
            # inserted for its output, which is why the expected counts below
            # are one higher when has_relu is False.
| prepare_node_occurrence = { |
| # activation, weight, bias and output |
| ns.call_module(torch.ao.quantization.PlaceholderObserver): 3 + int(use_bias) + int(not has_relu), |
| } |
            # We have an extra "to" and dequantize when is_reference is True
            # and has_relu is False: in that case the model contains an
            # nn.Identity, which is a CopyNode, and we add an extra
            # quant - dequant pair for CopyNodes in reference patterns
| convert_node_occurrence = { |
| # we don't support static fp16 ops, so the linear function |
| # is unfused |
| linear_fun: 1, |
| # activation, weight, bias and output |
| ns.call_method("to"): 3 + int(use_bias) + int(not has_relu and is_reference), |
| ns.call_method("dequantize"): 3 + int(use_bias) + int(not has_relu and is_reference) |
| } |
| self.checkGraphModeFxOp( |
| model, data, QuantType.DYNAMIC, linear_fun, |
| is_reference=is_reference, |
| custom_qconfig_dict={"": float16_static_qconfig}, |
| prepare_expected_node_occurrence=prepare_node_occurrence, |
| expected_node_occurrence=convert_node_occurrence, |
| backend_config=backend_config) |
| |
| @skipIfNoFBGEMM |
| def test_conv_module(self): |
| conv_module = {1 : torch.nn.Conv1d, 2 : torch.nn.Conv2d, 3 : torch.nn.Conv3d} |
| |
| class ConvWrapper(torch.nn.Module): |
| def __init__(self, dim): |
| super().__init__() |
| self.conv = conv_module[dim](3, 3, 3).float() |
| |
| def forward(self, x): |
| return self.conv(x) |
| |
| options = itertools.product([1, 2, 3], self.static_quant_types) |
| quantized_nodes = { |
| # dim |
| 1: ns.call_module(nnq.Conv1d), |
| 2: ns.call_module(nnq.Conv2d), |
| 3: ns.call_module(nnq.Conv3d), |
| } |
| for dim, quant_type in options: |
| self.checkGraphModeFxOp( |
| ConvWrapper(dim), self.img_data_dict[dim], quant_type, |
| quantized_nodes[dim]) |
| |
| @skipIfNoFBGEMM |
| def test_functional_conv(self): |
| with override_quantized_engine('fbgemm'): |
| """ Test for function conv and functional conv + relu |
| """ |
| convs = { |
| 1: torch.nn.functional.conv1d, |
| 2: torch.nn.functional.conv2d, |
| 3: torch.nn.functional.conv3d, |
| } |
| |
| class FuncConv(torch.nn.Module): |
| def __init__(self, dim, use_bias, has_relu, f_relu): |
| super().__init__() |
| self.dim = dim |
| self.w = torch.randn(tuple([3] * (dim + 2))) |
| self.b = torch.randn(3) if use_bias else None |
| self.stride = tuple([1] * dim) |
| self.padding = tuple([0] * dim) |
| self.dilation = tuple([1] * dim) |
| self.groups = 1 |
| self.use_bias = use_bias |
| if has_relu: |
| if f_relu: |
| self.relu = F.relu |
| else: |
| self.relu = torch.nn.ReLU() |
| else: |
| self.relu = torch.nn.Identity() |
| |
| def forward(self, x): |
| x = convs[self.dim](x, self.w, self.b, self.stride, self.padding, self.dilation, self.groups) |
| x = self.relu(x) |
| return x |
| |
| quant_type_to_qconv_fun = { |
| QuantType.STATIC: { |
| 1: ns.call_function(torch.ops.quantized.conv1d), |
| 2: ns.call_function(torch.ops.quantized.conv2d), |
| 3: ns.call_function(torch.ops.quantized.conv3d) |
| }, |
| QuantType.QAT: { |
| 1: ns.call_function(torch.ops.quantized.conv1d), |
| 2: ns.call_function(torch.ops.quantized.conv2d), |
| 3: ns.call_function(torch.ops.quantized.conv3d) |
| }, |
| } |
| quant_type_to_qconv_relu_fun = { |
| QuantType.STATIC: { |
| 1: ns.call_function(torch.ops.quantized.conv1d_relu), |
| 2: ns.call_function(torch.ops.quantized.conv2d_relu), |
| 3: ns.call_function(torch.ops.quantized.conv3d_relu) |
| }, |
| QuantType.QAT: { |
| 1: ns.call_function(torch.ops.quantized.conv1d_relu), |
| 2: ns.call_function(torch.ops.quantized.conv2d_relu), |
| 3: ns.call_function(torch.ops.quantized.conv3d_relu) |
| }, |
| } |
| |
| options = itertools.product( |
| [1, 2, 3], # dims |
| self.static_quant_types, |
| (True, False), # use_bias |
| (True, False), # has_relu |
| (True, False), # functional relu |
| ) |
| for dim, quant_type, use_bias, has_relu, f_relu in options: |
                # When has_relu is False, the model ends with an nn.Identity.
                # Since nn.Identity is a copy node, an observer/fake_quant is
                # also inserted for its output, which is why the expected
                # counts below are one higher when has_relu is False.
| quant_type_to_prepare_expected_node_occurrence = { |
| QuantType.DYNAMIC: {}, |
                    # There should be 3 observers: input activation, weight, and
                    # output activation (one more when there is no relu).
| QuantType.STATIC: { |
| ns.call_module(torch.ao.quantization.HistogramObserver): 2 if has_relu else 3, |
| ns.call_module(torch.ao.quantization.PerChannelMinMaxObserver): 1, |
| }, |
                    # There should be 3 fake quantizes: input activation, weight,
                    # and output activation (one more when there is no relu).
| QuantType.QAT: { |
| ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 3 if has_relu else 4, |
| }, |
| } |
| data_dims = [2, 3] + [4] * dim |
| data = (torch.randn(tuple(data_dims), dtype=torch.float),) |
| model = FuncConv(dim, use_bias, has_relu, f_relu) |
| if has_relu: |
| qconv_fun = quant_type_to_qconv_relu_fun[quant_type][dim] |
| else: |
| qconv_fun = quant_type_to_qconv_fun[quant_type][dim] |
| |
| convert_node_occurrence = { |
| ns.call_function(torch.quantize_per_tensor): 1, |
| qconv_fun: 1, |
| ns.call_method("dequantize"): 1 |
| } |
| prepare_expected_node_occurrence = \ |
| quant_type_to_prepare_expected_node_occurrence[quant_type] |
| result_dict = self.checkGraphModeFxOp( |
| model, data, quant_type, qconv_fun, |
| prepare_expected_node_occurrence=prepare_expected_node_occurrence, |
| expected_node_occurrence=convert_node_occurrence) |
| if quant_type != QuantType.DYNAMIC: |
| self.assertEqual(result_dict["quantized_output"], result_dict["quantized_reference_output"]) |
| # Ensure packed weights in lowered models are folded |
| self.assertIn("_packed_weight_0", result_dict["quantized"].state_dict().keys()) |
| |
| @skipIfNoFBGEMM |
| def test_quantized_conv_relu(self): |
| """tests for conv1d_relu/conv2d_relu/conv3d_relu""" |
| conv_module = {1 : torch.nn.Conv1d, 2 : torch.nn.Conv2d, 3 : torch.nn.Conv3d} |
| |
| class ConvNdRelu(torch.nn.Module): |
| def __init__(self, dim, inplace): |
| super().__init__() |
| self.conv = conv_module[dim](3, 3, 3).float() |
| self.relu = torch.nn.ReLU(inplace) |
| |
| def forward(self, x): |
| return self.relu(self.conv(x)) |
| |
| class ConvNdFunctionalRelu(torch.nn.Module): |
| def __init__(self, dim): |
| super().__init__() |
| self.conv = conv_module[dim](3, 3, 3).float() |
| |
| def forward(self, x): |
| return F.relu(self.conv(x)) |
| |
| class ConvNdInplaceFunctionalRelu(torch.nn.Module): |
| def __init__(self, dim): |
| super().__init__() |
| self.conv = conv_module[dim](3, 3, 3).float() |
| |
| def forward(self, x): |
| return F.relu(self.conv(x), True) |
| |
| options = itertools.product([1, 2, 3], self.static_quant_types) |
| quantized_nodes = { |
| # dim |
| 1: ns.call_module(nniq.ConvReLU1d), |
| 2: ns.call_module(nniq.ConvReLU2d), |
| 3: ns.call_module(nniq.ConvReLU3d), |
| } |
| for dim, quant_type in options: |
| for m in [ConvNdRelu(dim, True), |
| ConvNdRelu(dim, False), |
| ConvNdFunctionalRelu(dim), |
| ConvNdInplaceFunctionalRelu(dim)]: |
| self.checkGraphModeFxOp( |
| m, self.img_data_dict[dim], quant_type, |
| quantized_nodes[dim]) |
| |
| |
| def _test_binary_op_int8_impl(self, binary_op, ibinary_op, quantized_op): |
| data = (torch.randn(1, 1, 1, 1, dtype=torch.float), |
| torch.randn(1, 1, 1, 1, dtype=torch.float)) |
| options = itertools.product([True, False], [True, False], [True, False]) |
| quant_type = QuantType.STATIC |
| # testing for default int8 static quant |
| for is_inplace, is_scalar, is_reference in options: |
| if is_reference: |
| node_list = [ |
| ns.call_method("dequantize"), |
| ns.call_function(binary_op), |
| ns.call_function(torch.quantize_per_tensor) |
| ] |
| quantized_node = None |
| else: |
| node_list = None |
| quantized_node = ns.call_function(quantized_op) |
| |
| self.checkGraphModeFxOp( |
| BinaryOp(binary_op, ibinary_op, is_inplace, is_scalar), data, quant_type, |
| quantized_node, expected_node_list=node_list, is_reference=is_reference) |
            # This tests that the binary op is quantized even when it is not
            # fed a quantized input
| self.checkGraphModeFxOp( |
| BinaryOpNonQuantizedInput(binary_op, ibinary_op, is_inplace, is_scalar), |
| data, quant_type, quantized_node, |
| expected_node_list=node_list, is_reference=is_reference) |
| |
| |
| def _test_binary_op_float16_impl(self, binary_op, ibinary_op): |
| data = (torch.randn(1, 1, 1, 1, dtype=torch.float), |
| torch.randn(1, 1, 1, 1, dtype=torch.float)) |
| quant_type = QuantType.STATIC |
| # testing for fp16 static quant |
| # we are producing fp16 patterns |
| options = itertools.product([True, False], [True, False]) |
| custom_qconfig_dict = { |
| "object_type": [(binary_op, float16_static_qconfig)] |
| } |
| backend_config = get_test_only_legacy_native_backend_config() |
| for is_inplace, is_scalar in options: |
| node_occurrence = { |
| # output_conv1, output_add1, output_add2 for scalar |
| # output_conv1, output_conv2, output_add1, output_add2 for non-scalar |
| ns.call_method("to"): 3 if is_scalar else 4 |
| } |
| self.checkGraphModeFxOp( |
| BinaryOp(binary_op, ibinary_op, is_inplace, is_scalar), data, quant_type, |
| expected_node_occurrence=node_occurrence, |
| custom_qconfig_dict=custom_qconfig_dict, |
| backend_config=backend_config) |
| |
| node_occurrence = { |
| # input_add, output_add for scalar |
| # input_add1, input_add2, output_add for non-scalar |
| ns.call_method("to"): 2 if is_scalar else 3 |
| } |
| self.checkGraphModeFxOp( |
| BinaryOpNonQuantizedInput(binary_op, ibinary_op, is_inplace, is_scalar), data, quant_type, |
| expected_node_occurrence=node_occurrence, |
| custom_qconfig_dict=custom_qconfig_dict, |
| backend_config=backend_config) |
| |
| def _test_binary_op_relu_int8_impl(self, binary_op, ibinary_op, quantized_op): |
| data = (torch.rand((1, 1, 1, 1), dtype=torch.float), |
| torch.rand((1, 1, 1, 1), dtype=torch.float)) |
| quant_type = QuantType.STATIC |
| quantized_node = ns.call_function(quantized_op) |
| options = itertools.product( |
| [True, False], [nn.ReLU, F.relu, torch.relu], [True, False]) |
| for is_inplace_op, relu_callable, is_scalar in options: |
| model = BinaryOpRelu( |
| binary_op, ibinary_op, is_inplace_op, relu_callable, is_scalar) |
| self.checkGraphModeFxOp( |
| model, data, quant_type, quantized_node) |
| |
| def _test_binary_op_relu_float16_impl(self, binary_op, ibinary_op): |
| data = (torch.rand((1, 1, 1, 1), dtype=torch.float), |
| torch.rand((1, 1, 1, 1), dtype=torch.float)) |
| quant_type = QuantType.STATIC |
| options = itertools.product( |
| [True, False], [nn.ReLU, F.relu, torch.relu], [True, False]) |
| custom_qconfig_dict = { |
| "": float16_static_qconfig, |
| "object_type": [(torch.nn.Conv2d, None)] |
| } |
| backend_config = get_test_only_legacy_native_backend_config() |
| for is_inplace_op, is_functional_relu, is_scalar in options: |
| node_occurrence = { |
| ns.call_method("to"): 3 if is_scalar else 4 |
| } |
| model = BinaryOpRelu( |
| binary_op, ibinary_op, is_inplace_op, is_functional_relu, is_scalar) |
| self.checkGraphModeFxOp( |
| model, data, quant_type, custom_qconfig_dict=custom_qconfig_dict, |
| expected_node_occurrence=node_occurrence, |
| backend_config=backend_config) |
| |
| |
| @skipIfNoFBGEMM |
| def test_add(self): |
| self._test_binary_op_int8_impl( |
| operator.add, operator.iadd, torch.ops.quantized.add) |
| self._test_binary_op_float16_impl( |
| operator.add, operator.iadd) |
| |
| @unittest.skip("This is no longer needed right now, can enable later with new api") |
| def test_sub(self): |
| self._test_binary_op_float16_impl(operator.sub, operator.isub) |
| self._test_binary_op_float16_impl(torch.sub, None) |
| |
| @unittest.skip("This is no longer needed right now, can enable later with new api") |
| def test_div(self): |
| self._test_binary_op_float16_impl(operator.truediv, operator.itruediv) |
| self._test_binary_op_float16_impl(torch.div, None) |
| |
| @skipIfNoFBGEMM |
| def test_mul(self): |
| self._test_binary_op_int8_impl( |
| operator.mul, operator.imul, torch.ops.quantized.mul) |
| self._test_binary_op_float16_impl(operator.mul, operator.imul) |
| |
| @unittest.skip("This is no longer needed right now, can enable later with new api") |
| def test_sum(self): |
| class Sum(torch.nn.Module): |
| def forward(self, x): |
| x = torch.sum(x, [1], keepdim=True) |
| x = torch.sum(x, [1]) |
| return x |
| |
| data = torch.randn(1, 2, 3, 4, dtype=torch.float) |
| quant_type = QuantType.STATIC |
| # testing for fp16 static quant |
| # we are producing fp16 patterns |
| custom_qconfig_dict = { |
| "object_type": [(torch.sum, float16_static_qconfig)] |
| } |
| node_occurrence = { |
| # input_sum1, output_sum1, output_sum2 |
| ns.call_method("to"): 3 |
| } |
| self.checkGraphModeFxOp( |
| Sum(), data, quant_type, |
| expected_node_occurrence=node_occurrence, |
| custom_qconfig_dict=custom_qconfig_dict) |
| |
| @unittest.skip("This is no longer needed right now, can enable later with new api") |
| def test_bmm(self): |
| class BMMMethod(torch.nn.Module): |
| def forward(self, x, y): |
| return x.bmm(y) |
| |
| data = (torch.randn(1, 1, 1, dtype=torch.float), |
| torch.randn(1, 1, 1, dtype=torch.float)) |
| quant_type = QuantType.STATIC |
| # testing for fp16 static quant |
| # we are producing fp16 patterns |
| custom_qconfig_dict = { |
| "object_type": [(torch.bmm, float16_static_qconfig), |
| ("bmm", float16_static_qconfig)] |
| } |
| node_occurrence = { |
| # input_bmm1, input_bmm2, output_bmm |
| ns.call_method("to"): 3 |
| } |
| self.checkGraphModeFxOp( |
| BinaryOpNonQuantizedInput(torch.bmm, None, False, False), data, quant_type, |
| expected_node_occurrence=node_occurrence, |
| custom_qconfig_dict=custom_qconfig_dict) |
| |
| # TODO: support call_method("bmm") |
| # we can transform call_method("bmm") to call_function(torch.bmm) |
| # self.checkGraphModeFxOp( |
| # BMMMethod(), data, quant_type, |
| # expected_node_occurrence=node_occurrence, |
| # custom_qconfig_dict=custom_qconfig_dict, |
| # print_debug_info=True) |
| |
| @skipIfNoFBGEMM |
| def test_add_relu(self): |
| self._test_binary_op_relu_int8_impl( |
| operator.add, operator.iadd, torch.ops.quantized.add_relu) |
| self._test_binary_op_relu_float16_impl( |
| operator.add, operator.iadd) |
| |
| @skipIfNoFBGEMM |
| def test_add_relu_multiple_uses_of_relu(self): |
| class Sub(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.relu = torch.nn.ReLU(inplace=True) |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.sub = Sub() |
| |
| def forward(self, x, y): |
| x = x + y |
| x = self.sub.relu(x) |
| x = x + y |
| x = self.sub.relu(x) |
| return x |
| |
| m = M().eval() |
| example_inputs = (torch.randn(3), torch.randn(3)) |
| m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs) |
| m = convert_fx(m) |
| node_occurrence = { |
| ns.call_function(torch.quantize_per_tensor): 2, |
| ns.call_function(torch.ops.quantized.add_relu): 2, |
| ns.call_method("dequantize"): 1, |
| } |
| self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) |
| # check the model is scriptable |
| m = torch.jit.script(m) |
| # check the model is runnable |
| m(*example_inputs) |
| |
| @skipIfNoFBGEMM |
| def test_mul_relu(self): |
| self._test_binary_op_relu_int8_impl( |
| operator.mul, operator.imul, torch.ops.quantized.mul_relu) |
| self._test_binary_op_relu_float16_impl( |
| operator.mul, operator.imul) |
| |
| # TODO(future PR): make more generic |
| def _test_quantized_add_mul_qat(self, model, example_inputs, expected_node_occurrence): |
| qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')} |
| mp = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs) |
| self.checkGraphModuleNodes( |
| mp, expected_node_occurrence=expected_node_occurrence) |
| |
| @skipIfNoFBGEMM |
| def test_quantized_add_qat(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv1 = torch.nn.Conv2d(1, 1, 1) |
| self.conv2 = torch.nn.Conv2d(1, 1, 1) |
| |
| def forward(self, x): |
| x = torch.add(x, 1.0) |
| x = self.conv1(x) |
| x = torch.add(x, 1.0) |
| x = torch.relu(x) |
| x = self.conv2(x) |
| return x |
| |
| m = M() |
| example_inputs = (torch.randn(1, 1, 1, 1),) |
| expected_node_occurrence = { |
| ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 5, |
| } |
| self._test_quantized_add_mul_qat(m, example_inputs, expected_node_occurrence) |
| |
| @skipIfNoFBGEMM |
| def test_quantized_mul_qat(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv1 = torch.nn.Conv2d(1, 1, 1) |
| self.conv2 = torch.nn.Conv2d(1, 1, 1) |
| |
| def forward(self, x): |
| x = torch.mul(x, 1.0) |
| x = self.conv1(x) |
| x = torch.mul(x, 1.0) |
| x = torch.relu(x) |
| x = self.conv2(x) |
| return x |
| |
| m = M() |
| example_inputs = (torch.randn(1, 1, 1, 1),) |
| expected_node_occurrence = { |
| ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 5, |
| } |
| self._test_quantized_add_mul_qat(m, example_inputs, expected_node_occurrence) |
| |
| def test_int8_input_no_unnecessary_fq(self): |
| """ |
| If the inputs to the graph are quantized and the only node |
| does not need an activation observer, verifies that the |
| activation observer is not inserted. |
| """ |
| class M(nn.Module): |
| def __init__(self, scalar): |
| super().__init__() |
| self.scalar = scalar |
| self.add_func = torch.ao.nn.quantized.FloatFunctional() |
| |
| def forward(self, x): |
| return self.add_func.add_scalar(x, self.scalar) |
| |
| m = M(0.5) |
| mp = torch.ao.quantization.quantize_fx.prepare_qat_fx( |
| m, {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')}, |
| example_inputs=(torch.randn(1),), |
| prepare_custom_config={"input_quantized_idxs": [0]}) |
| expected_node_occurrence = { |
| ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 1, |
| } |
| self.checkGraphModuleNodes( |
| mp, expected_node_occurrence=expected_node_occurrence) |
| |
| @skipIfNoFBGEMM |
| def test_cat(self): |
| """ quantization of the output of cat will depend on the |
| input of cat. we only quantize the output of cat when its inputs are quantized. |
| """ |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv1 = torch.nn.Conv2d(2, 2, 2).float() |
| self.conv2 = torch.nn.Conv2d(2, 2, 2).float() |
| |
| def forward(self, x, y): |
| x = self.conv1(x) |
| y = self.conv2(y) |
| return torch.cat([x, y], 1) |
| |
| example_inputs = (torch.randn(1, 2, 5, 5, dtype=torch.float), |
| torch.randn(1, 2, 5, 5, dtype=torch.float)) |
| quantized_node = ns.call_function(torch.cat) |
| options = itertools.product(self.static_quant_types, [True, False]) |
| for quant_type, is_reference in options: |
| if is_reference: |
| converted_node_list = [ |
| ns.call_method("dequantize"), |
| ns.call_function(torch.cat), |
| ns.call_function(torch.quantize_per_tensor) |
| ] |
| converted_node_occurrence = { |
| # inputs and outputs of the two conv, and output of cat |
| ns.call_method("dequantize"): 5, |
| ns.call_function(torch.cat): 1, |
| # inputs and outputs of the two conv, and output of cat |
| ns.call_function(torch.quantize_per_tensor): 5, |
| } |
| else: |
| converted_node_list = None |
| converted_node_occurrence = { |
| # output of cat |
| ns.call_method("dequantize"): 1, |
| ns.call_function(torch.cat): 1, |
| # for two inputs |
| ns.call_function(torch.quantize_per_tensor): 2, |
| } |
| |
| self.checkGraphModeFxOp( |
| M(), |
| example_inputs, |
| quant_type, |
| quantized_node, |
| expected_node_list=converted_node_list, |
| expected_node_occurrence=converted_node_occurrence, |
| is_reference=is_reference) |
| |
| # check cat is using the same observer for input and output |
| m = M().eval() |
| m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs) |
        # the two inputs and the output of torch.cat share the same observer,
        # so 2 of the (non-deduplicated) observer entries are duplicates
| all_observers = len(dict(m.named_modules(remove_duplicate=False))) |
| distinct_observers = len(dict(m.named_modules())) |
| self.assertEqual(all_observers, distinct_observers + 2) |
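        # In other words, one observer module is registered under three
        # attribute names (both cat inputs and the cat output), so the
        # non-deduplicated module count is exactly 2 higher.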
| # make sure the converted model runs |
| m = convert_fx(m) |
| m(*example_inputs) |
| |
| @skipIfNoFBGEMM |
| def test_qbatch_norm(self): |
| bn_module = { |
| # TODO: quantized batchnorm 1d module is missing |
| # 1 : torch.nn.BatchNorm1d, |
| 2 : torch.nn.BatchNorm2d, |
| 3 : torch.nn.BatchNorm3d, |
| } |
| |
| class M(torch.nn.Module): |
| def __init__(self, dim): |
| super().__init__() |
| self.bn = bn_module[dim](3).to(torch.float) |
| |
| def forward(self, x): |
| return self.bn(x) |
| |
| options = itertools.product(self.static_quant_types, [2, 3], [True, False]) |
| quantized_nodes = { |
| False: { |
| # 1: ns.call_module(nnq.BatchNorm1d), |
| 2: ns.call_module(nnq.BatchNorm2d), |
| 3: ns.call_module(nnq.BatchNorm3d), |
| }, |
| True: { |
| # 1: ns.call_module(nn.BatchNorm1d), |
| 2: ns.call_module(nn.BatchNorm2d), |
| 3: ns.call_module(nn.BatchNorm3d), |
| } |
| } |
| for quant_type, dim, is_reference in options: |
| self.checkGraphModeFxOp( |
| M(dim), self.img_data_dict[dim], quant_type, quantized_nodes[is_reference][dim], is_reference=is_reference) |
| |
| @skipIfNoFBGEMM |
| def test_qbatch_norm_relu(self): |
| bn_module = {2 : torch.nn.BatchNorm2d, 3 : torch.nn.BatchNorm3d} |
| |
| class BNRelu(torch.nn.Module): |
| def __init__(self, dim, inplace): |
| super().__init__() |
| self.bn = bn_module[dim](3).to(torch.float) |
| self.relu = torch.nn.ReLU(inplace=inplace) |
| |
| def forward(self, x): |
| return self.relu(self.bn(x)) |
| |
| class BNFuncRelu(torch.nn.Module): |
| def __init__(self, dim): |
| super().__init__() |
| self.bn = bn_module[dim](3).to(torch.float) |
| |
| def forward(self, x): |
| return F.relu(self.bn(x), False) |
| |
| class BNFuncInplaceRelu(torch.nn.Module): |
| def __init__(self, dim): |
| super().__init__() |
| self.bn = bn_module[dim](3).to(torch.float) |
| |
| def forward(self, x): |
| return F.relu(self.bn(x), True) |
| |
| options = itertools.product(self.static_quant_types, [2, 3], [True, False]) |
| quantized_nodes = { |
| True: { |
| 2: ns.call_module(nni.BNReLU2d), |
| 3: ns.call_module(nni.BNReLU3d), |
| }, |
| False: { |
| 2: ns.call_module(nniq.BNReLU2d), |
| 3: ns.call_module(nniq.BNReLU3d), |
| } |
| } |
| for quant_type, dim, is_reference in options: |
| for instance in [BNRelu(dim, True), BNRelu(dim, False), |
| BNFuncRelu(dim), BNFuncInplaceRelu(dim)]: |
| self.checkGraphModeFxOp( |
| instance, self.img_data_dict[dim], quant_type, |
| quantized_nodes[is_reference][dim], is_reference=is_reference) |
| |
| def _test_activation_impl( |
| self, float_module, float_op, quantized_module, quantized_op): |
        ''' Test for activation ops (with inplace options); float_op can be a
        torch op or a functional op
        '''
| class M(torch.nn.Module): |
| def __init__(self, is_module, inplace): |
| super().__init__() |
| self.is_module = is_module |
| self.inplace = inplace |
| if self.is_module: |
| self.op = float_module(self.inplace) |
| else: |
| self.op = float_op |
| |
| def forward(self, input): |
| if self.is_module: |
| return self.op(input) |
| else: |
| return self.op(input, self.inplace) |
| |
| options = itertools.product([True, False], [True, False], self.static_quant_types, [True, False]) |
| quantized_nodes = { |
| # is_module |
| True: { |
| # is_reference |
| True: ns.call_module(float_module), |
| False: ns.call_module(quantized_module), |
| }, |
| False: { |
| True: ns.call_function(float_op), |
| False: ns.call_function(quantized_op), |
| } |
| } |
| |
| for is_module, is_inplace, quant_type, is_reference in options: |
| self.checkGraphModeFxOp( |
| M(is_module, is_inplace), self.img_data_2d, |
| quant_type, quantized_nodes[is_module][is_reference], is_reference=is_reference) |
| |
| def test_hardswish(self): |
| self._test_activation_impl(nn.Hardswish, F.hardswish, nnq.Hardswish, torch.ops.quantized.hardswish) |
| |
| def test_elu(self): |
| self._test_activation_impl(nn.ELU, F.elu, nnq.ELU, torch.ops.quantized.elu) |
| |
| def test_leaky_relu(self): |
| self._test_activation_impl(nn.LeakyReLU, F.leaky_relu, nnq.LeakyReLU, torch.ops.quantized.leaky_relu) |
| |
| def test_prelu(self): |
| class M(torch.nn.Module): |
| def __init__(self, num_param: int): |
| super().__init__() |
| self.op = torch.nn.PReLU(num_parameters=num_param) |
| |
| def forward(self, input): |
| return self.op(input) |
| |
| X = [[torch.randn(4, 4, 4, 4, dtype=torch.float)]] |
| options = itertools.product([1, 4], self.static_quant_types, [True, False]) |
| quantized_nodes = { |
| # is_reference |
| True: ns.call_module(torch.nn.PReLU), |
| False: ns.call_module(torch.ao.nn.quantized.PReLU), |
| } |
| |
| for num_parameter, quant_type, is_reference in options: |
| self.checkGraphModeFxOp( |
| M(num_parameter), X, quant_type, quantized_nodes[is_reference], |
| is_reference=is_reference) |
| |
| def _test_norm_impl( |
| self, float_module, float_op, op_args, data, quantized_module, quantized_op, |
| skip_op_arg_for_functional=False): |
        ''' Test for normalization ops; float_op can be a torch op or a
        functional op, and op_args is a list of positional arguments for the
        module/op
        '''
| class M(torch.nn.Module): |
| def __init__(self, is_module): |
| super().__init__() |
| self.is_module = is_module |
| if self.is_module: |
| self.op = float_module(*op_args) |
| else: |
| self.op = float_op |
| |
| def forward(self, input): |
| if self.is_module: |
| return self.op(input) |
| else: |
| args = [input] |
| if not skip_op_arg_for_functional: |
| args += op_args |
| return self.op(*args) |
| |
| options = itertools.product([True, False], self.static_quant_types) |
| quantized_nodes = { |
| # is_module |
| True: ns.call_module(quantized_module), |
| False: ns.call_function(quantized_op), |
| } |
| |
| for is_module, quant_type in options: |
| self.checkGraphModeFxOp( |
| M(is_module), data, quant_type, quantized_nodes[is_module]) |
| |
| def _test_norm_float16_impl( |
| self, float_module, float_op, op_args, data, |
| skip_op_arg_for_functional=False): |
        ''' Test for normalization ops; float_op can be a torch op or a
        functional op, and op_args is a list of positional arguments for the
        module/op
        '''
| class M(torch.nn.Module): |
| def __init__(self, is_module): |
| super().__init__() |
| self.is_module = is_module |
| if self.is_module: |
| self.op = float_module(*op_args) |
| else: |
| self.op = float_op |
| |
| def forward(self, input): |
| if self.is_module: |
| return self.op(input) |
| else: |
| args = [input] |
| if not skip_op_arg_for_functional: |
| args += op_args |
| return self.op(*args) |
| |
| options = itertools.product([True, False], self.static_quant_types) |
| qconfig_dict = { |
| "object_type": [ |
| (float_module, float16_static_qconfig), |
| (float_op, float16_static_qconfig) |
| ] |
| } |
| node_occurrence = { |
| ns.call_method("to"): 2 |
| } |
| for is_module, quant_type in options: |
| self.checkGraphModeFxOp( |
| M(is_module), data, quant_type, custom_qconfig_dict=qconfig_dict, expected_node_occurrence=node_occurrence) |
| |
| def test_layer_norm(self): |
| data = (torch.rand((1, 2, 5, 5), dtype=torch.float),) |
| self._test_norm_impl( |
| nn.LayerNorm, F.layer_norm, [[2, 5, 5]], data, nnq.LayerNorm, torch.ops.quantized.layer_norm) |
| |
| def test_instance_norm(self): |
| data_1d = (torch.rand((1, 4, 5), dtype=torch.float),) |
| data_2d = (torch.rand((1, 4, 5, 1), dtype=torch.float),) |
| data_3d = (torch.rand((1, 4, 5, 1, 1), dtype=torch.float),) |
| data_dict = {1 : data_1d, 2 : data_2d, 3 : data_3d} |
| instance_norm_modules = {1 : nn.InstanceNorm1d, |
| 2 : nn.InstanceNorm2d, |
| 3 : nn.InstanceNorm3d} |
| quantized_instance_norm_modules = { |
| 1 : nnq.InstanceNorm1d, |
| 2 : nnq.InstanceNorm2d, |
| 3 : nnq.InstanceNorm3d |
| } |
| for dim in [1, 2, 3]: |
| data = data_dict[dim] |
| module = instance_norm_modules[dim] |
| quantized_module = quantized_instance_norm_modules[dim] |
| self._test_norm_impl( |
| module, F.instance_norm, [4], data, |
| quantized_module, torch.ops.quantized.instance_norm, |
| skip_op_arg_for_functional=True) |
| |
| def test_norm_weight_bias(self): |
| class Linear(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.w = torch.ones(5, 5) |
| self.b = torch.zeros(5) |
| |
| def forward(self, x): |
| return torch.nn.functional.linear(x, self.w, self.b) |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.mods1 = Linear() |
| self.scale = torch.randn(5, 5) |
| self.bias = torch.randn(5, 5) |
| |
| def forward(self, x): |
| x1 = self.mods1(x) |
| y = F.layer_norm(x1, [5, 5], weight=self.scale, bias=self.bias) |
| return y |
| |
| model = M() |
| expected_occurrence = { |
| ns.call_function(torch.quantize_per_tensor): 1, |
| ns.call_function(torch.ops.quantized.linear): 1, |
| ns.call_function(torch.ops.quantized.layer_norm): 1, |
| ns.call_method("dequantize"): 1, |
| } |
| |
| self.checkGraphModeFxOp( |
| model, |
| (torch.rand(5, 5),), |
| QuantType.STATIC, |
| expected_node_occurrence=expected_occurrence, |
| custom_qconfig_dict=get_default_qconfig_mapping().to_dict() |
| ) |
| |
| def _test_default_node_quant_handler_ops( |
| self, module, functional, qconfig, is_reference=True, node_list=None, additional_quant_pattern_dict=None |
| ): |
| class M(torch.nn.Module): |
| def __init__(self, mod, func): |
| super().__init__() |
| self.module = mod() |
| self.functional = func |
| |
| def forward(self, x): |
| x = self.module(x) |
| x = self.functional(x) |
| return x |
| |
| if node_list is None: |
| node_list = [] |
| if additional_quant_pattern_dict is None: |
| additional_quant_pattern_dict = {} |
| |
| data = torch.randn((2, 2, 2, 2)) |
| quant_type = QuantType.STATIC |
| prepare_custom_qconfig_dict = {"additional_quant_pattern": additional_quant_pattern_dict} |
| qconfig_dict = {"": qconfig} |
| |
| m = M(module, functional).eval() |
        m_prep = prepare_fx(m, qconfig_dict, example_inputs=(data,),
                            prepare_custom_config=prepare_custom_qconfig_dict)
| m_prep(data) |
| convert_fn = convert_to_reference_fx if is_reference else convert_fx |
        m_quant = convert_fn(m_prep)
| m_quant(data) |
| |
| self.checkGraphModuleNodes(m_quant, expected_node_list=node_list) |
| |
| @unittest.skip("TODO: reenable with backend_config api") |
| def test_gelu_normal(self): |
| module = torch.nn.GELU |
| functional = torch.nn.functional.gelu |
| qconfig = torch.ao.quantization.get_default_qconfig("fbgemm") |
| is_reference = False |
| node_list = [ |
| ns.call_module(module), |
| ns.call_function(functional), |
| ] |
| self._test_default_node_quant_handler_ops( |
| module, functional, qconfig, is_reference, node_list) |
| |
| @unittest.skip("TODO: reenable with backend_config api") |
| def test_softmax_normal(self): |
| module = torch.nn.Softmax |
| functional = torch.nn.functional.softmax |
| qconfig = torch.ao.quantization.get_default_qconfig("fbgemm") |
| is_reference = False |
| node_list = [ |
| ns.call_module(torch.ao.nn.quantized.Softmax), |
| ns.call_function(functional), |
| ] |
| self._test_default_node_quant_handler_ops( |
| module, functional, qconfig, is_reference, node_list) |
| |
| @unittest.skip("This is no longer needed right now, can enable later with new api") |
| def test_gelu_reference(self): |
| module = torch.nn.GELU |
| functional = torch.nn.functional.gelu |
| qconfig = torch.ao.quantization.get_default_qconfig("fbgemm") |
| is_reference = True |
| node_list = [ |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_method("dequantize"), |
| ns.call_module(module), |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_method('dequantize'), |
| ns.call_function(functional), |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_method('dequantize') |
| ] |
| # TODO: change these to use backend_config |
| additional_patterns = {torch.nn.GELU: DefaultNodeQuantizeHandler, |
| torch.nn.functional.gelu: DefaultNodeQuantizeHandler} |
| self._test_default_node_quant_handler_ops( |
| module, functional, qconfig, is_reference, node_list, additional_patterns) |
| |
| self._test_default_node_quant_handler_ops(module, functional, self.custom_qconfig, is_reference, node_list, |
| additional_quant_pattern_dict=self.common_quant_patterns) |
| |
| @unittest.skip("This is no longer needed right now, can enable later with new api") |
| def test_softmax_reference(self): |
| module = torch.nn.Softmax |
| functional = torch.nn.functional.softmax |
| qconfig = torch.ao.quantization.get_default_qconfig("fbgemm") |
| is_reference = True |
| node_list = [ |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_method("dequantize"), |
| ns.call_module(module), |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_method('dequantize'), |
| ns.call_function(functional), |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_method('dequantize') |
| ] |
| additional_patterns = {torch.nn.Softmax: DefaultNodeQuantizeHandler, |
| torch.nn.functional.softmax: DefaultNodeQuantizeHandler} |
| self._test_default_node_quant_handler_ops( |
| module, functional, qconfig, is_reference, node_list, additional_patterns) |
| |
| self._test_default_node_quant_handler_ops(module, functional, self.custom_qconfig, is_reference, node_list, |
| additional_quant_pattern_dict=self.common_quant_patterns) |
| |
| @unittest.skip("This is no longer needed right now, can enable later with new api") |
| def test_silu_reference(self): |
| module = torch.nn.SiLU |
| functional = torch.nn.functional.silu |
| qconfig = float16_static_qconfig |
| is_reference = True |
| node_list = [ |
| ns.call_method("to"), |
| ns.call_method("dequantize"), |
| ns.call_module(module), |
| ns.call_method("to"), |
| ns.call_method('dequantize'), |
| ns.call_function(functional), |
| ns.call_method("to"), |
| ns.call_method('dequantize') |
| ] |
| self._test_default_node_quant_handler_ops( |
| module, functional, qconfig, is_reference, node_list) |
| |
| node_list = [ |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_method("dequantize"), |
| ns.call_module(module), |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_method("dequantize"), |
| ns.call_function(functional), |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_method("dequantize") |
| ] |
| self._test_default_node_quant_handler_ops(module, functional, self.custom_qconfig, is_reference, node_list, |
| additional_quant_pattern_dict=self.common_quant_patterns) |
| |
| @unittest.skip("This is no longer needed right now, can enable later with new api") |
| def test_mish_reference(self): |
| module = torch.nn.Mish |
| functional = torch.nn.functional.mish |
| qconfig = float16_static_qconfig |
| is_reference = True |
| node_list = [ |
| ns.call_method("to"), |
| ns.call_method("dequantize"), |
| ns.call_module(module), |
| ns.call_method("to"), |
| ns.call_method('dequantize'), |
| ns.call_function(functional), |
| ns.call_method("to"), |
| ns.call_method('dequantize') |
| ] |
| self._test_default_node_quant_handler_ops( |
| module, functional, qconfig, is_reference, node_list) |
| |
| node_list = [ |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_method("dequantize"), |
| ns.call_module(module), |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_method("dequantize"), |
| ns.call_function(functional), |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_method("dequantize") |
| ] |
| self._test_default_node_quant_handler_ops(module, functional, self.custom_qconfig, is_reference, node_list, |
| additional_quant_pattern_dict=self.common_quant_patterns) |
| |
| def test_bmm_int_reference(self): |
| """ int8 is not supported for bmm so we won't produce reference |
| pattern for it |
| """ |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.bmm = torch.bmm |
| |
| def forward(self, x, y): |
| out = self.bmm(x, y) |
| return out |
| |
| data_x = torch.randn((2, 2, 2,)) |
| data_y = torch.randn((2, 2, 2,)) |
| example_inputs = (data_x, data_y) |
| qconfig_dict = {"": torch.ao.quantization.get_default_qconfig("fbgemm")} |
| is_reference = True |
| node_list = [ |
| ns.call_function(torch.bmm), |
| ] |
| |
| m = M().eval() |
| m_prep = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) |
| m_prep(*example_inputs) |
| convert_fn = convert_to_reference_fx if is_reference else convert_fx |
| m_quant = convert_fn(m_prep) |
| m_quant(*example_inputs) |
| |
| self.checkGraphModuleNodes(m_quant, expected_node_list=node_list) |
| |
| @skipIfNoFBGEMM |
| def test_clamp(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.Conv2d(2, 2, 2).float() |
| self.relu6 = torch.nn.ReLU6() |
| self.relu6_ = torch.nn.ReLU6(True) |
| self.hardtanh = torch.nn.Hardtanh() |
| self.hardtanh_ = torch.nn.Hardtanh(inplace=True) |
| |
| def forward(self, x): |
| x = self.conv(x) |
| x = self.relu6(x) |
| self.relu6_(x) |
| x = F.relu6(x) |
| x = torch.clamp(x, -3, 3) |
| x = x.clamp(-2.5, 2.5) |
| # x = x.clamp_(-2, 2) # Enable when quantized `clamp_` is ready |
| x = self.hardtanh(x) |
| self.hardtanh_(x) |
| x = F.hardtanh(x) |
| return x |
| |
| data = (torch.rand((1, 2, 5, 5), dtype=torch.float),) |
| # list of node that should occur in order |
| node_list = [ |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_module(nnq.Conv2d), |
| ns.call_method('dequantize') |
| ] |
| for quant_type in self.static_quant_types: |
| self.checkGraphModeFxOp( |
| M(), data, quant_type, expected_node_list=node_list) |
| |
| def test_fixed_qparams_ops_fp16(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.sigmoid = torch.nn.Sigmoid() |
| self.tanh = torch.nn.Tanh() |
| |
| def forward(self, x): |
| x = self.sigmoid(x) |
| x = torch.sigmoid(x) |
| x = x.sigmoid() |
| x = self.tanh(x) |
| x = torch.tanh(x) |
| x = x.tanh() |
| return x |
| |
| data = (torch.randn((2, 2, 2, 2), dtype=torch.float),) |
| quant_type = QuantType.STATIC |
| # TODO: use get_default_qconfig_mapping once it handles fp16 |
| qconfig_mapping = QConfigMapping().set_global(float16_static_qconfig) |
| backend_config = get_test_only_legacy_native_backend_config() |
| node_occurrence = { |
| ns.call_method("to"): 7 |
| } |
| self.checkGraphModeFxOp( |
| M(), data, quant_type, custom_qconfig_dict=qconfig_mapping, |
| expected_node_occurrence=node_occurrence, |
| backend_config=backend_config) |
| |
| def test_fixed_qparams_ops_qint8(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.sigmoid = torch.nn.Sigmoid() |
| self.tanh = torch.nn.Tanh() |
| |
| def forward(self, x): |
| x = self.sigmoid(x) |
| x = torch.sigmoid(x) |
| x = x.sigmoid() |
| x = self.tanh(x) |
| x = torch.tanh(x) |
| x = x.tanh() |
| return x |
| |
| data = (torch.randn((2, 2, 2, 2), dtype=torch.float),) |
| quant_type = QuantType.STATIC |
| qconfig = torch.ao.quantization.QConfig( |
| activation=HistogramObserver.with_args(qscheme=torch.per_tensor_symmetric, dtype=torch.quint8), |
| weight=default_weight_observer) |
| qconfig_mapping = get_default_qconfig_mapping().set_global(qconfig) |
| node_occurrence = { |
| ns.call_function(torch.quantize_per_tensor): 7, |
| ns.call_method("dequantize"): 7 |
| } |
| self.checkGraphModeFxOp( |
| M(), data, quant_type, custom_qconfig_dict=qconfig_mapping, |
| expected_node_occurrence=node_occurrence, is_reference=True) |
| |
| def test_fixed_qparams_ops_wrong_qconfig(self): |
| """ Test that wrong qconfigs for fixed qparams ops results in the ops not being quantized. |
| """ |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.sigmoid = torch.nn.Sigmoid() |
| self.tanh = torch.nn.Tanh() |
| |
| def forward(self, x): |
| x = self.sigmoid(x) |
| x = torch.sigmoid(x) |
| x = x.sigmoid() |
| x = self.tanh(x) |
| x = torch.tanh(x) |
| x = x.tanh() |
| return x |
| |
| data = (torch.randn((2, 2, 2, 2), dtype=torch.float),) |
| qconfig_mapping = QConfigMapping().set_global(default_qconfig) |
| m = M().eval() |
| node_occurrence = { |
| ns.call_function(torch.quantize_per_tensor): 0, |
| ns.call_method("dequantize"): 0, |
| } |
| self.checkGraphModeFxOp( |
| m, data, QuantType.STATIC, custom_qconfig_dict=qconfig_mapping, |
| expected_node_occurrence=node_occurrence, is_reference=True) |
| self.assertTrue(isinstance(m.sigmoid, torch.nn.Sigmoid)) |
| self.assertTrue(isinstance(m.tanh, torch.nn.Tanh)) |
| |
| @skipIfNoFBGEMM |
| def test_general_shape_ops(self): |
| """ A test that checks dequantize will be swapped for |
| all supported general shape ops like aten::flatten |
| without actually checking for execution of these ops |
| """ |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.maxpool1d = torch.nn.MaxPool1d(kernel_size=3) |
| self.maxpool2d = torch.nn.MaxPool2d(kernel_size=3) |
| self.maxpool3d = torch.nn.MaxPool3d(kernel_size=3) |
| self.dropout = torch.nn.Dropout() |
| self.conv1 = torch.nn.Conv2d(3, 3, 3) |
| self.conv2 = torch.nn.Conv2d(3, 3, 3) |
| self.relu = torch.nn.ReLU() |
| |
| def forward(self, x): |
| x = self.conv1(x) |
| # add_scalar |
| x = x + 3 |
| # mul_scalar |
| x = x * 3 |
| # add_scalar_out |
| x += 3 |
| # mul_scalar_out |
| x *= 3 |
| # add_scalar_relu |
| x = x + 3 |
| x = F.relu(x) |
| # add_scalar_relu_out |
| x += 3 |
| x = F.relu(x) |
| # mul_scalar_relu |
| x = x * 3 |
| x = F.relu(x) |
| # mul_scalar_relu_out |
| x *= 3 |
| x = F.relu(x) |
| x = self.maxpool1d(x) |
| x = self.maxpool2d(x) |
| x = self.maxpool3d(x) |
| x = torch.flatten(x) |
| x = x.reshape([-1]) |
                x = x.resize_(1, 1, 1)
| x = x.view(-1) |
| # prim::ListConstruct |
| xs = [x, x] |
| # prim::ListUnpack |
| x, y = xs |
| # prim::TupleConstruct |
| xs = (x, x) |
| # prim::TupleUnpack |
| x, y = xs |
| x = x.transpose(1, 2) |
| x = x.contiguous() |
| # chunk is not supported since observer only supports |
| # observing single Tensor currently |
| x, y = torch.chunk(x, 2) |
| x = F.dropout(x) |
| x = self.dropout(x) |
| x = x.permute(0, 2, 3, 1) |
| x = x.repeat_interleave(3, 1) |
| x = torch.repeat_interleave(x, 3, 1) |
| x = self.relu(x) |
| x = F.relu(x) |
| x = F.relu(x, inplace=True) |
| x = x.relu() |
| x.relu_() |
| x = x.squeeze(0) |
| x.squeeze_(0) |
| x = torch.squeeze(x, 0) |
| x = x.unsqueeze(0) |
| x.unsqueeze_(0) |
| x = torch.unsqueeze(x, 0) |
| x = x.detach() |
| x.detach_() |
| x = x.repeat(4, 2) |
| y = [] |
| y.append(x) |
| z = torch.stack(y, 0) |
| z = [z, z] |
| x, _ = z |
| x = self.conv2(x) |
| return x |
| |
| example_inputs = (torch.rand(1, 3, 10, 10),) |
| # This model is not executable since we just put all ops |
| # in the same forward |
| m = M().eval() |
| qconfig_dict = {'': default_qconfig} |
| prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) |
| # not runnable |
| quantized = convert_fx(prepared) |
| |
        # This checks that the dequantize from the output of the first conv
        # is propagated to the end, so that we don't insert extra observers,
        # and that the two quantized::conv2d patterns were fused successfully.
        # There is one quantize_per_tensor for the input.
        # Check the exact counts of quantize and dequantize below.
| count_check = { |
| # input of conv and two outputs of getitem |
| ns.call_function(torch.quantize_per_tensor) : 2, |
| # output of the model and two outputs of getitem |
| ns.call_method('dequantize') : 2 |
| } |
| order_check = [ |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_module(nnq.Conv2d), |
| ns.call_module(nnq.Conv2d), |
| ns.call_method('dequantize'), |
| ] |
| self.checkGraphModuleNodes( |
| quantized, |
| expected_node_occurrence=count_check, |
| expected_node_list=order_check) |
| |
| |
| # Checking the is_reference output |
| m = M().eval() |
| qconfig_dict = {'': default_qconfig} |
| prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) |
| # not runnable |
| quantized = convert_to_reference_fx(prepared) |
| |
| |
| @skipIfNoFBGEMM |
| def test_ave_pool_with_custom_cfg(self): |
| """ A test that checks correct patterns are produced for |
| avg_pool2d with customized config |
| """ |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.avg_pool2d = torch.nn.AvgPool2d(3) |
| |
| |
| def forward(self, x): |
| x = self.avg_pool2d(x) |
| return x |
| |
| # This model is not executable since we just put all ops |
| # in the same forward |
| m = M().eval() |
| # nothing to fuse so skipping the fuse step |
| qconfig_dict = {'': default_qconfig} |
| example_inputs = (torch.randn(1, 3, 3, 3),) |
| prepared = prepare_fx( |
| m, qconfig_dict, example_inputs=example_inputs, |
| prepare_custom_config={"input_quantized_idxs": [0]}) |
| |
| # not runnable |
| quantized = convert_fx(prepared) |
| |
        # This checks that the dequantize is propagated to the end of the
        # graph (the input is already quantized), so that we don't insert
        # extra observers
| # check exact counts of quantize and dequantize |
| count_check = { |
| ns.call_method('dequantize') : 1 |
| } |
| order_check = [ |
| ns.call_module(nn.AvgPool2d), |
| ns.call_method('dequantize'), |
| ] |
| self.checkGraphModuleNodes( |
| quantized, |
| expected_node_occurrence=count_check, |
| expected_node_list=order_check) |
| |
| @skipIfNoFBGEMM |
| def test_general_value_ops(self): |
| """ A test that checks correct patterns are produced for |
| all supported general value ops like aten::avg_pool2d \ |
| without actually checking for execution of these ops |
| """ |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.Conv2d(3, 3, 3) |
| self.avg_pool1d = torch.nn.AvgPool1d(3) |
| self.avg_pool2d = torch.nn.AvgPool2d(3) |
| self.avg_pool3d = torch.nn.AvgPool3d(3) |
| self.adaptive_avg_pool1d = torch.nn.AdaptiveAvgPool1d(1) |
| self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1)) |
| self.adaptive_avg_pool3d = torch.nn.AdaptiveAvgPool3d((1, 1, 1)) |
| |
| def forward(self, x): |
| x = self.conv(x) |
| x = self.avg_pool1d(x) |
| x = self.avg_pool2d(x) |
| x = self.avg_pool3d(x) |
| x = self.adaptive_avg_pool1d(x) |
| x = self.adaptive_avg_pool2d(x) |
| x = self.adaptive_avg_pool3d(x) |
| x = F.avg_pool1d(x, 3) |
| x = F.avg_pool2d(x, 3) |
| x = F.avg_pool3d(x, 3) |
| x = F.adaptive_avg_pool1d(x, (1)) |
| x = F.adaptive_avg_pool2d(x, (1, 1)) |
| x = F.adaptive_avg_pool3d(x, (1, 1, 1)) |
| x = torch.mean(x) |
| x = torch.mean(x, [2, 3], False) |
| x = x.mean() |
| x = x.mean([2, 3], True) |
| x = F.interpolate(x, 4, mode='nearest') |
| x = F.interpolate(x, 4, mode='linear') |
| x = self.conv(x) |
| return x |
| |
| # This model is not executable since we just put all ops |
| # in the same forward |
| m = M().eval() |
| # nothing to fuse so skipping the fuse step |
| qconfig_dict = {'': default_qconfig} |
| example_inputs = (torch.randn(1, 3, 3, 3),) |
| prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) |
| # not runnable |
| quantized = convert_fx(prepared) |
| |
        # This checks that the dequantize from the output of the first conv
        # is propagated to the end, so that we don't insert extra
        # observers
| # check exact counts of quantize and dequantize |
| count_check = { |
| ns.call_function(torch.quantize_per_tensor) : 1, |
| ns.call_method('dequantize') : 1 |
| } |
| order_check = [ |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_module(nnq.Conv2d), |
| ns.call_module(nnq.Conv2d), |
| ns.call_method('dequantize'), |
| ] |
| self.checkGraphModuleNodes( |
| quantized, |
| expected_node_occurrence=count_check, |
| expected_node_list=order_check) |
| |
| def test_copy_node_fp32_input(self): |
| """ CopyNode works for both fp32 and int8 inputs, this is a test to make |
| sure that a CopyNode can be successfully quantized in both cases |
| """ |
| class M(torch.nn.Module): |
| def forward(self, x): |
| x = x.relu() |
| return x |
| |
| m = M().eval() |
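        # default_reuse_input_qconfig makes a CopyNode like relu reuse the
        # observer (and hence the qparams) of its input instead of inserting
        # a new one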
| m = prepare_fx(m, {"": default_reuse_input_qconfig}, example_inputs=(torch.randn(1),)) |
| m = convert_fx(m) |
| # make sure it runs |
| m(torch.rand(1)) |
| |
| def test_getitem(self): |
| """ Make sure we only insert observer for getitem if the following node is matched |
| or needs to be quantized |
| """ |
| class M(torch.nn.Module): |
| def forward(self, xs): |
| x = xs[0] |
| return x |
| |
| m = M().eval() |
| example_inputs = (torch.rand(1, 2),) |
| qconfig_mapping = get_default_qconfig_mapping() |
| m = prepare_fx(m, qconfig_mapping, example_inputs=example_inputs) |
| self.checkGraphModuleNodes(m, expected_node_occurrence={ |
| ns.call_module(torch.ao.quantization.MinMaxObserver): 0 |
| }) |
| m = convert_fx(m) |
| m(*example_inputs) |
| |
| class M2(torch.nn.Module): |
| def forward(self, xs): |
| x = xs[0] |
| x = torch.sigmoid(x) |
| return x |
| |
| m2 = M2().eval() |
| example_inputs = ([torch.rand(1, 2)],) |
| qconfig_mapping = get_default_qconfig_mapping() |
| m2 = prepare_fx(m2, qconfig_mapping, example_inputs=example_inputs) |
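        # sigmoid is a fixed-qparams op, so fixed-qparams observers are
        # inserted: presumably one observing the sigmoid input (the getitem
        # output) and one observing the sigmoid output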
| self.checkGraphModuleNodes(m2, expected_node_occurrence={ |
| ns.call_module(torch.ao.quantization.FixedQParamsObserver): 2 |
| }) |
| m2 = convert_fx(m2) |
| self.checkGraphModuleNodes(m2, expected_node_list=[ |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_method("dequantize") |
| ]) |
| m2(*example_inputs) |
| |
        # test that prepare recognizes non-Tensor inputs for getitem
| class M3(torch.nn.Module): |
| def forward(self, x): |
| s = x.shape |
| n, c = s[:2] |
| x = torch.sigmoid(x) |
| return x |
| |
| m3 = M3().eval() |
| example_inputs = (torch.rand(1, 2, 3, 4),) |
| qconfig_mapping = get_default_qconfig_mapping() |
| m3 = prepare_fx(m3, qconfig_mapping, example_inputs=example_inputs) |
| self.checkGraphModuleNodes(m3, expected_node_occurrence={ |
| ns.call_module(torch.ao.quantization.FixedQParamsObserver): 2 |
| }) |
| m3 = convert_fx(m3) |
| self.checkGraphModuleNodes(m3, expected_node_list=[ |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_method("dequantize") |
| ]) |
        m3(*example_inputs)

    @skipIfNoFBGEMM
| def test_fixed_qparams_ops(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.Conv2d(3, 3, 3) |
| self.sigmoid = torch.nn.Sigmoid() |
| self.hardsigmoid = torch.nn.Hardsigmoid() |
| self.tanh = torch.nn.Tanh() |
| self.softmax = torch.nn.Softmax(dim=0) |
| |
| def forward(self, x): |
| x = self.conv(x) |
| # F.sigmoid is deprecated |
| x = self.sigmoid(x) |
| x = torch.sigmoid(x) |
| x = x.sigmoid() |
| x = self.hardsigmoid(x) |
| x = F.hardsigmoid(x) |
| x = F.hardsigmoid(x, inplace=True) |
| x = self.tanh(x) |
| # F.tanh is deprecated |
| x = torch.tanh(x) |
| x = x.tanh() |
| # TODO(future PR): handle F.softmax |
| x = self.softmax(x) |
| return x |
| |
| for eval_mode in [True, False]: |
| # This model is not executable since we just put all ops |
| # in the same forward |
| m = M() |
| if eval_mode: |
| m.eval() |
| qconfig_mapping = get_default_qconfig_mapping() |
| prepare = prepare_fx |
| fq_count = 10 |
| else: |
| m.train() |
| qconfig_mapping = get_default_qat_qconfig_mapping() |
| prepare = prepare_qat_fx |
| fq_count = 10 |
| # nothing to fuse so skipping the fuse step |
| m_copy = copy.deepcopy(m) |
| example_inputs = (torch.rand(3, 3, 3, 3),) |
| prepared = prepare(m, qconfig_mapping, example_inputs=example_inputs) |
| prepared_copy = copy.deepcopy(prepared) |
| # check that prepare does not change model result |
| if eval_mode: |
| self.assertEqual(m_copy(*example_inputs), prepared_copy(*example_inputs)) |
| # check the correct number of activation_post_process is inserted |
| expected_activation_post_process = FixedQParamsObserver if eval_mode else FixedQParamsFakeQuantize |
| count_check = { |
| ns.call_module(expected_activation_post_process) : fq_count, |
| } |
| self.checkGraphModuleNodes( |
| prepared, |
| expected_node_occurrence=count_check) |
| # not runnable |
| quantized = convert_fx(prepared) |
| quantized_reference = convert_to_reference_fx(prepared_copy) |
| |
| # This checks that the dequantize from the output of first conv |
| # is being propagated to the end, so that we don't insert extra |
| # observers |
| # check exact counts of quantize and dequantize |
| count_check = { |
| ns.call_function(torch.quantize_per_tensor) : 1, |
| ns.call_method('dequantize') : 1 |
| } |
| order_check = [ |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_module(nnq.Conv2d), |
| ns.call_module(nn.Sigmoid), |
| ns.call_module(nnq.Softmax), |
| ns.call_method('dequantize'), |
| ] |
| self.checkGraphModuleNodes( |
| quantized, |
| expected_node_occurrence=count_check, |
| expected_node_list=order_check) |
| |
| reference_count_check = { |
| ns.call_function(torch.quantize_per_tensor) : 12, |
| ns.call_method('dequantize') : 12 |
| } |
| reference_order_check = [ |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_method('dequantize'), |
| ns.call_module(nnqr.Conv2d), |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_method('dequantize'), |
| ns.call_module(nn.Sigmoid), |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_method('dequantize'), |
| ns.call_module(nn.Softmax), |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_method('dequantize'), |
| ] |
| self.checkGraphModuleNodes( |
| quantized_reference, |
| expected_node_occurrence=reference_count_check, |
| expected_node_list=reference_order_check) |
| |
            # Verify that the softmax fixed qparams are correct: softmax
            # outputs lie in [0, 1], so for quint8 scale = (1 - 0) / 2 ** 8
            # = 1 / 256 and zero_point = 0
            self.assertTrue(abs(quantized.softmax.scale - (1.0 / 256)) <= 1e-8)
            self.assertTrue(quantized.softmax.zero_point == 0)
| |
| def test_float_functional(self): |
| class TorchAdd(nn.Module): |
| """Wrapper around torch.add so that all ops can be found at build""" |
| def __init__(self): |
| super().__init__() |
| self.add_func = nnq.FloatFunctional() |
| |
| def forward(self, x, y): |
| return self.add_func.add(x, y) |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.ff1 = TorchAdd() |
| self.ff2 = nnq.FloatFunctional() |
| self.ff3 = nnq.FloatFunctional() |
| self.ff4 = nnq.FloatFunctional() |
| self.ff5 = nnq.FloatFunctional() |
| self.ff6 = nnq.FloatFunctional() |
| |
| def forward(self, x): |
| x = self.ff1(x, x) |
| x = self.ff2.add_scalar(x, 3) |
| x = self.ff3.mul(x, x) |
| x = self.ff4.mul_scalar(x, 3) |
| x = self.ff5.add_relu(x, x) |
| x = self.ff6.cat([x]) |
| return x |
| |
| example_inputs = (torch.rand(3, 3),) |
        # Note: the QAT test succeeded by chance; to make it actually work
        # we need to fix eager mode FloatFunctional by removing
        # activation_post_process in add_scalar and mul_scalar
| for quant_type in self.static_quant_types: |
| m = M() |
| ref_m = torch.ao.quantization.QuantWrapper(M()) |
| is_qat = quant_type == QuantType.QAT |
| if is_qat: |
| m.train() |
| ref_m.train() |
| qconfig = default_qat_qconfig |
| expected_act_post_process = torch.ao.quantization.FakeQuantize |
| else: |
| m.eval() |
| ref_m.eval() |
| qconfig = default_qconfig |
| expected_act_post_process = torch.ao.quantization.MinMaxObserver |
| |
| prepare_fx_function = prepare_qat_fx if is_qat else prepare_fx |
| qconfig_dict = {"": qconfig} |
| m = prepare_fx_function(m, qconfig_dict, example_inputs=example_inputs) |
| node_occurrence = { |
| ns.call_module(expected_act_post_process): 7, |
| ns.call_module(torch.ao.nn.quantized.FloatFunctional): 0 |
| } |
| self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) |
| m(*example_inputs) |
| node_list = [ |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_function(torch.ops.quantized.add), |
| ns.call_function(torch.ops.quantized.add), |
| ns.call_function(torch.ops.quantized.mul), |
| ns.call_function(torch.ops.quantized.mul), |
| ns.call_function(torch.ops.quantized.add_relu), |
| ns.call_function(torch.cat), |
| ns.call_method('dequantize') |
| ] |
| m = convert_fx(m) |
| self.checkGraphModuleNodes(m, expected_node_list=node_list) |
| |
| # make sure numerics match with eager mode |
| ref_m.qconfig = qconfig |
| prepare_function = prepare_qat if is_qat else prepare |
| ref_m = prepare_function(ref_m) |
| ref_m(*example_inputs) |
| ref_m = convert(ref_m) |
            # FX Graph Mode and Eager Mode now diverge in the numerics of add_scalar and mul_scalar
| # self.assertEqual(m(data), ref_m(data)) |
| |
| def test_embedding(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12) |
| |
| def forward(self, indices): |
| return self.emb(indices) |
| |
| for qconfig_type in [float_qparams_weight_only_qconfig, float_qparams_weight_only_qconfig_4bit]: |
| model = M().eval() |
| indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3]) |
| example_inputs = (indices,) |
| configs = [ |
| (qconfig_type, ns.call_module(nnq.Embedding)), |
| (None, ns.call_module(nn.Embedding)), |
| (default_qconfig, ns.call_module(nn.Embedding)), |
| ] |
| |
| for qconfig, node in configs: |
| qconfig_dict = {"": qconfig} |
| m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) |
| self.checkGraphModuleNodes(m, expected_node_occurrence={ |
| ns.call_module(torch.ao.quantization.MinMaxObserver): 0 |
| }) |
| m = convert_fx(m) |
| self.checkGraphModuleNodes(m, expected_node=node) |
| # make sure it runs |
| m(*example_inputs) |
| |
| def test_embedding_bag(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, include_last_offset=True) |
| |
| def forward(self, indices, offsets): |
| return self.emb(indices, offsets) |
| |
| indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3]) |
| offsets = torch.tensor([0, 19, 20, 28, 28, 32]) |
| quantized_node = ns.call_module(nnq.EmbeddingBag) |
| example_inputs = (indices, offsets) |
| |
| for dtype in [torch.quint8, torch.quint4x2]: |
| model = M().eval() |
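            # embedding weights are quantized per channel (one scale and
            # zero_point per embedding row, ch_axis=0) with float qparams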
| float_qparams_observer = PerChannelMinMaxObserver.with_args(dtype=dtype, |
| qscheme=torch.per_channel_affine_float_qparams, |
| ch_axis=0) |
| float_qparams_qconfig = QConfig(activation=default_placeholder_observer, |
| weight=float_qparams_observer) |
| self.checkGraphModeFxOp( |
| model, |
| example_inputs, |
| QuantType.DYNAMIC, |
| quantized_node, |
| custom_qconfig_dict={"": float_qparams_qconfig} |
| ) |
| |
            # check that EmbeddingBag is left unquantized with a None or a
            # static qconfig
            for qconfig in [None, default_qconfig]:
                qconfig_dict = {"": qconfig}
                m = M().eval()
                m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
| self.checkGraphModuleNodes(m, expected_node_occurrence={ |
| ns.call_module(torch.ao.quantization.MinMaxObserver): 0 |
| }) |
| m = convert_fx(m) |
| self.checkGraphModuleNodes(m, expected_node=ns.call_module(nn.EmbeddingBag)) |
| # make sure it runs |
| m(*example_inputs) |
| |
| def _test_rnn_impl(self, qconfigs, M, module_type_strs, module_types, sample_input): |
| options = itertools.product(qconfigs, module_type_strs) |
| for qconfig, module_type_str in options: |
| model_eager = M(module_type_str).eval() |
| model_graph = copy.deepcopy(model_eager) |
            # fp16 dynamic quant is not supported for qnnpack
            if torch.backends.quantized.engine == 'qnnpack' and \
                    qconfig is float16_dynamic_qconfig:
                continue
| |
| eager_qconfig_dict = {x : qconfig for x in module_types} |
| model_eager = quantize_dynamic(model_eager, qconfig_spec=eager_qconfig_dict) |
| |
| graph_qconfig_dict = { |
| "object_type": [ |
| (x, qconfig) for x in module_types |
| ] |
| } |
| model_graph = prepare_fx(model_graph, graph_qconfig_dict, example_inputs=(sample_input,)) |
| model_graph = convert_fx(model_graph) |
| self.assertEqual(model_eager(sample_input), model_graph(sample_input)) |
| self.checkScriptable(model_graph, [[sample_input]], True) |
| |
| @override_qengines |
| def test_rnn_cell(self): |
| if torch.backends.quantized.engine not in ('fbgemm', 'qnnpack'): |
| return |
| qconfigs = [per_channel_dynamic_qconfig, default_dynamic_qconfig, float16_dynamic_qconfig] |
| module_type_strs = ['LSTMCell', 'GRUCell', 'RNNTanh', 'RNNReLU'] |
| module_types = [torch.nn.LSTMCell, torch.nn.GRUCell, torch.nn.RNNCell] |
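        # RNNTanh and RNNReLU are both torch.nn.RNNCell (with different
        # nonlinearities), hence three module types for the four module
        # type strings above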
| sample_input = torch.tensor([[100, -155], |
| [-155, 100], |
| [100, -155]], dtype=torch.float) |
| self._test_rnn_impl(qconfigs, RNNCellDynamicModel, module_type_strs, module_types, sample_input) |
| |
| @override_qengines |
| def test_rnn(self): |
| if torch.backends.quantized.engine not in ('fbgemm', 'qnnpack'): |
| return |
| qconfigs = [per_channel_dynamic_qconfig, default_dynamic_qconfig, float16_dynamic_qconfig] |
| module_type_strs = ['LSTM', 'GRU'] |
| module_types = [torch.nn.LSTM, torch.nn.GRU] |
| niter = 10 |
| sample_input = torch.tensor([[100, -155], |
| [-155, 100], |
| [100, -155]], dtype=torch.float).unsqueeze(0).repeat(niter, 1, 1) |
| self._test_rnn_impl(qconfigs, RNNDynamicModel, module_type_strs, module_types, sample_input) |
| |
| def _test_conv_transpose_impl( |
| self, float_cls: Callable, q_cls: Callable, data: torch.Tensor): |
| with override_quantized_engine('qnnpack'): |
| # Create fp32 versions of FX and Eager models |
| m1 = torch.nn.Sequential(float_cls(1, 1, 1)) |
| m2 = torch.nn.Sequential(float_cls(1, 1, 1)) |
| m2.load_state_dict(m1.state_dict()) |
| m2 = torch.ao.quantization.QuantWrapper(m2) |
| # FX graph |
| result_dict = self.checkGraphModeFxOp( |
| m1, (data,), QuantType.STATIC, |
| expected_node_occurrence={ |
| ns.call_module(q_cls): 1, |
| }) |
| q_result1 = result_dict["quantized_output"] |
| # Eager |
| m2.qconfig = get_default_qconfig(torch.backends.quantized.engine) |
| m2.eval() |
| m2p = torch.ao.quantization.prepare(m2) |
| m2p(data) |
| m2q = torch.ao.quantization.convert(m2p) |
| q_result2 = m2q(data) |
| # verify results match |
| self.assertEqual(q_result1, q_result2) |
| |
    @unittest.skipUnless('qnnpack' in supported_qengines,
                         "This PyTorch build has not been built with, or does not support, QNNPACK")
| def test_conv_transpose_1d(self): |
| self._test_conv_transpose_impl( |
| torch.nn.ConvTranspose1d, nnq.ConvTranspose1d, torch.randn(4, 1, 4)) |
| |
    @unittest.skipUnless('qnnpack' in supported_qengines,
                         "This PyTorch build has not been built with, or does not support, QNNPACK")
| def test_conv_transpose_2d(self): |
| self._test_conv_transpose_impl( |
| torch.nn.ConvTranspose2d, nnq.ConvTranspose2d, torch.randn(4, 1, 4, 4)) |
| |
| def test_reshape_fp16(self): |
| class M(torch.nn.Module): |
| def __init__(self, w, b): |
| super().__init__() |
| self.w = w |
| self.b = b |
| |
| def forward(self, x): |
| x = torch.nn.functional.linear(x, self.w) |
| x = x.reshape(-1, 4) |
| x = torch.nn.functional.linear(x, self.w) |
| return x |
| |
| w = torch.randn(4, 4) |
| b = torch.randn(4) |
| m = M(w, b).eval() |
| qconfig_dict = { |
| # reshape will be quantized to fp16 as requested by this qconfig |
| "": float16_static_qconfig, |
| "object_type": [ |
| (torch.nn.functional.linear, default_qconfig) |
| ] |
| } |
| backend_config = get_test_only_legacy_native_backend_config() |
| example_inputs = (torch.randn(1, 4),) |
| m = prepare_fx( |
| m, qconfig_dict, example_inputs=example_inputs, |
| backend_config=backend_config) |
| expected_occurrence = { |
| # input and weight of first and second linear, output of first and second linear |
| ns.call_module(torch.ao.quantization.MinMaxObserver): 6, |
| # we insert placeholder observer for both input and output of reshape |
| ns.call_module(torch.ao.quantization.PlaceholderObserver): 2 |
| } |
| self.checkGraphModuleNodes( |
| m, |
| expected_node_occurrence=expected_occurrence |
| ) |
| m = convert_fx(m, backend_config=backend_config) |
| expected_occurrence = { |
| ns.call_function(torch.quantize_per_tensor): 2, |
| # dequantize after first linear, before reshape and before output |
| ns.call_method("dequantize"): 3, |
| # before reshape, to(fp16) |
| ns.call_method("to"): 1, |
| ns.call_function(torch.ops.quantized.linear): 2 |
| } |
| self.checkGraphModuleNodes( |
| m, |
| expected_node_occurrence=expected_occurrence |
| ) |
| # make sure it runs |
| m(torch.randn(2, 4)) |
| |
| def test_multiple_qconfigs_for_single_value(self): |
| """ Test multiple qconfigs for a single value""" |
| class M(torch.nn.Module): |
| def __init__(self, w, b): |
| super().__init__() |
| self.w = w |
| self.b = b |
| |
| def forward(self, x): |
| x = torch.nn.functional.linear(x, self.w) |
| x = torch.sigmoid(x) |
| return x |
| |
| w = torch.randn(4, 4) |
| b = torch.randn(4) |
| m = M(w, b).eval() |
| # TODO: use get_default_qconfig_mapping once it handles fp16 |
| qconfig_mapping = QConfigMapping() \ |
| .set_global(float16_static_qconfig) \ |
| .set_object_type(torch.nn.functional.linear, default_qconfig) |
| example_inputs = (torch.randn(1, 4),) |
| backend_config = get_test_only_legacy_native_backend_config() |
| m = prepare_fx( |
| m, qconfig_mapping, example_inputs=example_inputs, |
| backend_config=backend_config) |
| expected_occurrence = { |
| # input and weight of linear, output of linear |
| ns.call_module(torch.ao.quantization.MinMaxObserver): 3, |
| # input and output of sigmoid |
| ns.call_module(torch.ao.quantization.PlaceholderObserver): 2, |
| } |
| self.checkGraphModuleNodes( |
| m, |
| expected_node_occurrence=expected_occurrence |
| ) |
| # make sure it runs |
| m = convert_fx(m) |
| expected_occurrence = { |
| ns.call_function(torch.quantize_per_tensor): 1, |
| ns.call_method("dequantize"): 3, |
| ns.call_method("to"): 2 |
| } |
| self.checkGraphModuleNodes( |
| m, |
| expected_node_occurrence=expected_occurrence |
| ) |
| |
| def test_boolean_tensor(self): |
| """ Make sure we don't insert observer for boolean Tensors """ |
| class M(torch.nn.Module): |
| def forward(self, x, mask): |
| mask = mask.unsqueeze(0) |
| mask = mask.unsqueeze(1) |
| x = x.masked_fill(mask, 1) |
| return x |
| |
| m = M().eval() |
| example_inputs = (torch.rand(1, 2, 3, 4), torch.rand(3, 4).bool()) |
| m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs) |
| expected_occurrence = { |
| ns.call_module(torch.ao.quantization.MinMaxObserver): 0 |
| } |
| self.checkGraphModuleNodes( |
| m, |
| expected_node_occurrence=expected_occurrence) |
| m = convert_fx(m) |
| m(*example_inputs) |
| |
| def test_chunk(self): |
| class M(torch.nn.Module): |
| def forward(self, x): |
| x, y = torch.chunk(x, 2) |
| x = x + y |
| return x |
| m = M().eval() |
| example_inputs = (torch.rand(2, 2, 2, 2),) |
| m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs) |
| m(*example_inputs) |
| m = convert_fx(m) |
| m(*example_inputs) |
| # make sure everything runs |
| |
| def test_ref_pattern_multi_use(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(5, 5) |
| self.linear1 = torch.nn.Linear(5, 5) |
| |
| def forward(self, x): |
| y = self.linear(x) |
| z = self.linear1(x) |
| a = torch.mul(z, 5) |
| b = torch.add(z, 5) |
| return (y, a, b) |
| |
| m = M().eval() |
| qconfig_dict = { |
| "": None, |
| "object_type": [ |
| (torch.nn.Linear, get_default_qconfig("fbgemm")), |
| (torch.nn.ReLU, get_default_qconfig("fbgemm")), |
| ], |
| } |
| example_inputs = (torch.randn(1, 5),) |
| m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) |
| m = convert_fx(m) |
| expected_occurrence = { |
| ns.call_function(torch.quantize_per_tensor): 1, |
| ns.call_module(nnq.Linear): 2, |
| ns.call_method("dequantize"): 2, |
| ns.call_function(torch.add): 1, |
| ns.call_function(torch.mul): 1, |
| } |
| self.checkGraphModuleNodes( |
| m, |
| expected_node_occurrence=expected_occurrence) |
| |
| def test_qmatmul(self): |
| class M(torch.nn.Module): |
| def forward(self, x, y): |
| z = torch.matmul(x, y) |
| return z |
| |
| m = M().eval() |
| example_inputs = (torch.randn(2, 2), torch.randn(2, 2)) |
| qconfig_dict = get_default_qconfig_mapping("fbgemm") |
| mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) |
| mp(*example_inputs) |
| mq = convert_fx(mp) |
| expected_occurrence = { |
| ns.call_function(torch.matmul): 0, |
| ns.call_function(torch.ops.quantized.matmul): 1, |
| } |
| self.checkGraphModuleNodes( |
| mq, |
| expected_node_occurrence=expected_occurrence) |
        # verify no crash
        mq(*example_inputs)
| |
| def test_pixel_shuffle(self): |
| class MyBias(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.bias = nn.Parameter(torch.randn(8)) |
| |
| class MyModel(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = nn.Conv2d(8, 8, 1, bias=False) |
| self.bias = MyBias() |
| |
| def forward(self, x): |
| x = self.conv(x) |
| x = nn.functional.pixel_shuffle(x, 2) |
| x = x.view(-1, 8, 2, 2) |
| bias = self.bias.bias |
| return x + bias |
| |
| backend_config = get_qnnpack_backend_config() |
| qconfig_mapping = get_default_qconfig_mapping("qnnpack") |
| model = MyModel() |
| m = prepare_fx( |
| model, |
| qconfig_mapping=qconfig_mapping, |
| example_inputs=(torch.randn(1, 8, 3, 3),), |
| backend_config=backend_config |
| ) |
| m = convert_fx(m) |
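        # pixel_shuffle and view are treated as quantization-preserving, so
        # no extra dequantize/quantize pair is inserted around them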
| expected_occurrence = { |
| ns.call_function(torch.quantize_per_tensor): 2, |
| ns.call_method("dequantize"): 1, |
| } |
| self.checkGraphModuleNodes(m, expected_node_occurrence=expected_occurrence) |
| |
| def test_pixel_shuffle_module(self) -> None: |
| class MyBias(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.bias = nn.Parameter(torch.randn(8)) |
| |
| class MyModel(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = nn.Conv2d(8, 8, 1, bias=False) |
| self.ps = nn.PixelShuffle(upscale_factor=2) |
| self.bias = MyBias() |
| |
| def forward(self, x): |
| x = self.conv(x) |
| x = self.ps(x) |
| x = x.view(-1, 8, 2, 2) |
| bias = self.bias.bias |
| return x + bias |
| |
| backend_config = get_qnnpack_backend_config() |
| qconfig_mapping = get_default_qconfig_mapping("qnnpack") |
| model = MyModel() |
| m = prepare_fx( |
| model, |
| qconfig_mapping=qconfig_mapping, |
| example_inputs=(torch.randn(1, 8, 3, 3),), |
| backend_config=backend_config |
| ) |
| m = convert_fx(m) |
| expected_occurrence = { |
| ns.call_function(torch.quantize_per_tensor): 2, |
| ns.call_method("dequantize"): 1, |
| ns.call_module(nn.PixelShuffle): 1, |
| } |
| self.checkGraphModuleNodes(m, expected_node_occurrence=expected_occurrence) |
| |
| def test_pixel_unshuffle(self): |
| class MyBias(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.bias = nn.Parameter(torch.randn(64)) |
| |
| class MyModel(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = nn.Conv2d(8, 8, 1, bias=False) |
| self.bias = MyBias() |
| |
| def forward(self, x): |
| x = self.conv(x) |
| x = nn.functional.pixel_unshuffle(x, 2) |
| bias = self.bias.bias |
| return x + bias |
| |
| for backend in ["fbgemm", "qnnpack"]: |
| if backend == "fbgemm": |
| backend_config = get_fbgemm_backend_config() |
| else: |
| backend_config = get_qnnpack_backend_config() |
| qconfig_mapping = get_default_qconfig_mapping(backend) |
| model = MyModel() |
| m = prepare_fx( |
| model, |
| qconfig_mapping=qconfig_mapping, |
| example_inputs=(torch.randn(1, 8, 6, 6),), |
| backend_config=backend_config |
| ) |
| m = convert_fx(m) |
| expected_occurrence = { |
| ns.call_function(torch.quantize_per_tensor): 2, |
| ns.call_method("dequantize"): 1, |
| } |
| self.checkGraphModuleNodes(m, expected_node_occurrence=expected_occurrence) |
| |
| def test_pixel_unshuffle_module(self) -> None: |
| class MyBias(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.bias = nn.Parameter(torch.randn(64)) |
| |
| class MyModel(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = nn.Conv2d(8, 8, 1, bias=False) |
| self.unshuffle = nn.PixelUnshuffle(downscale_factor=2) |
| self.bias = MyBias() |
| |
| def forward(self, x): |
| x = self.conv(x) |
| x = self.unshuffle(x) |
| bias = self.bias.bias |
| return x + bias |
| |
| for backend in ["fbgemm", "qnnpack"]: |
| if backend == "fbgemm": |
| backend_config = get_fbgemm_backend_config() |
| else: |
| backend_config = get_qnnpack_backend_config() |
| qconfig_mapping = get_default_qconfig_mapping(backend) |
| model = MyModel() |
| m = prepare_fx( |
| model, |
| qconfig_mapping=qconfig_mapping, |
| example_inputs=(torch.randn(1, 8, 6, 6),), |
| backend_config=backend_config |
| ) |
| m = convert_fx(m) |
| expected_occurrence = { |
| ns.call_function(torch.quantize_per_tensor): 2, |
| ns.call_method("dequantize"): 1, |
| ns.call_module(nn.PixelUnshuffle): 1, |
| } |
            self.checkGraphModuleNodes(m, expected_node_occurrence=expected_occurrence)

    def test_narrow(self):
| class MyBias(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.bias = nn.Parameter(torch.randn(4)) |
| |
| class MyModel(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = nn.Conv2d(8, 8, 1, bias=False) |
| self.bias = MyBias() |
| |
| def forward(self, x): |
| x = self.conv(x) |
| x = torch.narrow(x, 1, 0, 4) |
| bias = self.bias.bias |
| return x + bias |
| |
| for backend in ["fbgemm", "qnnpack"]: |
| if backend == "fbgemm": |
| backend_config = get_fbgemm_backend_config() |
| else: |
| backend_config = get_qnnpack_backend_config() |
| qconfig_mapping = get_default_qconfig_mapping(backend) |
| model = MyModel() |
| m = prepare_fx( |
| model, |
| qconfig_mapping=qconfig_mapping, |
| example_inputs=(torch.randn(1, 8, 3, 3),), |
| backend_config=backend_config |
| ) |
| m = convert_fx(m) |
| expected_occurrence = { |
| ns.call_function(torch.quantize_per_tensor): 2, |
| ns.call_method("dequantize"): 1, |
| } |
| self.checkGraphModuleNodes(m, expected_node_occurrence=expected_occurrence) |
| |
| class TestQuantizeFxModels(QuantizationTestCase): |
| @skipIfNoFBGEMM |
| @unittest.skipIf(not TEST_CUDA, "gpu is not available.") |
    def test_static_gpu_convert_basic(self):
        class Net(nn.Module):
| def __init__(self): |
| super().__init__() |
| self.relu1 = nn.ReLU() |
| self.conv1 = nn.Conv2d(1, 6, 5) |
| self.linear1 = nn.Linear(120, 1) |
| |
| def forward(self, x): |
| x = self.relu1(self.conv1(x)) |
| y = self.linear1(x.view(-1)) |
| return y |
| |
| input = torch.randn((5, 1, 6, 6)).to('cuda') |
| example_inputs = (input,) |
| model = Net().to('cuda').eval() |
| qconfig_dict = {"": torch.ao.quantization.get_default_qconfig('fbgemm')} |
| model_prepared = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) |
| model_prepared(*example_inputs) |
| model_quantized = convert_to_reference_fx(model_prepared) |
| out = model_quantized(*example_inputs) |
| self.assertEqual(out.device.type, 'cuda') |
| |
| @skipIfNoFBGEMM |
| @unittest.skipIf(not TEST_CUDA, "gpu is not available.") |
    def test_switch_device_prepare_convert(self):
        class Net(nn.Module):
| def __init__(self): |
| super().__init__() |
| self.relu1 = nn.ReLU() |
| self.conv1 = nn.Conv2d(1, 6, 5) |
| self.linear1 = nn.Linear(120, 1) |
| |
| def forward(self, x): |
| x = self.relu1(self.conv1(x)) |
| y = self.linear1(x.view(-1)) |
| return y |
| |
| for device in ['cuda', 'cpu']: |
| device_after = 'cuda' if device == 'cpu' else 'cpu' |
| input = torch.randn((5, 1, 6, 6)).to(device) |
| model = Net().to(device).eval() |
| qconfig_dict = {"": torch.ao.quantization.get_default_qconfig('fbgemm')} |
| model_prepared = prepare_fx(model, qconfig_dict, example_inputs=(input,)) |
| model_prepared(input) |
| model_prepared.to(device_after) |
| model_quantized = convert_to_reference_fx(model_prepared) |
| out = model_quantized(input.to(device_after)) |
| self.assertEqual(out.device.type, device_after) |
| |
| @skipIfNoFBGEMM |
| @unittest.skipIf(not TEST_CUDA, "gpu is not available.") |
| def test_prepare_serialize_switch_device_convert(self): |
| class Net(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv1 = nn.Conv2d(1, 6, 5) |
| self.linear1 = nn.Linear(120, 1) |
| |
| def forward(self, x): |
| x = self.conv1(x) |
| y = self.linear1(x.view(-1)) |
| return y |
| |
| for device in ['cuda', 'cpu']: |
| for device_after in ['cuda', 'cpu']: |
| input = torch.randn((5, 1, 6, 6)).to(device) |
| model = Net().to(device).eval() |
| qconfig_dict = {"": torch.ao.quantization.get_default_qconfig('fbgemm')} |
| model_prepared_first = prepare_fx(model, qconfig_dict, example_inputs=(input,)) |
| model_prepared_second = prepare_fx(model, qconfig_dict, example_inputs=(input,)) |
| model_prepared_first(input) |
| state_dict = model_prepared_first.state_dict() |
| del model_prepared_first |
| model_prepared_second.load_state_dict(state_dict) |
| model_prepared_second.to(device_after) |
| model_quantized = convert_to_reference_fx(model_prepared_second) |
| out = model_quantized(input.to(device_after)) |
| self.assertEqual(out.device.type, device_after) |
| |
| @skip_if_no_torchvision |
| def test_model_dropout(self): |
| from torchvision import models |
| m = models.mobilenet_v3_small() |
| qconfig_mapping = torch.ao.quantization.get_default_qat_qconfig_mapping('fbgemm') |
| example_inputs = (torch.randn(1, 3, 224, 224),) |
| mp = prepare_qat_fx(m, qconfig_mapping, example_inputs=example_inputs) |
| mp(*example_inputs) |
| with override_quantized_engine("qnnpack") if IS_ARM64 else contextlib.nullcontext(): |
| mq = convert_fx(mp) |
| mq(*example_inputs) |
| |
| def _test_model_impl( |
| self, mode, name, model, eager_quantizable_model, |
| check_with_eager=True, |
| diff_of_quant=None, |
| diff_from_eager=None): |
| if diff_of_quant is None or diff_from_eager is None: |
| diff_of_quant = {} |
| diff_from_eager = {} |
| |
| if mode not in diff_of_quant or mode not in diff_from_eager: |
| diff_of_quant[mode] = {} |
| diff_from_eager[mode] = {} |
| |
| input_tensor = torch.rand(1, 3, 224, 224) |
| input_tensor_inception = torch.rand(1, 3, 299, 299) |
| output_value = torch.randint(0, 1, (1,)) |
| |
| # print('quantizing:', name, ' mode:', mode) |
| if name == 'inception_v3': |
| input_value = input_tensor_inception |
| else: |
| input_value = input_tensor |
| |
| qconfig = default_qconfig if mode == 'static' else default_qat_qconfig |
| qconfig_dict = {'': qconfig} |
| script = torch.jit.script(model) |
| |
        # make sure graph module and script module are both runnable
| original_out = model(input_value) |
| is_not_tuple_out = not isinstance(original_out, tuple) |
| script_out = script(input_value) |
| |
| # set to train just before quantization |
| prepare_fx_fn = prepare_fx |
| if mode != 'static': |
| model.train() |
| prepare_fx_fn = prepare_qat_fx |
| |
        prepared = prepare_fx_fn(model, qconfig_dict, example_inputs=(input_value,))
| |
| if mode == 'ddp': |
| mp.spawn(run_ddp, |
| args=(world_size, prepared), |
| nprocs=world_size, |
| join=True) |
| elif mode == 'qat': |
| assert prepared.training, 'prepared must be in training mode for qat' |
| optimizer = torch.optim.SGD(prepared.parameters(), lr=0.0001) |
| criterion = nn.CrossEntropyLoss() |
| train_one_epoch(prepared, criterion, optimizer, [(input_value, output_value)], torch.device('cpu'), 1) |
| else: |
| for i in range(10): |
| prepared(input_value) |
| |
| # print('after observation root:', prepared.root) |
| |
| qgraph = convert_fx(prepared) |
| # print('after quantization root:', qgraph.root) |
| # print('after quantization code:', qgraph.src) |
| qgraph.eval() |
| qgraph_script = torch.jit.script(qgraph) |
| # print('quantized and scripted:', qgraph_script.graph) |
| |
        qgraph_out = qgraph(input_value)
        qgraph_script_out = qgraph_script(input_value)
| |
| if is_not_tuple_out: |
| diff_of_quant[mode][name] = (original_out - qgraph_out).abs().max() |
            assert torch.allclose(qgraph_out, qgraph_script_out), 'graph, scripted graph'
| else: |
| print('tuple output') |
| |
| if eager_quantizable_model is not None: |
| # comparing to eager mode quantization |
| qeager = eager_quantizable_model |
| ref_out = qeager(input_value) |
| qeager.qconfig = qconfig |
| if mode == 'static': |
| qeager.fuse_model() |
| prepare(qeager, inplace=True) |
| else: |
| qeager.train() |
| qeager.fuse_model() |
| prepare_qat(qeager, inplace=True) |
| |
| # calibration |
| if mode == 'ddp': |
| mp.spawn(run_ddp, |
| args=(world_size, qeager), |
| nprocs=world_size, |
| join=True) |
| elif mode == 'qat': |
| assert qeager.training, 'qeager should be in training mode for qat' |
| optimizer = torch.optim.SGD(qeager.parameters(), lr=0.0001) |
| train_one_epoch(qeager, criterion, optimizer, [(input_value, output_value)], torch.device('cpu'), 1) |
| else: |
| for i in range(10): |
| qeager(input_value) |
| |
| # print('ref after observation:', qeager) |
| |
| convert(qeager, inplace=True) |
| qeager.eval() |
| |
| # print('ref after quantization:', qeager) |
| qeager_out = qeager(input_value) |
| qeager_script = torch.jit.script(qeager) |
| qscript_out = qeager_script(input_value) |
| if is_not_tuple_out: |
| diff_from_eager[mode][name] = (qeager_out - qgraph_out).abs().max() |
| if check_with_eager: |
| self.assertEqual(diff_from_eager[mode][name], 0, |
| 'Result of graph mode quantization and ' + |
| 'eager mode quantization on model: ' + name + |
| ' should match. Mode: ' + mode + |
| ' diff:' + str(diff_from_eager[mode][name])) |
| |
| def _test_building_block(self, quant_type, BB): |
| eager = BB().float() |
| graph = copy.deepcopy(eager) |
| |
| if quant_type == QuantType.STATIC: |
| qconfig = default_qconfig |
| eager_prepare = prepare |
| graph_prepare = prepare_fx |
| eager.eval() |
| graph.eval() |
| calibrate_or_train = test_only_eval_fn |
| data = self.img_data_2d |
| is_qat = False |
| else: |
| assert quant_type == QuantType.QAT |
| qconfig = default_qat_qconfig |
| eager_prepare = prepare_qat |
| graph_prepare = prepare_qat_fx |
| eager.train() |
| graph.train() |
| calibrate_or_train = test_only_train_fn |
| data = self.img_data_2d_train |
| is_qat = True |
| |
| if hasattr(eager, "fuse_model"): |
| eager.fuse_model() |
| eager = QuantWrapper(eager) |
| eager.qconfig = qconfig |
| eager = eager_prepare(eager) |
| |
| qconfig_dict = {"": qconfig} |
| graph = graph_prepare(graph, qconfig_dict, example_inputs=(data[0][0],)) |
| |
| eager_out = eager(data[0][0]) |
| graph_out = graph(data[0][0]) |
        # Eager Mode and FX Graph Mode now differ in numerics in both Post
        # Training and QAT because FX Graph Mode uses the same fake_quant
        # instance for the input and output of a CopyNode
        # self.assertEqual(eager_out, graph_out)
| |
| calibrate_or_train(eager, data) |
| calibrate_or_train(graph, data) |
| |
| eager = convert(eager) |
| graph = convert_fx(graph) |
| |
| eager_out = eager(data[0][0]) |
| graph_out = graph(data[0][0]) |
| |
| @override_qengines |
| def test_resnet_base(self): |
| models = [ResNetBase] |
| options = itertools.product(self.static_quant_types, models) |
| for quant_type, M in options: |
| self._test_building_block(quant_type, M) |
| |
| @skip_if_no_torchvision |
| @skipIfNoFBGEMM |
| @unittest.skip("skip for now since tbb failed") |
| def test_torchvision(self): |
| from torchvision import models |
| from torchvision.models import quantization as quantized_models |
| from torchvision.models.quantization.utils import _replace_relu |
| |
| def get_available_classification_models(models): |
| return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"] |
| |
| model_list = get_available_classification_models(models) |
| quantized_model_list = get_available_classification_models(quantized_models) |
| |
| quantized_model_list = set(quantized_model_list) |
| # test eager and graph consistency |
| model_list = quantized_model_list |
        # mobilenet/inception_v3/googlenet QAT is not working due to
        # AdaptiveAveragePool QAT; we might observe the output of
        # AdaptiveAveragePool in the future and re-enable the test
        # (also, relu6 is replaced with relu in mobilenet_v2)
        fx_eager_not_matching = [
            ("mobilenet_v2", "qat"),
            ("inception_v3", "qat"),
            ("googlenet", "qat")
        ]
| |
| diff_of_quant = {} |
| diff_from_eager = {} |
| modes = ['static', 'qat'] |
| options = itertools.product(modes, model_list) |
| for mode, name in options: |
| pretrained = name in quantized_model_list # load pretrained model to compare with quantized model |
| kwargs = {} |
| # turn off transform input for inception_v3 since |
| # it's not quantized in eager mode and in fx graph |
| # mode we can't skip quantizing a method right now |
| # (might be supported in the future) |
| if name in ["inception_v3", "googlenet"]: |
| kwargs["transform_input"] = False |
| eager_quantizable_model = None |
| if name in quantized_model_list: |
| eager_quantizable_model = quantized_models.__dict__[name](pretrained=False, quantize=False, **kwargs).eval().float() |
| # compare with eager mode quantized model when it is available |
| pretrained = eager_quantizable_model is not None |
| model = models.__dict__[name](pretrained=pretrained, **kwargs).eval().float() |
| if name == "mobilenet_v2": |
| _replace_relu(model) |
| # disable aux logits |
| if hasattr(model, "aux_logits"): |
| model.aux_logits = False |
| model.AuxLogits = None |
| if eager_quantizable_model: |
| eager_quantizable_model.aux_logits = False |
| eager_quantizable_model.AuxLogits = None |
| |
| check_with_eager = (name, mode) not in fx_eager_not_matching |
| self._test_model_impl( |
| mode, name, model, eager_quantizable_model, |
| check_with_eager, |
| diff_of_quant, diff_from_eager) |
| |
| def print_diffs(diffs): |
| for mode, diffs_for_mode in diffs.items(): |
| print('mode:', mode) |
| for name, diff in diffs_for_mode.items(): |
| print(name, ':', diff) |
| |
| # print('differences between float and quantized') |
| # print_diffs(diff_of_quant) |
| # print('----------------------') |
| # print('differences between graph mode and eager mode') |
| # print_diffs(diff_from_eager) |
| # print('----------------------') |
| |
| @skip_if_no_torchvision |
| @skipIfNoFBGEMM |
| @unittest.skip("TODO: Test is always failing - https://github.com/pytorch/pytorch/issues/54979") |
| def test_resnet18_ddp(self): |
| from torchvision import models |
| from torchvision.models import quantization as quantized_models |
        eager_quantizable_model = quantized_models.__dict__['resnet18'](pretrained=False, quantize=False).eval().float()
        model = models.__dict__['resnet18'](pretrained=False).eval().float()
| self._test_model_impl( |
| 'ddp', 'resnet18', model, eager_quantizable_model) |
| |
| @override_qengines |
| def test_qat_embeddingbag_linear(self): |
| for device in get_supported_device_types(): |
| class EmbeddingBagLinear(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, mode='sum') |
| self.linear = torch.nn.Linear(12, 1).to(dtype=torch.float) |
| |
| def forward(self, input: torch.Tensor, offsets: Optional[torch.Tensor] = None, |
| per_sample_weights: Optional[torch.Tensor] = None): |
| x = self.emb(input, offsets, per_sample_weights) |
| x = self.linear(x) |
| return x |
| |
| qengine = torch.backends.quantized.engine |
| qconfig_dict = QConfigMapping() \ |
| .set_global(get_default_qat_qconfig(qengine)) \ |
| .set_object_type(torch.nn.EmbeddingBag, default_embedding_qat_qconfig) |
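            # EmbeddingBag requires the embedding-specific QAT qconfig
            # (weight-only fake quant with float qparams); the global QAT
            # qconfig is used for the rest of the model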
| |
| train_indices = [[torch.randint(0, 10, (12, 12)), torch.randn((12, 1))] for _ in range(2)] |
| eval_output = [[torch.randint(0, 10, (12, 1))]] |
| |
| model = EmbeddingBagLinear().train() |
| prepared_fx_model = prepare_qat_fx(model, qconfig_dict, example_inputs=(train_indices[0][0],)) |
| test_only_train_fn(prepared_fx_model, train_indices) |
| quant_model = convert_fx(prepared_fx_model, |
| qconfig_mapping=qconfig_dict) |
| |
| def checkQuantized(model): |
                # Make sure EmbeddingBag is now a quantized EmbeddingBag.
                self.assertEqual(type(model.emb), nnq.EmbeddingBag)
                # Also test that Linear has been quantized.
                self.assertEqual(type(model.linear), nnq.Linear)
| |
| test_only_eval_fn(model, eval_output) |
| self.checkScriptable(model, eval_output) |
| self.checkNoQconfig(model) |
            checkQuantized(quant_model)

    @override_qengines
| def test_qat_embedding_linear(self): |
| for device in get_supported_device_types(): |
| class EmbeddingLinear(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12) |
| self.linear = torch.nn.Linear(12, 1).to(dtype=torch.float) |
| |
| def forward(self, input: torch.Tensor): |
| x = torch.sum(self.emb(input), dim=1) |
| x = self.linear(x) |
| return x |
| |
| qengine = torch.backends.quantized.engine |
| qconfig_dict = {"": get_default_qat_qconfig(qengine), |
| "object_type": [(torch.nn.Embedding, default_embedding_qat_qconfig)]} |
| |
| |
| train_indices = [[torch.randint(0, 10, (12, 12)), torch.randn((12, 1))] for _ in range(2)] |
| eval_output = [[torch.randint(0, 10, (12, 1))]] |
| |
| model = EmbeddingLinear().train() |
| prepared_fx_model = prepare_qat_fx(model, qconfig_dict, example_inputs=(train_indices[0][0],)) |
| test_only_train_fn(prepared_fx_model, train_indices) |
| quant_model = convert_fx(prepared_fx_model, |
| qconfig_mapping=qconfig_dict) |
| |
| def checkQuantized(model): |
                # Make sure Embedding is now a quantized Embedding.
                self.assertEqual(type(model.emb), nnq.Embedding)
                # Also test that Linear has been quantized.
                self.assertEqual(type(model.linear), nnq.Linear)
| |
| test_only_eval_fn(model, eval_output) |
| self.checkScriptable(model, eval_output) |
| self.checkNoQconfig(model) |
| checkQuantized(quant_model) |
| |
| @given( |
| device=st.sampled_from( |
| ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"] |
| ) |
| ) |
| @settings(deadline=None) |
| @override_qengines |
| def test_qat_functional_linear(self, device): |
| if torch.backends.quantized.engine not in ('fbgemm', 'qnnpack'): |
| return |
| |
| class Linear(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.w = torch.ones(5, 5) |
| self.b = torch.zeros(5) |
| |
| def forward(self, x): |
| return torch.nn.functional.linear(x, self.w, self.b) |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.mods1 = torch.nn.Sequential(Linear(), Linear()) |
| self.mods2 = Linear() |
| |
| def forward(self, x): |
| x = self.mods1(x) |
| x = self.mods2(x) |
| return x |
| |
| model = M().train() |
| ref_fake_quant = FakeQuantize.with_args( |
| observer=MovingAverageMinMaxObserver, |
| quant_min=0, |
| quant_max=255, |
| dtype=torch.quint8, |
| reduce_range=False, |
| ) |
| ref_weight_fake_quant = FakeQuantize.with_args( |
| observer=MovingAverageMinMaxObserver, |
| quant_min=-128, |
| quant_max=127, |
| dtype=torch.qint8, |
| reduce_range=False, |
| ) |
| ref_qat_qconfig = QConfig( |
| activation=ref_fake_quant, weight=ref_weight_fake_quant |
| ) |
| qconfig_dict = {"": ref_qat_qconfig} |
| example_inputs = (torch.randn(1, 5),) |
| prepared_ref = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs) |
| |
| custom_fake_quant = FusedMovingAvgObsFakeQuantize.with_args( |
| observer=MovingAverageMinMaxObserver, |
| quant_min=0, |
| quant_max=255, |
| dtype=torch.quint8, |
| reduce_range=False, |
| ) |
| custom_weight_fake_quant = FusedMovingAvgObsFakeQuantize.with_args( |
| observer=MovingAverageMinMaxObserver, |
| quant_min=-128, |
| quant_max=127, |
| dtype=torch.qint8, |
| reduce_range=False, |
| ) |
| custom_qconfig = QConfig( |
| activation=custom_fake_quant, weight=custom_weight_fake_quant |
| ) |
| custom_qconfig_dict = {"": custom_qconfig} |
| prepared = prepare_qat_fx(model, custom_qconfig_dict, example_inputs=example_inputs) |
| |
| prepared.to(device) |
| prepared_ref.to(device) |
| |
| prepared.apply(torch.ao.quantization.disable_fake_quant) |
| prepared.apply(torch.ao.quantization.disable_observer) |
| prepared_ref.apply(torch.ao.quantization.disable_fake_quant) |
| prepared_ref.apply(torch.ao.quantization.disable_observer) |
| |
| for i in range(10): |
| if i == 2: |
| prepared.apply(torch.ao.quantization.enable_observer) |
| prepared_ref.apply(torch.ao.quantization.enable_observer) |
| if i == 4: |
| prepared.apply(torch.ao.quantization.enable_fake_quant) |
| prepared_ref.apply(torch.ao.quantization.enable_fake_quant) |
| |
| inp = torch.randn(5, 5, device=device, requires_grad=True) |
| out_ref = prepared_ref(inp) |
| out = prepared(inp) |
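            # the fused observer + fake-quant modules should match the
            # numerics of the reference FakeQuantize modules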
| torch.testing.assert_close(out, out_ref) |
| |
| # try backward pass |
| labels = torch.randn(5, 5, device=device) |
| loss = (out - labels).sum() |
| grad = torch.autograd.grad(loss, [inp]) |
| loss_ref = (out_ref - labels).sum() |
| grad_ref = torch.autograd.grad(loss_ref, [inp]) |
| torch.testing.assert_close(grad[0], grad_ref[0]) |
| |
| if 'fbgemm' in torch.backends.quantized.supported_engines: |
| # During the lowering step in convert, fold_weight calls quantized::linear_prepack |
| # which doesn't support QuantizedCuda backend |
| prepared.cpu() |
| prepared_ref.cpu() |
| converted = convert_fx(prepared) |
| converted_ref = convert_fx(prepared_ref) |
| inp = torch.rand(5, 5) |
| out = converted(inp) |
| out_ref = converted_ref(inp) |
| |
            torch.testing.assert_close(out, out_ref)


if __name__ == '__main__':
| raise RuntimeError("This test file is not meant to be run directly, use:\n\n" |
| "\tpython test/test_quantization.py TESTNAME\n\n" |
| "instead.") |