| import unittest |
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.quantized as nnq |
| import torch.nn.intrinsic as nni |
| import torch.nn.intrinsic.quantized as nniq |
| import torch.nn.intrinsic.qat as nniqat |
| from torch.nn.utils.rnn import PackedSequence |
| from torch.quantization import \ |
| get_observer_dict, default_weight_observer, \ |
| quantize, prepare, convert, prepare_qat, quantize_qat, fuse_modules, \ |
| quantize_dynamic, default_qconfig, default_debug_qconfig, default_qat_qconfig, \ |
| default_dynamic_qconfig, per_channel_dynamic_qconfig, HistogramObserver, MinMaxObserver, \ |
| PerChannelMinMaxObserver, RecordingObserver, MovingAverageMinMaxObserver, \ |
| MovingAveragePerChannelMinMaxObserver, QuantWrapper, default_eval_fn, \ |
| float16_dynamic_qconfig, MinMaxDynamicQuantObserver |
| |
| from torch.quantization import QConfig |
| from torch.quantization import default_histogram_observer |
| from torch.quantization import default_observer |
| from torch.quantization import default_per_channel_weight_observer |
| from torch.quantization import default_per_channel_qconfig |
| from torch.quantization._quantize_script import quantize_script, quantize_dynamic_script |
| |
| from torch.testing._internal.common_utils import TEST_WITH_UBSAN, IS_WINDOWS |
| from torch.testing._internal.common_quantization import QuantizationTestCase, \ |
| AnnotatedSingleLayerLinearModel, SingleLayerLinearModel, \ |
| AnnotatedConvModel, ConvModel, \ |
| AnnotatedConvBnModel, ConvBnModel, \ |
| SkipQuantModel, QuantStubModel, \ |
| ModelForFusion, ModelWithSequentialFusion, ManualLinearQATModel, ManualConvLinearQATModel, \ |
| ModelWithFunctionals, \ |
| test_only_eval_fn, test_only_train_fn, \ |
| prepare_dynamic, convert_dynamic, SingleLayerLinearDynamicModel, \ |
| TwoLayerLinearModel, NestedModel, ResNetBase, LSTMDynamicModel, \ |
| ModelWithNoQconfigPropagation, ModelForFusionWithBias, \ |
| ActivationsTestModel, ActivationsQATTestModel, NormalizationTestModel |
| |
| from torch.testing._internal.common_quantization import AnnotatedTwoLayerLinearModel, AnnotatedNestedModel, \ |
| AnnotatedSubNestedModel, AnnotatedCustomConfigNestedModel |
| from torch.testing._internal.common_quantization import AnnotatedSkipQuantModel |
| |
| from torch.testing._internal.common_quantized import override_quantized_engine |
| from hypothesis import given |
| from hypothesis import strategies as st |
| import torch.testing._internal.hypothesis_utils as hu |
| hu.assert_deadline_disabled() |
| import io |
| import copy |
| |
| @unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines, |
| " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs" |
| " with instruction set support avx2 or newer.") |
| class TestPostTrainingStatic(QuantizationTestCase): |
| @given(qconfig=st.sampled_from((torch.quantization.default_qconfig, torch.quantization.default_per_channel_qconfig))) |
| def test_single_layer(self, qconfig): |
| r"""Quantize SingleLayerLinearModel which has one Linear module, make sure it is swapped |
| to nnq.Linear which is the quantized version of the module |
| """ |
| model = AnnotatedSingleLayerLinearModel() |
| model.qconfig = qconfig |
| model = prepare(model) |
| # Check if observers and quant/dequant nodes are inserted |
| self.checkNoPrepModules(model) |
| self.checkHasPrepModules(model.fc1) |
| self.checkObservers(model) |
| |
| test_only_eval_fn(model, self.calib_data) |
| model = convert(model) |
| |
| def checkQuantized(model): |
| self.checkNoPrepModules(model) |
| self.checkHasPrepModules(model.fc1) |
| self.checkWrappedQuantizedLinear(model.fc1) |
| test_only_eval_fn(model, self.calib_data) |
| self.checkScriptable(model, self.calib_data) |
| |
| checkQuantized(model) |
| |
| # test one line API - out of place version |
| base = AnnotatedSingleLayerLinearModel() |
| base.qconfig = qconfig |
| keys_before = set(list(base.state_dict().keys())) |
| model = quantize(base, test_only_eval_fn, self.calib_data) |
| checkQuantized(model) |
| keys_after = set(list(base.state_dict().keys())) |
| self.assertEqual(keys_before, keys_after) # simple check that nothing changed |
| |
| # in-place version |
| model = AnnotatedSingleLayerLinearModel() |
| model.qconfig = qconfig |
| quantize(model, test_only_eval_fn, self.calib_data, inplace=True) |
| checkQuantized(model) |
| |
| def test_two_layers(self): |
| r"""TwoLayerLinearModel has two Linear modules but we only quantize the second one |
| `fc2`, and `fc1`is not quantized |
| """ |
| model = AnnotatedTwoLayerLinearModel() |
| model = prepare(model) |
| |
| self.checkNoPrepModules(model) |
| self.checkObservers(model) |
| self.checkNoPrepModules(model.fc1) |
| self.checkHasPrepModules(model.fc2) |
| |
| test_only_eval_fn(model, self.calib_data) |
| model = convert(model) |
| |
| def checkQuantized(model): |
| self.checkNoPrepModules(model) |
| self.checkNoPrepModules(model.fc1) |
| self.checkHasPrepModules(model.fc2) |
| self.assertEqual(type(model.fc1), torch.nn.Linear) |
| self.checkWrappedQuantizedLinear(model.fc2) |
| test_only_eval_fn(model, self.calib_data) |
| self.checkScriptable(model, self.calib_data) |
| |
| checkQuantized(model) |
| |
| # test one line API |
| model = quantize(AnnotatedTwoLayerLinearModel(), test_only_eval_fn, |
| self.calib_data) |
| checkQuantized(model) |
| |
| def test_nested1(self): |
| r"""Test quantization for nested model, top level 'fc3' and |
| 'fc1' of submodule 'sub2', 'sub2.fc2' is not quantized |
| """ |
| model = AnnotatedNestedModel() |
| |
| def checkPrepModules(model, before_calib=False): |
| if before_calib: |
| self.checkObservers(model) |
| self.checkNoPrepModules(model) |
| self.checkNoPrepModules(model.sub1) |
| self.checkNoPrepModules(model.sub1.fc) |
| self.checkNoPrepModules(model.sub1.relu) |
| self.checkNoPrepModules(model.sub2) |
| self.checkHasPrepModules(model.sub2.fc1) |
| self.checkNoPrepModules(model.sub2.fc2) |
| self.checkHasPrepModules(model.fc3) |
| |
| model = prepare(model) |
| checkPrepModules(model, True) |
| test_only_eval_fn(model, self.calib_data) |
| model = convert(model) |
| |
| def checkQuantized(model): |
| checkPrepModules(model) |
| self.checkLinear(model.sub1.fc) |
| self.checkWrappedQuantizedLinear(model.fc3) |
| self.checkWrappedQuantizedLinear(model.sub2.fc1) |
| self.checkLinear(model.sub2.fc2) |
| test_only_eval_fn(model, self.calib_data) |
| self.checkScriptable(model, self.calib_data) |
| |
| checkQuantized(model) |
| |
| # test one line API |
| model = quantize(AnnotatedNestedModel(), test_only_eval_fn, |
| self.calib_data) |
| checkQuantized(model) |
| |
| |
| def test_nested2(self): |
| model = AnnotatedSubNestedModel() |
| model = prepare(model) |
| |
| def checkPrepModules(model, before_calib=False): |
| if before_calib: |
| self.checkObservers(model) |
| self.checkNoPrepModules(model) |
| self.checkNoPrepModules(model.sub1) |
| self.checkNoPrepModules(model.sub1.fc) |
| self.checkNoPrepModules(model.sub1.relu) |
| self.checkHasPrepModules(model.sub2) |
| self.checkNoPrepModules(model.sub2.module.fc1) |
| self.checkNoPrepModules(model.sub2.module.fc2) |
| self.checkHasPrepModules(model.fc3) |
| |
| checkPrepModules(model, True) |
| |
| test_only_eval_fn(model, self.calib_data) |
| model = convert(model) |
| |
| def checkQuantized(model): |
| checkPrepModules(model) |
| self.checkLinear(model.sub1.fc) |
| self.assertEqual(type(model.sub1.relu), torch.nn.ReLU) |
| self.checkQuantizedLinear(model.sub2.module.fc1) |
| self.checkQuantizedLinear(model.sub2.module.fc2) |
| self.checkWrappedQuantizedLinear(model.fc3) |
| test_only_eval_fn(model, self.calib_data) |
| self.checkScriptable(model, self.calib_data) |
| |
| checkQuantized(model) |
| |
| # test one line API |
| model = quantize(AnnotatedSubNestedModel(), test_only_eval_fn, |
| self.calib_data) |
| checkQuantized(model) |
| |
| def test_nested3(self): |
| r"""More complicated nested test case with child qconfig overrides |
| parent qconfig |
| """ |
| model = AnnotatedCustomConfigNestedModel() |
| model = prepare(model) |
| |
| def checkPrepModules(model, before_calib=False): |
| if before_calib: |
| self.checkObservers(model) |
| self.checkNoPrepModules(model) |
| self.checkNoPrepModules(model.sub1) |
| self.checkNoPrepModules(model.sub1.fc) |
| self.checkNoPrepModules(model.sub1.relu) |
| self.checkNoPrepModules(model.sub2) |
| self.checkHasPrepModules(model.sub2.fc1) |
| self.checkHasPrepModules(model.sub2.fc2) |
| self.checkHasPrepModules(model.fc3) |
| |
| checkPrepModules(model, True) |
| |
| test_only_eval_fn(model, self.calib_data) |
| model = convert(model) |
| |
| def checkQuantized(model): |
| checkPrepModules(model) |
| self.checkWrappedQuantizedLinear(model.sub2.fc1) |
| self.checkWrappedQuantizedLinear(model.sub2.fc2) |
| self.checkWrappedQuantizedLinear(model.fc3) |
| test_only_eval_fn(model, self.calib_data) |
| self.checkScriptable(model, self.calib_data) |
| |
| checkQuantized(model) |
| |
| # test one line API |
| model = quantize(AnnotatedCustomConfigNestedModel(), test_only_eval_fn, |
| self.calib_data) |
| checkQuantized(model) |
| |
| def test_skip_quant(self): |
| r"""The case when we want to skip quantizing some layers |
| """ |
| |
| model = AnnotatedSkipQuantModel() |
| model = prepare(model) |
| self.checkObservers(model) |
| |
| test_only_eval_fn(model, self.calib_data) |
| model = convert(model) |
| |
| def checkQuantized(model): |
| self.checkLinear(model.fc) |
| self.checkQuantDequant(model.sub) |
| self.checkQuantizedLinear(model.sub.module.fc1) |
| self.checkQuantizedLinear(model.sub.module.fc2) |
| self.assertEqual(type(model.sub.module.relu1), nnq.ReLU) |
| self.assertEqual(type(model.sub.module.relu2), nnq.ReLU) |
| self.checkScriptable(model, self.calib_data) |
| |
| checkQuantized(model) |
| |
| # test one line API |
| model = quantize(AnnotatedSkipQuantModel(), test_only_eval_fn, self.calib_data) |
| checkQuantized(model) |
| |
| |
| def test_manual(self): |
| r"""User inserts QuantStub and DeQuantStub in model code |
| and call the quantization utility functions. |
| """ |
| model = QuantStubModel() |
| # prepare() propagates the qconfig of parents to children and |
| # inserts observers |
| model = prepare(model) |
| self.checkObservers(model) |
| |
| test_only_eval_fn(model, self.calib_data) |
| model = convert(model) |
| |
| def checkQuantized(model): |
| self.assertEqual(type(model.fc), nnq.Linear) |
| test_only_eval_fn(model, self.calib_data) |
| self.checkScriptable(model, self.calib_data) |
| |
| checkQuantized(model) |
| |
| # test one line API |
| model = quantize(QuantStubModel(), test_only_eval_fn, self.calib_data) |
| checkQuantized(model) |
| |
| @given(qconfig=st.sampled_from((torch.quantization.default_qconfig, torch.quantization.default_per_channel_qconfig))) |
| def test_resnet_base(self, qconfig): |
| r"""Test quantization for bottleneck topology used in resnet/resnext |
| and add coverage for conversion of average pool and float functional |
| """ |
| model = ResNetBase().float().eval() |
| model = QuantWrapper(model) |
| model.qconfig = qconfig |
| fuse_list = ['module.conv1', 'module.bn1', 'module.relu1'] |
| fuse_modules(model, fuse_list, inplace=True) |
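| # Fusion merges conv1/bn1/relu1 into a single intrinsic module before |
| # prepare(), so a single observer covers the fused op and, since the model |
| # is in eval mode, the BatchNorm is folded into the conv weights; after |
| # convert() the sequence runs as one quantized ConvReLU2d kernel. |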
| model = prepare(model) |
| self.checkObservers(model) |
| test_only_eval_fn(model, self.img_data) |
| model = convert(model) |
| |
| def checkQuantized(model): |
| self.assertEqual(type(model.module.conv1), nn.intrinsic.quantized.ConvReLU2d) |
| self.assertEqual(type(model.module.myop), nn.quantized.QFunctional) |
| self.assertEqual(type(model.module.avgpool), nn.AdaptiveAvgPool2d) |
| test_only_eval_fn(model, self.img_data) |
| |
| checkQuantized(model) |
| |
| def test_normalization(self): |
| r""" |
| Test quantization of normalization layers |
| """ |
| model = NormalizationTestModel() |
| model.qconfig = torch.quantization.get_default_qconfig('fbgemm') |
| prepare(model, inplace=True) |
| self.checkObservers(model) |
| test_only_eval_fn(model, self.calib_data) |
| model = convert(model) |
| |
| def checkQuantized(model): |
| self.checkNoPrepModules(model.layer_norm) |
| self.assertEqual(type(model.layer_norm), nnq.LayerNorm) |
| test_only_eval_fn(model, self.calib_data) |
| self.checkScriptable(model, self.calib_data) |
| |
| checkQuantized(model) |
| |
| model_oneline = quantize( |
| NormalizationTestModel(), test_only_eval_fn, self.calib_data) |
| checkQuantized(model_oneline) |
| |
| @given(qengine=st.sampled_from(("qnnpack", "fbgemm"))) |
| def test_save_load_state_dict(self, qengine): |
| r"""Test PTQ flow of creating a model and quantizing it and saving the quantized state_dict |
| Load the quantized state_dict for eval and compare results against original model |
| """ |
| if qengine == 'qnnpack': |
| if IS_WINDOWS or TEST_WITH_UBSAN: |
| return |
| with override_quantized_engine(qengine): |
| model = TwoLayerLinearModel() |
| model = torch.quantization.QuantWrapper(model) |
| model.qconfig = torch.quantization.get_default_qconfig(qengine) |
| |
| model = prepare(model) |
| # calibrate |
| test_only_eval_fn(model, self.calib_data) |
| model = convert(model) |
| x = torch.rand(2, 5, dtype=torch.float) |
| ref = model(x) |
| |
| quant_state_dict = model.state_dict() |
| |
| # Create model again for eval |
| model = TwoLayerLinearModel() |
| model = torch.quantization.QuantWrapper(model) |
| model.qconfig = torch.quantization.get_default_qconfig(qengine) |
| model = prepare(model) |
| model = convert(model) |
| new_state_dict = model.state_dict() |
| |
| # Check to make sure the state dict keys match original model after convert. |
| self.assertEqual(set(new_state_dict.keys()), set(quant_state_dict.keys())) |
| |
| model.load_state_dict(quant_state_dict) |
| |
| out = model(x) |
| self.assertEqual(ref, out) |
| |
| def test_activations(self): |
| r""" |
| Test quantization of activations |
| """ |
| model = ActivationsTestModel() |
| model.qconfig = torch.quantization.get_default_qconfig('fbgemm') |
| prepare(model, inplace=True) |
| self.checkObservers(model) |
| test_only_eval_fn(model, self.calib_data) |
| model = convert(model) |
| |
| def checkQuantized(model): |
| self.checkNoPrepModules(model.hardswish) |
| self.assertEqual(type(model.hardswish), nnq.Hardswish) |
| test_only_eval_fn(model, self.calib_data) |
| self.checkScriptable(model, self.calib_data) |
| |
| checkQuantized(model) |
| |
| # test one line API |
| model_oneline = quantize(ActivationsTestModel(), test_only_eval_fn, |
| self.calib_data) |
| checkQuantized(model_oneline) |
| |
| |
| @unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines, |
| " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs" |
| " with instruction set support avx2 or newer.") |
| class TestPostTrainingDynamic(QuantizationTestCase): |
| def test_single_layer(self): |
| r"""Dynamic Quantize SingleLayerLinearDynamicModel which has one Linear module, |
| make sure it is swapped to nnqd.Linear which is the quantized version of |
| the module |
| """ |
| for dtype in [torch.qint8, torch.float16]: |
| model = SingleLayerLinearDynamicModel().eval() |
| qconfig = float16_dynamic_qconfig if dtype == torch.float16 else default_dynamic_qconfig |
| qconfig_dict = { |
| 'fc1': qconfig |
| } |
| prepare_dynamic(model, qconfig_dict) |
| convert_dynamic(model) |
| |
| def checkQuantized(model): |
| self.checkDynamicQuantizedLinear(model.fc1, dtype) |
| self.checkScriptable(model, self.calib_data, check_save_load=True) |
| |
| checkQuantized(model) |
| |
| # test one line API - out of place version |
| base = SingleLayerLinearDynamicModel() |
| keys_before = set(list(base.state_dict().keys())) |
| model = quantize_dynamic(base, qconfig_dict) |
| checkQuantized(model) |
| keys_after = set(list(base.state_dict().keys())) |
| self.assertEqual(keys_before, keys_after) # simple check that nothing changed |
| |
| # in-place version |
| model = SingleLayerLinearDynamicModel() |
| quantize_dynamic(model, qconfig_dict, inplace=True) |
| checkQuantized(model) |
| |
| # Test set qconfig |
| model = SingleLayerLinearDynamicModel() |
| quantize_dynamic(model, set([nn.Linear]), inplace=True, dtype=dtype) |
| checkQuantized(model) |
| |
| def test_two_layers(self): |
| r"""TwoLayerLinearModel has two Linear modules but we only quantize the second one |
| `fc2`, and `fc1`is not quantized |
| """ |
| for dtype in [torch.qint8, torch.float16]: |
| model = TwoLayerLinearModel().eval() |
| qconfig = float16_dynamic_qconfig if dtype == torch.float16 else default_dynamic_qconfig |
| qconfig_dict = { |
| 'fc2': qconfig |
| } |
| prepare_dynamic(model, qconfig_dict) |
| |
| convert_dynamic(model) |
| |
| def checkQuantized(model): |
| self.assertEqual(type(model.fc1), torch.nn.Linear) |
| self.checkDynamicQuantizedLinear(model.fc2, dtype=dtype) |
| self.checkScriptable(model, self.calib_data, check_save_load=True) |
| |
| checkQuantized(model) |
| |
| # test one line API |
| model = quantize_dynamic(TwoLayerLinearModel().eval(), qconfig_dict) |
| checkQuantized(model) |
| |
| # Test set API |
| model = quantize_dynamic(TwoLayerLinearModel().eval(), {'fc2'}, dtype=dtype) |
| checkQuantized(model) |
| |
| def test_nested1(self): |
| r"""Test quantization for nested model, top level 'fc3' and |
| 'fc1' of submodule 'sub2', 'sub2.fc2' is not quantized |
| """ |
| for dtype in [torch.qint8, torch.float16]: |
| model = NestedModel().eval() |
| qconfig = float16_dynamic_qconfig if dtype == torch.float16 else default_dynamic_qconfig |
| qconfig_dict = { |
| 'fc3': qconfig, |
| 'sub2.fc1': qconfig |
| } |
| |
| prepare_dynamic(model, qconfig_dict) |
| convert_dynamic(model) |
| |
| def checkQuantized(model): |
| self.checkLinear(model.sub1.fc) |
| self.checkDynamicQuantizedLinear(model.fc3, dtype=dtype) |
| self.checkDynamicQuantizedLinear(model.sub2.fc1, dtype=dtype) |
| self.checkLinear(model.sub2.fc2) |
| self.checkScriptable(model, self.calib_data, check_save_load=True) |
| |
| checkQuantized(model) |
| |
| # test one line API |
| model = quantize_dynamic(NestedModel().eval(), qconfig_dict) |
| checkQuantized(model) |
| |
| model = quantize_dynamic(NestedModel().eval(), {'fc3', 'sub2.fc1'}, dtype=dtype) |
| checkQuantized(model) |
| |
| def test_nested2(self): |
| r"""Another test case for quantized, we will quantize all submodules |
| of submodule sub2 |
| """ |
| for dtype in [torch.qint8, torch.float16]: |
| model = NestedModel().eval() |
| qconfig = float16_dynamic_qconfig if dtype == torch.float16 else default_dynamic_qconfig |
| qconfig_dict = { |
| 'fc3': qconfig, |
| 'sub2': qconfig |
| } |
| prepare_dynamic(model, qconfig_dict) |
| |
| convert_dynamic(model) |
| |
| def checkQuantized(model): |
| self.checkLinear(model.sub1.fc) |
| self.assertEqual(type(model.sub1.relu), torch.nn.ReLU) |
| self.checkDynamicQuantizedLinear(model.sub2.fc1, dtype=dtype) |
| self.checkDynamicQuantizedLinear(model.sub2.fc2, dtype=dtype) |
| self.checkDynamicQuantizedLinear(model.fc3, dtype=dtype) |
| self.checkScriptable(model, self.calib_data, check_save_load=True) |
| |
| checkQuantized(model) |
| |
| # test one line API |
| model = quantize_dynamic(NestedModel().eval(), qconfig_dict, dtype=dtype) |
| checkQuantized(model) |
| |
| # Test set API |
| model = quantize_dynamic(NestedModel().eval(), {'fc3', 'sub2'}, dtype=dtype) |
| checkQuantized(model) |
| |
| def test_nested3(self): |
| r"""More complicated nested test case with child qconfig overrides |
| parent qconfig |
| """ |
| for dtype in [torch.qint8, torch.float16]: |
| model = NestedModel().eval() |
| qconfig = float16_dynamic_qconfig if dtype == torch.float16 else default_dynamic_qconfig |
| qconfig_dynamic_dict = { |
| 'fc3': qconfig, |
| 'sub2': qconfig, |
| 'sub2.fc1': qconfig |
| } |
| prepare_dynamic(model, qconfig_dynamic_dict) |
| |
| convert_dynamic(model) |
| |
| def checkQuantized(model): |
| self.checkDynamicQuantizedLinear(model.sub2.fc1, dtype=dtype) |
| self.checkDynamicQuantizedLinear(model.sub2.fc2, dtype=dtype) |
| self.checkDynamicQuantizedLinear(model.fc3, dtype=dtype) |
| self.checkScriptable(model, self.calib_data, check_save_load=True) |
| |
| checkQuantized(model) |
| |
| # test one line API |
| model = quantize_dynamic(NestedModel().eval(), qconfig_dynamic_dict) |
| checkQuantized(model) |
| |
| # Test set API |
| model = quantize_dynamic(NestedModel().eval(), {'fc3', 'sub2', 'sub2.fc1'}, dtype=dtype) |
| checkQuantized(model) |
| |
| def test_type_match_rule(self): |
| r"""Test quantization for nested model, top level 'fc3' and |
| 'fc1' of submodule 'sub2', All 'torch.nn.Linear' modules are quantized |
| """ |
| for dtype in [torch.qint8, torch.float16]: |
| model = NestedModel().eval() |
| qconfig = float16_dynamic_qconfig if dtype == torch.float16 else default_dynamic_qconfig |
| qconfig_dict = { |
| 'fc3': None, |
| 'sub2.fc1': None, |
| torch.nn.Linear: qconfig |
| } |
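| # Explicit name entries take precedence over the type-based entry, so |
| # 'fc3' and 'sub2.fc1' (mapped to None) stay in float while every other |
| # nn.Linear is dynamically quantized. |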
| |
| prepare_dynamic(model, qconfig_dict) |
| test_only_eval_fn(model, self.calib_data) |
| convert_dynamic(model) |
| |
| def checkQuantized(model): |
| self.checkDynamicQuantizedLinear(model.sub1.fc, dtype=dtype) |
| self.checkLinear(model.fc3) |
| self.checkLinear(model.sub2.fc1) |
| self.checkDynamicQuantizedLinear(model.sub2.fc2, dtype=dtype) |
| test_only_eval_fn(model, self.calib_data) |
| self.checkScriptable(model, self.calib_data, check_save_load=True) |
| |
| checkQuantized(model) |
| |
| # test one line API |
| model = quantize_dynamic(NestedModel().eval(), qconfig_dict, dtype=dtype) |
| checkQuantized(model) |
| |
| def test_per_channel_quantize(self): |
| r"""Test quantization for per_channel dynamic quantization |
| """ |
| model = NestedModel().eval() |
| qconfig_dict = { |
| torch.nn.Linear: per_channel_dynamic_qconfig |
| } |
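| # per_channel_dynamic_qconfig uses a per-channel weight observer, so each |
| # output channel of a Linear weight gets its own scale and zero_point |
| # instead of a single per-tensor pair. |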
| |
| prepare_dynamic(model, qconfig_dict) |
| test_only_eval_fn(model, self.calib_data) |
| convert_dynamic(model) |
| |
| def checkQuantized(model): |
| self.checkDynamicQuantizedLinear(model.sub1.fc, dtype=torch.qint8) |
| self.checkDynamicQuantizedLinear(model.fc3, dtype=torch.qint8) |
| self.checkDynamicQuantizedLinear(model.sub2.fc1, dtype=torch.qint8) |
| self.checkDynamicQuantizedLinear(model.sub2.fc2, dtype=torch.qint8) |
| test_only_eval_fn(model, self.calib_data) |
| self.checkScriptable(model, self.calib_data, check_save_load=True) |
| |
| checkQuantized(model) |
| # test one line API |
| model = quantize_dynamic(NestedModel().eval(), qconfig_dict) |
| checkQuantized(model) |
| |
| @unittest.skip("temporarily disable the test") |
| @given(qengine=st.sampled_from(("fbgemm",))) |
| def test_quantized_rnn(self, qengine): |
| d_in, d_hid = 2, 2 |
| |
| # TODO: qlinear_prepack_fp16 currently doesn't support QNNPACK |
| # re-add "qnnpack" to the engine set when this is supported |
| |
| with override_quantized_engine(qengine): |
| model = LSTMDynamicModel().eval() |
| cell = model.lstm |
| |
| # Replace parameter values s.t. the range of values is exactly |
| # 255, thus we will have 0 quantization error in the quantized |
| # GEMM call. This is for testing purposes. |
| # |
| # Note that the current implementation does not support |
| # accumulation values outside of the range representable by a |
| # 16 bit integer, instead resulting in a saturated value. We |
| # must take care that in our test we do not end up with a dot |
| # product that overflows the int16 range, e.g. |
| # (255*127+255*127) = 64770. So, we hardcode the test values |
| # here and ensure a mix of signedness. |
| vals = [[100, -155], |
| [100, -155], |
| [-155, 100], |
| [-155, 100], |
| [100, -155], |
| [-155, 100], |
| [-155, 100], |
| [100, -155]] |
| if isinstance(cell, torch.nn.LSTM): |
| num_chunks = 4 |
| vals = vals[:d_hid * num_chunks] |
| cell.weight_ih_l0 = torch.nn.Parameter( |
| torch.tensor(vals, dtype=torch.float), |
| requires_grad=False) |
| cell.weight_hh_l0 = torch.nn.Parameter( |
| torch.tensor(vals, dtype=torch.float), |
| requires_grad=False) |
| |
| ref = copy.deepcopy(cell) |
| |
| model_int8 = quantize_dynamic(model=model, dtype=torch.qint8) |
| model_fp16 = quantize_dynamic(model=model, dtype=torch.float16) |
| |
| # Smoke test extra reprs |
| self.assertTrue('DynamicQuantizedLSTM' in str(model_int8)) |
| self.assertTrue('DynamicQuantizedLSTM' in str(model_fp16)) |
| cell_int8 = model_int8.lstm |
| cell_fp16 = model_fp16.lstm |
| |
| assert type(cell_int8) == torch.nn.quantized.dynamic.LSTM, \ |
| 'torch.nn.LSTM should be converted to torch.nn.quantized.dynamic.LSTM after quantize_dynamic' |
| assert type(cell_fp16) == torch.nn.quantized.dynamic.LSTM, \ |
| 'torch.nn.LSTM should be converted to torch.nn.quantized.dynamic.LSTM after quantize_dynamic' |
| |
| niter = 10 |
| x = torch.tensor([[100, -155], |
| [-155, 100], |
| [100, -155]], dtype=torch.float).unsqueeze(0).repeat(niter, 1, 1) |
| |
| h0_vals = [[-155, 100], |
| [-155, 155], |
| [100, -155]] |
| |
| hx = torch.tensor(h0_vals, dtype=torch.float).unsqueeze(0) |
| cx = torch.tensor(h0_vals, dtype=torch.float).unsqueeze(0) |
| |
| if isinstance(ref, torch.nn.LSTM): |
| hiddens = (hx, cx) |
| |
| ref_out, ref_hid = ref(x, hiddens) |
| |
| # Compare int8 quantized to unquantized |
| output_int8, final_hiddens_int8 = cell_int8(x, hiddens) |
| |
| torch.testing.assert_allclose(output_int8, ref_out) |
| self.assertEqual(output_int8, ref_out) |
| for out_val, ref_val in zip(final_hiddens_int8, ref_hid): |
| torch.testing.assert_allclose(out_val, ref_val) |
| |
| class ScriptWrapper(torch.nn.Module): |
| def __init__(self, cell): |
| super(ScriptWrapper, self).__init__() |
| self.cell = cell |
| |
| def forward(self, x, hiddens): |
| # type: (torch.Tensor, Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] |
| return self.cell(x, hiddens) |
| |
| # TODO: TorchScript overloads don't work without this wrapper |
| cell_script = torch.jit.script(ScriptWrapper(cell_int8)) |
| out_script, hid_script = cell_script(x, hiddens) |
| self.assertEqual(len(out_script), len(ref_out)) |
| for out_val, ref_val in zip(out_script, ref_out): |
| torch.testing.assert_allclose(out_val, ref_val) |
| |
| # Test save/load |
| b = io.BytesIO() |
| torch.jit.save(cell_script, b) |
| b.seek(0) |
| loaded = torch.jit.load(b) |
| out_loaded, hid_loaded = loaded(x, hiddens) |
| for loaded_val, ref_val in zip(out_loaded, ref_out): |
| torch.testing.assert_allclose(loaded_val, ref_val) |
| |
| # Compare fp16 quantized to unquantized |
| output_fp16, final_hiddens_fp16 = cell_fp16(x, hiddens) |
| |
| torch.testing.assert_allclose(output_fp16, ref_out) |
| self.assertEqual(output_fp16, ref_out) |
| for out, ref_val in zip(final_hiddens_fp16, ref_hid): |
| torch.testing.assert_allclose(out, ref_val) |
| |
| # Test tracing |
| # TODO: TorchScript overloads don't work without this wrapper |
| cell_trace = torch.jit.trace(ScriptWrapper(cell_int8), (x, (hx, cx))) |
| out_script, hid_script = cell_trace(x, hiddens) |
| for out_val, ref_val in zip(out_script, ref_out): |
| torch.testing.assert_allclose(out_val, ref_val) |
| |
| # print(cell_trace.code) |
| |
| # Test save/load |
| b = io.BytesIO() |
| torch.jit.save(cell_trace, b) |
| b.seek(0) |
| loaded = torch.jit.load(b) |
| out_loaded, hid_loaded = loaded(x, hiddens) |
| for loaded_val, ref_val in zip(out_loaded, ref_out): |
| torch.testing.assert_allclose(loaded_val, ref_val) |
| |
| # Compare fp16 quantized to unquantized |
| output_fp16, final_hiddens_fp16 = cell_fp16(x, hiddens) |
| |
| torch.testing.assert_allclose(output_fp16, ref_out) |
| self.assertEqual(output_fp16, ref_out) |
| for out, ref_val in zip(final_hiddens_fp16, ref_hid): |
| torch.testing.assert_allclose(out, ref_val) |
| |
| class ScriptWrapperPacked(torch.nn.Module): |
| def __init__(self, cell): |
| super(ScriptWrapperPacked, self).__init__() |
| self.cell = cell |
| |
| def forward(self, |
| x, # type: PackedSequence |
| hiddens # type: Tuple[torch.Tensor, torch.Tensor] |
| ): |
| # type: (...) -> Tuple[PackedSequence, Tuple[torch.Tensor, torch.Tensor]] |
| return self.cell(x, hiddens) |
| |
| cell_packed = torch.jit.script(ScriptWrapperPacked(cell_int8)) |
| packed_input = torch.nn.utils.rnn.pack_padded_sequence(x, torch.tensor([10, 5, 2])) |
| ref_out_packed, ref_hid_packed = ref(packed_input, hiddens) |
| output_packed, hiddens_packed = cell_packed(packed_input, hiddens) |
| |
| for packed_val, ref_val in zip(output_packed, ref_out_packed): |
| if isinstance(packed_val, torch.Tensor): |
| torch.testing.assert_allclose(packed_val, ref_val) |
| else: |
| self.assertEqual(packed_val, ref_val) |
| |
| # Test save/load |
| b = io.BytesIO() |
| torch.jit.save(cell_packed, b) |
| b.seek(0) |
| loaded_packed = torch.jit.load(b) |
| out_loaded_packed, hid_loaded_packed = loaded_packed(packed_input, hiddens) |
| for packed_val, ref_val in zip(out_loaded_packed, ref_out_packed): |
| if isinstance(packed_val, torch.Tensor): |
| torch.testing.assert_allclose(packed_val, ref_val) |
| else: |
| self.assertEqual(packed_val, ref_val) |
| |
| # Test default instantiation |
| seq_len = 128 |
| batch = 16 |
| input_size = 3 |
| hidden_size = 7 |
| num_layers = 2 |
| bias = True |
| bidirectional = False |
| |
| x = torch.rand(seq_len, batch, input_size) |
| h = torch.rand(num_layers * (bidirectional + 1), batch, hidden_size) |
| c = torch.rand(num_layers * (bidirectional + 1), batch, hidden_size) |
| |
| dtype = torch.qint8 |
| |
| cell_dq = torch.nn.quantized.dynamic.LSTM(input_size=input_size, |
| hidden_size=hidden_size, |
| num_layers=num_layers, |
| bias=bias, |
| batch_first=False, |
| dropout=0.0, |
| bidirectional=bidirectional, |
| dtype=dtype) |
| |
| y, (h, c) = cell_dq(x, (h, c)) |
| |
| |
| @unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines, |
| " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs" |
| " with instruction set support avx2 or newer.") |
| class TestQuantizationAwareTraining(QuantizationTestCase): |
| def test_manual(self): |
| model = ManualLinearQATModel() |
| model = prepare_qat(model) |
| self.checkObservers(model) |
| test_only_train_fn(model, self.train_data) |
| model = convert(model) |
| |
| def checkQuantized(model): |
| self.assertEqual(type(model.fc1), nnq.Linear) |
| self.assertEqual(type(model.fc2), nnq.Linear) |
| test_only_eval_fn(model, self.calib_data) |
| self.checkScriptable(model, self.calib_data) |
| |
| checkQuantized(model) |
| |
| model = quantize_qat(ManualLinearQATModel(), test_only_train_fn, |
| self.train_data) |
| checkQuantized(model) |
| |
| def test_activations(self): |
| model = ActivationsQATTestModel() |
| model = prepare_qat(model) |
| |
| self.assertEqual(type(model.fc1), torch.nn.qat.modules.Linear) |
| self.assertEqual(type(model.hardswish), torch.nn.qat.modules.Hardswish) |
| |
| self.checkObservers(model) |
| test_only_train_fn(model, self.train_data) |
| model = convert(model) |
| |
| def checkQuantized(model): |
| self.assertEqual(type(model.fc1), nnq.Linear) |
| self.assertEqual(type(model.hardswish), nnq.Hardswish) |
| test_only_eval_fn(model, self.calib_data) |
| self.checkScriptable(model, self.calib_data) |
| |
| checkQuantized(model) |
| |
| model = quantize_qat(ActivationsQATTestModel(), test_only_train_fn, |
| self.train_data) |
| checkQuantized(model) |
| |
| def test_eval_only_fake_quant(self): |
| r"""Using FakeQuant in evaluation only mode, |
| this is useful for estimating accuracy loss when we quantize the |
| network |
| """ |
| model = ManualLinearQATModel() |
| |
| model = prepare_qat(model) |
| self.checkObservers(model) |
| |
| model.eval() |
| test_only_eval_fn(model, self.calib_data) |
| |
| def test_conv_linear(self): |
| model = ManualConvLinearQATModel() |
| |
| model = prepare_qat(model) |
| self.checkObservers(model) |
| |
| test_only_train_fn(model, self.img_data) |
| model = convert(model) |
| |
| def checkQuantized(model): |
| self.assertEqual(type(model.conv), nnq.Conv2d) |
| self.assertEqual(type(model.fc1), nnq.Linear) |
| self.assertEqual(type(model.fc2), nnq.Linear) |
| test_only_eval_fn(model, self.img_data) |
| self.checkScriptable(model, self.img_data) |
| |
| checkQuantized(model) |
| |
| model = ManualConvLinearQATModel() |
| model = quantize_qat(model, test_only_train_fn, self.img_data) |
| checkQuantized(model) |
| |
| @given(qengine=st.sampled_from(("qnnpack", "fbgemm"))) |
| def test_train_save_load_eval(self, qengine): |
| r"""Test QAT flow of creating a model, doing QAT and saving the quantized state_dict |
| During eval, we first call prepare_qat and conver on the model and then load the state_dict |
| and compare results against original model |
| """ |
| if qengine == 'qnnpack': |
| if IS_WINDOWS or TEST_WITH_UBSAN: |
| return |
| with override_quantized_engine(qengine): |
| model = TwoLayerLinearModel() |
| model = torch.quantization.QuantWrapper(model) |
| model.qconfig = torch.quantization.get_default_qat_qconfig(qengine) |
| model = prepare_qat(model) |
| |
| fq_state_dict = model.state_dict() |
| |
| test_only_train_fn(model, self.train_data) |
| model = convert(model) |
| |
| quant_state_dict = model.state_dict() |
| |
| x = torch.rand(2, 5, dtype=torch.float) |
| ref = model(x) |
| |
| # Create model again for eval. Check result using quantized state_dict |
| model = TwoLayerLinearModel() |
| model = torch.quantization.QuantWrapper(model) |
| model.qconfig = torch.quantization.get_default_qat_qconfig(qengine) |
| torch.quantization.prepare_qat(model, inplace=True) |
| new_state_dict = model.state_dict() |
| |
| # Check to make sure the model after prepare_qat has the same state_dict as original. |
| self.assertEqual(set(fq_state_dict.keys()), set(new_state_dict.keys())) |
| |
| torch.quantization.convert(model, inplace=True) |
| model.eval() |
| model.load_state_dict(quant_state_dict) |
| out = model(x) |
| self.assertEqual(ref, out) |
| |
| # Check model created using prepare has same state dict as quantized state_dict |
| model = TwoLayerLinearModel() |
| model.eval() |
| model = torch.quantization.QuantWrapper(model) |
| model.qconfig = torch.quantization.get_default_qconfig(qengine) |
| torch.quantization.prepare(model, inplace=True) |
| torch.quantization.convert(model, inplace=True) |
| self.assertEqual(set(model.state_dict().keys()), set(quant_state_dict.keys())) |
| model.eval() |
| model.load_state_dict(quant_state_dict) |
| out = model(x) |
| self.assertEqual(ref, out) |
| |
| @unittest.skipUnless( |
| 'fbgemm' in torch.backends.quantized.supported_engines, |
| " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs" |
| " with instruction set support avx2 or newer.", |
| ) |
| class TestGraphModePostTrainingStatic(QuantizationTestCase): |
| def test_single_linear(self): |
| r"""Compare the result of quantizing single linear layer in |
| eager mode and graph mode |
| """ |
| # eager mode |
| annotated_linear_model = AnnotatedSingleLayerLinearModel().eval() |
| linear_model = SingleLayerLinearModel().eval() |
| # copy the weight from eager mode so that we can |
| # compare the result of the two quantized models later |
| linear_model.fc1.weight = torch.nn.Parameter(annotated_linear_model.fc1.module.weight.detach()) |
| linear_model.fc1.bias = torch.nn.Parameter(annotated_linear_model.fc1.module.bias.detach()) |
| model_eager = quantize(annotated_linear_model, test_only_eval_fn, |
| self.calib_data) |
| |
| qconfig_dict = {'': default_qconfig} |
| model_traced = torch.jit.trace(linear_model, self.calib_data[0][0]) |
| model_script = torch.jit.script(linear_model) |
| result_eager = model_eager(self.calib_data[0][0]) |
| for model_under_test in [model_traced, model_script]: |
| model_quantized = quantize_script( |
| model_under_test, |
| qconfig_dict, |
| test_only_eval_fn, |
| [self.calib_data], |
| inplace=False) |
| self.assertEqual(model_quantized(self.calib_data[0][0]), result_eager) |
| |
| def test_observer_with_ignored_function(self): |
| r"""Test observers with ignored function and make sure it works in |
| graph mode |
| """ |
| # eager mode |
| annotated_linear_model = AnnotatedSingleLayerLinearModel().eval() |
| for qconfig in [ |
| QConfig( |
| activation=default_observer, |
| weight=default_weight_observer), |
| QConfig( |
| activation=default_histogram_observer, |
| weight=default_weight_observer), |
| QConfig( |
| activation=default_observer, |
| weight=default_per_channel_weight_observer), |
| ]: |
| annotated_linear_model.qconfig = qconfig |
| linear_model = SingleLayerLinearModel().eval() |
| # copy the weight from eager mode so that we can |
| # compare the result of the two quantized models later |
| linear_model.fc1.weight = torch.nn.Parameter(annotated_linear_model.fc1.module.weight.detach()) |
| linear_model.fc1.bias = torch.nn.Parameter(annotated_linear_model.fc1.module.bias.detach()) |
| model_eager = quantize(annotated_linear_model, test_only_eval_fn, |
| self.calib_data) |
| |
| qconfig_dict = {'': qconfig} |
| model_traced = torch.jit.trace(linear_model, self.calib_data[0][0]) |
| model_script = torch.jit.script(linear_model) |
| result_eager = model_eager(self.calib_data[0][0]) |
| for model_under_test in [model_traced, model_script]: |
| model_quantized = quantize_script( |
| model_under_test, |
| qconfig_dict, |
| test_only_eval_fn, |
| [self.calib_data], |
| inplace=False) |
| self.assertEqual(model_quantized(self.calib_data[0][0]), result_eager) |
| |
| def test_conv(self): |
| r"""Compare the result of quantizing conv layer in |
| eager mode and graph mode |
| """ |
| # eager mode |
| annotated_conv_model = AnnotatedConvModel().eval() |
| conv_model = ConvModel().eval() |
| # copy the weight from eager mode so that we can |
| # compare the result of the two quantized models later |
| conv_model.conv.weight = torch.nn.Parameter(annotated_conv_model.conv.weight.detach()) |
| model_eager = quantize(annotated_conv_model, default_eval_fn, |
| self.img_data) |
| qconfig_dict = {'': default_qconfig} |
| model_traced = torch.jit.trace(conv_model, self.img_data[0][0]) |
| model_script = torch.jit.script(conv_model) |
| result_eager = model_eager(self.img_data[0][0]) |
| for model_under_test in [model_traced, model_script]: |
| model_quantized = quantize_script( |
| model_under_test, |
| qconfig_dict, |
| default_eval_fn, |
| [self.img_data], |
| inplace=False) |
| self.assertEqual(model_quantized(self.img_data[0][0]), result_eager) |
| |
| @unittest.skip("This doesn't work right now, re-enable after fold_convbn is fixed") |
| def test_conv_bn(self): |
| r"""Compare the result of quantizing conv + bn layer in |
| eager mode and graph mode |
| """ |
| # eager mode |
| conv_model = AnnotatedConvBnModel().eval() |
| conv_model_to_script = ConvBnModel().eval() |
| # copy the weight from eager mode so that we can |
| # compare the result of the two quantized models later |
| conv_model_to_script.conv.weight = torch.nn.Parameter(conv_model.conv.weight.detach()) |
| fuse_modules(conv_model, ['conv', 'bn'], inplace=True) |
| model_eager = quantize(conv_model, default_eval_fn, |
| self.img_data) |
| qconfig_dict = { |
| '': default_qconfig |
| } |
| model_script = quantize_script( |
| torch.jit.script(conv_model_to_script), |
| qconfig_dict, |
| default_eval_fn, |
| [self.img_data], |
| inplace=False) |
| result_eager = model_eager(self.img_data[0][0]) |
| result_script = model_script(self.img_data[0][0]) |
| self.assertEqual(result_eager, result_script) |
| |
| def test_nested(self): |
| # Eager mode |
| eager_model = AnnotatedNestedModel().eval() |
| |
| # Graph mode |
| script_model = NestedModel().eval() |
| # Copy weights from eager_model into script_model |
| script_model.sub1.fc.weight = torch.nn.Parameter(eager_model.sub1.fc.weight.detach()) |
| script_model.sub1.fc.bias = torch.nn.Parameter(eager_model.sub1.fc.bias.detach()) |
| script_model.sub2.fc1.weight = torch.nn.Parameter(eager_model.sub2.fc1.module.weight.detach()) |
| script_model.sub2.fc1.bias = torch.nn.Parameter(eager_model.sub2.fc1.module.bias.detach()) |
| script_model.sub2.fc2.weight = torch.nn.Parameter(eager_model.sub2.fc2.weight.detach()) |
| script_model.sub2.fc2.bias = torch.nn.Parameter(eager_model.sub2.fc2.bias.detach()) |
| script_model.fc3.weight = torch.nn.Parameter(eager_model.fc3.module.weight.detach()) |
| script_model.fc3.bias = torch.nn.Parameter(eager_model.fc3.module.bias.detach()) |
| |
| model_eager = quantize(eager_model, test_only_eval_fn, self.calib_data) |
| qconfig_dict = { |
| 'sub2.fc1': default_per_channel_qconfig, |
| 'fc3': default_qconfig |
| } |
| model_traced = torch.jit.trace(script_model, self.calib_data[0][0]) |
| model_script = torch.jit.script(script_model) |
| result_eager = model_eager(self.calib_data[0][0]) |
| for model_under_test in [model_traced, model_script]: |
| model_quantized = quantize_script( |
| model_under_test, |
| qconfig_dict, |
| test_only_eval_fn, |
| [self.calib_data], |
| inplace=False) |
| self.assertEqual(model_quantized(self.calib_data[0][0]), result_eager) |
| |
| def test_skip_quant(self): |
| """ Test None qconfig |
| """ |
| # Eager mode |
| eager_model = AnnotatedSkipQuantModel().eval() |
| |
| # Graph mode |
| script_model = SkipQuantModel().eval() |
| # Copy weights from eager_model into script_model |
| script_model.sub.fc1.weight = torch.nn.Parameter(eager_model.sub.module.fc1.weight.detach()) |
| script_model.sub.fc1.bias = torch.nn.Parameter(eager_model.sub.module.fc1.bias.detach()) |
| script_model.sub.fc2.weight = torch.nn.Parameter(eager_model.sub.module.fc2.weight.detach()) |
| script_model.sub.fc2.bias = torch.nn.Parameter(eager_model.sub.module.fc2.bias.detach()) |
| script_model.fc.weight = torch.nn.Parameter(eager_model.fc.weight.detach()) |
| script_model.fc.bias = torch.nn.Parameter(eager_model.fc.bias.detach()) |
| |
| model_eager = quantize(eager_model, test_only_eval_fn, self.calib_data) |
| qconfig_dict = { |
| '': default_qconfig, |
| 'fc': None |
| } |
| model_traced = torch.jit.trace(script_model, self.calib_data[0][0]) |
| model_script = torch.jit.script(script_model) |
| result_eager = model_eager(self.calib_data[0][0]) |
| for model_under_test in [model_traced, model_script]: |
| model_quantized = quantize_script( |
| model_under_test, |
| qconfig_dict, |
| test_only_eval_fn, |
| [self.calib_data], |
| inplace=False) |
| self.assertEqual(model_quantized(self.calib_data[0][0]), result_eager) |
| |
| def test_single_linear_dynamic(self): |
| r"""Compare the result of dynamic quantization of single linear layer in |
| eager mode and graph mode. |
| """ |
| # eager mode |
| annotated_linear_model = AnnotatedSingleLayerLinearModel().eval() |
| linear_model = SingleLayerLinearModel().eval() |
| # copy the weight from eager mode so that we can |
| # compare the result of the two quantized models later |
| linear_model.fc1.weight = torch.nn.Parameter(annotated_linear_model.fc1.module.weight.detach()) |
| linear_model.fc1.bias = torch.nn.Parameter(annotated_linear_model.fc1.module.bias.detach()) |
| qconfig_dict = {'': default_dynamic_qconfig} |
| model_eager = quantize_dynamic(annotated_linear_model, qconfig_dict) |
| |
| model_traced = torch.jit.trace(linear_model, self.calib_data[0][0]) |
| model_script = torch.jit.script(linear_model) |
| result_eager = model_eager(self.calib_data[0][0]) |
| |
| for model_under_test in [model_traced, model_script]: |
| model_quantized = quantize_dynamic_script( |
| model_under_test, |
| qconfig_dict, |
| [self.calib_data[0][0]]) |
| self.assertEqual(model_quantized(self.calib_data[0][0]), result_eager) |
| |
| # Check to make sure choose_qparams->quant->dequant->linear is numerically |
| # equivalent to the final quantized model. |
| model_fake_quantized = quantize_dynamic_script( |
| model_under_test, |
| qconfig_dict, |
| [self.calib_data[0][0]], |
| debug=True) |
| self.assertEqual(model_fake_quantized(self.calib_data[0][0]), result_eager) |
| |
| |
| class TestFunctionalModule(QuantizationTestCase): |
| # Histogram observers are slow, so the hypothesis deadline is disabled to ensure the test doesn't time out |
| @given(train_mode=st.booleans()) |
| def test_functional_module(self, train_mode): |
| model = ModelWithFunctionals() |
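| # ModelWithFunctionals wraps tensor ops (add, cat, add_relu) in |
| # nn.quantized.FloatFunctional modules; after convert() these become |
| # nnq.QFunctional so the ops run with quantized kernels. |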
| x = torch.rand(10, 1, dtype=torch.float) |
| xq = torch.quantize_per_tensor(x, 0.01, 30, torch.quint8) |
| self.checkScriptable(model, [(x, x)], check_save_load=True) |
| if train_mode: |
| model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm') |
| model = prepare_qat(model) |
| else: |
| model.qconfig = torch.quantization.get_default_qconfig('qnnpack') |
| model = prepare(model) |
| # Check if observers and quant/dequant nodes are inserted |
| self.checkNoPrepModules(model) |
| self.checkObservers(model) |
| # Calibrate |
| model(xq.dequantize()) |
| model = convert(model) |
| |
| def checkQuantized(model): |
| self.checkNoPrepModules(model) |
| self.assertEqual(type(model.myadd), torch.nn.quantized.QFunctional) |
| self.assertEqual(type(model.mycat), torch.nn.quantized.QFunctional) |
| self.assertEqual(type(model.myadd_relu), torch.nn.quantized.QFunctional) |
| |
| checkQuantized(model) |
| self.checkScriptable(model, [(xq, xq)], check_save_load=True) |
| |
| @unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines, |
| " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs" |
| " with instruction set support avx2 or newer.") |
| class TestFusion(QuantizationTestCase): |
| def test_fuse_module_train(self): |
| model = ModelForFusion(default_qat_qconfig).train() |
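| # In training mode fuse_modules() produces intrinsic modules that keep the |
| # BatchNorm (nni.ConvBn2d / nni.ConvBnReLU2d) so it can be simulated during |
| # QAT; in eval mode (see test_fuse_module_eval) the BN is instead folded |
| # into the conv weights. |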
| # Test step by step fusion |
| model = fuse_modules(model, ['conv1', 'bn1', 'relu1']) |
| model = fuse_modules(model, ['sub1.conv', 'sub1.bn']) |
| self.assertEqual(type(model.conv1), nni.ConvBnReLU2d, |
| "Fused Conv + BN + Relu first layer") |
| self.assertEqual(type(model.bn1), torch.nn.Identity, |
| "Fused Conv + BN + Relu (skipped BN)") |
| self.assertEqual(type(model.relu1), torch.nn.Identity, |
| "Fused Conv + BN + Relu (skipped Relu)") |
| |
| self.assertEqual(type(model.sub1.conv), nni.ConvBn2d, |
| "Fused submodule Conv + BN") |
| self.assertEqual(type(model.sub1.bn), torch.nn.Identity, |
| "Fused submodule Conv + BN (skipped BN)") |
| self.assertEqual(type(model.sub2.conv), torch.nn.Conv2d, |
| "Non-fused submodule Conv") |
| self.assertEqual(type(model.sub2.relu), torch.nn.ReLU, |
| "Non-fused submodule ReLU") |
| model = prepare_qat(model) |
| self.checkObservers(model) |
| |
| def checkQAT(model): |
| self.assertEqual(type(model.conv1), nniqat.ConvBnReLU2d) |
| self.assertEqual(type(model.bn1), nn.Identity) |
| self.assertEqual(type(model.relu1), nn.Identity) |
| self.assertEqual(type(model.sub1.conv), nniqat.ConvBn2d) |
| self.assertEqual(type(model.sub1.bn), nn.Identity) |
| self.assertEqual(type(model.sub2.conv), nn.Conv2d) |
| self.assertEqual(type(model.sub2.relu), nn.ReLU) |
| |
| checkQAT(model) |
| test_only_train_fn(model, self.img_data) |
| model = convert(model) |
| |
| def checkQuantized(model): |
| self.assertEqual(type(model.conv1), nniq.ConvReLU2d) |
| self.assertEqual(type(model.bn1), nn.Identity) |
| self.assertEqual(type(model.relu1), nn.Identity) |
| self.assertEqual(type(model.sub1.conv), nnq.Conv2d) |
| self.assertEqual(type(model.sub1.bn), nn.Identity) |
| self.assertEqual(type(model.sub2.conv), nn.Conv2d) |
| self.assertEqual(type(model.sub2.relu), nn.ReLU) |
| test_only_eval_fn(model, self.img_data) |
| checkQuantized(model) |
| |
| model = ModelForFusion(default_qat_qconfig).train() |
| model = fuse_modules(model, [['conv1', 'bn1', 'relu1'], |
| ['sub1.conv', 'sub1.bn']]) |
| model = quantize_qat(model, test_only_train_fn, self.img_data) |
| checkQuantized(model) |
| |
| |
| def test_fuse_module_eval(self): |
| model = ModelForFusion(default_qconfig) |
| model.eval() |
| model = fuse_modules(model, [['conv1', 'bn1', 'relu1'] , |
| ['conv2', 'relu2'], |
| ['bn2', 'relu3'], |
| ['sub1.conv', 'sub1.bn']]) |
| self.assertEqual(type(model.conv1), nni.ConvReLU2d, |
| "Fused Conv + BN + Relu first layer (BN is folded)") |
| self.assertEqual(type(model.conv1[0]), nn.Conv2d, |
| "Fused Conv + BN + Relu (Conv + folded BN only)") |
| self.assertEqual(type(model.conv1[1]), nn.ReLU, |
| "Fused Conv + BN + Relu second layer (Relu only)") |
| self.assertEqual(type(model.bn1), nn.Identity, |
| "Fused Conv + BN + Relu second layer (Skipped BN)") |
| self.assertEqual(type(model.relu1), nn.Identity, |
| "Fused Conv + BN + Relu second layer (Skipped Relu)") |
| self.assertEqual(type(model.conv2), nni.ConvReLU3d, |
| "Fused Conv + BN + Relu first layer (BN is folded)") |
| self.assertEqual(type(model.bn2), nni.BNReLU3d, |
| "Fused BN + Relu first layer (Relu is folded))") |
| self.assertEqual(type(model.relu3), nn.Identity, |
| "Fused BN + Relu second layer (Skipped Relu)") |
| self.assertEqual(type(model.conv2[0]), nn.Conv3d, |
| "Fused Conv + BN + Relu (Conv + folded BN only)") |
| self.assertEqual(type(model.conv2[1]), nn.ReLU, |
| "Fused Conv + BN + Relu second layer (Relu only)") |
| self.assertEqual(type(model.relu2), nn.Identity, |
| "Fused Conv + BN + Relu second layer (Skipped Relu)") |
| |
| self.assertEqual(type(model.sub1.conv), nn.Conv2d, |
| "Fused submodule Conv + folded BN") |
| self.assertEqual(type(model.sub1.bn), nn.Identity, |
| "Fused submodule (skipped BN)") |
| self.assertEqual(type(model.sub2.conv), nn.Conv2d, |
| "Non-fused submodule Conv") |
| self.assertEqual(type(model.sub2.relu), torch.nn.ReLU, |
| "Non-fused submodule ReLU") |
| |
| model = prepare(model) |
| self.checkObservers(model) |
| test_only_eval_fn(model, self.img_data) |
| model = convert(model) |
| |
| def checkQuantized(model): |
| self.assertEqual(type(model.conv1), nniq.ConvReLU2d) |
| self.assertEqual(type(model.bn1), nn.Identity) |
| self.assertEqual(type(model.relu1), nn.Identity) |
| self.assertEqual(type(model.sub1.conv), nnq.Conv2d) |
| self.assertEqual(type(model.sub1.bn), nn.Identity) |
| self.assertEqual(type(model.sub2.conv), nn.Conv2d) |
| self.assertEqual(type(model.sub2.relu), nn.ReLU) |
| self.assertEqual(type(model.bn2), nniq.BNReLU3d) |
| test_only_eval_fn(model, self.img_data) |
| checkQuantized(model) |
| |
| model = ModelForFusion(default_qconfig).eval() |
| model = fuse_modules(model, [['conv1', 'bn1', 'relu1'], |
| ['conv2', 'relu2'], |
| ['bn2', 'relu3'], |
| ['sub1.conv', 'sub1.bn']]) |
| model = quantize(model, test_only_eval_fn, self.img_data) |
| checkQuantized(model) |
| |
| def test_fusion_sequential_model_train(self): |
| model = ModelWithSequentialFusion().train() |
| model.to(torch.float) |
| fuse_modules(model, [['conv1', 'relu1'] , |
| ['features.0.0', 'features.0.1', 'features.0.2'], |
| ['features.1.0', 'features.1.1', 'features.1.2'], |
| ['features.2.0', 'features.2.1', 'features.2.2'], |
| ['classifier.0', 'classifier.1']], inplace=True) |
| self.assertEqual(type(model.conv1), nni.ConvReLU2d, |
| "Fused Conv + Relu: nni.ConvReLU2d") |
| self.assertEqual(type(model.conv1[0]), nn.Conv2d, |
| "Fused Conv + Relu: Conv2d") |
| self.assertEqual(type(model.conv1[1]), nn.ReLU, |
| "Fused Conv + Relu: Relu") |
| self.assertEqual(type(model.relu1), nn.Identity, |
| "Fused Conv + Relu: Identity") |
| for i in range(3): |
| self.assertEqual(type(model.features[i][0]), nni.ConvBnReLU2d, |
| "Fused submodule Conv + folded BN") |
| self.assertEqual(type(model.features[i][1]), nn.Identity, |
| "Fused submodule (skipped BN)") |
| self.assertEqual(type(model.features[i][2]), nn.Identity, |
| "Non-fused submodule Conv") |
| self.assertEqual(type(model.classifier[0]), nni.LinearReLU) |
| self.assertEqual(type(model.classifier[1]), nn.Identity) |
| model.qconfig = default_qat_qconfig |
| prepare_qat(model, inplace=True) |
| self.checkObservers(model) |
| model(self.img_data[0][0]) |
| |
| |
| def checkQAT(model): |
| self.assertEqual(type(model.conv1), nniqat.ConvReLU2d) |
| self.assertEqual(type(model.relu1), nn.Identity) |
| for i in range(3): |
| self.assertEqual(type(model.features[i][0]), nniqat.ConvBnReLU2d, |
| "Fused submodule Conv + folded BN") |
| self.assertEqual(type(model.features[i][1]), nn.Identity, |
| "Fused submodule (skipped BN)") |
| self.assertEqual(type(model.features[i][2]), nn.Identity, |
| "Non-fused submodule Conv") |
| self.assertEqual(type(model.classifier[0]), nniqat.LinearReLU) |
| self.assertEqual(type(model.classifier[1]), nn.Identity) |
| |
| checkQAT(model) |
| model(self.img_data[1][0]) |
| convert(model, inplace=True) |
| model(self.img_data[1][0]) |
| self.checkModelWithSequentialQuantized(model) |
| |
| def test_fusion_sequential_model_eval(self): |
| model = ModelWithSequentialFusion().eval() |
| model.to(torch.float) |
| fuse_modules(model, [['conv1', 'relu1'] , |
| ['features.0.0', 'features.0.1', 'features.0.2'], |
| ['features.1.0', 'features.1.1', 'features.1.2'], |
| ['features.2.0', 'features.2.1', 'features.2.2'], |
| ['classifier.0', 'classifier.1']], inplace=True) |
| self.assertEqual(type(model.conv1), nni.ConvReLU2d, |
| "Fused Conv + Relu: nni.ConvReLU2d") |
| self.assertEqual(type(model.conv1[0]), nn.Conv2d, |
| "Fused Conv + Relu: Conv2d") |
| self.assertEqual(type(model.conv1[1]), nn.ReLU, |
| "Fused Conv + Relu: Relu") |
| self.assertEqual(type(model.relu1), nn.Identity, |
| "Fused Conv + Relu: Identity") |
| for i in range(3): |
| self.assertEqual(type(model.features[i][0]), nni.ConvReLU2d, |
| "Fused submodule Conv + folded BN") |
| self.assertEqual(type(model.features[i][1]), nn.Identity, |
| "Fused submodule (skipped BN)") |
| self.assertEqual(type(model.features[i][2]), nn.Identity, |
| "Non-fused submodule Conv") |
| self.assertEqual(type(model.classifier[0]), nni.LinearReLU) |
| self.assertEqual(type(model.classifier[1]), nn.Identity) |
| model.qconfig = default_qconfig |
| prepare(model, inplace=True) |
| self.checkObservers(model) |
| model(self.img_data[0][0]) |
| convert(model, inplace=True) |
| model(self.img_data[1][0]) |
| self.checkModelWithSequentialQuantized(model) |
| |
| def checkModelWithSequentialQuantized(self, model): |
| self.assertEqual(type(model.conv1), nniq.ConvReLU2d) |
| self.assertEqual(type(model.relu1), nn.Identity) |
| for i in range(3): |
| self.assertEqual(type(model.features[i][0]), nniq.ConvReLU2d) |
| self.assertEqual(type(model.features[i][1]), nn.Identity) |
| self.assertEqual(type(model.features[i][2]), nn.Identity) |
| self.assertEqual(type(model.classifier[0]), nniq.LinearReLU) |
| self.assertEqual(type(model.classifier[1]), nn.Identity) |
| |
| def test_fusion_conv_with_bias(self): |
| model = ModelForFusionWithBias().train() |
| # output with no fusion. |
| out_ref = model(self.img_data[0][0]) |
| |
| model.qconfig = QConfig(activation=torch.nn.Identity, |
| weight=torch.nn.Identity) |
| model = fuse_modules(model, [["conv1", "bn1", "relu1"], |
| ["conv2", "bn2"]]) |
| prep_model = prepare_qat(model, inplace=False) |
| # output with fusion but no observers. |
| out_fused = prep_model(self.img_data[0][0]) |
| self.assertEqual(out_ref, out_fused) |
| |
| model.qconfig = default_qat_qconfig |
| prepare_qat(model, inplace=True) |
| |
| model(self.img_data[0][0]) |
| |
| def checkQAT(model): |
| self.assertEqual(type(model.conv1), nniqat.ConvBnReLU2d) |
| self.assertEqual(type(model.bn1), nn.Identity) |
| self.assertEqual(type(model.relu1), nn.Identity) |
| self.assertEqual(type(model.conv2), nniqat.ConvBn2d) |
| self.assertEqual(type(model.bn2), nn.Identity) |
| |
| checkQAT(model) |
| |
| class TestObserver(QuantizationTestCase): |
| @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)), |
| qscheme=st.sampled_from((torch.per_tensor_affine, torch.per_tensor_symmetric)), |
| reduce_range=st.booleans()) |
| def test_per_tensor_observers(self, qdtype, qscheme, reduce_range): |
| # reduce_range cannot be true for symmetric quantization with uint8 |
| if qdtype == torch.quint8 and qscheme == torch.per_tensor_symmetric: |
| reduce_range = False |
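| # reduce_range restricts the quantized range by one bit (e.g. [0, 127] for |
| # quint8), which fbgemm relies on to avoid overflow in its 16-bit |
| # accumulation of int8 products. |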
| ObserverList = [MinMaxObserver(dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range), |
| MovingAverageMinMaxObserver(averaging_constant=0.5, |
| dtype=qdtype, |
| qscheme=qscheme, |
| reduce_range=reduce_range)] |
| for myobs in ObserverList: |
| # calculate_qparams() should return (with a warning) for observers that have not seen any data |
| qparams = myobs.calculate_qparams() |
| if type(myobs) == MinMaxObserver: |
| x = torch.tensor([1.0, 2.0, 2.0, 3.0, 4.0, 5.0, 6.0]) |
| y = torch.tensor([4.0, 5.0, 5.0, 6.0, 7.0, 8.0]) |
| else: |
| # The moving average of min/max for x and y matches the extreme |
| # values of the x/y used above for the MinMaxObserver case |
| x = torch.tensor([0.0, 2.0, 2.0, 3.0, 4.0, 5.0, 6.0]) |
| y = torch.tensor([2.0, 5.0, 5.0, 6.0, 7.0, 10.0]) |
| |
| result = myobs(x) |
| result = myobs(y) |
| self.assertEqual(result, y) |
| self.assertEqual(myobs.min_val, 1.0) |
| self.assertEqual(myobs.max_val, 8.0) |
| qparams = myobs.calculate_qparams() |
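| # Reference values: the observed min/max are 1.0 and 8.0, and the affine |
| # scale always includes 0 in the range, so scale = 8 / 255 ~= 0.0313725; |
| # the symmetric scale is 2 * 8 / 255 ~= 0.062745. With reduce_range the |
| # quantized range shrinks to 127 steps, scaling these by 255 / 127. |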
| if reduce_range: |
| if qscheme == torch.per_tensor_symmetric: |
| ref_scale = 0.062745 * 255 / 127 |
| ref_zero_point = 0 if qdtype is torch.qint8 else 128 |
| else: |
| ref_scale = 0.0313725 * 255 / 127 |
| ref_zero_point = -64 if qdtype is torch.qint8 else 0 |
| else: |
| if qscheme == torch.per_tensor_symmetric: |
| ref_scale = 0.062745 |
| ref_zero_point = 0 if qdtype is torch.qint8 else 128 |
| else: |
| ref_scale = 0.0313725 |
| ref_zero_point = -128 if qdtype is torch.qint8 else 0 |
| self.assertEqual(qparams[1].item(), ref_zero_point) |
| self.assertAlmostEqual(qparams[0].item(), ref_scale, delta=1e-5) |
| state_dict = myobs.state_dict() |
| b = io.BytesIO() |
| torch.save(state_dict, b) |
| b.seek(0) |
| loaded_dict = torch.load(b) |
| for key in state_dict: |
| self.assertEqual(state_dict[key], loaded_dict[key]) |
| loaded_obs = MinMaxObserver(dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range) |
| loaded_obs.load_state_dict(loaded_dict) |
| loaded_qparams = loaded_obs.calculate_qparams() |
| self.assertEqual(myobs.min_val, loaded_obs.min_val) |
| self.assertEqual(myobs.max_val, loaded_obs.max_val) |
| self.assertEqual(myobs.calculate_qparams(), loaded_obs.calculate_qparams()) |
| |
| |
| @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=2, max_dims=4, |
| min_side=1, max_side=10), |
| qparams=hu.qparams()), |
| reduce_range=st.booleans()) |
| def test_per_tensor_dynamic_quant_observers(self, X, reduce_range): |
| |
| X, (scale, zero_point, torch_type) = X |
| x = torch.from_numpy(X) |
| |
| obs = MinMaxDynamicQuantObserver(dtype=torch.quint8, reduce_range=reduce_range) |
| |
| result = obs(x) |
| qparams = obs.calculate_qparams() |
| ref = torch._choose_qparams_per_tensor(x, reduce_range) |
| |
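| # MinMaxDynamicQuantObserver should reproduce the scale / zero_point chosen by |
| # the dynamic-quantization helper for the same tensor and reduce_range setting. |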
| self.assertEqual(ref[0], qparams[0]) |
| self.assertEqual(ref[1], qparams[1]) |
| |
| def test_tensor_list_observer(self): |
| from torch.quantization.observer import _MinMaxTensorListObserver |
| x = [torch.tensor([1.0, 2.5, 3.5]), |
| torch.tensor([2.0, 4.5, 3.5]), |
| torch.tensor([4.0, 2.5, 3.5]), ] |
| obs = _MinMaxTensorListObserver() |
| obs(x) |
| qparams = obs.calculate_qparams() |
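| # Build per-tensor references with independent MinMaxObservers: the list |
| # observer should track the same min/max and qparams for each element. |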
| ref_min_val = [] |
| ref_max_val = [] |
| ref_qparams = [] |
| for i in x: |
| obs_ref = MinMaxObserver() |
| obs_ref(i) |
| ref_min_val.append(obs_ref.min_val) |
| ref_max_val.append(obs_ref.max_val) |
| ref_qparams.append(obs_ref.calculate_qparams()) |
| for i in range(len(x)): |
| self.assertEqual(obs.min_val[i], ref_min_val[i]) |
| self.assertEqual(obs.max_val[i], ref_max_val[i]) |
| self.assertEqual(qparams[0][i], ref_qparams[i][0]) |
| self.assertEqual(qparams[1][i], ref_qparams[i][1]) |
| |
| @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)), |
| qscheme=st.sampled_from((torch.per_channel_affine, torch.per_channel_symmetric)), |
| ch_axis=st.sampled_from((0, 1, 2, 3)), reduce_range=st.booleans()) |
| def test_per_channel_observers(self, qdtype, qscheme, ch_axis, reduce_range): |
| # reduce_range cannot be true for symmetric quantization with uint8 |
| if qdtype == torch.quint8 and qscheme == torch.per_channel_symmetric: |
| reduce_range = False |
| ObserverList = [PerChannelMinMaxObserver(reduce_range=reduce_range, |
| ch_axis=ch_axis, |
| dtype=qdtype, |
| qscheme=qscheme), |
| MovingAveragePerChannelMinMaxObserver(averaging_constant=0.5, |
| reduce_range=reduce_range, |
| ch_axis=ch_axis, |
| dtype=qdtype, |
| qscheme=qscheme)] |
| |
| for myobs in ObserverList: |
| # Calculate qparams should work for empty observers |
| qparams = myobs.calculate_qparams() |
| x = torch.tensor( |
| [ |
| [[[1.0, 2.0], [2.0, 2.5]], [[3.0, 4.0], [4.5, 6.0]]], |
| [[[-4.0, -3.0], [5.0, 5.0]], [[6.0, 3.0], [7.0, 8.0]]], |
| ] |
| ) |
| if type(myobs) == MovingAveragePerChannelMinMaxObserver: |
| # Scaling the input tensor to model change in min/max values |
| # across batches |
| result = myobs(0.5 * x) |
| result = myobs(1.5 * x) |
| self.assertEqual(result, 1.5 * x) |
| else: |
| result = myobs(x) |
| self.assertEqual(result, x) |
| |
| qparams = myobs.calculate_qparams() |
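| # The reference tables below are indexed by ch_axis: for each candidate channel |
| # axis of the 2x2x2x2 input, min/max are reduced over the remaining dimensions. |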
| ref_min_vals = [[1.0, -4.0], [-4.0, 3.0], [-4.0, 2.0], [-4.0, -3.0]] |
| ref_max_vals = [[6.0, 8.0], [5.0, 8.0], [6.0, 8.0], [7.0, 8.0]] |
| per_channel_symmetric_ref_scales = [ |
| [0.04705882, 0.06274509], |
| [0.03921569, 0.0627451], |
| [0.04705882, 0.0627451], |
| [0.05490196, 0.0627451], |
| ] |
| per_channel_affine_ref_scales = [ |
| [0.02352941, 0.04705882], |
| [0.03529412, 0.03137255], |
| [0.03921569, 0.03137255], |
| [0.04313726, 0.04313726], |
| ] |
| per_channel_affine_qint8_zp = [ |
| [-128, -43], |
| [-15, -128], |
| [-26, -128], |
| [-35, -58], |
| ] |
| per_channel_affine_quint8_zp = [[0, 85], [113, 0], [102, 0], [93, 70]] |
| |
| self.assertEqual(myobs.min_vals, ref_min_vals[ch_axis]) |
| self.assertEqual(myobs.max_vals, ref_max_vals[ch_axis]) |
| if qscheme == torch.per_channel_symmetric: |
| ref_scales = per_channel_symmetric_ref_scales[ch_axis] |
| ref_zero_points = [0, 0] if qdtype is torch.qint8 else [128, 128] |
| else: |
| ref_scales = per_channel_affine_ref_scales[ch_axis] |
| ref_zero_points = ( |
| per_channel_affine_qint8_zp[ch_axis] |
| if qdtype is torch.qint8 |
| else per_channel_affine_quint8_zp[ch_axis] |
| ) |
| |
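| # As in the per-tensor test, reduce_range halves the quantized range, which |
| # multiplies each scale by 255 / 127 and roughly halves each zero_point. |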
| if reduce_range: |
| ref_scales = [s * 255 / 127 for s in ref_scales] |
| ref_zero_points = [math.floor(z / 2) for z in ref_zero_points] |
| |
| self.assertTrue(torch.allclose(qparams[0], torch.tensor(ref_scales, dtype=qparams[0].dtype))) |
| self.assertTrue(torch.allclose(qparams[1], torch.tensor(ref_zero_points, dtype=qparams[1].dtype))) |
| |
| # Test for serializability |
| state_dict = myobs.state_dict() |
| b = io.BytesIO() |
| torch.save(state_dict, b) |
| b.seek(0) |
| loaded_dict = torch.load(b) |
| for key in state_dict: |
| self.assertEqual(state_dict[key], loaded_dict[key]) |
| loaded_obs = PerChannelMinMaxObserver(reduce_range=reduce_range, ch_axis=ch_axis, dtype=qdtype, qscheme=qscheme) |
| loaded_obs.load_state_dict(loaded_dict) |
| loaded_qparams = loaded_obs.calculate_qparams() |
| self.assertEqual(myobs.min_vals, loaded_obs.min_vals) |
| self.assertEqual(myobs.max_vals, loaded_obs.max_vals) |
| self.assertEqual(myobs.calculate_qparams(), loaded_obs.calculate_qparams()) |
| |
| def test_observer_scriptable(self): |
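| # Each observer should survive torch.jit.script and a save/load round trip |
| # while producing the same qparams as its eager counterpart. |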
| obs_list = [MinMaxObserver(), MovingAverageMinMaxObserver(), MinMaxDynamicQuantObserver()] |
| for obs in obs_list: |
| scripted = torch.jit.script(obs) |
| |
| x = torch.rand(3, 4) |
| obs(x) |
| scripted(x) |
| self.assertEqual(obs.calculate_qparams(), scripted.calculate_qparams()) |
| |
| buf = io.BytesIO() |
| torch.jit.save(scripted, buf) |
| buf.seek(0) |
| loaded = torch.jit.load(buf) |
| self.assertEqual(obs.calculate_qparams(), loaded.calculate_qparams()) |
| |
| # Check TensorListObserver |
| from torch.quantization.observer import _MinMaxTensorListObserver |
| obs = _MinMaxTensorListObserver() |
| scripted = torch.jit.script(obs) |
| x = [torch.rand(3, 4), torch.rand(4, 5)] |
| obs(x) |
| scripted(x) |
| self.assertEqual(obs.calculate_qparams(), scripted.calculate_qparams()) |
| |
| def test_no_qconfig_propagation(self): |
| model = ModelWithNoQconfigPropagation() |
| model.qconfig = torch.quantization.default_qconfig |
| |
| model = prepare(model) |
| self.assertTrue(hasattr(model.fc1, 'qconfig'), |
| "QConfig is expected to propagate") |
| self.assertFalse(hasattr(model.no_quant_module, 'qconfig'), |
| "QConfig is expected to NOT propagate") |
| |
| |
| @unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines, |
| " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs" |
| " with instruction set support avx2 or newer.") |
| class TestRecordHistogramObserver(QuantizationTestCase): |
| def test_record_observer(self): |
| model = AnnotatedSingleLayerLinearModel() |
| model.qconfig = default_debug_qconfig |
| model = prepare(model) |
| # run the evaluation and dump all tensors |
| test_only_eval_fn(model, self.calib_data) |
| test_only_eval_fn(model, self.calib_data) |
| observer_dict = {} |
| get_observer_dict(model, observer_dict) |
| |
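| # default_debug_qconfig installs RecordingObserver, which stores every tensor |
| # it observes; two eval passes over calib_data should therefore leave |
| # 2 * len(self.calib_data) recorded activations on fc1. |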
| self.assertTrue('fc1.module.activation_post_process' in observer_dict.keys(), |
| 'observer is not recorded in the dict') |
| self.assertEqual(len(observer_dict['fc1.module.activation_post_process'].get_tensor_value()), 2 * len(self.calib_data)) |
| self.assertEqual(observer_dict['fc1.module.activation_post_process'].get_tensor_value()[0], model(self.calib_data[0][0])) |
| |
| @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)), |
| qscheme=st.sampled_from((torch.per_tensor_affine, torch.per_tensor_symmetric))) |
| def test_observer_scriptable(self, qdtype, qscheme): |
| obs = RecordingObserver(dtype=qdtype, qscheme=qscheme) |
| scripted = torch.jit.script(obs) |
| |
| x = torch.rand(3, 4) |
| obs(x) |
| scripted(x) |
| self.assertTrue(torch.equal(obs.get_tensor_value()[0], scripted.get_tensor_value()[0])) |
| buf = io.BytesIO() |
| torch.jit.save(scripted, buf) |
| buf.seek(0) |
| loaded = torch.jit.load(buf) |
| self.assertTrue(torch.equal(obs.get_tensor_value()[0], loaded.get_tensor_value()[0])) |
| |
| @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)), |
| qscheme=st.sampled_from((torch.per_tensor_affine, torch.per_tensor_symmetric)), |
| reduce_range=st.booleans()) |
| def test_histogram_observer(self, qdtype, qscheme, reduce_range): |
| myobs = HistogramObserver(bins=3, dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range) |
| # Calculate qparams should work for empty observers |
| qparams = myobs.calculate_qparams() |
| x = torch.tensor([2.0, 3.0, 4.0, 5.0], requires_grad=True) |
| y = torch.tensor([5.0, 6.0, 7.0, 8.0]) |
| out_x = myobs(x) |
| self.assertTrue(out_x.requires_grad) |
| myobs(y) |
| self.assertEqual(myobs.min_val, 2.0) |
| self.assertEqual(myobs.max_val, 8.0) |
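| # With bins=3 over the observed range [2, 8], the bin edges are roughly |
| # [2, 4), [4, 6), [6, 8], so the eight samples land as counts [2, 3, 3]. |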
| self.assertEqual(myobs.histogram, [2., 3., 3.]) |
| |
| qparams = myobs.calculate_qparams() |
| |
| if reduce_range: |
| if qscheme == torch.per_tensor_symmetric: |
| ref_scale = 0.0470588 * 255 / 127 |
| ref_zero_point = 0 if qdtype is torch.qint8 else 128 |
| else: |
| ref_scale = 0.0235294 * 255 / 127 |
| ref_zero_point = -64 if qdtype is torch.qint8 else 0 |
| else: |
| if qscheme == torch.per_tensor_symmetric: |
| ref_scale = 0.0470588 |
| ref_zero_point = 0 if qdtype is torch.qint8 else 128 |
| else: |
| ref_scale = 0.0235294 |
| ref_zero_point = -128 if qdtype is torch.qint8 else 0 |
| |
| self.assertEqual(qparams[1].item(), ref_zero_point) |
| self.assertAlmostEqual(qparams[0].item(), ref_scale, delta=1e-5) |
| # Test for serializability |
| state_dict = myobs.state_dict() |
| b = io.BytesIO() |
| torch.save(state_dict, b) |
| b.seek(0) |
| loaded_dict = torch.load(b) |
| for key in state_dict: |
| self.assertEqual(state_dict[key], loaded_dict[key]) |
| loaded_obs = HistogramObserver(bins=3, dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range) |
| loaded_obs.load_state_dict(loaded_dict) |
| loaded_qparams = loaded_obs.calculate_qparams() |
| self.assertEqual(myobs.min_val, loaded_obs.min_val) |
| self.assertEqual(myobs.max_val, loaded_obs.max_val) |
| self.assertEqual(myobs.histogram, loaded_obs.histogram) |
| self.assertEqual(myobs.bins, loaded_obs.bins) |
| self.assertEqual(myobs.calculate_qparams(), loaded_obs.calculate_qparams()) |
| |
| def test_histogram_observer_one_sided(self): |
| myobs = HistogramObserver(bins=8, dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=True) |
| x = torch.tensor([0.0, 0.3, 1.2, 1.7]) |
| y = torch.tensor([0.1, 1.3, 2.0, 2.7]) |
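| # All inputs are non-negative, so min_val should be 0 and the affine |
| # quint8 zero_point should stay at 0 even with reduce_range enabled. |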
| myobs(x) |
| myobs(y) |
| self.assertEqual(myobs.min_val, 0) |
| qparams = myobs.calculate_qparams() |
| self.assertEqual(qparams[1].item(), 0) |
| |
| |
| |
| if __name__ == '__main__': |
| from torch.testing._internal.common_utils import run_tests |
| run_tests() |