| import torch |
| import torch.nn.functional as F |
| import torch.nn as nn |
| import torch.nn.quantized as nnq |
| import torch.nn.quantized.dynamic as nnqd |
| import torch.nn.intrinsic as nni |
| import torch.nn.intrinsic.quantized as nniq |
| import torch.multiprocessing as mp |
| |
| # graph mode quantization based on fx |
| from torch.quantization.quantize_fx import ( |
| prepare_fx, |
| convert_fx, |
| prepare_qat_fx, |
| ) |
| |
| from torch.quantization.fx.pattern_utils import ( |
| is_match, |
| MatchAllNode, |
| ) |
| |
| from torch.quantization import ( |
| QuantType, |
| QuantStub, |
| DeQuantStub, |
| QuantWrapper, |
| quant_type_to_str, |
| default_qconfig, |
| default_dynamic_qconfig, |
| default_qat_qconfig, |
| per_channel_dynamic_qconfig, |
| float16_dynamic_qconfig, |
| float_qparams_weight_only_qconfig, |
| get_default_qconfig, |
| get_default_qat_qconfig, |
| fuse_modules, |
| prepare, |
| prepare_qat, |
| convert, |
| quantize_dynamic, |
| default_placeholder_observer, |
| PerChannelMinMaxObserver, |
| QConfigDynamic, |
| FixedQParamsFakeQuantize, |
| ) |
| |
| # test utils |
| from torch.testing._internal.common_cuda import TEST_MULTIGPU, TEST_CUDA |
| from torch.testing._internal.common_quantization import ( |
| QuantizationTestCase, |
| skipIfNoFBGEMM, |
| skip_if_no_torchvision, |
| train_one_epoch, |
| run_ddp, |
| test_only_eval_fn, |
| test_only_train_fn, |
| ) |
| |
| from torch.testing._internal.common_quantization import ( |
| LinearModelWithSubmodule, |
| ResNetBase, |
| RNNDynamicModel, |
| RNNCellDynamicModel, |
| ) |
| |
| from torch.testing._internal.common_quantized import ( |
| override_qengines, |
| ) |
| |
| from torch.testing._internal.common_distributed import skip_if_not_multigpu |
| |
| from torch.testing._internal.common_quantization import NodeSpec as ns |
| |
| from torch.testing import FileCheck |
| |
| import copy |
| import itertools |
| import operator |
| import unittest |
| import io |
| |
| class TestFuseFx(QuantizationTestCase): |
| def test_fuse_conv_bn_relu(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv1d = nn.Conv1d(1, 1, 1) |
| self.conv2d = nn.Conv2d(1, 1, 1) |
| self.conv3d = nn.Conv3d(1, 1, 1) |
| self.bn1d = nn.BatchNorm1d(1) |
| self.bn2d = nn.BatchNorm2d(1) |
| self.bn3d = nn.BatchNorm3d(1) |
| self.conv1d2 = nn.Conv1d(1, 1, 1) |
| self.conv2d2 = nn.Conv2d(1, 1, 1) |
| self.conv3d2 = nn.Conv3d(1, 1, 1) |
| self.bn1d2 = nn.BatchNorm1d(1) |
| self.bn2d2 = nn.BatchNorm2d(1) |
| self.bn3d2 = nn.BatchNorm3d(1) |
| self.relu = nn.ReLU() |
| |
| def forward(self, x): |
| x = self.conv1d(x) |
| x = self.bn1d(x) |
| x = self.conv2d(x) |
| x = self.bn2d(x) |
| x = self.conv3d(x) |
| x = self.bn3d(x) |
| x = self.conv1d2(x) |
| x = self.bn1d2(x) |
| x = self.relu(x) |
| x = self.conv2d2(x) |
| x = self.bn2d2(x) |
| x = self.relu(x) |
| x = self.conv3d2(x) |
| x = self.bn3d2(x) |
| x = self.relu(x) |
| return x |
| |
| # test train mode |
| m = M().train() |
| # currently we don't check if the modules are configured with a qconfig before fusion
| # TODO: if we decide to do that in the future, this test needs to |
| # be updated |
| # train mode fuse_fx is called in prepare_qat_fx |
| m = prepare_qat_fx(m, {}) |
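| # in train mode conv + bn fuses into ConvBn and conv + bn + relu fuses into ConvBnReLU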
| expected_nodes = [ |
| ns.call_module(nni.ConvBn1d), |
| ns.call_module(nni.ConvBn2d), |
| ns.call_module(nni.ConvBn3d), |
| ns.call_module(nni.ConvBnReLU1d), |
| ns.call_module(nni.ConvBnReLU2d), |
| ns.call_module(nni.ConvBnReLU3d), |
| ] |
| expected_occurrence = { |
| ns.call_module(nn.ReLU): 0 |
| } |
| self.checkGraphModuleNodes( |
| m, |
| expected_node_list=expected_nodes, |
| expected_node_occurrence=expected_occurrence) |
| |
| # test eval mode |
| m = M().eval() |
| from torch.quantization.quantize_fx import fuse_fx |
| # fuse_fx is a top-level API and only supports eval mode
| m = fuse_fx(m) |
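| # in eval mode bn is folded into the preceding conv, so conv + bn becomes Conv
| # and conv + bn + relu becomes ConvReLU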
| expected_nodes = [ |
| ns.call_module(nn.Conv1d), |
| ns.call_module(nn.Conv2d), |
| ns.call_module(nn.Conv3d), |
| ns.call_module(nni.ConvReLU1d), |
| ns.call_module(nni.ConvReLU2d), |
| ns.call_module(nni.ConvReLU3d), |
| ] |
| # all standalone ReLU modules should have been fused away
| expected_occurrence = { |
| ns.call_module(nn.ReLU): 0 |
| } |
| self.checkGraphModuleNodes( |
| m, |
| expected_node_list=expected_nodes, |
| expected_node_occurrence=expected_occurrence) |
| |
| def test_fuse_module_relu(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv1d = nn.Conv1d(1, 1, 1) |
| self.conv2d = nn.Conv2d(1, 1, 1) |
| self.conv3d = nn.Conv3d(1, 1, 1) |
| self.bn1d = nn.BatchNorm1d(1) |
| self.bn2d = nn.BatchNorm2d(1) |
| self.bn3d = nn.BatchNorm3d(1) |
| self.relu = nn.ReLU() |
| |
| def forward(self, x): |
| x = self.conv1d(x) |
| x = self.relu(x) |
| x = self.conv2d(x) |
| x = self.relu(x) |
| x = self.conv3d(x) |
| x = self.relu(x) |
| x = self.bn1d(x) |
| x = self.relu(x) |
| x = self.bn2d(x) |
| x = self.relu(x) |
| x = self.bn3d(x) |
| x = self.relu(x) |
| return x |
| |
| m = M().eval() |
| from torch.quantization.quantize_fx import fuse_fx |
| m = fuse_fx(m) |
| expected_nodes = [ |
| ns.call_module(nni.ConvReLU1d), |
| ns.call_module(nni.ConvReLU2d), |
| ns.call_module(nni.ConvReLU3d), |
| ns.call_module(nni.BNReLU2d), |
| ns.call_module(nni.BNReLU3d), |
| ] |
| self.checkGraphModuleNodes(m, expected_node_list=expected_nodes) |
| |
| @skipIfNoFBGEMM |
| class TestQuantizeFx(QuantizationTestCase): |
| def test_pattern_match(self): |
| """ test MatchAllNode with |
| conv - bn - add - relu pattern |
| """ |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = nn.Conv2d(1, 1, 1) |
| self.bn = nn.BatchNorm2d(1) |
| self.relu = nn.ReLU() |
| |
| def forward(self, x, y): |
| x = self.conv(x) |
| x = self.bn(x) |
| x = x + y |
| x = self.relu(x) |
| return x |
| |
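| # the pattern is specified starting from the last node and working backwards:
| # relu(add(bn(conv(x)), <any other node>))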
| pattern = (nn.ReLU, (operator.add, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode)) |
| m = torch.fx.symbolic_trace(M()) |
| modules = dict(m.named_modules()) |
| for n in m.graph.nodes: |
| if n.op == 'call_module' and type(modules[n.target]) == nn.ReLU: |
| self.assertTrue(is_match(modules, n, pattern)) |
| |
| def _get_conv_linear_test_cases(self): |
| ''' Returns a list of test cases, with format: |
| is_dynamic, ModuleClass, module_constructor_inputs, |
| inputs, quantized_node, weight_prepack_node
| ''' |
| class Conv(torch.nn.Module): |
| def __init__(self, weight): |
| super().__init__() |
| self.weight = torch.nn.Parameter(weight) |
| self.stride = (1, 1) |
| self.padding = (0, 0) |
| self.dilation = (1, 1) |
| self.groups = 1 |
| |
| def forward(self, x): |
| return F.conv2d(x, self.weight, None, self.stride, self.padding, self.dilation, self.groups) |
| |
| conv_input = torch.rand(1, 3, 224, 224) |
| conv_weight = torch.rand(3, 3, 3, 3) |
| |
| class Linear(torch.nn.Module): |
| def __init__(self, weight): |
| super().__init__() |
| self.weight = torch.nn.Parameter(weight) |
| |
| def forward(self, x): |
| return F.linear(x, self.weight) |
| |
| linear_input = torch.rand(8, 5) |
| linear_weight = torch.rand(10, 5) |
| |
| class LinearModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(5, 10) |
| |
| def forward(self, x): |
| return self.linear(x) |
| |
| linear_module_input = torch.rand(8, 5) |
| |
| tests = [ |
| (False, Conv, (conv_weight,), (conv_input,), |
| ns.call_function(torch.ops.quantized.conv2d), |
| ns.call_function(torch.ops.quantized.conv2d_prepack)), |
| (True, Linear, (linear_weight,), (linear_input,), |
| ns.call_function(torch.ops.quantized.linear_dynamic), |
| ns.call_function(torch.ops.quantized.linear_prepack)), |
| (False, Linear, (linear_weight,), (linear_input,), |
| ns.call_function(torch.ops.quantized.linear), |
| ns.call_function(torch.ops.quantized.linear_prepack)), |
| (True, LinearModule, (), (linear_module_input,), |
| ns.call_module(nnqd.Linear), |
| None), |
| (False, LinearModule, (), (linear_module_input,), |
| ns.call_module(nnq.Linear), |
| None), |
| ] |
| return tests |
| |
| """ |
| Unit tests for functionalities |
| """ |
| @skipIfNoFBGEMM |
| def test_functional_no_debug(self): |
| """ Test quantizing functional conv and linear |
| """ |
| tests = self._get_conv_linear_test_cases() |
| for (is_dynamic, ModuleClass, module_constructor_inputs, |
| inputs, quantized_node, weight_prepack_node) in tests: |
| quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC |
| node_occurrence = dict() |
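| # the weight prepack op should be folded into the quantized op during convert,
| # so it should not appear in the final graph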
| if weight_prepack_node: |
| node_occurrence[weight_prepack_node] = 0 |
| self.checkGraphModeFxOp( |
| ModuleClass(*module_constructor_inputs), |
| inputs, quant_type, |
| expected_node=quantized_node, |
| expected_node_occurrence=node_occurrence, |
| debug=False) |
| |
| @skipIfNoFBGEMM |
| def test_functional_debug(self): |
| """ Test quantizing functional conv and linear with debug option |
| """ |
| tests = self._get_conv_linear_test_cases() |
| for (is_dynamic, ModuleClass, module_constructor_inputs, |
| inputs, quantized_node, weight_prepack_node) in tests: |
| quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC |
| node_occurrence = dict() |
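| # in debug mode the op is kept as a float op (the reference pattern),
| # so neither the quantized op nor the prepack op appears in the graph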
| if weight_prepack_node: |
| node_occurrence[weight_prepack_node] = 0 |
| node_occurrence[quantized_node] = 0 |
| self.checkGraphModeFxOp( |
| ModuleClass(*module_constructor_inputs), |
| inputs, quant_type, |
| expected_node_occurrence=node_occurrence, |
| debug=True) |
| |
| @skipIfNoFBGEMM |
| def test_dynamic_quant_weight_observer(self): |
| ''' Test that weight observer is run in convert step |
| ''' |
| |
| class M(torch.nn.Module): |
| def __init__(self, weight): |
| super().__init__() |
| self.weight = torch.nn.Parameter(weight) |
| |
| def forward(self, x): |
| return F.linear(x, self.weight) |
| |
| m = M(torch.rand(1, 1)).eval() |
| qconfig = default_dynamic_qconfig |
| qconfig_dict = {'': qconfig} |
| prepared = prepare_fx(m, qconfig_dict) |
| quantized = convert_fx(prepared, debug=True) |
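| # in debug mode the weight qparams computed at convert time are stored as
| # attributes on the converted module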
| qparams = (quantized._scale_0, quantized._zero_point_0) |
| weight_obs = qconfig.weight() |
| weight_obs(quantized.weight) |
| ref_qparams = weight_obs.calculate_qparams() |
| self.assertEqual(qparams, ref_qparams) |
| |
| def test_conv_bn_relu(self): |
| convs = { |
| 1: nn.Conv1d, |
| 2: nn.Conv2d, |
| 3: nn.Conv3d, |
| } |
| bns = { |
| 1: nn.BatchNorm1d, |
| 2: nn.BatchNorm2d, |
| 3: nn.BatchNorm3d, |
| } |
| quantized_convs = { |
| 1: nnq.Conv1d, |
| 2: nnq.Conv2d, |
| 3: nnq.Conv3d, |
| } |
| quantized_conv_relus = { |
| 1: nniq.ConvReLU1d, |
| 2: nniq.ConvReLU2d, |
| 3: nniq.ConvReLU3d, |
| } |
| |
| class M(torch.nn.Module): |
| def __init__(self, dim, has_relu): |
| super().__init__() |
| self.conv = convs[dim](3, 3, 3) |
| self.bn = bns[dim](3) |
| self.relu = nn.ReLU() if has_relu else nn.Identity() |
| self.has_relu = has_relu |
| self.quant = QuantStub() |
| self.dequant = DeQuantStub() |
| |
| def forward(self, x): |
| x = self.quant(x) |
| x = self.conv(x) |
| x = self.bn(x) |
| if self.has_relu: |
| x = self.relu(x) |
| x = self.dequant(x) |
| return x |
| |
| options = itertools.product([1, 2], [True, False], self.static_quant_types) |
| for dim, has_relu, quant_type in options: |
| expected_node = ns.call_module( |
| quantized_conv_relus[dim] if has_relu |
| else quantized_convs[dim]) |
| m = M(dim, has_relu) |
| m_eager = copy.deepcopy(m) |
| result = self.checkGraphModeFxOp( |
| m, |
| self.img_data_dict[dim], |
| quant_type, |
| expected_node=expected_node, |
| ) |
| |
| # check numerics |
| qengine = torch.backends.quantized.engine |
| if quant_type == QuantType.STATIC: |
| m_eager.eval() |
| qconfig = get_default_qconfig(qengine) |
| prepare_fn = prepare |
| else: |
| m_eager.train() |
| qconfig = get_default_qat_qconfig(qengine) |
| prepare_fn = prepare_qat |
| |
| fuse_list = ["conv", "bn"] |
| if has_relu: |
| fuse_list.append("relu") |
| fuse_modules(m_eager, fuse_list, inplace=True) |
| m_eager.qconfig = qconfig |
| m_eager = prepare_fn(m_eager) |
| m_eager(*self.img_data_dict[dim][0]) |
| m_eager = convert(m_eager) |
| result_eager = m_eager(*self.img_data_dict[dim][0]) |
| self.assertEqual(result, result_eager) |
| |
| |
| @skipIfNoFBGEMM |
| def test_dynamic_quant_fp16(self): |
| class Linear(torch.nn.Module): |
| def __init__(self, weight): |
| super().__init__() |
| self.weight = torch.nn.Parameter(weight) |
| |
| def forward(self, x): |
| return F.linear(x, self.weight) |
| |
| linear_input = torch.rand(8, 5) |
| linear_weight = torch.rand(10, 5) |
| |
| class LinearModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(5, 10) |
| |
| def forward(self, x): |
| return self.linear(x) |
| |
| linear_module_input = torch.rand(8, 5) |
| |
| tests = [ |
| (Linear, (linear_weight,), (linear_input,), |
| ns.call_function(torch.ops.quantized.linear_dynamic), |
| ns.call_function(torch.ops.quantized.linear_prepack_fp16)), |
| (LinearModule, (), (linear_module_input,), |
| ns.call_module(nnqd.Linear), |
| None), |
| ] |
| for (ModuleClass, module_constructor_inputs, |
| inputs, quantized_node, weight_prepack_node) in tests: |
| for debug in [True, False]: |
| node_occurrence = dict() |
| if weight_prepack_node: |
| node_occurrence[weight_prepack_node] = 0 |
| m = ModuleClass(*module_constructor_inputs).eval() |
| qconfig_dict = {"": float16_dynamic_qconfig} |
| m = prepare_fx(m, qconfig_dict) |
| m = convert_fx(m, debug=debug) |
| self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) |
| |
| @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported") |
| @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") |
| @override_qengines |
| def test_qat_prepare_device_affinity(self): |
| """ |
| Tests that FX QAT prepare pass respects device affinity |
| """ |
| class Model(nn.Module): |
| |
| def __init__(self): |
| super(Model, self).__init__() |
| self.conv = nn.Conv2d(1, 1, 1) |
| self.bn = nn.BatchNorm2d(1) |
| self.relu = nn.ReLU() |
| |
| def forward(self, x): |
| x = self.conv(x) |
| x = self.bn(x) |
| x = self.relu(x) |
| return x |
| |
| model = Model() |
| qengine = torch.backends.quantized.engine |
| qconfig_dict = {'': torch.quantization.get_default_qat_qconfig(qengine)} |
| device = torch.device('cuda:0') |
| model.to(device) |
| |
| # QAT prepare |
| model = prepare_qat_fx(model, qconfig_dict) |
| |
| # ensure that running an input on CUDA works without any needed changes |
| input = torch.randn(4, 1, 4, 4, device=device) |
| model(input) |
| |
| # ensure all buffers and parameters are on the device we expect |
| model_devices = {p.device for p in model.parameters()} | \ |
| {p.device for p in model.buffers()} |
| self.assertEqual(len(model_devices), 1) |
| model_device = next(iter(model_devices)) |
| self.assertEqual(model_device, device) |
| |
| @skipIfNoFBGEMM |
| def test_dict_output(self): |
| """ Make sure quantization runs for models with dictionary output |
| """ |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.Conv2d(1, 1, 1) |
| |
| def forward(self, x): |
| return {"output": self.conv(x["input"])} |
| |
| dict_input = {"input": torch.randn(1, 1, 1, 1)} |
| m = M().eval() |
| qconfig_dict = {"": default_qconfig} |
| m = prepare_fx(m, qconfig_dict) |
| m(dict_input) |
| m = convert_fx(m) |
| m(dict_input) |
| |
| @override_qengines |
| def test_attention(self): |
| """ Make sure quantization runs for a corner case in attention module |
| """ |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.Conv2d(1, 1, 1) |
| |
| def forward(self, x): |
| x = self.conv(x) |
| q, k, v = x.chunk(3, dim=0) |
| q = q.contiguous().view(-1, 1).transpose(0, 1) |
| k = k.contiguous().view(-1, 1).transpose(0, 1) |
| v = v.contiguous().view(-1, 1).transpose(0, 1) |
| torch._assert( |
| k.size(1) == 1, "key size should be equal to 1" |
| ) |
| r = torch.mm(k, v) |
| return q * k + r |
| |
| tensor_input = torch.randn(3, 1, 1, 1) |
| m = M().eval() |
| qconfig_dict = { |
| "": None, |
| "object_type": [ |
| (nn.Conv2d, default_qconfig), |
| ("chunk", None) |
| ] |
| } |
| # make sure it runs |
| m = prepare_fx(m, qconfig_dict) |
| m(tensor_input) |
| m = convert_fx(m) |
| m(tensor_input) |
| |
| def test_standalone_module(self): |
| class StandaloneModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.Conv2d(1, 1, 1) |
| |
| def forward(self, x): |
| return self.conv(x) |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.Conv2d(1, 1, 1) |
| self.standalone = StandaloneModule() |
| |
| def forward(self, x): |
| x = self.conv(x) |
| x = self.standalone(x) |
| return x |
| |
| class RefM(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv1 = torch.nn.Conv2d(1, 1, 1) |
| self.conv2 = torch.nn.Conv2d(1, 1, 1) |
| |
| def forward(self, x): |
| x = self.conv1(x) |
| x = self.conv2(x) |
| return x |
| |
| data = torch.randn(1, 1, 1, 1) |
| # instantiate M and RefM and align the parameters |
| original_m = M().eval() |
| original_ref_m = RefM().eval() |
| original_ref_m.conv1.weight = torch.nn.Parameter(original_m.conv.weight.detach()) |
| original_ref_m.conv1.bias = torch.nn.Parameter(original_m.conv.bias.detach()) |
| original_ref_m.conv2.weight = torch.nn.Parameter(original_m.standalone.conv.weight.detach()) |
| original_ref_m.conv2.bias = torch.nn.Parameter(original_m.standalone.conv.bias.detach()) |
| |
| qconfig_dict = {"": default_qconfig} |
| config_name = {"standalone_module_name": ["standalone"]} |
| config_class = {"standalone_module_class": [StandaloneModule]} |
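| # a standalone module can be specified either by module name or by module class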
| for prepare_config in [config_name, config_class]: |
| original_m_copy = copy.deepcopy(original_m) |
| original_ref_m_copy = copy.deepcopy(original_ref_m) |
| # check prepared model |
| m = prepare_fx( |
| original_m_copy, qconfig_dict, prepare_custom_config_dict=prepare_config) |
| # calibration |
| m(data) |
| # observers for the input and output of the first conv; observers for the
| # standalone module will be inserted in the standalone module itself
| count_check = { |
| ns.call_module(torch.quantization.MinMaxObserver): 2 |
| } |
| self.checkGraphModuleNodes(m, expected_node_occurrence=count_check) |
| # for input and output of conv in the standalone module |
| count_check = { |
| ns.call_module(torch.quantization.MinMaxObserver): 2 |
| } |
| self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=count_check) |
| |
| # check converted/quantized model |
| m = convert_fx(m) |
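| # the parent module quantizes the input, runs the quantized conv, and
| # dequantizes before calling the standalone module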
| count_check = { |
| ns.call_function(torch.quantize_per_tensor) : 1, |
| ns.call_module(nnq.Conv2d) : 1, |
| ns.call_method('dequantize') : 1, |
| } |
| self.checkGraphModuleNodes(m, expected_node_occurrence=count_check) |
| count_check = { |
| # the standalone module takes float input and produces float output,
| # so we'll see quantize and dequantize inside the module
| ns.call_function(torch.quantize_per_tensor) : 1, |
| ns.call_module(nnq.Conv2d): 1, |
| ns.call_method('dequantize') : 1, |
| } |
| self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=count_check) |
| res = m(data) |
| |
| # quantize the reference model |
| ref_m = prepare_fx(original_ref_m_copy, qconfig_dict) |
| ref_m(data) |
| ref_m = convert_fx(ref_m) |
| ref_res = ref_m(data) |
| self.assertEqual(res, ref_res) |
| |
| @skipIfNoFBGEMM |
| def test_qconfig_none(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.conv1 = nn.Conv2d(1, 1, 1) |
| self.conv2 = nn.Conv2d(1, 1, 1) |
| |
| def forward(self, x): |
| x = self.conv1(x) |
| x = self.conv2(x) |
| return x |
| |
| m = M().eval() |
| qconfig_dict = {"": default_qconfig, |
| "module_name": [("conv2", None)]} |
| m = prepare_fx(m, qconfig_dict) |
| data = torch.randn(1, 1, 1, 1) |
| m(data) |
| m = convert_fx(m) |
| m(data) |
| # first conv is quantized, second conv is not quantized |
| node_list = [ |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_module(nnq.Conv2d), |
| ns.call_method("dequantize"), |
| ns.call_module(nn.Conv2d), |
| ] |
| self.checkGraphModuleNodes(m, expected_node_list=node_list) |
| |
| def test_qconfig_module_type(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.conv1 = nn.Conv2d(1, 1, 1) |
| self.conv2 = nn.Conv2d(1, 1, 1) |
| |
| def forward(self, x): |
| x = self.conv1(x) |
| x = self.conv2(x) |
| return x |
| |
| m = M().eval() |
| qconfig_dict = {"object_type": [(torch.nn.Conv2d, default_qconfig)]} |
| m = prepare_fx(m, qconfig_dict) |
| data = torch.randn(1, 1, 1, 1) |
| m(data) |
| m = convert_fx(m) |
| m(data) |
| # both convs are quantized since a qconfig is set for the Conv2d type
| node_list = [ |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_module(nnq.Conv2d), |
| ns.call_module(nnq.Conv2d), |
| ns.call_method("dequantize"), |
| ] |
| self.checkGraphModuleNodes(m, expected_node_list=node_list) |
| |
| def test_qconfig_function(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| |
| def forward(self, x, y): |
| return x + y |
| |
| m = M().eval() |
| qconfig_dict = {"object_type": [(operator.add, default_qconfig)]} |
| m = prepare_fx(m, qconfig_dict) |
| data = torch.randn(1, 1, 1, 1) |
| m(data, data) |
| m = convert_fx(m) |
| m(data, data) |
| # the add is quantized since a qconfig is set for operator.add
| node_list = [ |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_function(torch.ops.quantized.add), |
| ns.call_method("dequantize"), |
| ] |
| self.checkGraphModuleNodes(m, expected_node_list=node_list) |
| |
| def test_qconfig_module_name_regex(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.conv1 = nn.Conv2d(1, 1, 1) |
| self.conv2 = nn.Conv2d(1, 1, 1) |
| |
| def forward(self, x): |
| x = self.conv1(x) |
| x = self.conv2(x) |
| return x |
| |
| m = M().eval() |
| qconfig_dict = {"module_name_regex": [("conv*", default_qconfig)]} |
| m = prepare_fx(m, qconfig_dict) |
| data = torch.randn(1, 1, 1, 1) |
| m(data) |
| m = convert_fx(m) |
| m(data) |
| # both convs are quantized since both module names match the regex
| node_list = [ |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_module(nnq.Conv2d), |
| ns.call_module(nnq.Conv2d), |
| ns.call_method("dequantize"), |
| ] |
| self.checkGraphModuleNodes(m, expected_node_list=node_list) |
| |
| def test_qconfig_precedence(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.linear = nn.Linear(1, 1) |
| self.conv = nn.Conv2d(1, 1, 1) |
| self.module_conv1 = nn.Conv2d(1, 1, 1) |
| self.module_conv2 = nn.Conv2d(1, 1, 1) |
| |
| def forward(self, x): |
| # global |
| x = self.linear(x) |
| # global + object_type --> object_type |
| x = self.conv(x) |
| # global + object_type + module_name_regex --> module_name_regex |
| x = self.module_conv1(x) |
| # global + object_type + module_name_regex + module_name --> module_name |
| x = self.module_conv2(x) |
| return x |
| |
| m = M().eval() |
| global_qconfig = default_qconfig |
| object_type_qconfig = default_dynamic_qconfig |
| module_name_regex_qconfig = float16_dynamic_qconfig |
| module_name_qconfig = default_qat_qconfig |
| qconfig_dict = { |
| "": global_qconfig, |
| "object_type": [(nn.Conv2d, object_type_qconfig)], |
| "module_name_regex": [("module_conv*", module_name_regex_qconfig)], |
| "module_name": [("module_conv2", module_name_qconfig)]} |
| m = prepare_fx(m, qconfig_dict) |
| self.assertEqual(m.linear.qconfig, global_qconfig) |
| self.assertEqual(m.conv.qconfig, object_type_qconfig) |
| self.assertEqual(m.module_conv1.qconfig, module_name_regex_qconfig) |
| self.assertEqual(m.module_conv2.qconfig, module_name_qconfig) |
| |
| def test_remove_qconfig(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.avg_pool = torch.nn.AvgPool2d(1) |
| |
| def forward(self, x): |
| return self.avg_pool(x) |
| |
| m = M().eval() |
| qconfig_dict = {'': default_qconfig} |
| m = prepare_fx(m, qconfig_dict) |
| data = torch.randn(1, 1, 1, 1) |
| m(data) |
| m = convert_fx(m) |
| m(data) |
| for name, module in m.named_modules(): |
| self.assertFalse(hasattr(module, 'qconfig'), |
| 'qconfig is not removed for ' + name) |
| |
| def test_default_quant_after_none_qconfig(self): |
| """ Make sure default quant is inserted properly""" |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv1 = torch.nn.Conv2d(1, 1, 1) |
| self.conv2 = torch.nn.Conv2d(1, 1, 1) |
| |
| def forward(self, x): |
| x = self.conv1(x) |
| x = x.transpose(1, 2) |
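| # conv1 has qconfig None, so its output (and the transpose) stays in fp32
| # and a default quantize node needs to be inserted before conv2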
| x = self.conv2(x)
| return x
| |
| m = M().eval() |
| qconfig_dict = { |
| "": default_qconfig, |
| "module_name": [ |
| ("conv1", None) |
| ] |
| } |
| m = prepare_fx(m, qconfig_dict) |
| m = convert_fx(m) |
| |
| @skipIfNoFBGEMM |
| def test_qat_and_script(self): |
| model = LinearModelWithSubmodule().train() |
| qengine = torch.backends.quantized.engine |
| qconfig_dict = {'': torch.quantization.get_default_qat_qconfig(qengine)} |
| model = prepare_qat_fx(model, qconfig_dict) |
| |
| # ensure scripting works |
| scripted = torch.jit.script(model) |
| # run one round to make sure model runs |
| x = torch.randn(5, 5) |
| scripted(x) |
| FileCheck().check_count('FakeQuantize = prim::GetAttr[name="', 4, exactly=True) \ |
| .run(scripted.graph) |
| |
| # disable fake_quant and observer |
| for epoch in range(3): |
| if epoch == 1: |
| scripted.apply(torch.quantization.disable_observer) |
| if epoch == 2: |
| scripted.apply(torch.quantization.disable_fake_quant) |
| |
| # ensure the fake_quant and observer have been disabled. |
| matches = ['.fake_quant_enabled', '.observer_enabled'] |
| for key, v in scripted.state_dict().items(): |
| if any(x in key for x in matches): |
| self.assertEqual(v, torch.tensor([0], dtype=torch.uint8)) |
| |
| # enable them back |
| scripted.apply(torch.quantization.enable_fake_quant) |
| scripted.apply(torch.quantization.enable_observer) |
| for key, v in scripted.state_dict().items(): |
| if any(x in key for x in matches): |
| self.assertEqual(v, torch.tensor([1], dtype=torch.uint8)) |
| |
| @skipIfNoFBGEMM |
| def test_save_observer_state_dict(self): |
| orig = LinearModelWithSubmodule().eval() |
| model = orig |
| qconfig_dict = {'': torch.quantization.get_default_qconfig('fbgemm')} |
| model = prepare_fx(model, qconfig_dict) |
| |
| # run it through input |
| x = torch.randn(5, 5) |
| model(x) |
| |
| quant = convert_fx(model) |
| |
| # save state_dict of model |
| obs_dict = torch.quantization.get_observer_state_dict(model) |
| b = io.BytesIO() |
| torch.save(obs_dict, b) |
| b.seek(0) |
| |
| # Load the stats into new model |
| model_2 = orig |
| model_2 = prepare_fx(model_2, qconfig_dict) |
| |
| loaded_dict = torch.load(b) |
| torch.quantization.load_observer_state_dict(model_2, loaded_dict) |
| |
| quant_2 = convert_fx(model_2) |
| |
| # Verify that loaded state dict produces same results. |
| self.assertEqual(quant(x), quant_2(x)) |
| |
| @skipIfNoFBGEMM |
| def test_custom_module_class(self): |
| class CustomModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(3, 3) |
| |
| def forward(self, x): |
| return self.linear(x) |
| |
| class ObservedCustomModule(torch.nn.Module): |
| def __init__(self, linear): |
| super().__init__() |
| self.linear = linear |
| |
| def forward(self, x): |
| return self.linear(x) |
| |
| @classmethod |
| def from_float(cls, float_module): |
| assert hasattr(float_module, 'qconfig') |
| observed = cls(float_module.linear) |
| observed.qconfig = float_module.qconfig |
| return observed |
| |
| class StaticQuantCustomModule(torch.nn.Module): |
| def __init__(self, linear): |
| super().__init__() |
| self.linear = linear |
| |
| def forward(self, x): |
| return self.linear(x) |
| |
| @classmethod |
| def from_observed(cls, observed_module): |
| assert hasattr(observed_module, 'qconfig') |
| assert hasattr(observed_module, 'activation_post_process') |
| observed_module.linear.activation_post_process = \ |
| observed_module.activation_post_process |
| quantized = cls(nnq.Linear.from_float(observed_module.linear)) |
| return quantized |
| |
| class DynamicQuantCustomModule(torch.nn.Module): |
| def __init__(self, linear): |
| super().__init__() |
| self.linear = linear |
| |
| def forward(self, x): |
| return self.linear(x) |
| |
| @classmethod |
| def from_observed(cls, observed_module): |
| assert hasattr(observed_module, 'qconfig') |
| quantized = cls(nnqd.Linear.from_float(observed_module.linear)) |
| return quantized |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(3, 3) |
| self.custom = CustomModule() |
| |
| def forward(self, x): |
| x = self.linear(x) |
| x = self.custom(x) |
| return x |
| |
| class RefM(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear1 = torch.nn.Linear(3, 3) |
| self.linear2 = torch.nn.Linear(3, 3) |
| |
| def forward(self, x): |
| x = self.linear1(x) |
| x = self.linear2(x) |
| return x |
| |
| data = torch.randn(3, 3) |
| # instantiate M and RefM and align the parameters |
| original_m = M().eval() |
| original_ref_m = RefM().eval() |
| original_ref_m.linear1.weight = torch.nn.Parameter(original_m.linear.weight.detach()) |
| original_ref_m.linear1.bias = torch.nn.Parameter(original_m.linear.bias.detach()) |
| original_ref_m.linear2.weight = torch.nn.Parameter(original_m.custom.linear.weight.detach()) |
| original_ref_m.linear2.bias = torch.nn.Parameter(original_m.custom.linear.bias.detach()) |
| |
| test_configs = { |
| "static": (default_qconfig, StaticQuantCustomModule, 3), |
| "dynamic": (default_dynamic_qconfig, DynamicQuantCustomModule, 0) |
| } |
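| # each entry is (qconfig, quantized module class, expected number of
| # activation observers after prepare)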
| |
| for quant_type in [QuantType.DYNAMIC]: |
| key = quant_type_to_str(quant_type) |
| qconfig, quantized_module_class, num_observers = test_configs[key] |
| qconfig_dict = {"": qconfig} |
| if key == "static": |
| prepare_custom_config_dict = { |
| "float_to_observed_custom_module_class": { |
| "static": { |
| CustomModule: ObservedCustomModule |
| } |
| } |
| } |
| convert_custom_config_dict = { |
| "observed_to_quantized_custom_module_class": { |
| "static": { |
| ObservedCustomModule: quantized_module_class |
| } |
| } |
| } |
| else: |
| prepare_custom_config_dict = { |
| "non_traceable_module_class": [ |
| CustomModule |
| ] |
| } |
| convert_custom_config_dict = { |
| "observed_to_quantized_custom_module_class": { |
| "dynamic": { |
| CustomModule: quantized_module_class |
| } |
| } |
| } |
| |
| # check prepared model |
| m = prepare_fx( |
| original_m, |
| qconfig_dict, |
| prepare_custom_config_dict=prepare_custom_config_dict) |
| # calibration |
| m(data) |
| # all activation observers are inserted in the top level module |
| count_check = { |
| ns.call_module(torch.quantization.MinMaxObserver): num_observers |
| } |
| self.checkGraphModuleNodes(m, expected_node_occurrence=count_check) |
| |
| # check converted/quantized model |
| m = convert_fx( |
| m, |
| convert_custom_config_dict=convert_custom_config_dict) |
| if quant_type == QuantType.STATIC: |
| count_check = { |
| ns.call_function(torch.quantize_per_tensor) : 1, |
| ns.call_module(nnq.Linear) : 1, |
| ns.call_method('dequantize') : 1, |
| } |
| self.checkGraphModuleNodes(m, expected_node_occurrence=count_check) |
| self.assertEqual(type(m.custom), quantized_module_class) |
| res = m(data) |
| |
| # quantize the reference model |
| ref_m = prepare_fx(original_ref_m, qconfig_dict) |
| ref_m(data) |
| ref_m = convert_fx(ref_m) |
| ref_res = ref_m(data) |
| self.assertEqual(res, ref_res) |
| |
| @skipIfNoFBGEMM |
| def test_non_traceable_module(self): |
| class NonTraceable(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| |
| def forward(self, x): |
| for k in x.keys(): |
| print(x[k]) |
| return x |
| |
| class NonTraceable2(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| |
| def forward(self, x): |
| # data dependent control flow is not traceable |
| for i in x: |
| print(i) |
| return x |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.m1 = NonTraceable() |
| self.m2 = NonTraceable2() |
| |
| def forward(self, x): |
| x = self.m1(x) |
| x = self.m2(x) |
| return x |
| |
| m = M().eval() |
| qconfig_dict = {"": default_qconfig} |
| prepare_custom_config_dict = { |
| "non_traceable_module_name": [ |
| "m1" |
| ], |
| "non_traceable_module_class": [ |
| NonTraceable2 |
| ] |
| } |
| m = prepare_fx( |
| m, qconfig_dict, |
| prepare_custom_config_dict=prepare_custom_config_dict) |
| |
| node_occurrence = { |
| ns.call_module(NonTraceable) : 1, |
| ns.call_module(NonTraceable2) : 1, |
| } |
| # make sure these modules are not traced |
| self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) |
| |
| def test_prepared_model_deepcopy(self): |
| """Ensures that copy.deepcopy works correctly on a prepared model. |
| """ |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.Conv2d(1, 1, 1) |
| self._foobar = 'foobar' |
| self.foobar2 = 'foobar2' |
| |
| def forward(self, x): |
| x = self.conv(x) |
| return x |
| |
| m = M() |
| m.eval() |
| qconfig_dict = {'': torch.quantization.default_qconfig} |
| prepared = prepare_fx(m, qconfig_dict) |
| # calibrate |
| prepared(torch.randn(4, 1, 4, 4)) |
| # copy |
| prepared_copy = copy.deepcopy(prepared) |
| # quantize, should run with no errors |
| quantized = convert_fx(prepared_copy) |
| |
| def test_dequantize(self): |
| r""" Test to make sure dequantize node are placed before |
| non-quantizable node |
| """ |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.Conv2d(1, 1, 1) |
| self.act = torch.nn.GELU() |
| |
| def forward(self, x): |
| x = self.conv(x) |
| return self.act(x) |
| |
| data = torch.rand(5, 1, 3, 3, dtype=torch.float) |
| for quant_type in self.static_quant_types: |
| node_list = [ |
| ns.call_module(nnq.Conv2d), |
| ns.call_method("dequantize"), |
| ns.call_module(nn.GELU), |
| ] |
| self.checkGraphModeFxOp( |
| M().eval(), (data,), quant_type, expected_node_list=node_list) |
| |
| def test_sequential(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.convs = torch.nn.Sequential( |
| torch.nn.Conv2d(1, 1, 1), |
| torch.nn.Conv2d(1, 1, 1) |
| ) |
| |
| def forward(self, x): |
| x = self.convs(x) |
| return x |
| |
| data = torch.rand(5, 1, 3, 3, dtype=torch.float) |
| for quant_type in self.static_quant_types: |
| node_list = [ |
| ns.call_module(nnq.Conv2d), |
| ns.call_module(nnq.Conv2d), |
| ] |
| self.checkGraphModeFxOp( |
| M().eval(), (data,), quant_type, expected_node_list=node_list) |
| |
| def _test_quantized_inputs_outputs( |
| self, prepare_custom_config_dict, prepare_count_check, |
| convert_count_check): |
| """ |
| Test the option to have inputs and outputs of the graph quantized |
| """ |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv1 = torch.nn.Conv2d(1, 1, 1) |
| self.conv2 = torch.nn.Conv2d(1, 1, 1) |
| |
| def forward(self, x): |
| x = self.conv1(x) |
| x = self.conv2(x) |
| return x |
| |
| # quantized input, quantized output |
| m = M() |
| qconfig_dict = {'': torch.quantization.default_qconfig} |
| m.eval() |
| mp = torch.quantization.quantize_fx.prepare_fx( |
| m, qconfig_dict, |
| prepare_custom_config_dict=prepare_custom_config_dict) |
| self.checkGraphModuleNodes(mp, expected_node_occurrence=prepare_count_check) |
| mp(torch.randn(1, 1, 4, 4)) |
| mq = torch.quantization.quantize_fx.convert_fx(mp) |
| self.checkGraphModuleNodes(mq, expected_node_occurrence=convert_count_check) |
| |
| def test_quantized_input_quantized_output(self): |
| prepare_custom_config_dict = { |
| 'input_quantized_idxs': [0], 'output_quantized_idxs': [0]} |
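| # the input is already quantized, so only the two conv outputs are observed;
| # no quantize/dequantize is needed at convert time since both the input and
| # the output of the graph stay quantized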
| prepare_count_check = { |
| ns.call_module(torch.quantization.MinMaxObserver): 2, |
| } |
| convert_count_check = { |
| ns.call_function(torch.quantize_per_tensor): 0, |
| ns.call_method('dequantize'): 0, |
| } |
| self._test_quantized_inputs_outputs( |
| prepare_custom_config_dict, prepare_count_check, convert_count_check) |
| |
| def test_fp32_input_quantized_output(self): |
| prepare_custom_config_dict = { |
| 'output_quantized_idxs': [0]} |
| prepare_count_check = { |
| ns.call_module(torch.quantization.MinMaxObserver): 3, |
| } |
| convert_count_check = { |
| ns.call_function(torch.quantize_per_tensor): 1, |
| ns.call_method('dequantize'): 0, |
| } |
| self._test_quantized_inputs_outputs( |
| prepare_custom_config_dict, prepare_count_check, convert_count_check) |
| |
| def test_quantized_input_fp32_output(self): |
| prepare_custom_config_dict = { |
| 'input_quantized_idxs': [0]} |
| prepare_count_check = { |
| ns.call_module(torch.quantization.MinMaxObserver): 2, |
| } |
| convert_count_check = { |
| ns.call_function(torch.quantize_per_tensor): 0, |
| ns.call_method('dequantize'): 1, |
| } |
| self._test_quantized_inputs_outputs( |
| prepare_custom_config_dict, prepare_count_check, convert_count_check) |
| |
| def test_fp32_input_fp32_output(self): |
| prepare_custom_config_dict = {} |
| prepare_count_check = { |
| ns.call_module(torch.quantization.MinMaxObserver): 3, |
| } |
| convert_count_check = { |
| ns.call_function(torch.quantize_per_tensor): 1, |
| ns.call_method('dequantize'): 1, |
| } |
| self._test_quantized_inputs_outputs( |
| prepare_custom_config_dict, prepare_count_check, convert_count_check) |
| |
| @skipIfNoFBGEMM |
| class TestQuantizeFxOps(QuantizationTestCase): |
| """Unit tests for individual ops |
| """ |
| @skipIfNoFBGEMM |
| def test_linear(self): |
| class ModuleLinear(torch.nn.Module): |
| def __init__(self, has_relu=False, f_relu=False): |
| super(ModuleLinear, self).__init__() |
| self.linear = torch.nn.Linear(30, 4).float() |
| if has_relu: |
| if f_relu: |
| self.relu = F.relu |
| else: |
| self.relu = torch.nn.ReLU() |
| else: |
| self.relu = torch.nn.Identity() |
| |
| def forward(self, x): |
| return self.relu(self.linear(x)) |
| |
| class FuncLinear(torch.nn.Module): |
| def __init__(self, has_relu=False, f_relu=False): |
| super(FuncLinear, self).__init__() |
| self.w = torch.randn(4, 30) |
| self.b = torch.randn(4) |
| if has_relu: |
| if f_relu: |
| self.relu = F.relu |
| else: |
| self.relu = torch.nn.ReLU() |
| else: |
| self.relu = torch.nn.Identity() |
| |
| def forward(self, x): |
| return self.relu(F.linear(x, self.w, self.b)) |
| |
| data = (torch.rand((1, 30), dtype=torch.float),) |
| options = itertools.product( |
| [(ModuleLinear(has_relu=False), True)], |
| # TODO: enable after raw `tensor` is supported in fx |
| # (FuncLinear(has_relu=False), False)], |
| self.all_quant_types) |
| quantized_nodes = { |
| # is_module |
| True: { |
| # quant_type: |
| QuantType.DYNAMIC: ns.call_module(nnqd.Linear), |
| QuantType.STATIC: ns.call_module(nnq.Linear), |
| # note that we are checking the final result |
| QuantType.QAT: ns.call_module(nnq.Linear), |
| }, |
| False: { |
| # quant_type: |
| QuantType.DYNAMIC: ns.call_function(torch.ops.quantized.linear_dynamic), |
| QuantType.STATIC: ns.call_function(torch.ops.quantized.linear), |
| QuantType.QAT: ns.call_function(torch.ops.quantized.linear), |
| } |
| } |
| for (model, is_module), quant_type in options: |
| self.checkGraphModeFxOp( |
| model, data, quant_type, quantized_nodes[is_module][quant_type]) |
| |
| for f_relu, quant_type in itertools.product([True, False], [QuantType.STATIC, QuantType.QAT]): |
| for model, quantized_node in [ |
| (ModuleLinear(has_relu=True, f_relu=f_relu), ns.call_module(nniq.LinearReLU))]: |
| # TODO: support functional linear + relu fusion |
| # (FuncLinear(has_relu=True, f_relu=f_relu), ns.call_function(torch.ops.quantized.linear_relu))]: |
| self.checkGraphModeFxOp(model, data, quant_type, quantized_node) |
| |
| @skipIfNoFBGEMM |
| def test_conv_module(self): |
| conv_module = {1 : torch.nn.Conv1d, 2 : torch.nn.Conv2d, 3 : torch.nn.Conv3d} |
| |
| class ConvWrapper(torch.nn.Module): |
| def __init__(self, dim): |
| super(ConvWrapper, self).__init__() |
| self.conv = conv_module[dim](3, 3, 3).float() |
| |
| def forward(self, x): |
| return self.conv(x) |
| |
| options = itertools.product([1, 2, 3], self.static_quant_types) |
| quantized_nodes = { |
| # dim |
| 1: ns.call_module(nnq.Conv1d), |
| 2: ns.call_module(nnq.Conv2d), |
| 3: ns.call_module(nnq.Conv3d), |
| } |
| for dim, quant_type in options: |
| model = self.checkGraphModeFxOp( |
| ConvWrapper(dim), self.img_data_dict[dim], quant_type, |
| quantized_nodes[dim]) |
| |
| @skipIfNoFBGEMM |
| def test_conv2d_functional(self): |
| for bias in [True, False]: |
| conv = torch.nn.Conv2d(1, 1, 1, bias=bias) |
| # There should be 3 observers: one for the input, one for the weight and
| # one for the output activation. No observer for the bias.
| prepare_expected_node_occurrence = { |
| ns.call_module(torch.quantization.HistogramObserver): 2, |
| ns.call_module(torch.quantization.PerChannelMinMaxObserver): 1, |
| } |
| expected_node_occurrence = \ |
| {ns.call_function(torch.ops.quantized.conv2d): 1} |
| self.checkGraphModeFxOp( |
| conv, (torch.randn(4, 1, 4, 4),), QuantType.STATIC, |
| prepare_expected_node_occurrence=prepare_expected_node_occurrence, |
| expected_node_occurrence=expected_node_occurrence, |
| ) |
| |
| @skipIfNoFBGEMM |
| def test_quantized_conv_relu(self): |
| """tests for conv1d_relu/conv2d_relu/conv3d_relu""" |
| conv_module = {1 : torch.nn.Conv1d, 2 : torch.nn.Conv2d, 3 : torch.nn.Conv3d} |
| |
| class ConvNdRelu(torch.nn.Module): |
| def __init__(self, dim, inplace): |
| super(ConvNdRelu, self).__init__() |
| self.conv = conv_module[dim](3, 3, 3).float() |
| self.relu = torch.nn.ReLU(inplace) |
| |
| def forward(self, x): |
| return self.relu(self.conv(x)) |
| |
| class ConvNdFunctionalRelu(torch.nn.Module): |
| def __init__(self, dim): |
| super(ConvNdFunctionalRelu, self).__init__() |
| self.conv = conv_module[dim](3, 3, 3).float() |
| |
| def forward(self, x): |
| return F.relu(self.conv(x)) |
| |
| class ConvNdInplaceFunctionalRelu(torch.nn.Module): |
| def __init__(self, dim): |
| super(ConvNdInplaceFunctionalRelu, self).__init__() |
| self.conv = conv_module[dim](3, 3, 3).float() |
| |
| def forward(self, x): |
| return F.relu(self.conv(x), True) |
| |
| options = itertools.product([1, 2, 3], self.static_quant_types) |
| quantized_nodes = { |
| # dim |
| 1: ns.call_module(nniq.ConvReLU1d), |
| 2: ns.call_module(nniq.ConvReLU2d), |
| 3: ns.call_module(nniq.ConvReLU3d), |
| } |
| for dim, quant_type in options: |
| for m in [ConvNdRelu(dim, True), |
| ConvNdRelu(dim, False), |
| ConvNdFunctionalRelu(dim), |
| ConvNdInplaceFunctionalRelu(dim)]: |
| self.checkGraphModeFxOp( |
| m, self.img_data_dict[dim], quant_type, |
| quantized_nodes[dim]) |
| |
| |
| def _test_quantized_binary_op_impl(self, binary_op, ibinary_op, quantized_op): |
| class Op(torch.nn.Module): |
| def __init__(self, is_inplace, is_scalar): |
| super(Op, self).__init__() |
| self.conv1 = torch.nn.Conv2d(1, 1, 1).float() |
| self.conv2 = torch.nn.Conv2d(1, 1, 1).float() |
| self.is_scalar = is_scalar |
| self.op = ibinary_op if is_inplace else binary_op |
| |
| def forward(self, x, y): |
| x = self.conv1(x) |
| y = 3 if self.is_scalar else self.conv2(y) |
| # x = x + y |
| x = self.op(x, y) |
| # x = y + x |
| x = self.op(y, x) |
| return x |
| |
| # TODO: decide whether we want to quantize or not |
| # in this case |
| # class NonQuantizedOp(torch.nn.Module): |
| # def __init__(self, is_inplace, is_scalar): |
| # super(NonQuantizedOp, self).__init__() |
| # self.is_scalar = is_scalar |
| # self.op = ibinary_op if is_inplace else binary_op |
| |
| # def forward(self, x, y): |
| # y = 3 if self.is_scalar else y |
| # x = self.op(x, y) |
| # return x |
| |
| data = (torch.randn(1, 1, 1, 1, dtype=torch.float), |
| torch.randn(1, 1, 1, 1, dtype=torch.float)) |
| quantized_node = ns.call_function(quantized_op) |
| options = itertools.product([True, False], [True, False]) |
| quant_type = QuantType.STATIC |
| for is_inplace, is_scalar in options: |
| self.checkGraphModeFxOp( |
| Op(is_inplace, is_scalar), data, quant_type, quantized_node) |
| |
| def _test_quantized_binary_op_relu_impl(self, binary_op, ibinary_op, quantized_op): |
| class OpRelu(torch.nn.Module): |
| def __init__(self, is_inplace, is_functional_relu, |
| is_scalar): |
| super(OpRelu, self).__init__() |
| self.conv1 = torch.nn.Conv2d(1, 1, 1).float() |
| self.conv2 = torch.nn.Conv2d(1, 1, 1).float() |
| self.op = ibinary_op if is_inplace else binary_op |
| self.is_functional_relu = is_functional_relu |
| self.is_scalar = is_scalar |
| self.relu = F.relu if self.is_functional_relu \ |
| else torch.nn.ReLU() |
| |
| def forward(self, x, y): |
| x = self.conv1(x) |
| y = 3 if self.is_scalar else self.conv2(y) |
| x = self.op(x, y) |
| x = self.relu(x) |
| x = self.op(y, x) |
| x = self.relu(x) |
| return x |
| |
| data = (torch.rand((1, 1, 1, 1), dtype=torch.float), |
| torch.rand((1, 1, 1, 1), dtype=torch.float)) |
| quant_type = QuantType.STATIC |
| quantized_node = ns.call_function(quantized_op) |
| options = itertools.product( |
| [True, False], [True, False], [True, False]) |
| for is_inplace_op, is_functional_relu, is_scalar in options: |
| self.checkGraphModeFxOp( |
| OpRelu(is_inplace_op, is_functional_relu, is_scalar), |
| data, quant_type, quantized_node) |
| |
| @skipIfNoFBGEMM |
| def test_quantized_add(self): |
| self._test_quantized_binary_op_impl( |
| operator.add, operator.iadd, torch.ops.quantized.add) |
| |
| @skipIfNoFBGEMM |
| def test_quantized_mul(self): |
| self._test_quantized_binary_op_impl( |
| operator.mul, operator.imul, torch.ops.quantized.mul) |
| |
| @skipIfNoFBGEMM |
| def test_quantized_add_relu(self): |
| self._test_quantized_binary_op_relu_impl( |
| operator.add, operator.iadd, torch.ops.quantized.add_relu) |
| |
| @skipIfNoFBGEMM |
| def test_quantized_mul_relu(self): |
| self._test_quantized_binary_op_relu_impl( |
| operator.mul, operator.imul, torch.ops.quantized.mul_relu) |
| |
| # TODO(future PR): make more generic |
| def _test_quantized_add_mul_qat(self, model, expected_node_occurrence): |
| qconfig_dict = {'': torch.quantization.get_default_qat_qconfig('fbgemm')} |
| mp = torch.quantization.quantize_fx.prepare_qat_fx(model, qconfig_dict) |
| self.checkGraphModuleNodes( |
| mp, expected_node_occurrence=expected_node_occurrence) |
| |
| @skipIfNoFBGEMM |
| def test_quantized_add_qat(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv1 = torch.nn.Conv2d(1, 1, 1) |
| self.conv2 = torch.nn.Conv2d(1, 1, 1) |
| |
| def forward(self, x): |
| x = torch.add(x, 1.0) |
| x = self.conv1(x) |
| x = torch.add(x, 1.0) |
| x = torch.relu(x) |
| x = self.conv2(x) |
| return x |
| |
| m = M() |
| expected_node_occurrence = { |
| ns.call_module(torch.quantization.FakeQuantize): 4, |
| } |
| self._test_quantized_add_mul_qat(m, expected_node_occurrence) |
| |
| @skipIfNoFBGEMM |
| def test_quantized_mul_qat(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv1 = torch.nn.Conv2d(1, 1, 1) |
| self.conv2 = torch.nn.Conv2d(1, 1, 1) |
| |
| def forward(self, x): |
| x = torch.mul(x, 1.0) |
| x = self.conv1(x) |
| x = torch.mul(x, 1.0) |
| x = torch.relu(x) |
| x = self.conv2(x) |
| return x |
| |
| m = M() |
| expected_node_occurrence = { |
| ns.call_module(torch.quantization.FakeQuantize): 4, |
| } |
| self._test_quantized_add_mul_qat(m, expected_node_occurrence) |
| |
| def test_int8_input_no_unnecessary_fq(self): |
| """ |
| If the input to the graph is quantized and the only node in the
| graph does not need an activation observer, verify that no
| activation observer is inserted.
| """ |
| class M(nn.Module): |
| def __init__(self, scalar): |
| super().__init__() |
| self.scalar = scalar |
| self.add_func = torch.nn.quantized.FloatFunctional() |
| |
| def forward(self, x): |
| return self.add_func.add_scalar(x, self.scalar) |
| |
| m = M(0.5) |
| mp = torch.quantization.quantize_fx.prepare_qat_fx( |
| m, {'': torch.quantization.get_default_qat_qconfig('fbgemm')}, |
| prepare_custom_config_dict={"input_quantized_idxs": [0]}) |
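| # add_scalar derives its output quantization parameters from its input,
| # so no FakeQuantize is needed when the input is already quantized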
| expected_node_occurrence = { |
| ns.call_module(torch.quantization.FakeQuantize): 0, |
| } |
| self.checkGraphModuleNodes( |
| mp, expected_node_occurrence=expected_node_occurrence) |
| |
| def test_quant_output_always_observed(self): |
| """ |
| If the output is hardcoded to be quantized, ensure that |
| there is always an observer, even if the last non-output node is not |
| quantizable.
| """ |
| qconfig_dict = {'': torch.quantization.get_default_qat_qconfig('fbgemm')} |
| prepare_custom_config_dict = {'output_quantized_idxs': [0]} |
| data = (torch.randn(4, 1, 4, 4),) |
| |
| # non-quantizable node, quantized output
| class M1(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.identity = torch.nn.Identity() |
| |
| def forward(self, x): |
| x = self.identity(x) |
| return x |
| |
| m1 = M1() |
| self.checkGraphModeFxOp( |
| m1, data, QuantType.QAT, |
| prepare_expected_node_occurrence={ |
| ns.call_module(torch.quantization.FakeQuantize): 1, |
| }, |
| expected_node_occurrence={ |
| ns.call_function(torch.quantize_per_tensor): 1, |
| }, |
| prepare_custom_config_dict=prepare_custom_config_dict) |
| |
| # quantizable node, quantized output
| class M2(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.Conv2d(1, 1, 1) |
| |
| def forward(self, x): |
| x = self.conv(x) |
| return x |
| |
| m2 = M2() |
| self.checkGraphModeFxOp( |
| m2, data, QuantType.QAT, |
| prepare_expected_node_occurrence={ |
| # one for weights, one for activations |
| ns.call_module(torch.quantization.FakeQuantize): 2, |
| }, |
| expected_node_occurrence={ |
| ns.call_function(torch.quantize_per_tensor): 1, |
| }, |
| prepare_custom_config_dict=prepare_custom_config_dict) |
| |
| @skipIfNoFBGEMM |
| def test_quantized_cat(self): |
| """ quantization of the output of cat will be depend on the |
| input of cat. we only quantize the output of cat when its inputs are quantized. |
| """ |
| class QuantizedCat(torch.nn.Module): |
| def __init__(self): |
| super(QuantizedCat, self).__init__() |
| self.conv1 = torch.nn.Conv2d(2, 2, 2).float() |
| self.conv2 = torch.nn.Conv2d(2, 2, 2).float() |
| |
| def forward(self, x, y): |
| x = self.conv1(x) |
| y = self.conv2(y) |
| return torch.cat([x, y], 1) |
| |
| # TODO: decide whether to quantize in this case |
| # class NonQuantizedCat(torch.nn.Module): |
| # def __init__(self): |
| # super(NonQuantizedCat, self).__init__() |
| |
| # def forward(self, x, y): |
| # return torch.cat([x, y], 1) |
| |
| data = (torch.randn(1, 2, 5, 5, dtype=torch.float), |
| torch.randn(1, 2, 5, 5, dtype=torch.float)) |
| quantized_node = ns.call_function(torch.ops.quantized.cat) |
| for quant_type in self.static_quant_types: |
| self.checkGraphModeFxOp(QuantizedCat(), data, quant_type, quantized_node) |
| |
| |
| @skipIfNoFBGEMM |
| def test_qbatch_norm(self): |
| bn_module = { |
| # TODO: quantized batchnorm 1d module is missing |
| # 1 : torch.nn.BatchNorm1d, |
| 2 : torch.nn.BatchNorm2d, |
| 3 : torch.nn.BatchNorm3d, |
| } |
| |
| class M(torch.nn.Module): |
| def __init__(self, dim): |
| super(M, self).__init__() |
| self.bn = bn_module[dim](3).to(torch.float) |
| |
| def forward(self, x): |
| return self.bn(x) |
| |
| options = itertools.product(self.static_quant_types, [2, 3]) |
| quantized_nodes = { |
| # 1: ns.call_module(nnq.BatchNorm1d), |
| 2: ns.call_module(nnq.BatchNorm2d), |
| 3: ns.call_module(nnq.BatchNorm3d), |
| } |
| for quant_type, dim in options: |
| model = self.checkGraphModeFxOp( |
| M(dim), self.img_data_dict[dim], quant_type, quantized_nodes[dim]) |
| |
| @skipIfNoFBGEMM |
| def test_qbatch_norm_relu(self): |
| bn_module = {2 : torch.nn.BatchNorm2d, 3 : torch.nn.BatchNorm3d} |
| |
| class BNRelu(torch.nn.Module): |
| def __init__(self, dim, inplace): |
| super(BNRelu, self).__init__() |
| self.bn = bn_module[dim](3).to(torch.float) |
| self.relu = torch.nn.ReLU(inplace=inplace) |
| |
| def forward(self, x): |
| return self.relu(self.bn(x)) |
| |
| class BNFuncRelu(torch.nn.Module): |
| def __init__(self, dim): |
| super(BNFuncRelu, self).__init__() |
| self.bn = bn_module[dim](3).to(torch.float) |
| |
| def forward(self, x): |
| return F.relu(self.bn(x), False) |
| |
| class BNFuncInplaceRelu(torch.nn.Module): |
| def __init__(self, dim): |
| super(BNFuncInplaceRelu, self).__init__() |
| self.bn = bn_module[dim](3).to(torch.float) |
| |
| def forward(self, x): |
| return F.relu(self.bn(x), True) |
| |
| options = itertools.product(self.static_quant_types, [2, 3]) |
| quantized_nodes = { |
| 2: ns.call_module(nniq.BNReLU2d), |
| 3: ns.call_module(nniq.BNReLU3d), |
| } |
| for quant_type, dim in options: |
| for instance in [BNRelu(dim, True), BNRelu(dim, False), |
| BNFuncRelu(dim), BNFuncInplaceRelu(dim)]: |
| self.checkGraphModeFxOp( |
| instance, self.img_data_dict[dim], quant_type, |
| quantized_nodes[dim]) |
| |
| def _test_activation_impl( |
| self, float_module, float_op, quantized_module, quantized_op): |
| ''' Test for activation ops (with inplace options), float_op can be
| a torch op or a functional op
| ''' |
| class M(torch.nn.Module): |
| def __init__(self, is_module, inplace): |
| super(M, self).__init__() |
| self.is_module = is_module |
| self.inplace = inplace |
| if self.is_module: |
| self.op = float_module(self.inplace) |
| else: |
| self.op = float_op |
| |
| def forward(self, input): |
| if self.is_module: |
| return self.op(input) |
| else: |
| return self.op(input, self.inplace) |
| |
| options = itertools.product([True, False], [True, False], self.static_quant_types) |
| quantized_nodes = { |
| # is_module |
| True: ns.call_module(quantized_module), |
| False: ns.call_function(quantized_op), |
| } |
| |
| for is_module, is_inplace, quant_type in options: |
| self.checkGraphModeFxOp( |
| M(is_module, is_inplace), self.img_data_2d, |
| quant_type, quantized_nodes[is_module]) |
| |
| def test_hardswish(self): |
| self._test_activation_impl(nn.Hardswish, F.hardswish, nnq.Hardswish, torch.ops.quantized.hardswish) |
| |
| def test_elu(self): |
| self._test_activation_impl(nn.ELU, F.elu, nnq.ELU, torch.ops.quantized.elu) |
| |
| def test_leaky_relu(self): |
| self._test_activation_impl(nn.LeakyReLU, F.leaky_relu, nnq.LeakyReLU, torch.ops.quantized.leaky_relu) |
| |
| def _test_norm_impl( |
| self, float_module, float_op, op_args, data, quantized_module, quantized_op, |
| skip_op_arg_for_functional=False): |
| ''' Test for normalization ops, float_op can be a torch op or a functional op,
| op_args is a list of positional arguments for the module/op
| ''' |
| class M(torch.nn.Module): |
| def __init__(self, is_module): |
| super(M, self).__init__() |
| self.is_module = is_module |
| if self.is_module: |
| self.op = float_module(*op_args) |
| else: |
| self.op = float_op |
| |
| def forward(self, input): |
| if self.is_module: |
| return self.op(input) |
| else: |
| args = [input] |
| if not skip_op_arg_for_functional: |
| args += op_args |
| return self.op(*args) |
| |
| options = itertools.product([True, False], self.static_quant_types) |
| quantized_nodes = { |
| # is_module |
| True: ns.call_module(quantized_module), |
| False: ns.call_function(quantized_op), |
| } |
| |
| for is_module, quant_type in options: |
| self.checkGraphModeFxOp( |
| M(is_module), data, quant_type, quantized_nodes[is_module]) |
| |
| def test_layer_norm(self): |
| data = (torch.rand((1, 2, 5, 5), dtype=torch.float),) |
| self._test_norm_impl( |
| nn.LayerNorm, F.layer_norm, [[2, 5, 5]], data, nnq.LayerNorm, torch.ops.quantized.layer_norm) |
| |
| def test_instance_norm(self): |
| data_1d = (torch.rand((1, 4, 5), dtype=torch.float),) |
| data_2d = (torch.rand((1, 4, 5, 1), dtype=torch.float),) |
| data_3d = (torch.rand((1, 4, 5, 1, 1), dtype=torch.float),) |
| data_dict = {1 : data_1d, 2 : data_2d, 3 : data_3d} |
| instance_norm_modules = {1 : nn.InstanceNorm1d, |
| 2 : nn.InstanceNorm2d, |
| 3 : nn.InstanceNorm3d} |
| quantized_instance_norm_modules = { |
| 1 : nnq.InstanceNorm1d, |
| 2 : nnq.InstanceNorm2d, |
| 3 : nnq.InstanceNorm3d |
| } |
| for dim in [1, 2, 3]: |
| data = data_dict[dim] |
| module = instance_norm_modules[dim] |
| quantized_module = quantized_instance_norm_modules[dim] |
| self._test_norm_impl( |
| module, F.instance_norm, [4], data, |
| quantized_module, torch.ops.quantized.instance_norm, |
| skip_op_arg_for_functional=True) |
| |
| @skipIfNoFBGEMM |
| def test_clamp(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.conv = torch.nn.Conv2d(2, 2, 2).float() |
| self.relu6 = torch.nn.ReLU6() |
| self.relu6_ = torch.nn.ReLU6(True) |
| self.hardtanh = torch.nn.Hardtanh() |
| self.hardtanh_ = torch.nn.Hardtanh(inplace=True) |
| |
| def forward(self, x): |
| x = self.conv(x) |
| x = self.relu6(x) |
| self.relu6_(x) |
| x = F.relu6(x) |
| x = torch.clamp(x, -3, 3) |
| x = x.clamp(-2.5, 2.5) |
| # x = x.clamp_(-2, 2) # Enable when quantized `clamp_` is ready |
| x = self.hardtanh(x) |
| self.hardtanh_(x) |
| x = F.hardtanh(x) |
| F.hardtanh_(x) |
| return x |
| |
| data = (torch.rand((1, 2, 5, 5), dtype=torch.float),) |
| # list of nodes that should occur, in order
| node_list = [ |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_module(nnq.Conv2d), |
| ns.call_function(F.hardtanh_), |
| ns.call_method('dequantize') |
| ] |
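| # Note: relu6 / clamp / hardtanh operate directly on the quantized tensor,
| # so they are not expected to be rewritten into quantized::* variants or to
| # introduce extra quantize/dequantize pairs between the quantized conv and
| # the final dequantize; F.hardtanh_ is listed above only to anchor the order.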
| for quant_type in self.static_quant_types: |
| m = self.checkGraphModeFxOp( |
| M(), data, quant_type, expected_node_list=node_list) |
| |
| @skipIfNoFBGEMM |
| def test_general_shape_ops(self): |
| """ A test that checks dequantize will be swapped for |
| all supported general shape ops like aten::flatten |
| without actually checking for execution of these ops |
| """ |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.maxpool1d = torch.nn.MaxPool1d(kernel_size=3) |
| self.maxpool2d = torch.nn.MaxPool2d(kernel_size=3) |
| self.maxpool3d = torch.nn.MaxPool3d(kernel_size=3) |
| self.dropout = torch.nn.Dropout() |
| self.conv1 = torch.nn.Conv2d(3, 3, 3) |
| self.conv2 = torch.nn.Conv2d(3, 3, 3) |
| self.relu = torch.nn.ReLU() |
| |
| def forward(self, x): |
| x = self.conv1(x) |
| # add_scalar |
| x = x + 3 |
| # mul_scalar |
| x = x * 3 |
| # add_scalar_out |
| x += 3 |
| # mul_scalar_out |
| x *= 3 |
| # add_scalar_relu |
| x = x + 3 |
| x = F.relu(x) |
| # add_scalar_relu_out |
| x += 3 |
| x = F.relu(x) |
| # mul_scalar_relu |
| x = x * 3 |
| x = F.relu(x) |
| # mul_scalar_relu_out |
| x *= 3 |
| x = F.relu(x) |
| x = self.maxpool1d(x) |
| x = self.maxpool2d(x) |
| x = self.maxpool3d(x) |
| x = torch.flatten(x) |
| x = torch.max(x) |
| x = torch.min(x) |
| x = x.reshape([-1]) |
| x = x.resize_(1, 1, x.numel()) |
| x = x.view(-1) |
| # prim::ListConstruct |
| xs = [x, x] |
| # prim::ListUnpack |
| x, y = xs |
| # prim::TupleConstruct |
| xs = (x, x) |
| # prim::TupleUnpack |
| x, y = xs |
| x = x.transpose(1, 2) |
| x = x.contiguous() |
| x, y = torch.chunk(x, 2) |
| x = F.dropout(x) |
| x = self.dropout(x) |
| x, _ = torch.sort(x) |
| x = x.permute(0, 2, 3, 1) |
| x = x.repeat_interleave(3, 1) |
| x = torch.repeat_interleave(x, 3, 1) |
| x = self.relu(x) |
| x = F.relu(x) |
| x = F.relu(x, inplace=True) |
| x = x.relu() |
| x.relu_() |
| x = x.squeeze(0) |
| x.squeeze_(0) |
| x = torch.squeeze(x, 0) |
| x = x.unsqueeze(0) |
| x.unsqueeze_(0) |
| x = torch.unsqueeze(x, 0) |
| x = x.detach() |
| x.detach_() |
| x = x.repeat(4, 2) |
| y = [] |
| y.append(x) |
| z = torch.stack(y, 0) |
| z = [z, z] |
| x, _ = z |
| x = self.conv2(x) |
| return x |
| |
| data = torch.rand(1, 3, 10, 10) |
| # This model is not executable since we just put all ops |
| # in the same forward |
| m = M().eval() |
| # nothing to fuse so skipping the fuse step |
| qconfig_dict = {'': default_qconfig} |
| prepared = prepare_fx(m, qconfig_dict) |
| # not runnable |
| quantized = convert_fx(prepared) |
| |
| # This checks that the dequantize from the output of the first conv
| # is propagated all the way to the end of the graph, so that no extra
| # observers are inserted and both convs are lowered to the
| # quantized::conv2d pattern
| # one quantize_per_tensor for input |
| # check exact counts of quantize and dequantize |
| count_check = { |
| ns.call_function(torch.quantize_per_tensor) : 1, |
| ns.call_method('dequantize') : 1 |
| } |
| order_check = [ |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_module(nnq.Conv2d), |
| ns.call_module(nnq.Conv2d), |
| ns.call_method('dequantize'), |
| ] |
| self.checkGraphModuleNodes( |
| quantized, |
| expected_node_occurrence=count_check, |
| expected_node_list=order_check) |
| |
| @skipIfNoFBGEMM |
| def test_general_value_ops(self): |
| """ A test that checks correct patterns are produced for |
| all supported general value ops like aten::avg_pool2d \ |
| without actually checking for execution of these ops |
| """ |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.Conv2d(3, 3, 3) |
| self.avg_pool1d = torch.nn.AvgPool1d(3) |
| self.avg_pool2d = torch.nn.AvgPool2d(3) |
| self.avg_pool3d = torch.nn.AvgPool3d(3) |
| self.adaptive_avg_pool1d = torch.nn.AdaptiveAvgPool1d((1)) |
| self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1)) |
| self.adaptive_avg_pool3d = torch.nn.AdaptiveAvgPool3d((1, 1, 1)) |
| |
| def forward(self, x): |
| x = self.conv(x) |
| x = self.avg_pool1d(x) |
| x = self.avg_pool2d(x) |
| x = self.avg_pool3d(x) |
| x = self.adaptive_avg_pool1d(x) |
| x = self.adaptive_avg_pool2d(x) |
| x = self.adaptive_avg_pool3d(x) |
| x = F.avg_pool1d(x, 3) |
| x = F.avg_pool2d(x, 3) |
| x = F.avg_pool3d(x, 3) |
| x = F.adaptive_avg_pool1d(x, (1)) |
| x = F.adaptive_avg_pool2d(x, (1, 1)) |
| x = F.adaptive_avg_pool3d(x, (1, 1, 1)) |
| x = torch.mean(x) |
| x = torch.mean(x, [2, 3], False) |
| x = x.mean() |
| x = x.mean([2, 3], True) |
| x = F.interpolate(x, 4, mode='nearest') |
| x = F.interpolate(x, 4, mode='linear') |
| x = self.conv(x) |
| return x |
| |
| # This model is not executable since we just put all ops |
| # in the same forward |
| m = M().eval() |
| # nothing to fuse so skipping the fuse step |
| qconfig_dict = {'': default_qconfig} |
| prepared = prepare_fx(m, qconfig_dict) |
| # not runnable |
| quantized = convert_fx(prepared) |
| |
| # This checks that the dequantize from the output of first conv |
| # is being propagated to the end, so that we don't insert extra |
| # observers |
| # check exact counts of quantize and dequantize |
| count_check = { |
| ns.call_function(torch.quantize_per_tensor) : 1, |
| ns.call_method('dequantize') : 1 |
| } |
| order_check = [ |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_module(nnq.Conv2d), |
| ns.call_module(nnq.Conv2d), |
| ns.call_method('dequantize'), |
| ] |
| self.checkGraphModuleNodes( |
| quantized, |
| expected_node_occurrence=count_check, |
| expected_node_list=order_check) |
| |
| @skipIfNoFBGEMM |
| def test_fixed_qparams_ops(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.Conv2d(3, 3, 3) |
| self.sigmoid = torch.nn.Sigmoid() |
| self.hardsigmoid = torch.nn.Hardsigmoid() |
| self.tanh = torch.nn.Tanh() |
| |
| def forward(self, x): |
| x = self.conv(x) |
| # F.sigmoid is deprecated |
| x = self.sigmoid(x) |
| x = torch.sigmoid(x) |
| x = x.sigmoid() |
| x.sigmoid_() |
| x = self.hardsigmoid(x) |
| x = F.hardsigmoid(x) |
| x = F.hardsigmoid(x, inplace=True) |
| x = x.hardsigmoid() |
| x.hardsigmoid_() |
| x = self.tanh(x) |
| # F.tanh is deprecated |
| x = torch.tanh(x) |
| x = x.tanh() |
| x.tanh_() |
| x = self.conv(x) |
| return x |
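| 
| # Fixed-qparams ops have a known output range, so their quantization
| # parameters do not need to be observed: sigmoid/hardsigmoid map to [0, 1]
| # (roughly scale=1/256, zero_point=0 for quint8) and tanh maps to [-1, 1]
| # (roughly scale=2/256, zero_point=128); the exact values come from the
| # FixedQParamsFakeQuantize defaults.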
| |
| for eval_mode in [True, False]: |
| # This model is not executable since we just put all ops |
| # in the same forward |
| m = M() |
| if eval_mode: |
| m.eval() |
| qconfig = default_qconfig |
| prepare = prepare_fx |
| fq_count = 0 |
| else: |
| m.train() |
| qconfig = default_qat_qconfig |
| prepare = prepare_qat_fx |
| fq_count = 13 |
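| # 13 = 4 sigmoid variants + 5 hardsigmoid variants + 4 tanh variants in
| # forward() above; in the eval/PTQ branch no FixedQParamsFakeQuantize
| # modules are expected, hence fq_count = 0 there.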
| |
| # nothing to fuse so skipping the fuse step |
| qconfig_dict = {'': qconfig} |
| prepared = prepare(m, qconfig_dict) |
| # check the correct number of activation_post_process is inserted |
| count_check = { |
| ns.call_module(FixedQParamsFakeQuantize) : fq_count, |
| } |
| self.checkGraphModuleNodes( |
| prepared, |
| expected_node_occurrence=count_check) |
| # not runnable |
| quantized = convert_fx(prepared) |
| |
| # This checks that the dequantize from the output of first conv |
| # is being propagated to the end, so that we don't insert extra |
| # observers |
| # check exact counts of quantize and dequantize |
| count_check = { |
| ns.call_function(torch.quantize_per_tensor) : 1, |
| ns.call_method('dequantize') : 1 |
| } |
| order_check = [ |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_module(nnq.Conv2d), |
| ns.call_module(nn.Sigmoid), |
| ns.call_module(nnq.Conv2d), |
| ns.call_method('dequantize'), |
| ] |
| self.checkGraphModuleNodes( |
| quantized, |
| expected_node_occurrence=count_check, |
| expected_node_list=order_check) |
| |
| def test_float_functional(self): |
| class TorchAdd(nn.Module): |
| """Wrapper around torch.add so that all ops can be found at build""" |
| def __init__(self): |
| super().__init__() |
| self.add_func = nnq.FloatFunctional() |
| |
| def forward(self, x, y): |
| return self.add_func.add(x, y) |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.ff1 = TorchAdd() |
| self.ff2 = nnq.FloatFunctional() |
| self.ff3 = nnq.FloatFunctional() |
| self.ff4 = nnq.FloatFunctional() |
| self.ff5 = nnq.FloatFunctional() |
| self.ff6 = nnq.FloatFunctional() |
| |
| def forward(self, x): |
| x = self.ff1(x, x) |
| x = self.ff2.add_scalar(x, 3) |
| x = self.ff3.mul(x, x) |
| x = self.ff4.mul_scalar(x, 3) |
| x = self.ff5.add_relu(x, x) |
| x = self.ff6.cat([x]) |
| return x |
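| 
| # Expected lowering of the FloatFunctional calls above (mirrored by
| # node_list below): add -> quantized.add, add_scalar -> quantized.add
| # (scalar overload), mul -> quantized.mul, mul_scalar -> quantized.mul
| # (scalar overload), add_relu -> quantized.add_relu, cat -> quantized.cat.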
| |
| data = torch.rand(3, 3) |
| # Note: the QAT test succeeded by chance; to make it actually work
| # we need to fix eager mode FloatFunctional by removing
| # activation_post_process in add_scalar and mul_scalar
| for quant_type in self.static_quant_types: |
| m = M() |
| ref_m = torch.quantization.QuantWrapper(M()) |
| is_qat = quant_type == QuantType.QAT |
| if is_qat: |
| m.train() |
| ref_m.train() |
| qconfig = default_qat_qconfig |
| expected_act_post_process = torch.quantization.FakeQuantize |
| else: |
| m.eval() |
| ref_m.eval() |
| qconfig = default_qconfig |
| expected_act_post_process = torch.quantization.MinMaxObserver |
| |
| prepare_fx_function = prepare_qat_fx if is_qat else prepare_fx |
| qconfig_dict = {"": qconfig} |
| m = prepare_fx_function(m, qconfig_dict) |
| node_occurrence = { |
| ns.call_module(expected_act_post_process): 5, |
| ns.call_module(torch.nn.quantized.FloatFunctional): 0 |
| } |
| self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) |
| m(data) |
| node_list = [ |
| ns.call_function(torch.quantize_per_tensor), |
| ns.call_function(torch.ops.quantized.add), |
| ns.call_function(torch.ops.quantized.add), |
| ns.call_function(torch.ops.quantized.mul), |
| ns.call_function(torch.ops.quantized.mul), |
| ns.call_function(torch.ops.quantized.add_relu), |
| ns.call_function(torch.ops.quantized.cat), |
| ns.call_method('dequantize') |
| ] |
| m = convert_fx(m) |
| self.checkGraphModuleNodes(m, expected_node_list=node_list) |
| |
| # make sure numerics match with eager mode |
| ref_m.qconfig = qconfig |
| prepare_function = prepare_qat if is_qat else prepare |
| ref_m = prepare_function(ref_m) |
| ref_m(data) |
| ref_m = convert(ref_m) |
| self.assertEqual(m(data), ref_m(data)) |
| |
| def test_embedding(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12) |
| |
| def forward(self, indices): |
| return self.emb(indices) |
| |
| model = M().eval() |
| indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3]) |
| quantized_node = ns.call_module(nnq.Embedding) |
| configs = [ |
| (float_qparams_weight_only_qconfig, ns.call_module(nnq.Embedding)), |
| (None, ns.call_module(nn.Embedding)), |
| (default_qconfig, ns.call_module(nn.Embedding)), |
| ] |
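| # Embedding weights are only quantizable with the float-qparams
| # weight-only qconfig; with a None qconfig or the default static qconfig
| # the module is expected to stay as nn.Embedding and no MinMaxObserver
| # should be attached to it.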
| |
| for qconfig, node in configs: |
| qconfig_dict = {"": qconfig} |
| m = prepare_fx(model, qconfig_dict) |
| self.checkGraphModuleNodes(m, expected_node_occurrence={ |
| ns.call_module(torch.quantization.MinMaxObserver): 0 |
| }) |
| m = convert_fx(m) |
| self.checkGraphModuleNodes(m, expected_node=node) |
| # make sure it runs |
| m(indices) |
| |
| def test_embedding_bag(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, include_last_offset=True) |
| |
| def forward(self, indices, offsets): |
| return self.emb(indices, offsets) |
| |
| indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3]) |
| offsets = torch.tensor([0, 19, 20, 28, 28, 32]) |
| quantized_node = ns.call_module(nnq.EmbeddingBag) |
| inputs = (indices, offsets) |
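| # EmbeddingBag is quantized weight-only: the qconfig below pairs a
| # placeholder activation observer with a per-channel float-qparams weight
| # observer. torch.quint8 stores one 8-bit value per byte, torch.quint4x2
| # packs two 4-bit values per byte; both are expected to lower to
| # nnq.EmbeddingBag.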
| |
| for dtype in [torch.quint8, torch.quint4x2]: |
| model = M().eval() |
| float_qparams_observer = PerChannelMinMaxObserver.with_args(dtype=dtype, |
| qscheme=torch.per_channel_affine_float_qparams, |
| ch_axis=0) |
| float_qparams_qconfig = QConfigDynamic(activation=default_placeholder_observer, |
| weight=float_qparams_observer) |
| self.checkGraphModeFxOp( |
| model, |
| inputs, |
| QuantType.DYNAMIC, |
| quantized_node, |
| custom_qconfig=float_qparams_qconfig |
| ) |
| |
| # check that EmbeddingBag is left unquantized with a None qconfig and with a static qconfig
| for qconfig in [None, default_qconfig]:
| qconfig_dict = {"": qconfig}
| m = M().eval()
| m = prepare_fx(m, qconfig_dict)
| self.checkGraphModuleNodes(m, expected_node_occurrence={ |
| ns.call_module(torch.quantization.MinMaxObserver): 0 |
| }) |
| m = convert_fx(m) |
| self.checkGraphModuleNodes(m, expected_node=ns.call_module(nn.EmbeddingBag)) |
| # make sure it runs |
| m(*inputs) |
| |
| def _test_rnn_impl(self, qconfigs, M, module_type_strs, module_types, sample_input): |
| options = itertools.product(qconfigs, module_type_strs) |
| for qconfig, module_type_str in options: |
| model_eager = M(module_type_str).eval() |
| model_graph = copy.deepcopy(model_eager) |
| # fp16 dynamic quant is not supported for qnnpack
| if torch.backends.quantized.engine == 'qnnpack' and \
| qconfig is float16_dynamic_qconfig:
| continue
| |
| eager_qconfig_dict = {x : qconfig for x in module_types} |
| model_eager = quantize_dynamic(model_eager, qconfig_spec=eager_qconfig_dict) |
| |
| graph_qconfig_dict = { |
| "object_type": [ |
| (x, qconfig) for x in module_types |
| ] |
| } |
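| # Note: graph_qconfig_dict expresses the same thing as the eager
| # qconfig_spec above, just keyed by "object_type"; e.g. for LSTM the two
| # are roughly:
| #   quantize_dynamic(model, qconfig_spec={torch.nn.LSTM: default_dynamic_qconfig})
| #   prepare_fx(model, {"object_type": [(torch.nn.LSTM, default_dynamic_qconfig)]})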
| model_graph = prepare_fx(model_graph, graph_qconfig_dict) |
| model_graph = convert_fx(model_graph) |
| self.assertEqual(model_eager(sample_input), model_graph(sample_input)) |
| self.checkScriptable(model_graph, [[sample_input]], True) |
| |
| def test_rnn_cell(self): |
| qconfigs = [per_channel_dynamic_qconfig, default_dynamic_qconfig, float16_dynamic_qconfig] |
| module_type_strs = ['LSTMCell', 'GRUCell', 'RNNTanh', 'RNNReLU'] |
| module_types = [torch.nn.LSTMCell, torch.nn.GRUCell, torch.nn.RNNCell] |
| sample_input = torch.tensor([[100, -155], |
| [-155, 100], |
| [100, -155]], dtype=torch.float) |
| self._test_rnn_impl(qconfigs, RNNCellDynamicModel, module_type_strs, module_types, sample_input) |
| |
| def test_rnn(self): |
| qconfigs = [per_channel_dynamic_qconfig, default_dynamic_qconfig, float16_dynamic_qconfig] |
| module_type_strs = ['LSTM'] |
| module_types = [torch.nn.LSTM] |
| niter = 10 |
| sample_input = torch.tensor([[100, -155], |
| [-155, 100], |
| [100, -155]], dtype=torch.float).unsqueeze(0).repeat(niter, 1, 1) |
| self._test_rnn_impl(qconfigs, RNNDynamicModel, module_type_strs, module_types, sample_input) |
| |
| |
| class TestQuantizeFxModels(QuantizationTestCase): |
| def _test_model_impl( |
| self, mode, name, model, eager_quantizable_model, |
| check_with_eager=True, |
| diff_of_quant=None, |
| diff_from_eager=None): |
| if diff_of_quant is None or diff_from_eager is None: |
| diff_of_quant = {} |
| diff_from_eager = {} |
| |
| if mode not in diff_of_quant or mode not in diff_from_eager: |
| diff_of_quant[mode] = {} |
| diff_from_eager[mode] = {} |
| |
| input_tensor = torch.rand(1, 3, 224, 224) |
| input_tensor_inception = torch.rand(1, 3, 299, 299) |
| output_value = torch.randint(0, 1, (1,)) |
| |
| # print('quantizing:', name, ' mode:', mode) |
| if name == 'inception_v3': |
| input_value = input_tensor_inception |
| else: |
| input_value = input_tensor |
| |
| qconfig = default_qconfig if mode == 'static' else default_qat_qconfig |
| qconfig_dict = {'': qconfig} |
| # print('graph module:', graph_module.src) |
| script = torch.jit.script(model) |
| |
| # make sure the graph module and the script module are both runnable
| original_out = model(input_value) |
| is_not_tuple_out = not isinstance(original_out, tuple) |
| script_out = script(input_value) |
| |
| # set to train just before quantization |
| prepare_fx_fn = prepare_fx |
| if mode != 'static': |
| model.train() |
| prepare_fx_fn = prepare_qat_fx |
| |
| prepared = prepare_fx_fn(model, qconfig_dict) |
| |
| if mode == 'ddp':
| # `world_size` is not defined elsewhere in this helper; assume two
| # processes for DDP (adjust to the number of available GPUs)
| world_size = 2
| mp.spawn(run_ddp,
| args=(world_size, prepared),
| nprocs=world_size,
| join=True)
| elif mode == 'qat': |
| assert prepared.training, 'prepared must be in training mode for qat' |
| optimizer = torch.optim.SGD(prepared.parameters(), lr=0.0001) |
| criterion = nn.CrossEntropyLoss() |
| train_one_epoch(prepared, criterion, optimizer, [(input_value, output_value)], torch.device('cpu'), 1) |
| else: |
| for i in range(10): |
| prepared(input_value) |
| |
| # print('after observation root:', prepared.root) |
| |
| qgraph = convert_fx(prepared) |
| # print('after quantization root:', qgraph.root) |
| # print('after quantization code:', qgraph.src) |
| qgraph.eval() |
| qgraph_script = torch.jit.script(qgraph) |
| # print('quantized and scripted:', qgraph_script.graph) |
| |
| qgraph_out = qgraph(input_value) |
| qgraph_script_out = qgraph_script(input_value)
| |
| if is_not_tuple_out: |
| diff_of_quant[mode][name] = (original_out - qgraph_out).abs().max() |
| assert torch.allclose(qgraph_out, qgraph_script_out), 'graph and scripted graph outputs should match'
| else: |
| print('tuple output') |
| |
| if eager_quantizable_model is not None: |
| # comparing to eager mode quantization |
| qeager = eager_quantizable_model |
| ref_out = qeager(input_value) |
| qeager.qconfig = qconfig |
| if mode == 'static': |
| qeager.fuse_model() |
| prepare(qeager, inplace=True) |
| else: |
| qeager.train() |
| qeager.fuse_model() |
| prepare_qat(qeager, inplace=True) |
| |
| # calibration |
| if mode == 'ddp': |
| mp.spawn(run_ddp, |
| args=(world_size, qeager), |
| nprocs=world_size, |
| join=True) |
| elif mode == 'qat': |
| assert qeager.training, 'qeager should be in training mode for qat' |
| optimizer = torch.optim.SGD(qeager.parameters(), lr=0.0001) |
| train_one_epoch(qeager, criterion, optimizer, [(input_value, output_value)], torch.device('cpu'), 1) |
| else: |
| for i in range(10): |
| qeager(input_value) |
| |
| # print('ref after observation:', qeager) |
| |
| convert(qeager, inplace=True) |
| qeager.eval() |
| |
| # print('ref after quantization:', qeager) |
| qeager_out = qeager(input_value) |
| qeager_script = torch.jit.script(qeager) |
| qscript_out = qeager_script(input_value) |
| if is_not_tuple_out: |
| diff_from_eager[mode][name] = (qeager_out - qgraph_out).abs().max() |
| if check_with_eager: |
| self.assertEqual(diff_from_eager[mode][name], 0, |
| 'Result of graph mode quantization and ' + |
| 'eager mode quantization on model: ' + name + |
| ' should match. Mode: ' + mode + |
| ' diff:' + str(diff_from_eager[mode][name])) |
| |
| def _test_building_block(self, quant_type, BB): |
| eager = BB().float() |
| graph = copy.deepcopy(eager) |
| |
| if quant_type == QuantType.STATIC: |
| qconfig = default_qconfig |
| eager_prepare = prepare |
| graph_prepare = prepare_fx |
| eager.eval() |
| graph.eval() |
| calibrate_or_train = test_only_eval_fn |
| data = self.img_data_2d |
| else: |
| assert quant_type == QuantType.QAT |
| qconfig = default_qat_qconfig |
| eager_prepare = prepare_qat |
| graph_prepare = prepare_qat_fx |
| eager.train() |
| graph.train() |
| calibrate_or_train = test_only_train_fn |
| data = self.img_data_2d_train |
| |
| if hasattr(eager, "fuse_model"): |
| eager.fuse_model() |
| eager = QuantWrapper(eager) |
| eager.qconfig = qconfig |
| eager = eager_prepare(eager) |
| |
| qconfig_dict = {"": qconfig} |
| graph = graph_prepare(graph, qconfig_dict) |
| |
| eager_out = eager(data[0][0]) |
| graph_out = graph(data[0][0]) |
| self.assertEqual(eager_out, graph_out) |
| |
| calibrate_or_train(eager, data) |
| calibrate_or_train(graph, data) |
| |
| eager = convert(eager) |
| graph = convert_fx(graph) |
| |
| eager_out = eager(data[0][0]) |
| graph_out = graph(data[0][0]) |
| self.assertEqual(eager_out, graph_out) |
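| 
| # Note: the two paths above are intentionally parallel; eager mode needs
| # explicit fuse_model() / QuantWrapper / prepare / convert steps, while the
| # FX path only needs prepare_fx (or prepare_qat_fx) + convert_fx with a
| # qconfig_dict, with fusion and quant/dequant insertion handled internally.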
| |
| @override_qengines |
| def test_resnet_base(self): |
| models = [ResNetBase] |
| options = itertools.product(self.static_quant_types, models) |
| for quant_type, M in options: |
| self._test_building_block(quant_type, M) |
| |
| @skip_if_no_torchvision |
| @skipIfNoFBGEMM |
| @unittest.skip("skip for now since tbb failed") |
| def test_torchvision(self): |
| from torchvision import models |
| from torchvision.models import quantization as quantized_models |
| |
| def get_available_classification_models(models): |
| return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"] |
| |
| model_list = get_available_classification_models(models) |
| quantized_model_list = get_available_classification_models(quantized_models) |
| |
| no_pretrained_model = set(['shufflenet_v2_x0_5', 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0']) |
| quantized_model_list = set(quantized_model_list) - no_pretrained_model |
| # test eager and graph consistency |
| model_list = quantized_model_list |
| # inception_v3 is not symbolically traceable: https://github.com/pytorch/pytorch/issues/48813 |
| model_list = set(model_list) - {'inception_v3'} |
| # mobilenet: dropout error: RuntimeError: "bernoulli_scalar_cpu_" not implemented for 'QUInt8'
| # inception_v3: looks like there is some problem with AuxLogits
| quantized_not_working = [('qat', 'inception_v3'), |
| ('static', 'inception_v3')] |
| |
| fx_eager_not_matching = ['googlenet', # because _transform_input is not quantized in eager |
| 'mobilenet_v2'] # because relu6 is replaced as relu in mobilenetv2 |
| |
| diff_of_quant = {} |
| diff_from_eager = {} |
| modes = ['static', 'qat'] |
| options = itertools.product(modes, model_list) |
| for mode, name in options: |
| pretrained = name in quantized_model_list # load pretrained model to compare with quantized model |
| if name in quantized_model_list: |
| if (mode, name) in quantized_not_working: |
| eager_quantizable_model = None |
| else: |
| eager_quantizable_model = quantized_models.__dict__[name](pretrained=True, quantize=False).eval().float() |
| # compare with eager mode quantized model when it is available |
| pretrained = eager_quantizable_model is not None |
| model = models.__dict__[name](pretrained=pretrained).eval().float() |
| check_with_eager = name not in fx_eager_not_matching |
| self._test_model_impl( |
| mode, name, model, eager_quantizable_model, |
| check_with_eager, |
| diff_of_quant, diff_from_eager) |
| |
| def print_diffs(diffs): |
| for mode, diffs_for_mode in diffs.items(): |
| print('mode:', mode) |
| for name, diff in diffs_for_mode.items(): |
| print(name, ':', diff) |
| |
| # print('differences between float and quantized') |
| # print_diffs(diff_of_quant) |
| # print('----------------------') |
| # print('differences between graph mode and eager mode') |
| # print_diffs(diff_from_eager) |
| # print('----------------------') |
| |
| @skip_if_no_torchvision |
| @skip_if_not_multigpu |
| @skipIfNoFBGEMM |
| def test_resnet18_ddp(self): |
| from torchvision import models |
| from torchvision.models import quantization as quantized_models |
| eager_quantizable_model = quantized_models.__dict__['resnet18'](pretrained=True, quantize=False).eval().float()
| model = models.__dict__['resnet18'](pretrained=True).eval().float()
| self._test_model_impl( |
| 'ddp', 'resnet18', model, eager_quantizable_model) |