from __future__ import absolute_import, division, print_function, unicode_literals
import torch
import torch.nn.quantized as nnq
from torch.quantization import QuantWrapper, QuantStub, DeQuantStub, \
    default_eval_fn, QConfig, default_qconfig, default_observer, quantize, \
    prepare, convert

from common_utils import TestCase, run_tests

class SingleLayerLinearModel(torch.nn.Module):
    def __init__(self):
        super(SingleLayerLinearModel, self).__init__()
        self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float)

    def forward(self, x):
        x = self.fc1(x)
        return x

class TwoLayerLinearModel(torch.nn.Module):
    def __init__(self):
        super(TwoLayerLinearModel, self).__init__()
        self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
        self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

class LinearReluModel(torch.nn.Module):
    def __init__(self):
        super(LinearReluModel, self).__init__()
        self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc(x))
        return x

class NestedModel(torch.nn.Module):
    def __init__(self):
        super(NestedModel, self).__init__()
        self.sub1 = LinearReluModel()
        self.sub2 = TwoLayerLinearModel()
        self.fc3 = torch.nn.Linear(5, 5).to(dtype=torch.float)

    def forward(self, x):
        x = self.sub1(x)
        x = self.sub2(x)
        x = self.fc3(x)
        return x

class InnerModule(torch.nn.Module):
    def __init__(self):
        super(InnerModule, self).__init__()
        self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float)

    def forward(self, x):
        return self.relu(self.fc2(self.relu(self.fc1(x))))

class WrappedModel(torch.nn.Module):
    def __init__(self):
        super(WrappedModel, self).__init__()
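        # qconfig set on the parent is propagated to all children by prepare();
        # QuantWrapper adds QuantStub/DeQuantStub around InnerModule so that
        # subtree works on quantized tensors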
        self.qconfig = default_qconfig
        self.sub = QuantWrapper(InnerModule())
        self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)
        # don't quantize this fc
        self.fc.qconfig = None

    def forward(self, x):
        return self.fc(self.sub(x))

class ManualQuantModel(torch.nn.Module):
    r"""A Module with manually inserted `QuantStub` and `DeQuantStub`
    """
    def __init__(self):
        super(ManualQuantModel, self).__init__()
        self.qconfig = default_qconfig
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
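        # QuantStub/DeQuantStub mark where tensors enter and leave the quantized
        # part of the model; convert() should swap them for nnq.Quantize /
        # nnq.DeQuantize (this is checked for the QuantWrapper case in
        # test_quant_wrapper)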
        self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)

    def forward(self, x):
        x = self.quant(x)
        x = self.fc(x)
        return self.dequant(x)

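# calibration data: random batches that are run through the prepared model so
# the observers can collect activation statistics, and reused afterwards to
# exercise the quantized model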
calib_data = [torch.rand(20, 5, dtype=torch.float) for _ in range(20)]

class ModelQuantizeAPITest(TestCase):

    def checkNoPrepModules(self, module):
        r"""Checks that the module does not contain child
        modules used for quantization preparation, e.g.
        quant, dequant and observer
        """
        self.assertFalse(hasattr(module, 'quant'))
        self.assertFalse(hasattr(module, 'dequant'))

    def checkHasPrepModules(self, module):
        r"""Checks that the module contains child
        modules used for quantization preparation, e.g.
        quant, dequant and observer
        """
        self.assertTrue(hasattr(module, 'module'))
        self.assertTrue(hasattr(module, 'quant'))
        self.assertTrue(hasattr(module, 'dequant'))

    def checkObservers(self, module):
        if hasattr(module, 'qconfig') and module.qconfig is not None and len(module._modules) == 0:
            self.assertTrue(hasattr(module, 'observer'))
        for child in module.children():
            self.checkObservers(child)

    def checkQuantDequant(self, mod):
        self.assertEqual(type(mod.quant), nnq.Quantize)
        self.assertEqual(type(mod.dequant), nnq.DeQuantize)

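    # a quantized Linear should be swapped to nnq.Linear (with a qint32 bias)
    # and wrapped with quant/dequant modules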
    def checkQuantizedLinear(self, mod):
        self.assertEqual(type(mod.module), nnq.Linear)
        self.assertEqual(mod.module.bias.dtype, torch.qint32)
        self.checkQuantDequant(mod)

    def checkLinear(self, mod):
        self.assertEqual(type(mod), torch.nn.Linear)

    def test_single_layer(self):
        r"""Quantize SingleLayerLinearModel, which has one Linear module, and make sure
        it is swapped for nnq.Linear, the quantized version of the module
        """
        model = SingleLayerLinearModel()
        qconfig_dict = {
            '': default_qconfig
        }
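        # '' (the empty key) applies default_qconfig to the whole model; after
        # prepare() every module to be quantized gets observer and quant/dequant
        # child modules (see the checks below)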
        model = prepare(model, qconfig_dict)
        # Check if observers and quant/dequant nodes are inserted
        self.checkNoPrepModules(model)
        self.checkHasPrepModules(model.fc1)
        self.checkObservers(model)

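        # calibrate by running the sample data through the observed model,
        # then convert the observed modules to their quantized versions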
        default_eval_fn(model, calib_data)
        convert(model)

        def checkQuantized(model):
            self.checkNoPrepModules(model)
            self.checkHasPrepModules(model.fc1)
            self.checkQuantizedLinear(model.fc1)
            default_eval_fn(model, calib_data)

        checkQuantized(model)

        # test one line API
        model = quantize(SingleLayerLinearModel(), default_eval_fn, calib_data, qconfig_dict)
        checkQuantized(model)

    def test_two_layers(self):
        r"""TwoLayerLinearModel has two Linear modules but we only quantize the second one,
        `fc2`; `fc1` is not quantized
        """
        model = TwoLayerLinearModel()
        qconfig_dict = {
            'fc2': default_qconfig
        }
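        # only modules listed in qconfig_dict are prepared and quantized;
        # fc1 stays a float torch.nn.Linear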
        model = prepare(model, qconfig_dict)

        self.checkNoPrepModules(model)
        self.checkObservers(model)
        self.checkNoPrepModules(model.fc1)
        self.checkHasPrepModules(model.fc2)

        default_eval_fn(model, calib_data)
        convert(model)

        def checkQuantized(model):
            self.checkNoPrepModules(model)
            self.checkNoPrepModules(model.fc1)
            self.checkHasPrepModules(model.fc2)
            self.assertEqual(type(model.fc1), torch.nn.Linear)
            self.checkQuantizedLinear(model.fc2)
            default_eval_fn(model, calib_data)

        checkQuantized(model)

        # test one line API
        model = quantize(TwoLayerLinearModel(), default_eval_fn, calib_data, qconfig_dict)
        checkQuantized(model)

    def test_nested1(self):
        r"""Test quantization for a nested model: top level 'fc3' and
        'fc1' of submodule 'sub2' are quantized, 'sub2.fc2' is not
        """
        model = NestedModel()
        qconfig_dict = {
            'fc3': default_qconfig,
            'sub2.fc1': default_qconfig
        }
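        # qconfig_dict keys are fully qualified submodule names, so 'sub2.fc1'
        # targets a single nested child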

        def checkPrepModules(model, before_calib=False):
            if before_calib:
                self.checkObservers(model)
            self.checkNoPrepModules(model)
            self.checkNoPrepModules(model.sub1)
            self.checkNoPrepModules(model.sub1.fc)
            self.checkNoPrepModules(model.sub1.relu)
            self.checkNoPrepModules(model.sub2)
            self.checkHasPrepModules(model.sub2.fc1)
            self.checkNoPrepModules(model.sub2.fc2)
            self.checkHasPrepModules(model.fc3)

        model = prepare(model, qconfig_dict)
        checkPrepModules(model, True)
        default_eval_fn(model, calib_data)
        convert(model)

        def checkQuantized(model):
            checkPrepModules(model)
            self.checkLinear(model.sub1.fc)
            self.checkQuantizedLinear(model.fc3)
            self.checkQuantizedLinear(model.sub2.fc1)
            self.checkLinear(model.sub2.fc2)
            default_eval_fn(model, calib_data)

        checkQuantized(model)

        # test one line API
        model = quantize(NestedModel(), default_eval_fn, calib_data, qconfig_dict)
        checkQuantized(model)


    def test_nested2(self):
        r"""Another nested test case: this time we quantize all submodules
        of submodule 'sub2'. This inserts redundant quant/dequant pairs; to
        remove them the user needs to wrap modules with QuantWrapper or insert
        QuantStub/DeQuantStub manually, see `test_quant_wrapper` and
        `test_manual`
        """
        model = NestedModel()
        qconfig_dict = {
            'fc3': default_qconfig,
            'sub2': default_qconfig
        }
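        # assigning a qconfig to 'sub2' applies it to all of sub2's children,
        # so fc1 and fc2 each get their own quant/dequant pair (the redundancy
        # mentioned in the docstring)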
        model = prepare(model, qconfig_dict)

        def checkPrepModules(model, before_calib=False):
            if before_calib:
                self.checkObservers(model)
            self.checkNoPrepModules(model)
            self.checkNoPrepModules(model.sub1)
            self.checkNoPrepModules(model.sub1.fc)
            self.checkNoPrepModules(model.sub1.relu)
            self.checkNoPrepModules(model.sub2)
            self.checkHasPrepModules(model.sub2.fc1)
            self.checkHasPrepModules(model.sub2.fc2)
            self.checkHasPrepModules(model.fc3)

        checkPrepModules(model, True)

        default_eval_fn(model, calib_data)
        convert(model)

        def checkQuantized(model):
            checkPrepModules(model)
            self.checkLinear(model.sub1.fc)
            self.assertEqual(type(model.sub1.relu), torch.nn.ReLU)
            self.checkQuantizedLinear(model.sub2.fc1)
            self.checkQuantizedLinear(model.sub2.fc2)
            self.checkQuantizedLinear(model.fc3)
            default_eval_fn(model, calib_data)

        checkQuantized(model)

        # test one line API
        model = quantize(NestedModel(), default_eval_fn, calib_data, qconfig_dict)
        checkQuantized(model)

    def test_nested3(self):
        r"""More complicated nested test case where a child qconfig overrides
        the parent qconfig
        """
        model = NestedModel()
        custom_options = {
            'dtype': torch.quint8,
            'qscheme': torch.per_tensor_affine
        }
        custom_qconfig = QConfig(weight=default_observer(),
                                 activation=default_observer(**custom_options))
        qconfig_dict = {
            'fc3': default_qconfig,
            'sub2': default_qconfig,
            'sub2.fc1': custom_qconfig
        }
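        # the more specific 'sub2.fc1' entry overrides the qconfig inherited
        # from 'sub2'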
        model = prepare(model, qconfig_dict)

        def checkPrepModules(model, before_calib=False):
            if before_calib:
                self.checkObservers(model)
            self.checkNoPrepModules(model)
            self.checkNoPrepModules(model.sub1)
            self.checkNoPrepModules(model.sub1.fc)
            self.checkNoPrepModules(model.sub1.relu)
            self.checkNoPrepModules(model.sub2)
            self.checkHasPrepModules(model.sub2.fc1)
            self.checkHasPrepModules(model.sub2.fc2)
            self.checkHasPrepModules(model.fc3)

        checkPrepModules(model, True)

        default_eval_fn(model, calib_data)
        convert(model)

        def checkQuantized(model):
            checkPrepModules(model)
            self.checkQuantizedLinear(model.sub2.fc1)
            self.checkQuantizedLinear(model.sub2.fc2)
            self.checkQuantizedLinear(model.fc3)
            default_eval_fn(model, calib_data)

        checkQuantized(model)

        # test one line API
        model = quantize(NestedModel(), default_eval_fn, calib_data, qconfig_dict)
        checkQuantized(model)

    def test_quant_wrapper(self):
        r"""The user needs to wrap the modules to be quantized with QuantWrapper,
        and call the quantization utility functions.
        """
        model = WrappedModel()

        # since we didn't provide a qconfig_dict, the model is modified in place,
        # but we could also write `model = prepare(model)`
        prepare(model)
        self.checkObservers(model)

        default_eval_fn(model, calib_data)
        convert(model)

        def checkQuantized(model):
            self.checkLinear(model.fc)
            self.checkQuantDequant(model.sub)
            self.assertEqual(type(model.sub.module.fc1), nnq.Linear)
            self.assertEqual(type(model.sub.module.fc2), nnq.Linear)
            self.assertEqual(type(model.sub.module.relu), nnq.ReLU)
            default_eval_fn(model, calib_data)

        checkQuantized(model)

        # test one line API
        model = quantize(WrappedModel(), default_eval_fn, calib_data, {})
        checkQuantized(model)


    def test_manual(self):
        r"""The user inserts QuantStub and DeQuantStub in the model code
        and calls the quantization utility functions.
        """
        model = ManualQuantModel()
        # propagate the qconfig of parents to children, model is changed
        # in place
        prepare(model)
        self.checkObservers(model)

        default_eval_fn(model, calib_data)
        convert(model)

        def checkQuantized(model):
            self.assertEqual(type(model.fc), nnq.Linear)
            default_eval_fn(model, calib_data)

        checkQuantized(model)

        # test one line API
        model = quantize(ManualQuantModel(), default_eval_fn, calib_data)
        checkQuantized(model)


if __name__ == '__main__':
    run_tests()