import torch
import torch.nn as nn
import torch.nn.quantized as nnq
import torch.nn.intrinsic as nni
import torch.nn.intrinsic.quantized as nniq
import torch.nn.intrinsic.qat as nniqat
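# Module aliases used throughout this file: nnq = quantized modules,
# nni = intrinsic (fused) float modules, nniq = intrinsic quantized modules,
# nniqat = intrinsic quantization-aware-training modules.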
from torch.nn.utils.rnn import PackedSequence
from torch.quantization import (
quantize,
prepare,
convert,
prepare_qat,
quantize_qat,
fuse_modules,
quantize_dynamic,
QuantWrapper,
QConfig,
default_qconfig,
default_qat_qconfig,
default_dynamic_qconfig,
per_channel_dynamic_qconfig,
float16_dynamic_qconfig,
)
from torch.testing._internal.common_quantization import (
QuantizationTestCase,
AnnotatedSingleLayerLinearModel,
QuantStubModel,
ModelForFusion,
ModelWithSequentialFusion,
ManualLinearQATModel,
ManualConvLinearQATModel,
ModelWithFunctionals,
ModelMultipleOps,
ModelMultipleOpsNoAvgPool,
SingleLayerLinearDynamicModel,
TwoLayerLinearModel,
NestedModel,
ResNetBase,
RNNDynamicModel,
RNNCellDynamicModel,
ModelForFusionWithBias,
ActivationsTestModel,
ActivationsQATTestModel,
NormalizationTestModel,
NormalizationQATTestModel,
test_only_eval_fn,
test_only_train_fn,
prepare_dynamic,
convert_dynamic,
skipIfNoFBGEMM,
)
# annotated models
from torch.testing._internal.common_quantization import (
AnnotatedTwoLayerLinearModel,
AnnotatedNestedModel,
AnnotatedSubNestedModel,
AnnotatedCustomConfigNestedModel,
AnnotatedSkipQuantModel,
)
from torch.testing._internal.common_quantized import (
override_quantized_engine,
supported_qengines,
override_qengines,
)
from hypothesis import given
from hypothesis import strategies as st
import torch.testing._internal.hypothesis_utils as hu
hu.assert_deadline_disabled()
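# The tests below exercise eager-mode post-training static quantization. As a
# rough sketch (module and data names here are placeholders, not part of this
# test suite), the flow is:
#
#   model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
#   prepared = torch.quantization.prepare(model)       # insert observers
#   prepared(calibration_batch)                        # calibrate on sample data
#   quantized = torch.quantization.convert(prepared)   # swap in quantized modules
#
# or, equivalently, the one-line API torch.quantization.quantize(model, eval_fn, data).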
class TestPostTrainingStatic(QuantizationTestCase):
def test_single_layer(self):
r"""Quantize SingleLayerLinearModel which has one Linear module, make sure it is swapped
to nnq.Linear which is the quantized version of the module
"""
for qengine in supported_qengines:
with override_quantized_engine(qengine):
qconfig = torch.quantization.get_default_qconfig(qengine)
model = AnnotatedSingleLayerLinearModel(qengine)
model.qconfig = qconfig
model = prepare(model)
# Check if observers and quant/dequant nodes are inserted
self.checkNoPrepModules(model)
self.checkHasPrepModules(model.fc1)
self.checkObservers(model)
test_only_eval_fn(model, self.calib_data)
model = convert(model)
def checkQuantized(model):
self.checkNoPrepModules(model)
self.checkHasPrepModules(model.fc1)
self.checkWrappedQuantizedLinear(model.fc1)
test_only_eval_fn(model, self.calib_data)
self.checkScriptable(model, self.calib_data)
checkQuantized(model)
# test one line API - out of place version
base = AnnotatedSingleLayerLinearModel(qengine)
base.qconfig = qconfig
keys_before = set(list(base.state_dict().keys()))
model = quantize(base, test_only_eval_fn, self.calib_data)
checkQuantized(model)
keys_after = set(list(base.state_dict().keys()))
self.assertEqual(keys_before, keys_after) # simple check that nothing changed
# in-place version
model = AnnotatedSingleLayerLinearModel(qengine)
model.qconfig = qconfig
quantize(model, test_only_eval_fn, self.calib_data, inplace=True)
checkQuantized(model)
@skipIfNoFBGEMM
def test_two_layers(self):
r"""TwoLayerLinearModel has two Linear modules but we only quantize the second one
`fc2`, and `fc1`is not quantized
"""
with override_quantized_engine('fbgemm'):
model = AnnotatedTwoLayerLinearModel()
model = prepare(model)
self.checkNoPrepModules(model)
self.checkObservers(model)
self.checkNoPrepModules(model.fc1)
self.checkHasPrepModules(model.fc2)
test_only_eval_fn(model, self.calib_data)
model = convert(model)
def checkQuantized(model):
self.checkNoPrepModules(model)
self.checkNoPrepModules(model.fc1)
self.checkHasPrepModules(model.fc2)
self.assertEqual(type(model.fc1), torch.nn.Linear)
self.checkWrappedQuantizedLinear(model.fc2)
test_only_eval_fn(model, self.calib_data)
self.checkScriptable(model, self.calib_data)
checkQuantized(model)
# test one line API
model = quantize(AnnotatedTwoLayerLinearModel(), test_only_eval_fn,
self.calib_data)
checkQuantized(model)
def test_nested1(self):
r"""Test quantization for nested model, top level 'fc3' and
'fc1' of submodule 'sub2', 'sub2.fc2' is not quantized
"""
for qengine in supported_qengines:
with override_quantized_engine(qengine):
model = AnnotatedNestedModel(qengine)
def checkPrepModules(model, before_calib=False):
if before_calib:
self.checkObservers(model)
self.checkNoPrepModules(model)
self.checkNoPrepModules(model.sub1)
self.checkNoPrepModules(model.sub1.fc)
self.checkNoPrepModules(model.sub1.relu)
self.checkNoPrepModules(model.sub2)
self.checkHasPrepModules(model.sub2.fc1)
self.checkNoPrepModules(model.sub2.fc2)
self.checkHasPrepModules(model.fc3)
model = prepare(model)
checkPrepModules(model, True)
test_only_eval_fn(model, self.calib_data)
model = convert(model)
def checkQuantized(model):
checkPrepModules(model)
self.checkLinear(model.sub1.fc)
self.checkWrappedQuantizedLinear(model.fc3)
self.checkWrappedQuantizedLinear(model.sub2.fc1)
self.checkLinear(model.sub2.fc2)
test_only_eval_fn(model, self.calib_data)
self.checkScriptable(model, self.calib_data)
checkQuantized(model)
# test one line API
model = quantize(AnnotatedNestedModel(qengine), test_only_eval_fn,
self.calib_data)
checkQuantized(model)
@skipIfNoFBGEMM
def test_nested2(self):
model = AnnotatedSubNestedModel()
model = prepare(model)
def checkPrepModules(model, before_calib=False):
if before_calib:
self.checkObservers(model)
self.checkNoPrepModules(model)
self.checkNoPrepModules(model.sub1)
self.checkNoPrepModules(model.sub1.fc)
self.checkNoPrepModules(model.sub1.relu)
self.checkHasPrepModules(model.sub2)
self.checkNoPrepModules(model.sub2.module.fc1)
self.checkNoPrepModules(model.sub2.module.fc2)
self.checkHasPrepModules(model.fc3)
checkPrepModules(model, True)
test_only_eval_fn(model, self.calib_data)
model = convert(model)
def checkQuantized(model):
checkPrepModules(model)
self.checkLinear(model.sub1.fc)
self.assertEqual(type(model.sub1.relu), torch.nn.ReLU)
self.checkQuantizedLinear(model.sub2.module.fc1)
self.checkQuantizedLinear(model.sub2.module.fc2)
self.checkWrappedQuantizedLinear(model.fc3)
test_only_eval_fn(model, self.calib_data)
self.checkScriptable(model, self.calib_data)
checkQuantized(model)
# test one line API
model = quantize(AnnotatedSubNestedModel(), test_only_eval_fn,
self.calib_data)
checkQuantized(model)
def test_nested3(self):
r"""More complicated nested test case with child qconfig overrides
parent qconfig
"""
for qengine in supported_qengines:
with override_quantized_engine(qengine):
model = AnnotatedCustomConfigNestedModel()
model = prepare(model)
def checkPrepModules(model, before_calib=False):
if before_calib:
self.checkObservers(model)
self.checkNoPrepModules(model)
self.checkNoPrepModules(model.sub1)
self.checkNoPrepModules(model.sub1.fc)
self.checkNoPrepModules(model.sub1.relu)
self.checkNoPrepModules(model.sub2)
self.checkHasPrepModules(model.sub2.fc1)
self.checkHasPrepModules(model.sub2.fc2)
self.checkHasPrepModules(model.fc3)
checkPrepModules(model, True)
test_only_eval_fn(model, self.calib_data)
model = convert(model)
def checkQuantized(model):
checkPrepModules(model)
self.checkWrappedQuantizedLinear(model.sub2.fc1)
self.checkWrappedQuantizedLinear(model.sub2.fc2)
self.checkWrappedQuantizedLinear(model.fc3)
test_only_eval_fn(model, self.calib_data)
self.checkScriptable(model, self.calib_data)
checkQuantized(model)
# test one line API
model = quantize(AnnotatedCustomConfigNestedModel(), test_only_eval_fn,
self.calib_data)
checkQuantized(model)
def test_skip_quant(self):
r"""The case when we want to skip quantizing some layers
"""
for qengine in supported_qengines:
with override_quantized_engine(qengine):
model = AnnotatedSkipQuantModel(qengine)
model = prepare(model)
self.checkObservers(model)
test_only_eval_fn(model, self.calib_data)
model = convert(model)
def checkQuantized(model):
self.checkLinear(model.fc)
self.checkQuantDequant(model.sub)
self.checkQuantizedLinear(model.sub.module.fc1)
self.checkQuantizedLinear(model.sub.module.fc2)
self.assertEqual(type(model.sub.module.relu1), nnq.ReLU)
self.assertEqual(type(model.sub.module.relu2), nnq.ReLU)
self.checkScriptable(model, self.calib_data)
checkQuantized(model)
# test one line API
model = quantize(AnnotatedSkipQuantModel(qengine), test_only_eval_fn, self.calib_data)
checkQuantized(model)
@skipIfNoFBGEMM
def test_manual(self):
r"""User inserts QuantStub and DeQuantStub in model code
and calls the quantization utility functions.
"""
model = QuantStubModel()
# propagate the qconfig of parents to children, model is changed
# inplace
model = prepare(model)
self.checkObservers(model)
test_only_eval_fn(model, self.calib_data)
model = convert(model)
def checkQuantized(model):
self.assertEqual(type(model.fc), nnq.Linear)
test_only_eval_fn(model, self.calib_data)
self.checkScriptable(model, self.calib_data)
checkQuantized(model)
# test one line API
model = quantize(QuantStubModel(), test_only_eval_fn, self.calib_data)
checkQuantized(model)
def test_resnet_base(self):
r"""Test quantization for bottleneck topology used in resnet/resnext
and add coverage for conversion of average pool and float functional
"""
for qengine in supported_qengines:
with override_quantized_engine(qengine):
qconfig = torch.quantization.get_default_qconfig(qengine)
model = ResNetBase().float().eval()
model = QuantWrapper(model)
model.qconfig = qconfig
fuse_list = ['module.conv1', 'module.bn1', 'module.relu1']
fuse_modules(model, fuse_list, inplace=True)
model = prepare(model)
self.checkObservers(model)
test_only_eval_fn(model, self.img_data)
model = convert(model)
def checkQuantized(model):
self.assertEqual(type(model.module.conv1), nn.intrinsic.quantized.ConvReLU2d)
self.assertEqual(type(model.module.myop), nn.quantized.QFunctional)
self.assertEqual(type(model.module.avgpool), nn.AdaptiveAvgPool2d)
test_only_eval_fn(model, self.img_data)
checkQuantized(model)
@skipIfNoFBGEMM
def test_normalization(self):
r"""
Test quantization of normalization layers
"""
model = NormalizationTestModel()
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
prepare(model, inplace=True)
self.checkObservers(model)
test_only_eval_fn(model, self.calib_data)
model = convert(model)
def checkQuantized(model):
self.checkNoPrepModules(model.layer_norm)
self.checkNoPrepModules(model.group_norm)
self.checkNoPrepModules(model.instance_norm1d)
self.checkNoPrepModules(model.instance_norm2d)
self.checkNoPrepModules(model.instance_norm3d)
self.assertEqual(type(model.layer_norm), nnq.LayerNorm)
self.assertEqual(type(model.group_norm), nnq.GroupNorm)
self.assertEqual(type(model.instance_norm1d), nnq.InstanceNorm1d)
self.assertEqual(type(model.instance_norm2d), nnq.InstanceNorm2d)
self.assertEqual(type(model.instance_norm3d), nnq.InstanceNorm3d)
test_only_eval_fn(model, self.calib_data)
self.checkScriptable(model, self.calib_data)
checkQuantized(model)
model_oneline = quantize(
NormalizationTestModel(), test_only_eval_fn, self.calib_data)
checkQuantized(model_oneline)
def test_save_load_state_dict(self):
r"""Test PTQ flow of creating a model and quantizing it and saving the quantized state_dict
Load the quantized state_dict for eval and compare results against original model
"""
for qengine in supported_qengines:
with override_quantized_engine(qengine):
model = TwoLayerLinearModel()
model = torch.quantization.QuantWrapper(model)
model.qconfig = torch.quantization.get_default_qconfig(qengine)
model = prepare(model)
# calibrate
test_only_eval_fn(model, self.calib_data)
model = convert(model)
x = torch.rand(2, 5, dtype=torch.float)
ref = model(x)
quant_state_dict = model.state_dict()
# Create model again for eval
model = TwoLayerLinearModel()
model = torch.quantization.QuantWrapper(model)
model.qconfig = torch.quantization.get_default_qconfig(qengine)
model = prepare(model)
model = convert(model)
new_state_dict = model.state_dict()
# Check to make sure the state dict keys match original model after convert.
self.assertEqual(set(new_state_dict.keys()), set(quant_state_dict.keys()))
model.load_state_dict(quant_state_dict)
out = model(x)
self.assertEqual(ref, out)
@skipIfNoFBGEMM
def test_activations(self):
r"""
Test quantization of activations
"""
model = ActivationsTestModel()
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
prepare(model, inplace=True)
self.checkObservers(model)
test_only_eval_fn(model, self.calib_data)
model = convert(model)
def checkQuantized(model):
self.checkNoPrepModules(model.hardswish)
self.assertEqual(type(model.hardswish), nnq.Hardswish)
test_only_eval_fn(model, self.calib_data)
self.checkScriptable(model, self.calib_data)
checkQuantized(model)
# test one line API
model_oneline = quantize(ActivationsTestModel(), test_only_eval_fn,
self.calib_data)
checkQuantized(model_oneline)
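# The tests below cover eager-mode dynamic quantization: weights are quantized
# ahead of time and activations are quantized on the fly at inference time.
# A minimal sketch (model and data names are placeholders):
#
#   quantized = torch.quantization.quantize_dynamic(
#       model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
#
# qconfig_spec may also be a dict mapping submodule names or module types to qconfigs,
# as exercised by prepare_dynamic/convert_dynamic below.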
@skipIfNoFBGEMM
class TestPostTrainingDynamic(QuantizationTestCase):
def test_single_layer(self):
r"""Dynamic Quantize SingleLayerLinearDynamicModel which has one Linear module,
make sure it is swapped to nnqd.Linear which is the quantized version of
the module
"""
for dtype in [torch.qint8, torch.float16]:
model = SingleLayerLinearDynamicModel().eval()
qconfig = float16_dynamic_qconfig if dtype == torch.float16 else default_dynamic_qconfig
qconfig_dict = {
'fc1': qconfig
}
prepare_dynamic(model, qconfig_dict)
convert_dynamic(model)
def checkQuantized(model):
self.checkDynamicQuantizedLinear(model.fc1, dtype)
self.checkScriptable(model, self.calib_data, check_save_load=True)
checkQuantized(model)
# test one line API - out of place version
base = SingleLayerLinearDynamicModel()
keys_before = set(list(base.state_dict().keys()))
model = quantize_dynamic(base, qconfig_dict)
checkQuantized(model)
keys_after = set(list(base.state_dict().keys()))
self.assertEqual(keys_before, keys_after) # simple check that nothing changed
# in-place version
model = SingleLayerLinearDynamicModel()
quantize_dynamic(model, qconfig_dict, inplace=True)
checkQuantized(model)
# Test set qconfig
model = SingleLayerLinearDynamicModel()
quantize_dynamic(model, set([nn.Linear]), inplace=True, dtype=dtype)
checkQuantized(model)
def test_two_layers(self):
r"""TwoLayerLinearModel has two Linear modules but we only quantize the second one
`fc2`, and `fc1`is not quantized
"""
for dtype in [torch.qint8, torch.float16]:
model = TwoLayerLinearModel().eval()
qconfig = float16_dynamic_qconfig if dtype == torch.float16 else default_dynamic_qconfig
qconfig_dict = {
'fc2': qconfig
}
prepare_dynamic(model, qconfig_dict)
convert_dynamic(model)
def checkQuantized(model):
self.assertEqual(type(model.fc1), torch.nn.Linear)
self.checkDynamicQuantizedLinear(model.fc2, dtype=dtype)
self.checkScriptable(model, self.calib_data, check_save_load=True)
checkQuantized(model)
# test one line API
model = quantize_dynamic(TwoLayerLinearModel().eval(), qconfig_dict)
checkQuantized(model)
# Test set API
model = quantize_dynamic(TwoLayerLinearModel().eval(), {'fc2'}, dtype=dtype)
checkQuantized(model)
def test_nested1(self):
r"""Test quantization for nested model, top level 'fc3' and
'fc1' of submodule 'sub2', 'sub2.fc2' is not quantized
"""
for dtype in [torch.qint8, torch.float16]:
model = NestedModel().eval()
qconfig = float16_dynamic_qconfig if dtype == torch.float16 else default_dynamic_qconfig
qconfig_dict = {
'fc3': qconfig,
'sub2.fc1': qconfig
}
prepare_dynamic(model, qconfig_dict)
convert_dynamic(model)
def checkQuantized(model):
self.checkLinear(model.sub1.fc)
self.checkDynamicQuantizedLinear(model.fc3, dtype=dtype)
self.checkDynamicQuantizedLinear(model.sub2.fc1, dtype=dtype)
self.checkLinear(model.sub2.fc2)
self.checkScriptable(model, self.calib_data, check_save_load=True)
checkQuantized(model)
# test one line API
model = quantize_dynamic(NestedModel().eval(), qconfig_dict)
checkQuantized(model)
model = quantize_dynamic(NestedModel().eval(), {'fc3', 'sub2.fc1'}, dtype=dtype)
checkQuantized(model)
def test_nested2(self):
r"""Another test case for quantized, we will quantize all submodules
of submodule sub2
"""
for dtype in [torch.qint8, torch.float16]:
model = NestedModel().eval()
qconfig = float16_dynamic_qconfig if dtype == torch.float16 else default_dynamic_qconfig
qconfig_dict = {
'fc3': qconfig,
'sub2': qconfig
}
prepare_dynamic(model, qconfig_dict)
convert_dynamic(model)
def checkQuantized(model):
self.checkLinear(model.sub1.fc)
self.assertEqual(type(model.sub1.relu), torch.nn.ReLU)
self.checkDynamicQuantizedLinear(model.sub2.fc1, dtype=dtype)
self.checkDynamicQuantizedLinear(model.sub2.fc2, dtype=dtype)
self.checkDynamicQuantizedLinear(model.fc3, dtype=dtype)
self.checkScriptable(model, self.calib_data, check_save_load=True)
checkQuantized(model)
# test one line API
model = quantize_dynamic(NestedModel().eval(), qconfig_dict, dtype=dtype)
checkQuantized(model)
# Test set API
model = quantize_dynamic(NestedModel().eval(), {'fc3', 'sub2'}, dtype=dtype)
checkQuantized(model)
def test_nested3(self):
r"""More complicated nested test case with child qconfig overrides
parent qconfig
"""
for dtype in [torch.qint8, torch.float16]:
model = NestedModel().eval()
qconfig = float16_dynamic_qconfig if dtype == torch.float16 else default_dynamic_qconfig
qconfig_dynamic_dict = {
'fc3': qconfig,
'sub2': qconfig,
'sub2.fc1': qconfig
}
prepare_dynamic(model, qconfig_dynamic_dict)
convert_dynamic(model)
def checkQuantized(model):
self.checkDynamicQuantizedLinear(model.sub2.fc1, dtype=dtype)
self.checkDynamicQuantizedLinear(model.sub2.fc2, dtype=dtype)
self.checkDynamicQuantizedLinear(model.fc3, dtype=dtype)
self.checkScriptable(model, self.calib_data, check_save_load=True)
checkQuantized(model)
# test one line API
model = quantize_dynamic(NestedModel().eval(), qconfig_dynamic_dict)
checkQuantized(model)
# Test set API
model = quantize_dynamic(NestedModel().eval(), {'fc3', 'sub2', 'sub2.fc1'}, dtype=dtype)
checkQuantized(model)
def test_type_match_rule(self):
r"""Test quantization for nested model, top level 'fc3' and
'fc1' of submodule 'sub2', All 'torch.nn.Linear' modules are quantized
"""
for dtype in [torch.qint8, torch.float16]:
model = NestedModel().eval()
qconfig = float16_dynamic_qconfig if dtype == torch.float16 else default_dynamic_qconfig
qconfig_dict = {
'fc3': None,
'sub2.fc1': None,
torch.nn.Linear: qconfig
}
prepare_dynamic(model, qconfig_dict)
test_only_eval_fn(model, self.calib_data)
convert_dynamic(model)
def checkQuantized(model):
self.checkDynamicQuantizedLinear(model.sub1.fc, dtype=dtype)
self.checkLinear(model.fc3)
self.checkLinear(model.sub2.fc1)
self.checkDynamicQuantizedLinear(model.sub2.fc2, dtype=dtype)
test_only_eval_fn(model, self.calib_data)
self.checkScriptable(model, self.calib_data, check_save_load=True)
checkQuantized(model)
# test one line API
model = quantize_dynamic(NestedModel().eval(), qconfig_dict, dtype=dtype)
checkQuantized(model)
def test_per_channel_linear_quantize(self):
r"""Test quantization for per_channel dynamic quantization
"""
model = NestedModel().eval()
qconfig_dict = {
torch.nn.Linear: per_channel_dynamic_qconfig
}
prepare_dynamic(model, qconfig_dict)
test_only_eval_fn(model, self.calib_data)
convert_dynamic(model)
def checkQuantized(model):
self.checkDynamicQuantizedLinear(model.sub1.fc, dtype=torch.qint8)
self.checkDynamicQuantizedLinear(model.fc3, dtype=torch.qint8)
self.checkDynamicQuantizedLinear(model.sub2.fc1, dtype=torch.qint8)
self.checkDynamicQuantizedLinear(model.sub2.fc2, dtype=torch.qint8)
test_only_eval_fn(model, self.calib_data)
self.checkScriptable(model, self.calib_data, check_save_load=True)
checkQuantized(model)
# test one line API
model = quantize_dynamic(NestedModel().eval(), qconfig_dict)
checkQuantized(model)
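# per_channel_dynamic_qconfig gives each output channel of the weight its own
# scale and zero_point, which usually tracks the float weights more closely than
# a single per-tensor scale. A sketch of using it directly (the model name is a
# placeholder):
#
#   quantized = quantize_dynamic(model, {torch.nn.Linear: per_channel_dynamic_qconfig})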
@given(qconfig=st.sampled_from([per_channel_dynamic_qconfig, default_dynamic_qconfig]),
dtype=st.sampled_from([torch.qint8, torch.float16]))
def test_quantized_rnn(self, qconfig, dtype):
r"""Test dynamic quantization, scriptability and serialization for dynamic quantized lstm modules on int8 and fp16
"""
model = RNNDynamicModel('LSTM').eval()
niter = 10
x = torch.tensor([[100, -155],
[-155, 100],
[100, -155]], dtype=torch.float).unsqueeze(0).repeat(niter, 1, 1)
qconfig_dict = {
torch.nn.LSTM : qconfig
}
if dtype == torch.float16:
model_quantized = quantize_dynamic(model=model, dtype=dtype)
else:
model_quantized = quantize_dynamic(model=model, qconfig_spec=qconfig_dict, dtype=dtype)
# Smoke test extra reprs
self.assertTrue('DynamicQuantizedLSTM' in str(model_quantized))
self.checkDynamicQuantizedModule(model_quantized.mod, torch.nn.quantized.dynamic.LSTM, dtype)
self.checkScriptable(model_quantized, [(x, x)], check_save_load=True)
class ScriptWrapperPacked(torch.nn.Module):
def __init__(self, cell):
super(ScriptWrapperPacked, self).__init__()
self.cell = cell
def forward(self,
x # type: PackedSequence
):
# type: (...) -> Tuple[PackedSequence, Tuple[torch.Tensor, torch.Tensor]]
return self.cell(x)
packed_input = torch.nn.utils.rnn.pack_padded_sequence(x, torch.tensor([10, 5, 2]))
model_with_packed_input = ScriptWrapperPacked(model_quantized.mod)
scripted = torch.jit.script(model_with_packed_input)
# We cannot trace with input dtype being a packed sequence
self._checkScriptable(model_with_packed_input, scripted, [(packed_input, x)], True)
@given(qconfig=st.sampled_from([per_channel_dynamic_qconfig, default_dynamic_qconfig]),
dtype=st.sampled_from([torch.qint8, torch.float16]))
def test_quantized_rnn_cell(self, qconfig, dtype):
r"""Test dynamic quantization, scriptability and serialization for dynamic quantized rnn cell modules on int8 and fp16
"""
qconfig_dict = {
torch.nn.LSTMCell : qconfig,
torch.nn.GRUCell : qconfig,
torch.nn.RNNCell : qconfig
}
for module_type in ['LSTMCell', 'GRUCell', 'RNNTanh', 'RNNReLU']:
model = RNNCellDynamicModel(module_type).eval()
x = torch.tensor([[100, -155],
[-155, 100],
[100, -155]], dtype=torch.float)
# fp16 dynamic quant is not supported for qnnpack
if torch.backends.quantized.engine == 'qnnpack' and dtype == torch.float16:
continue
if dtype == torch.float16:
model_quantized = quantize_dynamic(model=model, dtype=dtype)
else:
model_quantized = quantize_dynamic(model=model, qconfig_spec=qconfig_dict, dtype=dtype)
def checkQuantized(model, module_type):
mod_type_map = {'LSTMCell': torch.nn.quantized.dynamic.LSTMCell,
'GRUCell': torch.nn.quantized.dynamic.GRUCell,
'RNNTanh': torch.nn.quantized.dynamic.RNNCell,
'RNNReLU': torch.nn.quantized.dynamic.RNNCell}
mod_repr_map = {'LSTMCell': 'DynamicQuantizedLSTMCell',
'GRUCell': 'DynamicQuantizedGRUCell',
'RNNTanh': 'DynamicQuantizedRNNCell',
'RNNReLU': 'DynamicQuantizedRNNCell'}
self.assertTrue(mod_repr_map[module_type] in str(model_quantized))
self.checkDynamicQuantizedModule(model_quantized.mod, mod_type_map[module_type], dtype)
# Smoke test extra reprs
checkQuantized(model_quantized, module_type)
self.checkScriptable(model_quantized, [(x, x)], check_save_load=True)
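# The tests below exercise eager-mode quantization-aware training (QAT):
# fake-quantization modules simulate quantization effects during training.
# A rough sketch of the flow (model, train_fn and data names are placeholders):
#
#   model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
#   prepared = torch.quantization.prepare_qat(model)   # insert fake-quant + observers
#   train_fn(prepared, train_data)                     # fine-tune with fake quant
#   quantized = torch.quantization.convert(prepared.eval())
#
# or, equivalently, the one-line API torch.quantization.quantize_qat(model, train_fn, data).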
class TestQuantizationAwareTraining(QuantizationTestCase):
def test_manual(self):
for qengine in supported_qengines:
with override_quantized_engine(qengine):
model = ManualLinearQATModel(qengine)
model = prepare_qat(model)
self.checkObservers(model)
test_only_train_fn(model, self.train_data)
model = convert(model)
def checkQuantized(model):
self.assertEqual(type(model.fc1), nnq.Linear)
self.assertEqual(type(model.fc2), nnq.Linear)
test_only_eval_fn(model, self.calib_data)
self.checkScriptable(model, self.calib_data)
checkQuantized(model)
model = quantize_qat(ManualLinearQATModel(qengine), test_only_train_fn,
self.train_data)
checkQuantized(model)
def test_activations(self):
for qengine in supported_qengines:
with override_quantized_engine(qengine):
model = ActivationsQATTestModel(qengine)
model = prepare_qat(model)
self.assertEqual(type(model.fc1), torch.nn.qat.modules.Linear)
self.assertEqual(type(model.hardswish), torch.nn.qat.modules.Hardswish)
self.checkObservers(model)
test_only_train_fn(model, self.train_data)
model = convert(model)
def checkQuantized(model):
self.assertEqual(type(model.fc1), nnq.Linear)
self.assertEqual(type(model.hardswish), nnq.Hardswish)
test_only_eval_fn(model, self.calib_data)
self.checkScriptable(model, self.calib_data)
checkQuantized(model)
model = quantize_qat(ActivationsQATTestModel(qengine), test_only_train_fn,
self.train_data)
checkQuantized(model)
@override_qengines
def test_normalization(self):
qengine = torch.backends.quantized.engine
model = NormalizationQATTestModel(qengine)
model = prepare_qat(model)
self.assertEqual(type(model.fc1), torch.nn.qat.modules.Linear)
self.assertEqual(
type(model.group_norm), torch.nn.qat.modules.GroupNorm)
self.assertEqual(
type(model.instance_norm1d),
torch.nn.qat.modules.InstanceNorm1d)
self.assertEqual(
type(model.instance_norm2d),
torch.nn.qat.modules.InstanceNorm2d)
self.assertEqual(
type(model.instance_norm3d),
torch.nn.qat.modules.InstanceNorm3d)
self.assertEqual(
type(model.layer_norm), torch.nn.qat.modules.LayerNorm)
self.checkObservers(model)
test_only_train_fn(model, self.train_data)
model = convert(model)
def checkQuantized(model):
self.assertEqual(type(model.fc1), nnq.Linear)
self.assertEqual(type(model.group_norm), nnq.GroupNorm)
self.assertEqual(
type(model.instance_norm1d), nnq.InstanceNorm1d)
self.assertEqual(
type(model.instance_norm2d), nnq.InstanceNorm2d)
self.assertEqual(
type(model.instance_norm3d), nnq.InstanceNorm3d)
self.assertEqual(type(model.layer_norm), nnq.LayerNorm)
test_only_eval_fn(model, self.calib_data)
self.checkScriptable(model, self.calib_data)
checkQuantized(model)
model = quantize_qat(
NormalizationQATTestModel(qengine), test_only_train_fn,
self.train_data)
checkQuantized(model)
def test_eval_only_fake_quant(self):
r"""Using FakeQuant in evaluation only mode,
this is useful for estimating accuracy loss when we quantize the
network
"""
for qengine in supported_qengines:
with override_quantized_engine(qengine):
model = ManualLinearQATModel(qengine)
model = prepare_qat(model)
self.checkObservers(model)
model.eval()
test_only_eval_fn(model, self.calib_data)
def test_conv_linear(self):
for qengine in supported_qengines:
with override_quantized_engine(qengine):
model = ManualConvLinearQATModel()
model = prepare_qat(model)
self.checkObservers(model)
test_only_train_fn(model, self.img_data)
model = convert(model)
def checkQuantized(model):
self.assertEqual(type(model.conv), nnq.Conv2d)
self.assertEqual(type(model.fc1), nnq.Linear)
self.assertEqual(type(model.fc2), nnq.Linear)
test_only_eval_fn(model, self.img_data)
self.checkScriptable(model, self.img_data)
checkQuantized(model)
model = ManualConvLinearQATModel()
model = quantize_qat(model, test_only_train_fn, self.img_data)
checkQuantized(model)
def test_train_save_load_eval(self):
r"""Test QAT flow of creating a model, doing QAT and saving the quantized state_dict
During eval, we first call prepare_qat and conver on the model and then load the state_dict
and compare results against original model
"""
for qengine in supported_qengines:
with override_quantized_engine(qengine):
model = TwoLayerLinearModel()
model = torch.quantization.QuantWrapper(model)
model.qconfig = torch.quantization.get_default_qat_qconfig(qengine)
model = prepare_qat(model)
fq_state_dict = model.state_dict()
test_only_train_fn(model, self.train_data)
model = convert(model)
quant_state_dict = model.state_dict()
x = torch.rand(2, 5, dtype=torch.float)
ref = model(x)
# Create model again for eval. Check result using quantized state_dict
model = TwoLayerLinearModel()
model = torch.quantization.QuantWrapper(model)
model.qconfig = torch.quantization.get_default_qat_qconfig(qengine)
torch.quantization.prepare_qat(model, inplace=True)
new_state_dict = model.state_dict()
# Check to make sure the model after prepare_qat has the same state_dict as original.
self.assertEqual(set(fq_state_dict.keys()), set(new_state_dict.keys()))
torch.quantization.convert(model, inplace=True)
model.eval()
model.load_state_dict(quant_state_dict)
out = model(x)
self.assertEqual(ref, out)
# Check model created using prepare has same state dict as quantized state_dict
model = TwoLayerLinearModel()
model.eval()
model = torch.quantization.QuantWrapper(model)
model.qconfig = torch.quantization.get_default_qconfig(qengine)
torch.quantization.prepare(model, inplace=True)
torch.quantization.convert(model, inplace=True)
self.assertEqual(set(model.state_dict().keys()), set(quant_state_dict.keys()))
model.eval()
model.load_state_dict(quant_state_dict)
out = model(x)
self.assertEqual(ref, out)
class TestFunctionalModule(QuantizationTestCase):
# Histogram observers are slow, so the Hypothesis deadline is disabled to ensure this test doesn't time out
@given(train_mode=st.booleans())
def test_functional_module(self, train_mode):
model = ModelWithFunctionals()
x = torch.rand(10, 1, dtype=torch.float)
xq = torch.quantize_per_tensor(x, 0.01, 30, torch.quint8)
self.checkScriptable(model, [(x, x)], check_save_load=True)
if train_mode:
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
model = prepare_qat(model)
else:
model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
model = prepare(model)
# Check if observers and quant/dequant nodes are inserted
self.checkNoPrepModules(model)
self.checkObservers(model)
# Calibrate
model(xq.dequantize())
model = convert(model)
def checkQuantized(model):
self.checkNoPrepModules(model)
self.assertEqual(type(model.myadd), torch.nn.quantized.QFunctional)
self.assertEqual(type(model.mycat), torch.nn.quantized.QFunctional)
self.assertEqual(type(model.myadd_relu), torch.nn.quantized.QFunctional)
checkQuantized(model)
self.checkScriptable(model, [(xq, xq)], check_save_load=True)
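# The fusion tests below rely on torch.quantization.fuse_modules, which merges
# sequences such as Conv + BN, Conv + BN + ReLU, Conv + ReLU, Linear + ReLU and
# BN + ReLU into single fused modules; the remaining slots are replaced with
# nn.Identity. A minimal sketch (module names are placeholders):
#
#   fused = torch.quantization.fuse_modules(model, [['conv1', 'bn1', 'relu1'],
#                                                   ['sub.conv', 'sub.bn']])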
@skipIfNoFBGEMM
class TestFusion(QuantizationTestCase):
def test_fuse_module_train(self):
model = ModelForFusion(default_qat_qconfig).train()
# Test step by step fusion
model = fuse_modules(model, ['conv1', 'bn1', 'relu1'])
model = fuse_modules(model, ['sub1.conv', 'sub1.bn'])
self.assertEqual(type(model.conv1), nni.ConvBnReLU2d,
msg="Fused Conv + BN + Relu first layer")
self.assertEqual(type(model.bn1), torch.nn.Identity,
msg="Fused Conv + BN + Relu (skipped BN)")
self.assertEqual(type(model.relu1), torch.nn.Identity,
msg="Fused Conv + BN + Relu (skipped Relu)")
self.assertEqual(type(model.sub1.conv), nni.ConvBn2d,
msg="Fused submodule Conv + BN")
self.assertEqual(type(model.sub1.bn), torch.nn.Identity,
msg="Fused submodule Conv + BN (skipped BN)")
self.assertEqual(type(model.sub2.conv), torch.nn.Conv2d,
msg="Non-fused submodule Conv")
self.assertEqual(type(model.sub2.relu), torch.nn.ReLU,
msg="Non-fused submodule ReLU")
model = prepare_qat(model)
self.checkObservers(model)
def checkQAT(model):
self.assertEqual(type(model.conv1), nniqat.ConvBnReLU2d)
self.assertEqual(type(model.bn1), nn.Identity)
self.assertEqual(type(model.relu1), nn.Identity)
self.assertEqual(type(model.sub1.conv), nniqat.ConvBn2d)
self.assertEqual(type(model.sub1.bn), nn.Identity)
self.assertEqual(type(model.sub2.conv), nn.Conv2d)
self.assertEqual(type(model.sub2.relu), nn.ReLU)
checkQAT(model)
test_only_train_fn(model, self.img_data_1d)
model = convert(model)
def checkQuantized(model):
self.assertEqual(type(model.conv1), nniq.ConvReLU2d)
self.assertEqual(type(model.bn1), nn.Identity)
self.assertEqual(type(model.relu1), nn.Identity)
self.assertEqual(type(model.sub1.conv), nnq.Conv2d)
self.assertEqual(type(model.sub1.bn), nn.Identity)
self.assertEqual(type(model.sub2.conv), nn.Conv2d)
self.assertEqual(type(model.sub2.relu), nn.ReLU)
test_only_eval_fn(model, self.img_data_1d)
with self.assertRaisesRegex(RuntimeError, "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"):
checkQuantized(model)
model = ModelForFusion(default_qat_qconfig).train()
model = fuse_modules(model, [['conv1', 'bn1', 'relu1'],
['sub1.conv', 'sub1.bn']])
model = quantize_qat(model, test_only_train_fn, self.img_data_1d)
with self.assertRaisesRegex(RuntimeError, "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"):
checkQuantized(model)
def test_fuse_module_eval(self):
model = ModelForFusion(default_qconfig)
model.eval()
model = fuse_modules(model, [['conv3', 'bn3', 'relu4'],
['conv1', 'bn1', 'relu1'],
['conv2', 'relu2'],
['bn2', 'relu3'],
['sub1.conv', 'sub1.bn']])
self.assertEqual(type(model.conv1), nni.ConvReLU2d,
msg="Fused Conv + BN + Relu first layer (BN is folded)")
self.assertEqual(type(model.conv1[0]), nn.Conv2d,
msg="Fused Conv + BN + Relu (Conv + folded BN only)")
self.assertEqual(type(model.conv1[1]), nn.ReLU,
msg="Fused Conv + BN + Relu second layer (Relu only)")
self.assertEqual(type(model.bn1), nn.Identity,
msg="Fused Conv + BN + Relu second layer (Skipped BN)")
self.assertEqual(type(model.relu1), nn.Identity,
msg="Fused Conv + BN + Relu second layer (Skipped Relu)")
self.assertEqual(type(model.conv2), nni.ConvReLU3d,
msg="Fused Conv + BN + Relu first layer (BN is folded)")
self.assertEqual(type(model.bn2), nni.BNReLU3d,
msg="Fused BN + Relu first layer (Relu is folded))")
self.assertEqual(type(model.relu3), nn.Identity,
msg="Fused BN + Relu second layer (Skipped Relu)")
self.assertEqual(type(model.conv2[0]), nn.Conv3d,
msg="Fused Conv + BN + Relu (Conv + folded BN only)")
self.assertEqual(type(model.conv2[1]), nn.ReLU,
msg="Fused Conv + BN + Relu second layer (Relu only)")
self.assertEqual(type(model.relu2), nn.Identity,
msg="Fused Conv + BN + Relu second layer (Skipped Relu)")
self.assertEqual(type(model.conv3), nni.ConvReLU1d,
msg="Fused Conv + Relu for Conv1d (folded BN)")
self.assertEqual(type(model.conv3[0]), nn.Conv1d,
msg="Fused Conv + Relu for Conv1d ")
self.assertEqual(type(model.conv3[1]), nn.ReLU,
msg="Fused Conv + Relu for Conv1d")
self.assertEqual(type(model.bn3), nn.Identity,
msg="Fused Conv + BN + Relu for Conv1d (Skipped BN)")
self.assertEqual(type(model.sub1.conv), nn.Conv2d,
msg="Fused submodule Conv + folded BN")
self.assertEqual(type(model.sub1.bn), nn.Identity,
msg="Fused submodule (skipped BN)")
self.assertEqual(type(model.sub2.conv), nn.Conv2d,
msg="Non-fused submodule Conv")
self.assertEqual(type(model.sub2.relu), torch.nn.ReLU,
msg="Non-fused submodule ReLU")
model = prepare(model)
self.checkObservers(model)
test_only_eval_fn(model, self.img_data_1d)
model = convert(model)
def checkQuantized(model):
self.assertEqual(type(model.conv3), nniq.ConvReLU1d)
self.assertEqual(type(model.conv1), nniq.ConvReLU2d)
self.assertEqual(type(model.bn1), nn.Identity)
self.assertEqual(type(model.relu1), nn.Identity)
self.assertEqual(type(model.sub1.conv), nnq.Conv2d)
self.assertEqual(type(model.sub1.bn), nn.Identity)
self.assertEqual(type(model.sub2.conv), nn.Conv2d)
self.assertEqual(type(model.sub2.relu), nn.ReLU)
self.assertEqual(type(model.bn2), nniq.BNReLU3d)
test_only_eval_fn(model, self.img_data_1d)
checkQuantized(model)
model = ModelForFusion(default_qconfig).eval()
model = fuse_modules(model, [['conv1', 'bn1', 'relu1'],
['conv2', 'relu2'],
['bn2', 'relu3'],
['sub1.conv', 'sub1.bn'],
['conv3', 'bn3', 'relu4']])
model = quantize(model, test_only_eval_fn, self.img_data_1d)
checkQuantized(model)
def test_fusion_sequential_model_train(self):
for qengine in supported_qengines:
with override_quantized_engine(qengine):
model = ModelWithSequentialFusion().train()
model.to(torch.float)
fuse_modules(model, [['conv1', 'relu1'] ,
['features.0.0', 'features.0.1', 'features.0.2'],
['features.1.0', 'features.1.1', 'features.1.2'],
['features.2.0', 'features.2.1', 'features.2.2'],
['classifier.0', 'classifier.1']], inplace=True)
self.assertEqual(type(model.conv1), nni.ConvReLU2d,
msg="Fused Conv + Relu: nni.ConvReLU2d")
self.assertEqual(type(model.conv1[0]), nn.Conv2d,
msg="Fused Conv + Relu: Conv2d")
self.assertEqual(type(model.conv1[1]), nn.ReLU,
msg="Fused Conv + Relu: Relu")
self.assertEqual(type(model.relu1), nn.Identity,
msg="Fused Conv + Relu: Identity")
for i in range(3):
self.assertEqual(type(model.features[i][0]), nni.ConvBnReLU2d,
msg="Fused submodule Conv + folded BN")
self.assertEqual(type(model.features[i][1]), nn.Identity,
msg="Fused submodule (skipped BN)")
self.assertEqual(type(model.features[i][2]), nn.Identity,
msg="Non-fused submodule Conv")
self.assertEqual(type(model.classifier[0]), nni.LinearReLU)
self.assertEqual(type(model.classifier[1]), nn.Identity)
model.qconfig = torch.quantization.get_default_qat_qconfig(qengine)
prepare_qat(model, inplace=True)
self.checkObservers(model)
model(self.img_data[0][0])
def checkQAT(model):
self.assertEqual(type(model.conv1), nniqat.ConvReLU2d)
self.assertEqual(type(model.relu1), nn.Identity)
for i in range(3):
self.assertEqual(type(model.features[i][0]), nniqat.ConvBnReLU2d,
msg="Fused submodule Conv + folded BN")
self.assertEqual(type(model.features[i][1]), nn.Identity,
msg="Fused submodule (skipped BN)")
self.assertEqual(type(model.features[i][2]), nn.Identity,
msg="Non-fused submodule Conv")
self.assertEqual(type(model.classifier[0]), nniqat.LinearReLU)
self.assertEqual(type(model.classifier[1]), nn.Identity)
checkQAT(model)
model(self.img_data[1][0])
convert(model, inplace=True)
model(self.img_data[1][0])
self.checkModelWithSequentialQuantized(model)
def test_fusion_sequential_model_eval(self):
for qengine in supported_qengines:
with override_quantized_engine(qengine):
model = ModelWithSequentialFusion().eval()
model.to(torch.float)
fuse_modules(model, [['conv1', 'relu1'] ,
['features.0.0', 'features.0.1', 'features.0.2'],
['features.1.0', 'features.1.1', 'features.1.2'],
['features.2.0', 'features.2.1', 'features.2.2'],
['classifier.0', 'classifier.1']], inplace=True)
self.assertEqual(type(model.conv1), nni.ConvReLU2d,
msg="Fused Conv + Relu: nni.ConvReLU2d")
self.assertEqual(type(model.conv1[0]), nn.Conv2d,
msg="Fused Conv + Relu: Conv2d")
self.assertEqual(type(model.conv1[1]), nn.ReLU,
msg="Fused Conv + Relu: Relu")
self.assertEqual(type(model.relu1), nn.Identity,
msg="Fused Conv + Relu: Identity")
for i in range(3):
self.assertEqual(type(model.features[i][0]), nni.ConvReLU2d,
msg="Fused submodule Conv + folded BN")
self.assertEqual(type(model.features[i][1]), nn.Identity,
msg="Fused submodule (skipped BN)")
self.assertEqual(type(model.features[i][2]), nn.Identity,
msg="Non-fused submodule Conv")
self.assertEqual(type(model.classifier[0]), nni.LinearReLU)
self.assertEqual(type(model.classifier[1]), nn.Identity)
model.qconfig = torch.quantization.get_default_qconfig(qengine)
prepare(model, inplace=True)
self.checkObservers(model)
model(self.img_data[0][0])
convert(model, inplace=True)
model(self.img_data[1][0])
self.checkModelWithSequentialQuantized(model)
def checkModelWithSequentialQuantized(self, model):
self.assertEqual(type(model.conv1), nniq.ConvReLU2d)
self.assertEqual(type(model.relu1), nn.Identity)
for i in range(3):
self.assertEqual(type(model.features[i][0]), nniq.ConvReLU2d)
self.assertEqual(type(model.features[i][1]), nn.Identity)
self.assertEqual(type(model.features[i][2]), nn.Identity)
self.assertEqual(type(model.classifier[0]), nniq.LinearReLU)
self.assertEqual(type(model.classifier[1]), nn.Identity)
def test_fusion_conv_with_bias(self):
for qengine in supported_qengines:
with override_quantized_engine(qengine):
model = ModelForFusionWithBias().train()
# output with no fusion.
out_ref = model(self.img_data[0][0])
model.qconfig = QConfig(activation=torch.nn.Identity,
weight=torch.nn.Identity)
model = fuse_modules(model, [["conv1", "bn1", "relu1"],
["conv2", "bn2"]])
prep_model = prepare_qat(model, inplace=False)
# output with fusion but no observers.
out_fused = prep_model(self.img_data[0][0])
self.assertEqual(out_ref, out_fused)
model.qconfig = torch.quantization.get_default_qconfig(qengine)
prepare_qat(model, inplace=True)
model(self.img_data[0][0])
def checkQAT(model):
self.assertEqual(type(model.conv1), nniqat.ConvBnReLU2d)
self.assertEqual(type(model.bn1), nn.Identity)
self.assertEqual(type(model.relu1), nn.Identity)
self.assertEqual(type(model.conv2), nniqat.ConvBn2d)
self.assertEqual(type(model.bn2), nn.Identity)
checkQAT(model)
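# The numerics tests below compare quantized and float outputs via the
# signal-to-quantization-noise ratio, computed in decibels as
#   SQNR_dB = 20 * log10(||out_ref|| / ||out_ref - out_quant||)
# so 20 dB means the error norm is roughly 10% of the reference norm and
# 40 dB roughly 1%.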
class TestModelNumerics(QuantizationTestCase):
def test_float_quant_compare_per_tensor(self):
for qengine in supported_qengines:
with override_quantized_engine(qengine):
torch.manual_seed(42)
my_model = ModelMultipleOps().to(torch.float32)
my_model.eval()
calib_data = torch.rand(1024, 3, 15, 15, dtype=torch.float32)
eval_data = torch.rand(1, 3, 15, 15, dtype=torch.float32)
out_ref = my_model(eval_data)
qModel = torch.quantization.QuantWrapper(my_model)
qModel.eval()
qModel.qconfig = torch.quantization.default_qconfig
torch.quantization.fuse_modules(qModel.module, [['conv1', 'bn1', 'relu1']], inplace=True)
torch.quantization.prepare(qModel, inplace=True)
qModel(calib_data)
torch.quantization.convert(qModel, inplace=True)
out_q = qModel(eval_data)
SQNRdB = 20 * torch.log10(torch.norm(out_ref) / torch.norm(out_ref - out_q))
# Quantized model output should be numerically close to the floating point model output.
# A target SQNR of 30 dB corresponds to a quantization noise energy of roughly
# 1e-3 of the signal energy (an amplitude ratio of about 3%)
self.assertGreater(SQNRdB, 30, msg='Quantized model numerics diverge from float, expect SQNR > 30 dB')
def test_float_quant_compare_per_channel(self):
# Test for per-channel Quant
torch.manual_seed(67)
my_model = ModelMultipleOps().to(torch.float32)
my_model.eval()
calib_data = torch.rand(2048, 3, 15, 15, dtype=torch.float32)
eval_data = torch.rand(10, 3, 15, 15, dtype=torch.float32)
out_ref = my_model(eval_data)
q_model = torch.quantization.QuantWrapper(my_model)
q_model.eval()
q_model.qconfig = torch.quantization.default_per_channel_qconfig
torch.quantization.fuse_modules(q_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
torch.quantization.prepare(q_model)
q_model(calib_data)
torch.quantization.convert(q_model)
out_q = q_model(eval_data)
SQNRdB = 20 * torch.log10(torch.norm(out_ref) / torch.norm(out_ref - out_q))
# Quantized model output should be close to floating point model output numerically
# Setting target SQNR to be 35 dB
self.assertGreater(SQNRdB, 35, msg='Quantized model numerics diverge from float, expect SQNR > 35 dB')
def test_fake_quant_true_quant_compare(self):
for qengine in supported_qengines:
with override_quantized_engine(qengine):
torch.manual_seed(67)
my_model = ModelMultipleOpsNoAvgPool().to(torch.float32)
calib_data = torch.rand(2048, 3, 15, 15, dtype=torch.float32)
eval_data = torch.rand(10, 3, 15, 15, dtype=torch.float32)
my_model.eval()
out_ref = my_model(eval_data)
fq_model = torch.quantization.QuantWrapper(my_model)
fq_model.train()
fq_model.qconfig = torch.quantization.default_qat_qconfig
torch.quantization.fuse_modules(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
torch.quantization.prepare_qat(fq_model)
fq_model.eval()
fq_model.apply(torch.quantization.disable_fake_quant)
fq_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
fq_model(calib_data)
fq_model.apply(torch.quantization.enable_fake_quant)
fq_model.apply(torch.quantization.disable_observer)
out_fq = fq_model(eval_data)
SQNRdB = 20 * torch.log10(torch.norm(out_ref) / torch.norm(out_ref - out_fq))
# Quantized model output should be close to floating point model output numerically
# Setting target SQNR to be 35 dB
self.assertGreater(SQNRdB, 35, msg='Quantized model numerics diverge from float, expect SQNR > 35 dB')
torch.quantization.convert(fq_model)
out_q = fq_model(eval_data)
SQNRdB = 20 * torch.log10(torch.norm(out_fq) / (torch.norm(out_fq - out_q) + 1e-10))
self.assertGreater(SQNRdB, 60, msg='Fake quant and true quant numerics diverge, expect SQNR > 60 dB')
# Compare the numerics of weight-only and activation-only fake-quantized
# models against the float model
def test_weight_only_activation_only_fakequant(self):
for qengine in supported_qengines:
with override_quantized_engine(qengine):
torch.manual_seed(67)
calib_data = torch.rand(2048, 3, 15, 15, dtype=torch.float32)
eval_data = torch.rand(10, 3, 15, 15, dtype=torch.float32)
qconfigset = set([torch.quantization.default_weight_only_qconfig,
torch.quantization.default_activation_only_qconfig])
SQNRTarget = [35, 45]
for idx, qconfig in enumerate(qconfigset):
my_model = ModelMultipleOpsNoAvgPool().to(torch.float32)
my_model.eval()
out_ref = my_model(eval_data)
fq_model = torch.quantization.QuantWrapper(my_model)
fq_model.train()
fq_model.qconfig = qconfig
torch.quantization.fuse_modules(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
torch.quantization.prepare_qat(fq_model)
fq_model.eval()
fq_model.apply(torch.quantization.disable_fake_quant)
fq_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
fq_model(calib_data)
fq_model.apply(torch.quantization.enable_fake_quant)
fq_model.apply(torch.quantization.disable_observer)
out_fq = fq_model(eval_data)
SQNRdB = 20 * torch.log10(torch.norm(out_ref) / torch.norm(out_ref - out_fq))
self.assertGreater(SQNRdB, SQNRTarget[idx], msg='Quantized model numerics diverge from float')
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_quantization.py TESTNAME\n\n"
"instead.")