blob: 79905c2572c1b84d1f147f16566f71ad5c00f8ce [file] [log] [blame]
from __future__ import absolute_import, division, print_function, unicode_literals
import torch
import torch.nn.quantized as nnq
from torch.quantization import \
quantize, prepare, convert, prepare_qat, quantize_qat, fuse_modules
from common_utils import run_tests
from common_quantization import QuantizationTestCase, SingleLayerLinearModel, \
SkipQuantModel, QuantStubModel, \
ModForFusion, ManualLinearQATModel, ManualConvLinearQATModel, test_only_eval_fn, test_only_train_fn
from common_quantization import AnnotatedTwoLayerLinearModel, AnnotatedNestedModel, \
AnnotatedSubNestedModel, AnnotatedCustomConfigNestedModel
class PostTrainingQuantTest(QuantizationTestCase):
    r"""Post-training quantization tests.

    Every test runs the explicit ``prepare`` -> calibrate -> ``convert``
    flow, verifies the resulting model, and then repeats the same checks on
    the model produced by the one-line ``quantize`` API.
    """

    def test_single_layer(self):
        r"""Quantize SingleLayerLinearModel which has one Linear module, make
        sure it is swapped to nnq.Linear which is the quantized version of
        the module
        """
        model = SingleLayerLinearModel()
        model = prepare(model)
        # Check if observers and quant/dequant nodes are inserted
        self.checkNoPrepModules(model)
        self.checkHasPrepModules(model.fc1)
        self.checkObservers(model)

        test_only_eval_fn(model, self.calib_data)
        # Keep the returned module rather than relying on in-place mutation.
        model = convert(model)

        def check_quantized(quantized):
            self.checkNoPrepModules(quantized)
            self.checkHasPrepModules(quantized.fc1)
            self.checkWrappedQuantizedLinear(quantized.fc1)
            test_only_eval_fn(quantized, self.calib_data)

        check_quantized(model)

        # test one line API
        model = quantize(SingleLayerLinearModel(), test_only_eval_fn,
                         self.calib_data)
        check_quantized(model)

    def test_two_layers(self):
        r"""AnnotatedTwoLayerLinearModel has two Linear modules but we only
        quantize the second one, `fc2`; `fc1` is not quantized
        """
        model = AnnotatedTwoLayerLinearModel()
        model = prepare(model)

        self.checkNoPrepModules(model)
        self.checkObservers(model)
        self.checkNoPrepModules(model.fc1)
        self.checkHasPrepModules(model.fc2)

        test_only_eval_fn(model, self.calib_data)
        # Keep the returned module rather than relying on in-place mutation.
        model = convert(model)

        def check_quantized(quantized):
            self.checkNoPrepModules(quantized)
            self.checkNoPrepModules(quantized.fc1)
            self.checkHasPrepModules(quantized.fc2)
            # fc1 must remain an exact float Linear, hence the type equality.
            self.assertEqual(type(quantized.fc1), torch.nn.Linear)
            self.checkWrappedQuantizedLinear(quantized.fc2)
            test_only_eval_fn(quantized, self.calib_data)

        check_quantized(model)

        # test one line API
        model = quantize(AnnotatedTwoLayerLinearModel(), test_only_eval_fn,
                         self.calib_data)
        check_quantized(model)

    def test_nested1(self):
        r"""Test quantization for nested model, top level 'fc3' and
        'fc1' of submodule 'sub2', 'sub2.fc2' is not quantized
        """
        model = AnnotatedNestedModel()

        def check_prep_modules(candidate, before_calib=False):
            # Observers are only checked on the prepared, not yet converted,
            # model.
            if before_calib:
                self.checkObservers(candidate)
            self.checkNoPrepModules(candidate)
            self.checkNoPrepModules(candidate.sub1)
            self.checkNoPrepModules(candidate.sub1.fc)
            self.checkNoPrepModules(candidate.sub1.relu)
            self.checkNoPrepModules(candidate.sub2)
            self.checkHasPrepModules(candidate.sub2.fc1)
            self.checkNoPrepModules(candidate.sub2.fc2)
            self.checkHasPrepModules(candidate.fc3)

        model = prepare(model)
        check_prep_modules(model, True)
        test_only_eval_fn(model, self.calib_data)
        # Keep the returned module rather than relying on in-place mutation.
        model = convert(model)

        def check_quantized(quantized):
            check_prep_modules(quantized)
            self.checkLinear(quantized.sub1.fc)
            self.checkWrappedQuantizedLinear(quantized.fc3)
            self.checkWrappedQuantizedLinear(quantized.sub2.fc1)
            self.checkLinear(quantized.sub2.fc2)
            test_only_eval_fn(quantized, self.calib_data)

        check_quantized(model)

        # test one line API
        model = quantize(AnnotatedNestedModel(), test_only_eval_fn,
                         self.calib_data)
        check_quantized(model)

    def test_nested2(self):
        r"""Test quantization for AnnotatedSubNestedModel: submodule 'sub2'
        carries the prep modules as a whole (its 'fc1' and 'fc2' are reached
        through 'sub2.module'), top level 'fc3' is quantized, and
        'sub1.relu' stays a torch.nn.ReLU
        """
        model = AnnotatedSubNestedModel()
        model = prepare(model)

        def check_prep_modules(candidate, before_calib=False):
            # Observers are only checked on the prepared, not yet converted,
            # model.
            if before_calib:
                self.checkObservers(candidate)
            self.checkNoPrepModules(candidate)
            self.checkNoPrepModules(candidate.sub1)
            self.checkNoPrepModules(candidate.sub1.fc)
            self.checkNoPrepModules(candidate.sub1.relu)
            self.checkHasPrepModules(candidate.sub2)
            self.checkNoPrepModules(candidate.sub2.module.fc1)
            self.checkNoPrepModules(candidate.sub2.module.fc2)
            self.checkHasPrepModules(candidate.fc3)

        check_prep_modules(model, True)

        test_only_eval_fn(model, self.calib_data)
        # Keep the returned module rather than relying on in-place mutation.
        model = convert(model)

        def check_quantized(quantized):
            check_prep_modules(quantized)
            self.checkLinear(quantized.sub1.fc)
            self.assertEqual(type(quantized.sub1.relu), torch.nn.ReLU)
            self.checkQuantizedLinear(quantized.sub2.module.fc1)
            self.checkQuantizedLinear(quantized.sub2.module.fc2)
            self.checkWrappedQuantizedLinear(quantized.fc3)
            test_only_eval_fn(quantized, self.calib_data)

        check_quantized(model)

        # test one line API
        model = quantize(AnnotatedSubNestedModel(), test_only_eval_fn,
                         self.calib_data)
        check_quantized(model)

    def test_nested3(self):
        r"""More complicated nested test case with child qconfig overrides
        parent qconfig
        """
        model = AnnotatedCustomConfigNestedModel()
        model = prepare(model)

        def check_prep_modules(candidate, before_calib=False):
            # Observers are only checked on the prepared, not yet converted,
            # model.
            if before_calib:
                self.checkObservers(candidate)
            self.checkNoPrepModules(candidate)
            self.checkNoPrepModules(candidate.sub1)
            self.checkNoPrepModules(candidate.sub1.fc)
            self.checkNoPrepModules(candidate.sub1.relu)
            self.checkNoPrepModules(candidate.sub2)
            self.checkHasPrepModules(candidate.sub2.fc1)
            self.checkHasPrepModules(candidate.sub2.fc2)
            self.checkHasPrepModules(candidate.fc3)

        check_prep_modules(model, True)

        test_only_eval_fn(model, self.calib_data)
        # Keep the returned module rather than relying on in-place mutation.
        model = convert(model)

        def check_quantized(quantized):
            check_prep_modules(quantized)
            self.checkWrappedQuantizedLinear(quantized.sub2.fc1)
            self.checkWrappedQuantizedLinear(quantized.sub2.fc2)
            self.checkWrappedQuantizedLinear(quantized.fc3)
            test_only_eval_fn(quantized, self.calib_data)

        check_quantized(model)

        # test one line API
        model = quantize(AnnotatedCustomConfigNestedModel(), test_only_eval_fn,
                         self.calib_data)
        check_quantized(model)

    def test_skip_quant(self):
        r"""The case when we want to skip quantizing some layers
        """
        model = SkipQuantModel()
        # Assign the result, consistent with the other tests in this class.
        model = prepare(model)
        self.checkObservers(model)

        test_only_eval_fn(model, self.calib_data)
        model = convert(model)

        def check_quantized(quantized):
            self.checkLinear(quantized.fc)
            self.checkQuantDequant(quantized.sub)
            self.checkQuantizedLinear(quantized.sub.module.fc1)
            self.checkQuantizedLinear(quantized.sub.module.fc2)
            self.assertEqual(type(quantized.sub.module.relu), nnq.ReLU)
            test_only_eval_fn(quantized, self.calib_data)

        check_quantized(model)

        # test one line API
        model = quantize(SkipQuantModel(), test_only_eval_fn, self.calib_data)
        check_quantized(model)

    def test_manual(self):
        r"""User inserts QuantStub and DeQuantStub in model code
        and call the quantization utility functions.
        """
        model = QuantStubModel()
        # propagate the qconfig of parents to children; the result is
        # assigned, consistent with the other tests in this class
        model = prepare(model)
        self.checkObservers(model)

        test_only_eval_fn(model, self.calib_data)
        model = convert(model)

        def check_quantized(quantized):
            self.assertEqual(type(quantized.fc), nnq.Linear)
            test_only_eval_fn(quantized, self.calib_data)

        check_quantized(model)

        # test one line API
        model = quantize(QuantStubModel(), test_only_eval_fn, self.calib_data)
        check_quantized(model)
class QuantizationAwareTrainingTest(QuantizationTestCase):
    r"""Quantization-aware training tests.

    The tests run ``prepare_qat`` -> train -> ``convert`` explicitly and,
    where applicable, repeat the checks on the result of the one-line
    ``quantize_qat`` API.
    """

    def test_manual(self):
        r"""Train ManualLinearQATModel after prepare_qat and make sure both
        `fc1` and `fc2` end up as nnq.Linear after convert
        """
        model = ManualLinearQATModel()
        model = prepare_qat(model)
        self.checkObservers(model)

        test_only_train_fn(model, self.train_data)
        # Keep the returned module rather than relying on in-place mutation.
        model = convert(model)

        def check_quantized(quantized):
            self.assertEqual(type(quantized.fc1), nnq.Linear)
            self.assertEqual(type(quantized.fc2), nnq.Linear)
            test_only_eval_fn(quantized, self.calib_data)

        check_quantized(model)

        # test one line API
        model = quantize_qat(ManualLinearQATModel(), test_only_train_fn,
                             self.train_data)
        check_quantized(model)

    def test_eval_only_fake_quant(self):
        r"""Using FakeQuant in evaluation only mode,
        this is useful for estimating accuracy loss when we quantize the
        network
        """
        model = ManualLinearQATModel()
        model = prepare_qat(model)
        self.checkObservers(model)

        # No training and no convert: the prepared model is only evaluated.
        model.eval()
        test_only_eval_fn(model, self.calib_data)

    def test_conv_linear(self):
        r"""Train ManualConvLinearQATModel after prepare_qat and make sure
        `conv` becomes nnq.Conv2d and `fc1`/`fc2` become nnq.Linear after
        convert
        """
        model = ManualConvLinearQATModel()
        model = prepare_qat(model)
        self.checkObservers(model)

        test_only_train_fn(model, self.img_data)
        # Keep the returned module rather than relying on in-place mutation.
        model = convert(model)

        def check_quantized(quantized):
            self.assertEqual(type(quantized.conv), nnq.Conv2d)
            self.assertEqual(type(quantized.fc1), nnq.Linear)
            self.assertEqual(type(quantized.fc2), nnq.Linear)
            test_only_eval_fn(quantized, self.img_data)

        check_quantized(model)

        # test one line API
        model = ManualConvLinearQATModel()
        model = quantize_qat(model, test_only_train_fn, self.img_data)
        check_quantized(model)
class FusionTest(QuantizationTestCase):
    r"""Checks the module types left behind by ``fuse_modules`` on
    ``ModForFusion`` in both training and evaluation mode.
    """

    def test_fuse_module_train(self):
        r"""In train mode, conv1/bn1/relu1 are expected to fuse into a
        ConvBnReLU2d and sub1.conv/sub1.bn into a ConvBn2d, with the absorbed
        modules replaced by Identity; sub2 is not listed and must be unchanged
        """
        import torch.nn._intrinsic.modules.fused as torch_fused

        def expect_type(module, expected, msg):
            # Exact type comparison, not isinstance.
            self.assertEqual(type(module), expected, msg)

        net = ModForFusion()
        net.train()
        fusion_groups = [
            ['conv1', 'bn1', 'relu1'],
            ['sub1.conv', 'sub1.bn'],
        ]
        fuse_modules(net, fusion_groups)

        # Top-level Conv + BN + ReLU group
        expect_type(net.conv1, torch_fused.ConvBnReLU2d,
                    "Fused Conv + BN + Relu first layer")
        expect_type(net.bn1, torch.nn.Identity,
                    "Fused Conv + BN + Relu (skipped BN)")
        expect_type(net.relu1, torch.nn.Identity,
                    "Fused Conv + BN + Relu (skipped Relu)")

        # Conv + BN group inside submodule sub1
        expect_type(net.sub1.conv, torch_fused.ConvBn2d,
                    "Fused submodule Conv + BN")
        expect_type(net.sub1.bn, torch.nn.Identity,
                    "Fused submodule Conv + BN (skipped BN)")

        # sub2 was not part of any fusion group
        expect_type(net.sub2.conv, torch.nn.Conv2d,
                    "Non-fused submodule Conv")
        expect_type(net.sub2.bn, torch.nn.BatchNorm2d,
                    "Non-fused submodule BN")

    def test_fuse_module_eval(self):
        r"""In eval mode, conv1/bn1/relu1 are expected to fuse into a
        ConvReLU2d holding a Conv2d and a ReLU, and sub1.conv stays a plain
        Conv2d, with the absorbed modules replaced by Identity; sub2 is not
        listed and must be unchanged
        """
        import torch.nn._intrinsic.modules.fused as torch_fused

        def expect_type(module, expected, msg):
            # Exact type comparison, not isinstance.
            self.assertEqual(type(module), expected, msg)

        net = ModForFusion()
        net.eval()
        fusion_groups = [
            ['conv1', 'bn1', 'relu1'],
            ['sub1.conv', 'sub1.bn'],
        ]
        fuse_modules(net, fusion_groups)

        # Top-level Conv + BN + ReLU group
        expect_type(net.conv1, torch_fused.ConvReLU2d,
                    "Fused Conv + BN + Relu first layer (BN is folded)")
        expect_type(net.conv1[0], torch.nn.Conv2d,
                    "Fused Conv + BN + Relu (Conv + folded BN only)")
        expect_type(net.conv1[1], torch.nn.ReLU,
                    "Fused Conv + BN + Relu second layer (Relu only)")
        expect_type(net.bn1, torch.nn.Identity,
                    "Fused Conv + BN + Relu second layer (Skipped BN)")
        expect_type(net.relu1, torch.nn.Identity,
                    "Fused Conv + BN + Relu second layer (Skipped Relu)")

        # Conv + BN group inside submodule sub1
        expect_type(net.sub1.conv, torch.nn.Conv2d,
                    "Fused submodule Conv + folded BN")
        expect_type(net.sub1.bn, torch.nn.Identity,
                    "Fused submodule (skipped BN)")

        # sub2 was not part of any fusion group
        expect_type(net.sub2.conv, torch.nn.Conv2d,
                    "Non-fused submodule Conv")
        expect_type(net.sub2.bn, torch.nn.BatchNorm2d,
                    "Non-fused submodule BN")
# Only start the common_utils test runner when this file is run as a script.
if __name__ == "__main__":
    run_tests()