from __future__ import absolute_import, division, print_function, unicode_literals
import torch
import torch.nn.quantized as nnq
from torch.quantization import QuantWrapper, QuantStub, DeQuantStub, \
default_eval_fn, QConfig, default_qconfig, default_observer, quantize, \
prepare, convert
from common_utils import TestCase, run_tests
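# These tests exercise the eager-mode post-training quantization flow:
# prepare() attaches observers, default_eval_fn() runs calibration data
# through the model so the observers can record activation statistics,
# and convert() swaps float modules for their quantized counterparts.
# quantize() bundles the three steps into a single call.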
class SingleLayerLinearModel(torch.nn.Module):
def __init__(self):
super(SingleLayerLinearModel, self).__init__()
self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float)
def forward(self, x):
x = self.fc1(x)
return x
class TwoLayerLinearModel(torch.nn.Module):
def __init__(self):
super(TwoLayerLinearModel, self).__init__()
self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float)
def forward(self, x):
x = self.fc1(x)
x = self.fc2(x)
return x
class LinearReluModel(torch.nn.Module):
def __init__(self):
super(LinearReluModel, self).__init__()
self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)
self.relu = torch.nn.ReLU()
def forward(self, x):
x = self.relu(self.fc(x))
return x
class NestedModel(torch.nn.Module):
def __init__(self):
super(NestedModel, self).__init__()
self.sub1 = LinearReluModel()
self.sub2 = TwoLayerLinearModel()
self.fc3 = torch.nn.Linear(5, 5).to(dtype=torch.float)
def forward(self, x):
x = self.sub1(x)
x = self.sub2(x)
x = self.fc3(x)
return x
class InnerModule(torch.nn.Module):
def __init__(self):
super(InnerModule, self).__init__()
self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
self.relu = torch.nn.ReLU()
self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float)
def forward(self, x):
return self.relu(self.fc2(self.relu(self.fc1(x))))
class WrappedModel(torch.nn.Module):
def __init__(self):
super(WrappedModel, self).__init__()
self.qconfig = default_qconfig
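        # QuantWrapper surrounds InnerModule with QuantStub/DeQuantStub so
        # the wrapped submodule runs quantized while the rest stays float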
self.sub = QuantWrapper(InnerModule())
self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)
# don't quantize this fc
self.fc.qconfig = None
def forward(self, x):
return self.fc(self.sub(x))
class ManualQuantModel(torch.nn.Module):
r"""A Module with manually inserted `QuantStub` and `DeQuantStub`
"""
def __init__(self):
super(ManualQuantModel, self).__init__()
self.qconfig = default_qconfig
self.quant = QuantStub()
self.dequant = DeQuantStub()
self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)
def forward(self, x):
x = self.quant(x)
x = self.fc(x)
return self.dequant(x)
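# Random calibration batches (20 batches of shape [20, 5]); fed to
# default_eval_fn so the observers inserted by prepare() can record the
# value ranges used to compute quantization parameters.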
calib_data = [torch.rand(20, 5, dtype=torch.float) for _ in range(20)]
class ModelQuantizeAPITest(TestCase):
def checkNoPrepModules(self, module):
r"""Checks the module does not contain child
modules for quantization prepration, e.g.
quant, dequant and observer
"""
self.assertFalse(hasattr(module, 'quant'))
self.assertFalse(hasattr(module, 'dequant'))
def checkHasPrepModules(self, module):
r"""Checks the module contains child
modules for quantization prepration, e.g.
quant, dequant and observer
"""
self.assertTrue(hasattr(module, 'module'))
self.assertTrue(hasattr(module, 'quant'))
self.assertTrue(hasattr(module, 'dequant'))
def checkObservers(self, module):
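        # After prepare(), every leaf module with a non-None qconfig should
        # have an observer attached; recurse through children to verify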
if hasattr(module, 'qconfig') and module.qconfig is not None and len(module._modules) == 0:
self.assertTrue(hasattr(module, 'observer'))
for child in module.children():
self.checkObservers(child)
def checkQuantDequant(self, mod):
self.assertEqual(type(mod.quant), nnq.Quantize)
self.assertEqual(type(mod.dequant), nnq.DeQuantize)
def checkQuantizedLinear(self, mod):
self.assertEqual(type(mod.module), nnq.Linear)
self.assertEqual(mod.module.bias.dtype, torch.qint32)
self.checkQuantDequant(mod)
def checkLinear(self, mod):
self.assertEqual(type(mod), torch.nn.Linear)
def test_single_layer(self):
r"""Quantize SingleLayerLinearModel which has one Linear module, make sure it is swapped
to nnq.Linear which is the quantized version of the module
"""
model = SingleLayerLinearModel()
qconfig_dict = {
'': default_qconfig
}
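        # the empty key applies default_qconfig to the root module, and
        # prepare propagates it down to all submodules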
model = prepare(model, qconfig_dict)
# Check if observers and quant/dequant nodes are inserted
self.checkNoPrepModules(model)
self.checkHasPrepModules(model.fc1)
self.checkObservers(model)
default_eval_fn(model, calib_data)
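        # convert() swaps each observed float module for its quantized
        # counterpart in place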
convert(model)
def checkQuantized(model):
self.checkNoPrepModules(model)
self.checkHasPrepModules(model.fc1)
self.checkQuantizedLinear(model.fc1)
default_eval_fn(model, calib_data)
checkQuantized(model)
# test one line API
model = quantize(SingleLayerLinearModel(), default_eval_fn, calib_data, qconfig_dict)
checkQuantized(model)
def test_two_layers(self):
r"""TwoLayerLinearModel has two Linear modules but we only quantize the second one
`fc2`, and `fc1`is not quantized
"""
model = TwoLayerLinearModel()
qconfig_dict = {
'fc2': default_qconfig
}
model = prepare(model, qconfig_dict)
self.checkNoPrepModules(model)
self.checkObservers(model)
self.checkNoPrepModules(model.fc1)
self.checkHasPrepModules(model.fc2)
default_eval_fn(model, calib_data)
convert(model)
def checkQuantized(model):
self.checkNoPrepModules(model)
self.checkNoPrepModules(model.fc1)
self.checkHasPrepModules(model.fc2)
self.assertEqual(type(model.fc1), torch.nn.Linear)
self.checkQuantizedLinear(model.fc2)
default_eval_fn(model, calib_data)
checkQuantized(model)
# test one line API
model = quantize(TwoLayerLinearModel(), default_eval_fn, calib_data, qconfig_dict)
checkQuantized(model)
def test_nested1(self):
r"""Test quantization for nested model, top level 'fc3' and
'fc1' of submodule 'sub2', 'sub2.fc2' is not quantized
"""
model = NestedModel()
qconfig_dict = {
'fc3': default_qconfig,
'sub2.fc1': default_qconfig
}
def checkPrepModules(model, before_calib=False):
if before_calib:
self.checkObservers(model)
self.checkNoPrepModules(model)
self.checkNoPrepModules(model.sub1)
self.checkNoPrepModules(model.sub1.fc)
self.checkNoPrepModules(model.sub1.relu)
self.checkNoPrepModules(model.sub2)
self.checkHasPrepModules(model.sub2.fc1)
self.checkNoPrepModules(model.sub2.fc2)
self.checkHasPrepModules(model.fc3)
model = prepare(model, qconfig_dict)
checkPrepModules(model, True)
default_eval_fn(model, calib_data)
convert(model)
def checkQuantized(model):
checkPrepModules(model)
self.checkLinear(model.sub1.fc)
self.checkQuantizedLinear(model.fc3)
self.checkQuantizedLinear(model.sub2.fc1)
self.checkLinear(model.sub2.fc2)
default_eval_fn(model, calib_data)
checkQuantized(model)
# test one line API
model = quantize(NestedModel(), default_eval_fn, calib_data, qconfig_dict)
checkQuantized(model)
def test_nested2(self):
r"""Another test case for quantized, we will quantize all submodules
of submodule sub2, this will include redundant quant/dequant, to
remove them we need to manually call QuantWrapper or insert
QuantStub/DeQuantStub, see `test_quant_dequant_wrapper` and
`test_manual`
"""
model = NestedModel()
qconfig_dict = {
'fc3': default_qconfig,
'sub2': default_qconfig
}
model = prepare(model, qconfig_dict)
def checkPrepModules(model, before_calib=False):
if before_calib:
self.checkObservers(model)
self.checkNoPrepModules(model)
self.checkNoPrepModules(model.sub1)
self.checkNoPrepModules(model.sub1.fc)
self.checkNoPrepModules(model.sub1.relu)
self.checkNoPrepModules(model.sub2)
self.checkHasPrepModules(model.sub2.fc1)
self.checkHasPrepModules(model.sub2.fc2)
self.checkHasPrepModules(model.fc3)
checkPrepModules(model, True)
default_eval_fn(model, calib_data)
convert(model)
def checkQuantized(model):
checkPrepModules(model)
self.checkLinear(model.sub1.fc)
self.assertEqual(type(model.sub1.relu), torch.nn.ReLU)
self.checkQuantizedLinear(model.sub2.fc1)
self.checkQuantizedLinear(model.sub2.fc2)
self.checkQuantizedLinear(model.fc3)
default_eval_fn(model, calib_data)
checkQuantized(model)
# test one line API
model = quantize(NestedModel(), default_eval_fn, calib_data, qconfig_dict)
checkQuantized(model)
def test_nested3(self):
r"""More complicated nested test case with child qconfig overrides
parent qconfig
"""
model = NestedModel()
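        # build a custom qconfig whose activation observer uses quint8 with
        # per-tensor affine quantization; mapping it to 'sub2.fc1' overrides
        # the qconfig that 'sub2.fc1' would otherwise inherit from 'sub2'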
        custom_options = {
            'dtype': torch.quint8,
            'qscheme': torch.per_tensor_affine
        }
        custom_qconfig = QConfig(weight=default_observer(),
                                 activation=default_observer(**custom_options))
qconfig_dict = {
'fc3': default_qconfig,
'sub2': default_qconfig,
'sub2.fc1': custom_qconfig
}
model = prepare(model, qconfig_dict)
def checkPrepModules(model, before_calib=False):
if before_calib:
self.checkObservers(model)
self.checkNoPrepModules(model)
self.checkNoPrepModules(model.sub1)
self.checkNoPrepModules(model.sub1.fc)
self.checkNoPrepModules(model.sub1.relu)
self.checkNoPrepModules(model.sub2)
self.checkHasPrepModules(model.sub2.fc1)
self.checkHasPrepModules(model.sub2.fc2)
self.checkHasPrepModules(model.fc3)
checkPrepModules(model, True)
default_eval_fn(model, calib_data)
convert(model)
def checkQuantized(model):
checkPrepModules(model)
self.checkQuantizedLinear(model.sub2.fc1)
self.checkQuantizedLinear(model.sub2.fc2)
self.checkQuantizedLinear(model.fc3)
default_eval_fn(model, calib_data)
checkQuantized(model)
# test one line API
model = quantize(NestedModel(), default_eval_fn, calib_data, qconfig_dict)
checkQuantized(model)
def test_quant_wrapper(self):
r"""User need to modify the original code with QuantWrapper,
and call the quantization utility functions.
"""
model = WrappedModel()
        # since we didn't provide a qconfig_dict, the model is modified in
        # place, but we could do `model = prepare(model)` as well
prepare(model)
self.checkObservers(model)
default_eval_fn(model, calib_data)
convert(model)
def checkQuantized(model):
self.checkLinear(model.fc)
self.checkQuantDequant(model.sub)
self.assertEqual(type(model.sub.module.fc1), nnq.Linear)
self.assertEqual(type(model.sub.module.fc2), nnq.Linear)
self.assertEqual(type(model.sub.module.relu), nnq.ReLU)
default_eval_fn(model, calib_data)
checkQuantized(model)
# test one line API
model = quantize(WrappedModel(), default_eval_fn, calib_data, {})
checkQuantized(model)
def test_manual(self):
r"""User inserts QuantStub and DeQuantStub in model code
and call the quantization utility functions.
"""
model = ManualQuantModel()
        # propagates the qconfig of parents to children; the model is
        # changed in place
prepare(model)
self.checkObservers(model)
default_eval_fn(model, calib_data)
convert(model)
def checkQuantized(model):
self.assertEqual(type(model.fc), nnq.Linear)
default_eval_fn(model, calib_data)
checkQuantized(model)
# test one line API
model = quantize(ManualQuantModel(), default_eval_fn, calib_data)
checkQuantized(model)
if __name__ == '__main__':
run_tests()