| import unittest |
| import torch |
| import torch.nn as nn |
| import torch.nn.quantized as nnq |
| import torch.nn._intrinsic as nni |
| import torch.nn._intrinsic.quantized as nniq |
| import torch.nn._intrinsic.qat as nniqat |
| from torch.quantization import \ |
| QConfig_dynamic, default_weight_observer, \ |
| quantize, prepare, convert, prepare_qat, quantize_qat, fuse_modules, \ |
| quantize_dynamic, default_qconfig, default_qat_qconfig, \ |
| default_dynamic_qconfig, MinMaxObserver, QuantWrapper |
| |
| from common_utils import run_tests, tempfile |
| from common_quantization import QuantizationTestCase, SingleLayerLinearModel, \ |
| SkipQuantModel, QuantStubModel, \ |
| ModelForFusion, ManualLinearQATModel, ManualConvLinearQATModel, \ |
| ModForWrapping, \ |
| test_only_eval_fn, test_only_train_fn, \ |
| prepare_dynamic, convert_dynamic, SingleLayerLinearDynamicModel, \ |
| TwoLayerLinearModel, NestedModel, ResNetBase, LSTMDynamicModel |
| |
| from common_quantization import AnnotatedTwoLayerLinearModel, AnnotatedNestedModel, \ |
| AnnotatedSubNestedModel, AnnotatedCustomConfigNestedModel |
| |
| from hypothesis import given |
| from hypothesis import strategies as st |
| import io |
| import copy |
| |
| @unittest.skipIf( |
| not torch.fbgemm_is_cpu_supported(), |
| " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs" |
| " with instruction set support avx2 or newer.", |
| ) |
| class PostTrainingQuantTest(QuantizationTestCase): |
| def test_single_layer(self): |
| r"""Quantize SingleLayerLinearModel which has one Linear module, make sure it is swapped |
| to nnq.Linear which is the quantized version of the module |
| """ |
| model = SingleLayerLinearModel() |
| prepare(model) |
| # Check if observers and quant/dequant nodes are inserted |
| self.checkNoPrepModules(model) |
| self.checkHasPrepModules(model.fc1) |
| self.checkObservers(model) |
| |
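| # Calibrate with sample data, then swap the observed modules for their quantized versions |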
| test_only_eval_fn(model, self.calib_data) |
| convert(model) |
| |
| def checkQuantized(model): |
| self.checkNoPrepModules(model) |
| self.checkHasPrepModules(model.fc1) |
| self.checkWrappedQuantizedLinear(model.fc1) |
| test_only_eval_fn(model, self.calib_data) |
| self.checkScriptable(model, self.calib_data) |
| |
| checkQuantized(model) |
| |
| # test one line API |
| model = quantize(SingleLayerLinearModel(), test_only_eval_fn, |
| self.calib_data) |
| checkQuantized(model) |
| |
| def test_two_layers(self): |
| r"""TwoLayerLinearModel has two Linear modules but we only quantize the second one |
| `fc2`, and `fc1`is not quantized |
| """ |
| model = AnnotatedTwoLayerLinearModel() |
| prepare(model) |
| |
| self.checkNoPrepModules(model) |
| self.checkObservers(model) |
| self.checkNoPrepModules(model.fc1) |
| self.checkHasPrepModules(model.fc2) |
| |
| test_only_eval_fn(model, self.calib_data) |
| convert(model) |
| |
| def checkQuantized(model): |
| self.checkNoPrepModules(model) |
| self.checkNoPrepModules(model.fc1) |
| self.checkHasPrepModules(model.fc2) |
| self.assertEqual(type(model.fc1), torch.nn.Linear) |
| self.checkWrappedQuantizedLinear(model.fc2) |
| test_only_eval_fn(model, self.calib_data) |
| self.checkScriptable(model, self.calib_data) |
| |
| checkQuantized(model) |
| |
| # test one line API |
| model = quantize(AnnotatedTwoLayerLinearModel(), test_only_eval_fn, |
| self.calib_data) |
| checkQuantized(model) |
| |
| def test_nested1(self): |
| r"""Test quantization for nested model, top level 'fc3' and |
| 'fc1' of submodule 'sub2', 'sub2.fc2' is not quantized |
| """ |
| model = AnnotatedNestedModel() |
| |
| def checkPrepModules(model, before_calib=False): |
| if before_calib: |
| self.checkObservers(model) |
| self.checkNoPrepModules(model) |
| self.checkNoPrepModules(model.sub1) |
| self.checkNoPrepModules(model.sub1.fc) |
| self.checkNoPrepModules(model.sub1.relu) |
| self.checkNoPrepModules(model.sub2) |
| self.checkHasPrepModules(model.sub2.fc1) |
| self.checkNoPrepModules(model.sub2.fc2) |
| self.checkHasPrepModules(model.fc3) |
| |
| prepare(model) |
| checkPrepModules(model, True) |
| test_only_eval_fn(model, self.calib_data) |
| convert(model) |
| |
| def checkQuantized(model): |
| checkPrepModules(model) |
| self.checkLinear(model.sub1.fc) |
| self.checkWrappedQuantizedLinear(model.fc3) |
| self.checkWrappedQuantizedLinear(model.sub2.fc1) |
| self.checkLinear(model.sub2.fc2) |
| test_only_eval_fn(model, self.calib_data) |
| self.checkScriptable(model, self.calib_data) |
| |
| checkQuantized(model) |
| |
| # test one line API |
| model = quantize(AnnotatedNestedModel(), test_only_eval_fn, |
| self.calib_data) |
| checkQuantized(model) |
| |
| |
| def test_nested2(self): |
| model = AnnotatedSubNestedModel() |
| prepare(model) |
| |
| def checkPrepModules(model, before_calib=False): |
| if before_calib: |
| self.checkObservers(model) |
| self.checkNoPrepModules(model) |
| self.checkNoPrepModules(model.sub1) |
| self.checkNoPrepModules(model.sub1.fc) |
| self.checkNoPrepModules(model.sub1.relu) |
| self.checkHasPrepModules(model.sub2) |
| self.checkNoPrepModules(model.sub2.module.fc1) |
| self.checkNoPrepModules(model.sub2.module.fc2) |
| self.checkHasPrepModules(model.fc3) |
| |
| checkPrepModules(model, True) |
| |
| test_only_eval_fn(model, self.calib_data) |
| convert(model) |
| |
| def checkQuantized(model): |
| checkPrepModules(model) |
| self.checkLinear(model.sub1.fc) |
| self.assertEqual(type(model.sub1.relu), torch.nn.ReLU) |
| self.checkQuantizedLinear(model.sub2.module.fc1) |
| self.checkQuantizedLinear(model.sub2.module.fc2) |
| self.checkWrappedQuantizedLinear(model.fc3) |
| test_only_eval_fn(model, self.calib_data) |
| self.checkScriptable(model, self.calib_data) |
| |
| checkQuantized(model) |
| |
| # test one line API |
| model = quantize(AnnotatedSubNestedModel(), test_only_eval_fn, |
| self.calib_data) |
| checkQuantized(model) |
| |
| def test_nested3(self): |
| r"""More complicated nested test case with child qconfig overrides |
| parent qconfig |
| """ |
| model = AnnotatedCustomConfigNestedModel() |
| prepare(model) |
| |
| def checkPrepModules(model, before_calib=False): |
| if before_calib: |
| self.checkObservers(model) |
| self.checkNoPrepModules(model) |
| self.checkNoPrepModules(model.sub1) |
| self.checkNoPrepModules(model.sub1.fc) |
| self.checkNoPrepModules(model.sub1.relu) |
| self.checkNoPrepModules(model.sub2) |
| self.checkHasPrepModules(model.sub2.fc1) |
| self.checkHasPrepModules(model.sub2.fc2) |
| self.checkHasPrepModules(model.fc3) |
| |
| checkPrepModules(model, True) |
| |
| test_only_eval_fn(model, self.calib_data) |
| convert(model) |
| |
| def checkQuantized(model): |
| checkPrepModules(model) |
| self.checkWrappedQuantizedLinear(model.sub2.fc1) |
| self.checkWrappedQuantizedLinear(model.sub2.fc2) |
| self.checkWrappedQuantizedLinear(model.fc3) |
| test_only_eval_fn(model, self.calib_data) |
| self.checkScriptable(model, self.calib_data) |
| |
| checkQuantized(model) |
| |
| # test one line API |
| model = quantize(AnnotatedCustomConfigNestedModel(), test_only_eval_fn, |
| self.calib_data) |
| checkQuantized(model) |
| |
| def test_skip_quant(self): |
| r"""The case when we want to skip quantizing some layers |
| """ |
| |
| model = SkipQuantModel() |
| prepare(model) |
| self.checkObservers(model) |
| |
| test_only_eval_fn(model, self.calib_data) |
| convert(model) |
| |
| def checkQuantized(model): |
| self.checkLinear(model.fc) |
| self.checkQuantDequant(model.sub) |
| self.checkQuantizedLinear(model.sub.module.fc1) |
| self.checkQuantizedLinear(model.sub.module.fc2) |
| self.assertEqual(type(model.sub.module.relu), nnq.ReLU) |
| self.checkScriptable(model, self.calib_data) |
| |
| checkQuantized(model) |
| |
| # test one line API |
| model = quantize(SkipQuantModel(), test_only_eval_fn, self.calib_data) |
| checkQuantized(model) |
| |
| |
| def test_manual(self): |
| r"""User inserts QuantStub and DeQuantStub in model code |
| and call the quantization utility functions. |
| """ |
| model = QuantStubModel() |
| # propagate the qconfig of parents to children, model is changed |
| # inplace |
| prepare(model) |
| self.checkObservers(model) |
| |
| test_only_eval_fn(model, self.calib_data) |
| convert(model) |
| |
| def checkQuantized(model): |
| self.assertEqual(type(model.fc), nnq.Linear) |
| test_only_eval_fn(model, self.calib_data) |
| self.checkScriptable(model, self.calib_data) |
| |
| checkQuantized(model) |
| |
| # test one line API |
| model = quantize(QuantStubModel(), test_only_eval_fn, self.calib_data) |
| checkQuantized(model) |
| |
| def test_resnet_base(self): |
| r"""Test quantization for bottleneck topology used in resnet/resnext |
| and add coverage for conversion of average pool and float functional |
| """ |
| model = ResNetBase().float().eval() |
| model = QuantWrapper(model) |
| model.qconfig = default_qconfig |
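| # Fuse Conv + BN + ReLU into a single module before inserting observers |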
| fuse_list = [['module.conv1', 'module.bn1', 'module.relu1']] |
| fuse_modules(model, fuse_list) |
| prepare(model) |
| self.checkObservers(model) |
| test_only_eval_fn(model, self.img_data) |
| convert(model) |
| |
| def checkQuantized(model): |
| self.assertEqual(type(model.module.conv1), nn._intrinsic.quantized.ConvReLU2d) |
| self.assertEqual(type(model.module.myop), nn.quantized.QFunctional) |
| self.assertEqual(type(model.module.avgpool), nn.AdaptiveAvgPool2d) |
| test_only_eval_fn(model, self.img_data) |
| |
| checkQuantized(model) |
| |
| @unittest.skipIf( |
| not torch.fbgemm_is_cpu_supported(), |
| " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs" |
| " with instruction set support avx2 or newer.", |
| ) |
| class PostTrainingDynamicQuantTest(QuantizationTestCase): |
| def test_single_layer(self): |
| r"""Dynamic Quantize SingleLayerLinearDynamicModel which has one Linear module, |
| make sure it is swapped to nnqd.Linear which is the quantized version of |
| the module |
| """ |
| model = SingleLayerLinearDynamicModel().eval() |
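| # An empty key applies the qconfig to the root module and all of its children |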
| qconfig_dict = { |
| '': default_dynamic_qconfig |
| } |
| prepare_dynamic(model, qconfig_dict) |
| convert_dynamic(model) |
| |
| def checkQuantized(model): |
| self.checkDynamicQuantizedLinear(model.fc1) |
| |
| checkQuantized(model) |
| |
| # test one line API |
| model = quantize_dynamic(SingleLayerLinearDynamicModel().eval(), |
| qconfig_dict) |
| checkQuantized(model) |
| |
| def test_two_layers(self): |
| r"""TwoLayerLinearModel has two Linear modules but we only quantize the second one |
| `fc2`, and `fc1`is not quantized |
| """ |
| model = TwoLayerLinearModel().eval() |
| qconfig_dict = { |
| 'fc2': default_dynamic_qconfig |
| } |
| prepare_dynamic(model, qconfig_dict) |
| |
| convert_dynamic(model) |
| |
| def checkQuantized(model): |
| self.assertEqual(type(model.fc1), torch.nn.Linear) |
| self.checkDynamicQuantizedLinear(model.fc2) |
| |
| checkQuantized(model) |
| |
| # test one line API |
| model = quantize_dynamic(TwoLayerLinearModel().eval(), qconfig_dict) |
| checkQuantized(model) |
| |
| def test_nested1(self): |
| r"""Test quantization for nested model, top level 'fc3' and |
| 'fc1' of submodule 'sub2', 'sub2.fc2' is not quantized |
| """ |
| model = NestedModel().eval() |
| qconfig_dict = { |
| 'fc3': default_dynamic_qconfig, |
| 'sub2.fc1': default_dynamic_qconfig |
| } |
| |
| prepare_dynamic(model, qconfig_dict) |
| convert_dynamic(model) |
| |
| def checkQuantized(model): |
| self.checkLinear(model.sub1.fc) |
| self.checkDynamicQuantizedLinear(model.fc3) |
| self.checkDynamicQuantizedLinear(model.sub2.fc1) |
| self.checkLinear(model.sub2.fc2) |
| |
| checkQuantized(model) |
| |
| # test one line API |
| model = quantize_dynamic(NestedModel().eval(), qconfig_dict) |
| checkQuantized(model) |
| |
| def test_nested2(self): |
| r"""Another test case for quantized, we will quantize all submodules |
| of submodule sub2 |
| """ |
| model = NestedModel().eval() |
| qconfig_dict = { |
| 'fc3': default_dynamic_qconfig, |
| 'sub2': default_dynamic_qconfig |
| } |
| prepare_dynamic(model, qconfig_dict) |
| |
| convert_dynamic(model) |
| |
| def checkQuantized(model): |
| self.checkLinear(model.sub1.fc) |
| self.assertEqual(type(model.sub1.relu), torch.nn.ReLU) |
| self.checkDynamicQuantizedLinear(model.sub2.fc1) |
| self.checkDynamicQuantizedLinear(model.sub2.fc2) |
| self.checkDynamicQuantizedLinear(model.fc3) |
| |
| checkQuantized(model) |
| |
| # test one line API |
| model = quantize_dynamic(NestedModel().eval(), qconfig_dict) |
| checkQuantized(model) |
| |
| def test_nested3(self): |
| r"""More complicated nested test case with child qconfig overrides |
| parent qconfig |
| """ |
| model = NestedModel().eval() |
| custom_options = { |
| 'dtype': torch.quint8, |
| 'qscheme': torch.per_tensor_affine |
| } |
| custom_dynamic_qconfig = QConfig_dynamic(weight=default_weight_observer()) |
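| # Per-module overrides: the entry for 'sub2.fc1' takes precedence over the one for 'sub2' |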
| qconfig_dynamic_dict = { |
| 'fc3': default_dynamic_qconfig, |
| 'sub2': default_dynamic_qconfig, |
| 'sub2.fc1': custom_dynamic_qconfig |
| } |
| prepare_dynamic(model, qconfig_dynamic_dict) |
| |
| convert_dynamic(model) |
| |
| def checkQuantized(model): |
| self.checkDynamicQuantizedLinear(model.sub2.fc1) |
| self.checkDynamicQuantizedLinear(model.sub2.fc2) |
| self.checkDynamicQuantizedLinear(model.fc3) |
| |
| checkQuantized(model) |
| |
| # test one line API |
| model = quantize_dynamic(NestedModel().eval(), qconfig_dynamic_dict) |
| checkQuantized(model) |
| |
| def test_type_match_rule(self): |
| r"""Test quantization for nested model, top level 'fc3' and |
| 'fc1' of submodule 'sub2', All 'torch.nn.Linear' modules are quantized |
| """ |
| model = NestedModel().eval() |
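| # A module-type key matches every nn.Linear; the explicit None entries exempt 'fc3' and 'sub2.fc1' |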
| qconfig_dict = { |
| 'fc3': None, |
| 'sub2.fc1': None, |
| torch.nn.Linear: default_dynamic_qconfig |
| } |
| |
| prepare_dynamic(model, qconfig_dict) |
| test_only_eval_fn(model, self.calib_data) |
| convert_dynamic(model) |
| |
| def checkQuantized(model): |
| self.checkDynamicQuantizedLinear(model.sub1.fc) |
| self.checkLinear(model.fc3) |
| self.checkLinear(model.sub2.fc1) |
| self.checkDynamicQuantizedLinear(model.sub2.fc2) |
| test_only_eval_fn(model, self.calib_data) |
| |
| checkQuantized(model) |
| |
| # test one line API |
| model = quantize_dynamic(NestedModel().eval(), qconfig_dict) |
| checkQuantized(model) |
| |
| def test_quantized_rnn(self): |
| d_in, d_hid = 2, 2 |
| model = LSTMDynamicModel().eval() |
| cell = model.lstm |
| |
| # Replace parameter values such that the range of values is exactly |
| # 255, thus we will have 0 quantization error in the quantized |
| # GEMM call. This is for testing purposes. |
| # |
| # Note that the current implementation does not support |
| # accumulation values outside of the range representable by a |
| # 16 bit integer, instead resulting in a saturated value. We |
| # must take care that in our test we do not end up with a dot |
| # product that overflows the int16 range, e.g. |
| # (255*127+255*127) = 64770. So, we hardcode the test values |
| # here and ensure a mix of signedness. |
| vals = [[100, -155], |
| [100, -155], |
| [-155, 100], |
| [-155, 100], |
| [100, -155], |
| [-155, 100], |
| [-155, 100], |
| [100, -155]] |
| if isinstance(cell, torch.nn.LSTM): |
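| # An LSTM stacks the weights for its four gates (input, forget, cell, output) along dim 0 |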
| num_chunks = 4 |
| vals = vals[:d_hid * num_chunks] |
| cell.weight_ih_l0 = torch.nn.Parameter( |
| torch.tensor(vals, dtype=torch.float), |
| requires_grad=False) |
| cell.weight_hh_l0 = torch.nn.Parameter( |
| torch.tensor(vals, dtype=torch.float), |
| requires_grad=False) |
| |
| ref = copy.deepcopy(cell) |
| |
| qconfig_dynamic_dict = { |
| torch.nn.LSTM: default_dynamic_qconfig, |
| } |
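| # Map the float LSTM to its dynamically quantized counterpart during conversion |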
| default_dynamic_module_mapping = { |
| torch.nn.LSTM: torch.nn.quantized.dynamic.LSTM, |
| } |
| model_int8 = quantize_dynamic( |
| model, qconfig_dynamic_dict, default_dynamic_module_mapping |
| ) |
| cell_int8 = model_int8.lstm |
| |
| assert type(cell_int8) == torch.nn.quantized.dynamic.LSTM, \ |
| 'torch.nn.LSTM should be converted to torch.nn.quantized.dynamic.LSTM after quantize_dynamic' |
| |
| niter = 10 |
| x = torch.tensor([[100, -155], |
| [-155, 100], |
| [100, -155]], dtype=torch.float).unsqueeze(0).repeat(niter, 1, 1) |
| |
| h0_vals = [[-155, 100], |
| [-155, 155], |
| [100, -155]] |
| |
| hx = torch.tensor(h0_vals, dtype=torch.float).unsqueeze(0) |
| cx = torch.tensor(h0_vals, dtype=torch.float).unsqueeze(0) |
| |
| if isinstance(ref, torch.nn.LSTM): |
| hiddens = (hx, cx) |
| |
| ref_out, ref_hid = ref(x, hiddens) |
| |
| # Compare int8 quantized to unquantized |
| output_int8, final_hiddens_int8 = cell_int8(x, hiddens) |
| |
| torch.testing.assert_allclose(output_int8, ref_out) |
| self.assertEqual(output_int8, ref_out) |
| for out, ref in zip(final_hiddens_int8, ref_hid): |
| torch.testing.assert_allclose(out, ref) |
| |
| @unittest.skipIf( |
| not torch.fbgemm_is_cpu_supported(), |
| " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs" |
| " with instruction set support avx2 or newer.", |
| ) |
| class QuantizationAwareTrainingTest(QuantizationTestCase): |
| def test_manual(self): |
| model = ManualLinearQATModel() |
| prepare_qat(model) |
| self.checkObservers(model) |
| test_only_train_fn(model, self.train_data) |
| convert(model) |
| |
| def checkQuantized(model): |
| self.assertEqual(type(model.fc1), nnq.Linear) |
| self.assertEqual(type(model.fc2), nnq.Linear) |
| test_only_eval_fn(model, self.calib_data) |
| self.checkScriptable(model, self.calib_data) |
| |
| checkQuantized(model) |
| |
| model = quantize_qat(ManualLinearQATModel(), test_only_train_fn, |
| self.train_data) |
| checkQuantized(model) |
| |
| def test_eval_only_fake_quant(self): |
| r"""Using FakeQuant in evaluation only mode, |
| this is useful for estimating accuracy loss when we quantize the |
| network |
| """ |
| model = ManualLinearQATModel() |
| |
| prepare_qat(model) |
| self.checkObservers(model) |
| |
| model.eval() |
| test_only_eval_fn(model, self.calib_data) |
| |
| def test_conv_linear(self): |
| model = ManualConvLinearQATModel() |
| |
| prepare_qat(model) |
| self.checkObservers(model) |
| |
| test_only_train_fn(model, self.img_data) |
| convert(model) |
| |
| def checkQuantized(model): |
| self.assertEqual(type(model.conv), nnq.Conv2d) |
| self.assertEqual(type(model.fc1), nnq.Linear) |
| self.assertEqual(type(model.fc2), nnq.Linear) |
| test_only_eval_fn(model, self.img_data) |
| self.checkScriptable(model, self.img_data) |
| |
| checkQuantized(model) |
| |
| model = ManualConvLinearQATModel() |
| model = quantize_qat(model, test_only_train_fn, self.img_data) |
| checkQuantized(model) |
| |
| |
| class ScriptabilityTest(QuantizationTestCase): |
| def setUp(self): |
| self.model_under_test = ModForWrapping(quantized=False) |
| self.qmodel_under_test = ModForWrapping(quantized=True) |
| self.qmodel_under_test = self.qmodel_under_test.from_float( |
| self.model_under_test) |
| self.x = torch.rand(10) |
| self.qx = torch.quantize_linear(self.x.to(torch.float), scale=1.0, |
| zero_point=0, dtype=torch.qint32) |
| |
| def test_scriptability_serialization(self): |
| # test serialization of quantized functional modules |
| with tempfile.TemporaryFile() as f: |
| torch.save(self.qmodel_under_test, f) |
| f.seek(0) |
| loaded = torch.load(f) |
| self.assertEqual(self.qmodel_under_test.myadd.zero_point, loaded.myadd.zero_point) |
| state_dict = self.qmodel_under_test.state_dict() |
| self.assertTrue('myadd.zero_point' in state_dict.keys(), |
| 'zero point not in state dict for functional modules') |
| |
| x = torch.rand(10, 1, dtype=torch.float) |
| xq = torch.quantize_linear(x, 1.0, 0, torch.qint8) |
| self.checkScriptable(self.qmodel_under_test, [(xq, xq)], check_save_load=True) |
| self.checkScriptable(self.model_under_test, [(xq.dequantize(), xq.dequantize())], check_save_load=True) |
| |
| @unittest.skipIf(not torch.fbgemm_is_cpu_supported(), |
| 'Quantization requires FBGEMM. FBGEMM does not play' |
| ' well with UBSAN at the moment, so we skip the test if' |
| ' we are in a UBSAN environment.') |
| class FusionTest(QuantizationTestCase): |
| def test_fuse_module_train(self): |
| model = ModelForFusion(default_qat_qconfig).train() |
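| # Each inner list names a sequence of adjacent modules to fuse, in order |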
| fuse_modules(model, [['conv1', 'bn1', 'relu1'], |
| ['sub1.conv', 'sub1.bn']]) |
| self.assertEqual(type(model.conv1), nni.ConvBnReLU2d, |
| "Fused Conv + BN + Relu first layer") |
| self.assertEqual(type(model.bn1), torch.nn.Identity, |
| "Fused Conv + BN + Relu (skipped BN)") |
| self.assertEqual(type(model.relu1), torch.nn.Identity, |
| "Fused Conv + BN + Relu (skipped Relu)") |
| |
| self.assertEqual(type(model.sub1.conv), nni.ConvBn2d, |
| "Fused submodule Conv + BN") |
| self.assertEqual(type(model.sub1.bn), torch.nn.Identity, |
| "Fused submodule Conv + BN (skipped BN)") |
| self.assertEqual(type(model.sub2.conv), torch.nn.Conv2d, |
| "Non-fused submodule Conv") |
| self.assertEqual(type(model.sub2.relu), torch.nn.ReLU, |
| "Non-fused submodule ReLU") |
| prepare_qat(model) |
| self.checkObservers(model) |
| |
| def checkQAT(model): |
| self.assertEqual(type(model.conv1), nniqat.ConvBnReLU2d) |
| self.assertEqual(type(model.bn1), nn.Identity) |
| self.assertEqual(type(model.relu1), nn.Identity) |
| self.assertEqual(type(model.sub1.conv), nniqat.ConvBn2d) |
| self.assertEqual(type(model.sub1.bn), nn.Identity) |
| self.assertEqual(type(model.sub2.conv), nn.Conv2d) |
| self.assertEqual(type(model.sub2.relu), nn.ReLU) |
| |
| checkQAT(model) |
| test_only_train_fn(model, self.img_data) |
| convert(model) |
| |
| def checkQuantized(model): |
| self.assertEqual(type(model.conv1), nniq.ConvReLU2d) |
| self.assertEqual(type(model.bn1), nn.Identity) |
| self.assertEqual(type(model.relu1), nn.Identity) |
| self.assertEqual(type(model.sub1.conv), nnq.Conv2d) |
| self.assertEqual(type(model.sub1.bn), nn.Identity) |
| self.assertEqual(type(model.sub2.conv), nn.Conv2d) |
| self.assertEqual(type(model.sub2.relu), nn.ReLU) |
| test_only_eval_fn(model, self.img_data) |
| checkQuantized(model) |
| |
| model = ModelForFusion(default_qat_qconfig).train() |
| fuse_modules(model, [['conv1', 'bn1', 'relu1'], |
| ['sub1.conv', 'sub1.bn']]) |
| model = quantize_qat(model, test_only_train_fn, self.img_data) |
| checkQuantized(model) |
| |
| |
| def test_fuse_module_eval(self): |
| model = ModelForFusion(default_qconfig) |
| model.eval() |
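| # In eval mode, fusion folds the BatchNorm parameters into the preceding Conv weights |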
| fuse_modules(model, [['conv1', 'bn1', 'relu1'], |
| ['sub1.conv', 'sub1.bn']]) |
| self.assertEqual(type(model.conv1), nni.ConvReLU2d, |
| "Fused Conv + BN + Relu first layer (BN is folded)") |
| self.assertEqual(type(model.conv1[0]), nn.Conv2d, |
| "Fused Conv + BN + Relu (Conv + folded BN only)") |
| self.assertEqual(type(model.conv1[1]), nn.ReLU, |
| "Fused Conv + BN + Relu second layer (Relu only)") |
| self.assertEqual(type(model.bn1), nn.Identity, |
| "Fused Conv + BN + Relu (skipped BN)") |
| self.assertEqual(type(model.relu1), nn.Identity, |
| "Fused Conv + BN + Relu (skipped Relu)") |
| |
| self.assertEqual(type(model.sub1.conv), nn.Conv2d, |
| "Fused submodule Conv + folded BN") |
| self.assertEqual(type(model.sub1.bn), nn.Identity, |
| "Fused submodule (skipped BN)") |
| self.assertEqual(type(model.sub2.conv), nn.Conv2d, |
| "Non-fused submodule Conv") |
| self.assertEqual(type(model.sub2.relu), torch.nn.ReLU, |
| "Non-fused submodule ReLU") |
| |
| prepare(model) |
| self.checkObservers(model) |
| test_only_eval_fn(model, self.img_data) |
| convert(model) |
| |
| def checkQuantized(model): |
| self.assertEqual(type(model.conv1), nniq.ConvReLU2d) |
| self.assertEqual(type(model.bn1), nn.Identity) |
| self.assertEqual(type(model.relu1), nn.Identity) |
| self.assertEqual(type(model.sub1.conv), nnq.Conv2d) |
| self.assertEqual(type(model.sub1.bn), nn.Identity) |
| self.assertEqual(type(model.sub2.conv), nn.Conv2d) |
| self.assertEqual(type(model.sub2.relu), nn.ReLU) |
| test_only_eval_fn(model, self.img_data) |
| checkQuantized(model) |
| |
| model = ModelForFusion(default_qat_qconfig).eval() |
| fuse_modules(model, [['conv1', 'bn1', 'relu1'], |
| ['sub1.conv', 'sub1.bn']]) |
| model = quantize(model, test_only_eval_fn, self.img_data) |
| checkQuantized(model) |
| |
| |
| class ObserverTest(QuantizationTestCase): |
| @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)), |
| qscheme=st.sampled_from((torch.per_tensor_affine, torch.per_tensor_symmetric))) |
| def test_minmax_observer(self, qdtype, qscheme): |
| myobs = MinMaxObserver(dtype=qdtype, qscheme=qscheme) |
| x = torch.tensor([1.0, 2.0, 2.0, 3.0, 4.0, 5.0, 6.0]) |
| y = torch.tensor([4.0, 5.0, 5.0, 6.0, 7.0, 8.0]) |
| result = myobs(x) |
| result = myobs(y) |
| self.assertEqual(result, y) |
| self.assertEqual(myobs.min_val, 1.0) |
| self.assertEqual(myobs.max_val, 8.0) |
| qparams = myobs.calculate_qparams() |
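| # With an observed range of [1, 8] the observer extends the range to include zero: |
| # affine uses [0, 8] -> scale 8/255; symmetric uses [-8, 8] -> scale 16/255 |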
| if qscheme == torch.per_tensor_symmetric: |
| ref_scale = 0.062745 |
| ref_zero_point = 0 if qdtype is torch.qint8 else 128 |
| else: |
| ref_scale = 0.0313725 |
| ref_zero_point = -128 if qdtype is torch.qint8 else 0 |
| self.assertEqual(qparams[1].item(), ref_zero_point) |
| self.assertAlmostEqual(qparams[0].item(), ref_scale, delta=1e-5) |
| |
| def test_observer_scriptable(self): |
| obs = torch.quantization.default_observer()() |
| scripted = torch.jit.script(obs) |
| |
| x = torch.rand(3, 4) |
| obs(x) |
| scripted(x) |
| |
| self.assertEqual(obs.calculate_qparams(), scripted.calculate_qparams()) |
| |
| buf = io.BytesIO() |
| torch.jit.save(scripted, buf) |
| buf.seek(0) |
| loaded = torch.jit.load(buf) |
| self.assertEqual(obs.calculate_qparams(), loaded.calculate_qparams()) |
| |
| if __name__ == '__main__': |
| run_tests() |