# blob: f32ab8b7663049e9a880453f7c71b134096937a3 [file] [log] [blame]
import unittest
import math
import torch
import torch.nn as nn
import torch.nn.quantized as nnq
import torch.nn.intrinsic as nni
import torch.nn.intrinsic.quantized as nniq
import torch.nn.intrinsic.qat as nniqat
from torch.nn.utils.rnn import PackedSequence
from torch.quantization import \
get_observer_dict, default_weight_observer, \
quantize, prepare, convert, prepare_qat, quantize_qat, fuse_modules, \
quantize_dynamic, default_qconfig, default_debug_qconfig, default_qat_qconfig, \
default_dynamic_qconfig, per_channel_dynamic_qconfig, HistogramObserver, MinMaxObserver, \
PerChannelMinMaxObserver, RecordingObserver, MovingAverageMinMaxObserver, \
MovingAveragePerChannelMinMaxObserver, QuantWrapper, default_eval_fn, \
float16_dynamic_qconfig, MinMaxDynamicQuantObserver
from torch.quantization import QConfig
from torch.quantization import default_histogram_observer
from torch.quantization import default_observer
from torch.quantization import default_per_channel_weight_observer
from torch.quantization import default_per_channel_qconfig
from torch.quantization._quantize_script import quantize_script, quantize_dynamic_script
from torch.testing._internal.common_utils import TEST_WITH_UBSAN, IS_WINDOWS
from torch.testing._internal.common_quantization import QuantizationTestCase, \
AnnotatedSingleLayerLinearModel, SingleLayerLinearModel, \
AnnotatedConvModel, ConvModel, \
AnnotatedConvBnModel, ConvBnModel, \
SkipQuantModel, QuantStubModel, \
ModelForFusion, ModelWithSequentialFusion, ManualLinearQATModel, ManualConvLinearQATModel, \
ModelWithFunctionals, \
test_only_eval_fn, test_only_train_fn, \
prepare_dynamic, convert_dynamic, SingleLayerLinearDynamicModel, \
TwoLayerLinearModel, NestedModel, ResNetBase, LSTMDynamicModel, \
ModelWithNoQconfigPropagation, ModelForFusionWithBias, \
ActivationsTestModel, ActivationsQATTestModel, NormalizationTestModel
from torch.testing._internal.common_quantization import AnnotatedTwoLayerLinearModel, AnnotatedNestedModel, \
AnnotatedSubNestedModel, AnnotatedCustomConfigNestedModel
from torch.testing._internal.common_quantization import AnnotatedSkipQuantModel
from torch.testing._internal.common_quantized import override_quantized_engine
from hypothesis import given
from hypothesis import strategies as st
import torch.testing._internal.hypothesis_utils as hu
hu.assert_deadline_disabled()
import io
import copy
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
                     " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
                     " with instruction set support avx2 or newer.")
class TestPostTrainingStatic(QuantizationTestCase):
    r"""Eager-mode post-training static quantization tests.

    Most tests follow the same shape: ``prepare`` the model, calibrate it
    with ``test_only_eval_fn``, ``convert`` it, verify the resulting module
    types, and then repeat the verification through the one-line
    ``quantize`` API.
    """

    @given(qconfig=st.sampled_from((torch.quantization.default_qconfig, torch.quantization.default_per_channel_qconfig)))
    def test_single_layer(self, qconfig):
        r"""Quantize SingleLayerLinearModel which has one Linear module, make sure it is swapped
        to nnq.Linear which is the quantized version of the module
        """
        model = AnnotatedSingleLayerLinearModel()
        model.qconfig = qconfig
        model = prepare(model)
        # Check if observers and quant/dequant nodes are inserted
        self.checkNoPrepModules(model)
        self.checkHasPrepModules(model.fc1)
        self.checkObservers(model)

        test_only_eval_fn(model, self.calib_data)
        model = convert(model)

        def checkQuantized(model):
            self.checkNoPrepModules(model)
            self.checkHasPrepModules(model.fc1)
            self.checkWrappedQuantizedLinear(model.fc1)
            test_only_eval_fn(model, self.calib_data)
            self.checkScriptable(model, self.calib_data)

        checkQuantized(model)

        # test one line API - out of place version
        base = AnnotatedSingleLayerLinearModel()
        base.qconfig = qconfig
        keys_before = set(base.state_dict().keys())
        model = quantize(base, test_only_eval_fn, self.calib_data)
        checkQuantized(model)
        keys_after = set(base.state_dict().keys())
        self.assertEqual(keys_before, keys_after)  # simple check that nothing changed

        # in-place version
        model = AnnotatedSingleLayerLinearModel()
        model.qconfig = qconfig
        quantize(model, test_only_eval_fn, self.calib_data, inplace=True)
        checkQuantized(model)

    def test_two_layers(self):
        r"""TwoLayerLinearModel has two Linear modules but we only quantize the second one
        `fc2`, and `fc1` is not quantized
        """
        model = AnnotatedTwoLayerLinearModel()
        model = prepare(model)

        self.checkNoPrepModules(model)
        self.checkObservers(model)
        self.checkNoPrepModules(model.fc1)
        self.checkHasPrepModules(model.fc2)

        test_only_eval_fn(model, self.calib_data)
        model = convert(model)

        def checkQuantized(model):
            self.checkNoPrepModules(model)
            self.checkNoPrepModules(model.fc1)
            self.checkHasPrepModules(model.fc2)
            self.assertEqual(type(model.fc1), torch.nn.Linear)
            self.checkWrappedQuantizedLinear(model.fc2)
            test_only_eval_fn(model, self.calib_data)
            self.checkScriptable(model, self.calib_data)

        checkQuantized(model)

        # test one line API
        model = quantize(AnnotatedTwoLayerLinearModel(), test_only_eval_fn,
                         self.calib_data)
        checkQuantized(model)

    def test_nested1(self):
        r"""Test quantization for nested model, top level 'fc3' and
        'fc1' of submodule 'sub2', 'sub2.fc2' is not quantized
        """
        model = AnnotatedNestedModel()

        def checkPrepModules(model, before_calib=False):
            # Observers only exist between prepare() and convert().
            if before_calib:
                self.checkObservers(model)
            self.checkNoPrepModules(model)
            self.checkNoPrepModules(model.sub1)
            self.checkNoPrepModules(model.sub1.fc)
            self.checkNoPrepModules(model.sub1.relu)
            self.checkNoPrepModules(model.sub2)
            self.checkHasPrepModules(model.sub2.fc1)
            self.checkNoPrepModules(model.sub2.fc2)
            self.checkHasPrepModules(model.fc3)

        model = prepare(model)
        checkPrepModules(model, True)
        test_only_eval_fn(model, self.calib_data)
        model = convert(model)

        def checkQuantized(model):
            checkPrepModules(model)
            self.checkLinear(model.sub1.fc)
            self.checkWrappedQuantizedLinear(model.fc3)
            self.checkWrappedQuantizedLinear(model.sub2.fc1)
            self.checkLinear(model.sub2.fc2)
            test_only_eval_fn(model, self.calib_data)
            self.checkScriptable(model, self.calib_data)

        checkQuantized(model)

        # test one line API
        model = quantize(AnnotatedNestedModel(), test_only_eval_fn,
                         self.calib_data)
        checkQuantized(model)

    def test_nested2(self):
        r"""Nested model where the whole submodule 'sub2' is wrapped, so both
        'sub2.module.fc1' and 'sub2.module.fc2' are quantized, together with
        the top level 'fc3'; 'sub1' is checked as not quantized
        """
        model = AnnotatedSubNestedModel()
        model = prepare(model)

        def checkPrepModules(model, before_calib=False):
            # Observers only exist between prepare() and convert().
            if before_calib:
                self.checkObservers(model)
            self.checkNoPrepModules(model)
            self.checkNoPrepModules(model.sub1)
            self.checkNoPrepModules(model.sub1.fc)
            self.checkNoPrepModules(model.sub1.relu)
            self.checkHasPrepModules(model.sub2)
            self.checkNoPrepModules(model.sub2.module.fc1)
            self.checkNoPrepModules(model.sub2.module.fc2)
            self.checkHasPrepModules(model.fc3)

        checkPrepModules(model, True)

        test_only_eval_fn(model, self.calib_data)
        model = convert(model)

        def checkQuantized(model):
            checkPrepModules(model)
            self.checkLinear(model.sub1.fc)
            self.assertEqual(type(model.sub1.relu), torch.nn.ReLU)
            self.checkQuantizedLinear(model.sub2.module.fc1)
            self.checkQuantizedLinear(model.sub2.module.fc2)
            self.checkWrappedQuantizedLinear(model.fc3)
            test_only_eval_fn(model, self.calib_data)
            self.checkScriptable(model, self.calib_data)

        checkQuantized(model)

        # test one line API
        model = quantize(AnnotatedSubNestedModel(), test_only_eval_fn,
                         self.calib_data)
        checkQuantized(model)

    def test_nested3(self):
        r"""More complicated nested test case with child qconfig overrides
        parent qconfig
        """
        model = AnnotatedCustomConfigNestedModel()
        model = prepare(model)

        def checkPrepModules(model, before_calib=False):
            # Observers only exist between prepare() and convert().
            if before_calib:
                self.checkObservers(model)
            self.checkNoPrepModules(model)
            self.checkNoPrepModules(model.sub1)
            self.checkNoPrepModules(model.sub1.fc)
            self.checkNoPrepModules(model.sub1.relu)
            self.checkNoPrepModules(model.sub2)
            self.checkHasPrepModules(model.sub2.fc1)
            self.checkHasPrepModules(model.sub2.fc2)
            self.checkHasPrepModules(model.fc3)

        checkPrepModules(model, True)

        test_only_eval_fn(model, self.calib_data)
        model = convert(model)

        def checkQuantized(model):
            checkPrepModules(model)
            self.checkWrappedQuantizedLinear(model.sub2.fc1)
            self.checkWrappedQuantizedLinear(model.sub2.fc2)
            self.checkWrappedQuantizedLinear(model.fc3)
            test_only_eval_fn(model, self.calib_data)
            self.checkScriptable(model, self.calib_data)

        checkQuantized(model)

        # test one line API
        model = quantize(AnnotatedCustomConfigNestedModel(), test_only_eval_fn,
                         self.calib_data)
        checkQuantized(model)

    def test_skip_quant(self):
        r"""The case when we want to skip quantizing some layers
        """
        model = AnnotatedSkipQuantModel()
        model = prepare(model)
        self.checkObservers(model)

        test_only_eval_fn(model, self.calib_data)
        model = convert(model)

        def checkQuantized(model):
            self.checkLinear(model.fc)
            self.checkQuantDequant(model.sub)
            self.checkQuantizedLinear(model.sub.module.fc1)
            self.checkQuantizedLinear(model.sub.module.fc2)
            self.assertEqual(type(model.sub.module.relu1), nnq.ReLU)
            self.assertEqual(type(model.sub.module.relu2), nnq.ReLU)
            self.checkScriptable(model, self.calib_data)

        checkQuantized(model)

        # test one line API
        model = quantize(AnnotatedSkipQuantModel(), test_only_eval_fn, self.calib_data)
        checkQuantized(model)

    def test_manual(self):
        r"""User inserts QuantStub and DeQuantStub in model code
        and call the quantization utility functions.
        """
        model = QuantStubModel()
        # propagate the qconfig of parents to children, model is changed
        # inplace
        model = prepare(model)
        self.checkObservers(model)

        test_only_eval_fn(model, self.calib_data)
        model = convert(model)

        def checkQuantized(model):
            self.assertEqual(type(model.fc), nnq.Linear)
            test_only_eval_fn(model, self.calib_data)
            self.checkScriptable(model, self.calib_data)

        checkQuantized(model)

        # test one line API
        model = quantize(QuantStubModel(), test_only_eval_fn, self.calib_data)
        checkQuantized(model)

    @given(qconfig=st.sampled_from((torch.quantization.default_qconfig, torch.quantization.default_per_channel_qconfig)))
    def test_resnet_base(self, qconfig):
        r"""Test quantization for bottleneck topology used in resnet/resnext
        and add coverage for conversion of average pool and float functional
        """
        model = ResNetBase().float().eval()
        model = QuantWrapper(model)
        model.qconfig = qconfig
        fuse_list = ['module.conv1', 'module.bn1', 'module.relu1']
        fuse_modules(model, fuse_list, inplace=True)
        model = prepare(model)
        self.checkObservers(model)
        test_only_eval_fn(model, self.img_data)
        model = convert(model)

        def checkQuantized(model):
            self.assertEqual(type(model.module.conv1), nn.intrinsic.quantized.ConvReLU2d)
            self.assertEqual(type(model.module.myop), nn.quantized.QFunctional)
            self.assertEqual(type(model.module.avgpool), nn.AdaptiveAvgPool2d)
            test_only_eval_fn(model, self.img_data)

        checkQuantized(model)

    def test_normalization(self):
        r"""
        Test quantization of normalization layers
        """
        model = NormalizationTestModel()
        model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
        prepare(model, inplace=True)
        self.checkObservers(model)
        test_only_eval_fn(model, self.calib_data)
        model = convert(model)

        def checkQuantized(model):
            self.checkNoPrepModules(model.layer_norm)
            self.assertEqual(type(model.layer_norm), nnq.LayerNorm)
            test_only_eval_fn(model, self.calib_data)
            self.checkScriptable(model, self.calib_data)

        checkQuantized(model)

        # test one line API
        # The qconfig is set explicitly on the fresh model, exactly as for the
        # first model above, so this path does not depend on the model class
        # providing one itself.
        base = NormalizationTestModel()
        base.qconfig = torch.quantization.get_default_qconfig('fbgemm')
        model_oneline = quantize(base, test_only_eval_fn, self.calib_data)
        # Verify the model produced by the one-line API (previously the
        # already-checked `model` was verified a second time instead).
        checkQuantized(model_oneline)

    @given(qengine=st.sampled_from(("qnnpack", "fbgemm")))
    def test_save_load_state_dict(self, qengine):
        r"""Test PTQ flow of creating a model and quantizing it and saving the quantized state_dict
        Load the quantized state_dict for eval and compare results against original model
        """
        if qengine == 'qnnpack':
            if IS_WINDOWS or TEST_WITH_UBSAN:
                return
        with override_quantized_engine(qengine):
            model = TwoLayerLinearModel()
            model = torch.quantization.QuantWrapper(model)
            model.qconfig = torch.quantization.get_default_qconfig(qengine)

            model = prepare(model)
            # calibrate
            test_only_eval_fn(model, self.calib_data)
            model = convert(model)
            x = torch.rand(2, 5, dtype=torch.float)
            ref = model(x)

            quant_state_dict = model.state_dict()

            # Create model again for eval
            model = TwoLayerLinearModel()
            model = torch.quantization.QuantWrapper(model)
            model.qconfig = torch.quantization.get_default_qconfig(qengine)
            model = prepare(model)
            model = convert(model)
            new_state_dict = model.state_dict()

            # Check to make sure the state dict keys match original model after convert.
            self.assertEqual(set(new_state_dict.keys()), set(quant_state_dict.keys()))

            model.load_state_dict(quant_state_dict)

            out = model(x)
            self.assertEqual(ref, out)

    def test_activations(self):
        r"""
        Test quantization of activations
        """
        model = ActivationsTestModel()
        model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
        prepare(model, inplace=True)
        self.checkObservers(model)
        test_only_eval_fn(model, self.calib_data)
        model = convert(model)

        def checkQuantized(model):
            self.checkNoPrepModules(model.hardswish)
            self.assertEqual(type(model.hardswish), nnq.Hardswish)
            test_only_eval_fn(model, self.calib_data)
            self.checkScriptable(model, self.calib_data)

        checkQuantized(model)

        # test one line API
        model_oneline = quantize(ActivationsTestModel(), test_only_eval_fn,
                                 self.calib_data)
        checkQuantized(model_oneline)
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
                     " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
                     " with instruction set support avx2 or newer.")
class TestPostTrainingDynamic(QuantizationTestCase):
    r"""Eager-mode post-training dynamic quantization tests.

    Most tests run once per dtype (``torch.qint8`` and ``torch.float16``),
    first through ``prepare_dynamic``/``convert_dynamic`` and then through
    the one-line ``quantize_dynamic`` API.
    """

    @staticmethod
    def _dynamic_qconfig_for(dtype):
        r"""Return the dynamic qconfig these tests pair with ``dtype``:
        ``float16_dynamic_qconfig`` for ``torch.float16``, otherwise
        ``default_dynamic_qconfig``.
        """
        return float16_dynamic_qconfig if dtype == torch.float16 else default_dynamic_qconfig

    def test_single_layer(self):
        r"""Dynamic Quantize SingleLayerLinearDynamicModel which has one Linear module,
        make sure it is swapped to nnqd.Linear which is the quantized version of
        the module
        """
        for dtype in [torch.qint8, torch.float16]:
            model = SingleLayerLinearDynamicModel().eval()
            qconfig = self._dynamic_qconfig_for(dtype)
            qconfig_dict = {
                'fc1': qconfig
            }
            prepare_dynamic(model, qconfig_dict)
            convert_dynamic(model)

            def checkQuantized(model):
                self.checkDynamicQuantizedLinear(model.fc1, dtype)
                self.checkScriptable(model, self.calib_data, check_save_load=True)

            checkQuantized(model)

            # test one line API - out of place version
            base = SingleLayerLinearDynamicModel()
            keys_before = set(base.state_dict().keys())
            model = quantize_dynamic(base, qconfig_dict)
            checkQuantized(model)
            keys_after = set(base.state_dict().keys())
            self.assertEqual(keys_before, keys_after)  # simple check that nothing changed

            # in-place version
            model = SingleLayerLinearDynamicModel()
            quantize_dynamic(model, qconfig_dict, inplace=True)
            checkQuantized(model)

            # Test set qconfig
            model = SingleLayerLinearDynamicModel()
            quantize_dynamic(model, {nn.Linear}, inplace=True, dtype=dtype)
            checkQuantized(model)

    def test_two_layers(self):
        r"""TwoLayerLinearModel has two Linear modules but we only quantize the second one
        `fc2`, and `fc1` is not quantized
        """
        for dtype in [torch.qint8, torch.float16]:
            model = TwoLayerLinearModel().eval()
            qconfig = self._dynamic_qconfig_for(dtype)
            qconfig_dict = {
                'fc2': qconfig
            }
            prepare_dynamic(model, qconfig_dict)
            convert_dynamic(model)

            def checkQuantized(model):
                self.assertEqual(type(model.fc1), torch.nn.Linear)
                self.checkDynamicQuantizedLinear(model.fc2, dtype=dtype)
                self.checkScriptable(model, self.calib_data, check_save_load=True)

            checkQuantized(model)

            # test one line API
            model = quantize_dynamic(TwoLayerLinearModel().eval(), qconfig_dict)
            checkQuantized(model)

            # Test set API
            model = quantize_dynamic(TwoLayerLinearModel().eval(), {'fc2'}, dtype=dtype)
            checkQuantized(model)

    def test_nested1(self):
        r"""Test quantization for nested model, top level 'fc3' and
        'fc1' of submodule 'sub2', 'sub2.fc2' is not quantized
        """
        for dtype in [torch.qint8, torch.float16]:
            model = NestedModel().eval()
            qconfig = self._dynamic_qconfig_for(dtype)
            qconfig_dict = {
                'fc3': qconfig,
                'sub2.fc1': qconfig
            }

            prepare_dynamic(model, qconfig_dict)
            convert_dynamic(model)

            def checkQuantized(model):
                self.checkLinear(model.sub1.fc)
                self.checkDynamicQuantizedLinear(model.fc3, dtype=dtype)
                self.checkDynamicQuantizedLinear(model.sub2.fc1, dtype=dtype)
                self.checkLinear(model.sub2.fc2)
                self.checkScriptable(model, self.calib_data, check_save_load=True)

            checkQuantized(model)

            # test one line API
            model = quantize_dynamic(NestedModel().eval(), qconfig_dict)
            checkQuantized(model)

            # Test set API
            model = quantize_dynamic(NestedModel().eval(), {'fc3', 'sub2.fc1'}, dtype=dtype)
            checkQuantized(model)

    def test_nested2(self):
        r"""Another test case for quantized, we will quantize all submodules
        of submodule sub2
        """
        for dtype in [torch.qint8, torch.float16]:
            model = NestedModel().eval()
            qconfig = self._dynamic_qconfig_for(dtype)
            qconfig_dict = {
                'fc3': qconfig,
                'sub2': qconfig
            }
            prepare_dynamic(model, qconfig_dict)
            convert_dynamic(model)

            def checkQuantized(model):
                self.checkLinear(model.sub1.fc)
                self.assertEqual(type(model.sub1.relu), torch.nn.ReLU)
                self.checkDynamicQuantizedLinear(model.sub2.fc1, dtype=dtype)
                self.checkDynamicQuantizedLinear(model.sub2.fc2, dtype=dtype)
                self.checkDynamicQuantizedLinear(model.fc3, dtype=dtype)
                self.checkScriptable(model, self.calib_data, check_save_load=True)

            checkQuantized(model)

            # test one line API
            model = quantize_dynamic(NestedModel().eval(), qconfig_dict, dtype=dtype)
            checkQuantized(model)

            # Test set API
            model = quantize_dynamic(NestedModel().eval(), {'fc3', 'sub2'}, dtype=dtype)
            checkQuantized(model)

    def test_nested3(self):
        r"""More complicated nested test case with child qconfig overrides
        parent qconfig
        """
        for dtype in [torch.qint8, torch.float16]:
            model = NestedModel().eval()
            qconfig = self._dynamic_qconfig_for(dtype)
            qconfig_dynamic_dict = {
                'fc3': qconfig,
                'sub2': qconfig,
                'sub2.fc1': qconfig
            }
            prepare_dynamic(model, qconfig_dynamic_dict)
            convert_dynamic(model)

            def checkQuantized(model):
                self.checkDynamicQuantizedLinear(model.sub2.fc1, dtype=dtype)
                self.checkDynamicQuantizedLinear(model.sub2.fc2, dtype=dtype)
                self.checkDynamicQuantizedLinear(model.fc3, dtype=dtype)
                self.checkScriptable(model, self.calib_data, check_save_load=True)

            checkQuantized(model)

            # test one line API
            model = quantize_dynamic(NestedModel().eval(), qconfig_dynamic_dict)
            checkQuantized(model)

            # Test set API
            model = quantize_dynamic(NestedModel().eval(), {'fc3', 'sub2', 'sub2.fc1'}, dtype=dtype)
            checkQuantized(model)

    def test_type_match_rule(self):
        r"""Test type-based qconfig matching on the nested model: every
        'torch.nn.Linear' gets the dynamic qconfig, except 'fc3' and
        'sub2.fc1', whose per-name entries are None and which are checked
        as not quantized
        """
        for dtype in [torch.qint8, torch.float16]:
            model = NestedModel().eval()
            qconfig = self._dynamic_qconfig_for(dtype)
            qconfig_dict = {
                'fc3': None,
                'sub2.fc1': None,
                torch.nn.Linear: qconfig
            }

            prepare_dynamic(model, qconfig_dict)
            test_only_eval_fn(model, self.calib_data)
            convert_dynamic(model)

            def checkQuantized(model):
                self.checkDynamicQuantizedLinear(model.sub1.fc, dtype=dtype)
                self.checkLinear(model.fc3)
                self.checkLinear(model.sub2.fc1)
                self.checkDynamicQuantizedLinear(model.sub2.fc2, dtype=dtype)
                test_only_eval_fn(model, self.calib_data)
                self.checkScriptable(model, self.calib_data, check_save_load=True)

            checkQuantized(model)

            # test one line API
            model = quantize_dynamic(NestedModel().eval(), qconfig_dict, dtype=dtype)
            checkQuantized(model)

    def test_per_channel_quantize(self):
        r"""Test quantization for per_channel dynamic quantization
        """
        model = NestedModel().eval()
        qconfig_dict = {
            torch.nn.Linear: per_channel_dynamic_qconfig
        }

        prepare_dynamic(model, qconfig_dict)
        test_only_eval_fn(model, self.calib_data)
        convert_dynamic(model)

        def checkQuantized(model):
            self.checkDynamicQuantizedLinear(model.sub1.fc, dtype=torch.qint8)
            self.checkDynamicQuantizedLinear(model.fc3, dtype=torch.qint8)
            self.checkDynamicQuantizedLinear(model.sub2.fc1, dtype=torch.qint8)
            self.checkDynamicQuantizedLinear(model.sub2.fc2, dtype=torch.qint8)
            test_only_eval_fn(model, self.calib_data)
            self.checkScriptable(model, self.calib_data, check_save_load=True)

        checkQuantized(model)

        # test one line API
        model = quantize_dynamic(NestedModel().eval(), qconfig_dict)
        checkQuantized(model)

    @unittest.skip("temporarily disable the test")
    @given(qengine=st.sampled_from(("fbgemm",)))
    def test_quantized_rnn(self, qengine):
        r"""Dynamically quantize the LSTM of LSTMDynamicModel to int8 and fp16
        and compare the outputs against the float reference, in eager mode,
        scripted, traced and with packed-sequence input, including
        torch.jit save/load round trips.
        """
        d_hid = 2

        def jit_roundtrip(module):
            # Serialize a scripted/traced module to memory and load it back.
            buf = io.BytesIO()
            torch.jit.save(module, buf)
            buf.seek(0)
            return torch.jit.load(buf)

        # TODO: qlinear_prepack_fp16 currently doesn't support QNNPACK
        # re-add "qnnpack" to the engine set when this is supported
        with override_quantized_engine(qengine):
            model = LSTMDynamicModel().eval()
            cell = model.lstm

            # Replace parameter values s.t. the range of values is exactly
            # 255, thus we will have 0 quantization error in the quantized
            # GEMM call. This is for testing purposes.
            #
            # Note that the current implementation does not support
            # accumulation values outside of the range representable by a
            # 16 bit integer, instead resulting in a saturated value. We
            # must take care that in our test we do not end up with a dot
            # product that overflows the int16 range, e.g.
            # (255*127+255*127) = 64770. So, we hardcode the test values
            # here and ensure a mix of signedness.
            vals = [[100, -155],
                    [100, -155],
                    [-155, 100],
                    [-155, 100],
                    [100, -155],
                    [-155, 100],
                    [-155, 100],
                    [100, -155]]
            # NOTE(review): the `if` is assumed to guard only this assignment;
            # confirm against the upstream source.
            if isinstance(cell, torch.nn.LSTM):
                num_chunks = 4
            vals = vals[:d_hid * num_chunks]
            cell.weight_ih_l0 = torch.nn.Parameter(
                torch.tensor(vals, dtype=torch.float),
                requires_grad=False)
            cell.weight_hh_l0 = torch.nn.Parameter(
                torch.tensor(vals, dtype=torch.float),
                requires_grad=False)

            # Float reference, captured before quantization.
            ref = copy.deepcopy(cell)

            model_int8 = quantize_dynamic(model=model, dtype=torch.qint8)
            model_fp16 = quantize_dynamic(model=model, dtype=torch.float16)

            # Smoke test extra reprs
            self.assertIn('DynamicQuantizedLSTM', str(model_int8))
            self.assertIn('DynamicQuantizedLSTM', str(model_fp16))
            cell_int8 = model_int8.lstm
            cell_fp16 = model_fp16.lstm

            # unittest assertions instead of bare `assert`, which is stripped
            # when Python runs with -O.
            self.assertIs(
                type(cell_int8), torch.nn.quantized.dynamic.LSTM,
                'torch.nn.LSTM should be converted to torch.nn.quantized.dynamic.LSTM after quantize_dynamic')
            self.assertIs(
                type(cell_fp16), torch.nn.quantized.dynamic.LSTM,
                'torch.nn.LSTM should be converted to torch.nn.quantized.dynamic.LSTM after quantize_dynamic')

            niter = 10
            x = torch.tensor([[100, -155],
                              [-155, 100],
                              [100, -155]], dtype=torch.float).unsqueeze(0).repeat(niter, 1, 1)

            h0_vals = [[-155, 100],
                       [-155, 155],
                       [100, -155]]

            hx = torch.tensor(h0_vals, dtype=torch.float).unsqueeze(0)
            cx = torch.tensor(h0_vals, dtype=torch.float).unsqueeze(0)

            # NOTE(review): the `if` is assumed to guard only this assignment;
            # confirm against the upstream source.
            if isinstance(ref, torch.nn.LSTM):
                hiddens = (hx, cx)

            ref_out, ref_hid = ref(x, hiddens)

            # Compare int8 quantized to unquantized
            output_int8, final_hiddens_int8 = cell_int8(x, hiddens)

            torch.testing.assert_allclose(output_int8, ref_out)
            self.assertEqual(output_int8, ref_out)
            for out_val, ref_val in zip(final_hiddens_int8, ref_hid):
                torch.testing.assert_allclose(out_val, ref_val)

            class ScriptWrapper(torch.nn.Module):
                def __init__(self, cell):
                    super(ScriptWrapper, self).__init__()
                    self.cell = cell

                def forward(self, x, hiddens):
                    # type: (torch.Tensor, Tuple[torch.Tensor, torch.Tensor])
                    # -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
                    return self.cell(x, hiddens)

            # TODO: TorchScript overloads don't work without this wrapper
            cell_script = torch.jit.script(ScriptWrapper(cell_int8))
            out_script, _ = cell_script(x, hiddens)
            self.assertEqual(len(out_script), len(ref_out))
            for out_val, ref_val in zip(out_script, ref_out):
                torch.testing.assert_allclose(out_val, ref_val)

            # Test save/load
            loaded = jit_roundtrip(cell_script)
            out_loaded, _ = loaded(x, hiddens)
            for loaded_val, ref_val in zip(out_loaded, ref_out):
                torch.testing.assert_allclose(loaded_val, ref_val)

            # Compare fp16 quantized to unquantized
            output_fp16, final_hiddens_fp16 = cell_fp16(x, hiddens)

            torch.testing.assert_allclose(output_fp16, ref_out)
            self.assertEqual(output_fp16, ref_out)
            for out, ref_val in zip(final_hiddens_fp16, ref_hid):
                torch.testing.assert_allclose(out, ref_val)

            # Test tracing
            # TODO: TorchScript overloads don't work without this wrapper
            cell_trace = torch.jit.trace(ScriptWrapper(cell_int8), (x, (hx, cx)))
            out_script, _ = cell_trace(x, hiddens)
            for out_val, ref_val in zip(out_script, ref_out):
                torch.testing.assert_allclose(out_val, ref_val)

            # Test save/load
            loaded = jit_roundtrip(cell_trace)
            out_loaded, _ = loaded(x, hiddens)
            for loaded_val, ref_val in zip(out_loaded, ref_out):
                torch.testing.assert_allclose(loaded_val, ref_val)

            class ScriptWrapperPacked(torch.nn.Module):
                def __init__(self, cell):
                    super(ScriptWrapperPacked, self).__init__()
                    self.cell = cell

                def forward(self,
                            x,  # type: PackedSequence
                            hiddens  # type: Tuple[torch.Tensor, torch.Tensor]
                            ):
                    # type: (...) -> Tuple[PackedSequence, Tuple[torch.Tensor, torch.Tensor]]
                    return self.cell(x, hiddens)

            cell_packed = torch.jit.script(ScriptWrapperPacked(cell_int8))
            packed_input = torch.nn.utils.rnn.pack_padded_sequence(x, torch.tensor([10, 5, 2]))
            ref_out_packed, _ = ref(packed_input, hiddens)
            output_packed, _ = cell_packed(packed_input, hiddens)

            for packed_val, ref_val in zip(output_packed, ref_out_packed):
                if isinstance(packed_val, torch.Tensor):
                    torch.testing.assert_allclose(packed_val, ref_val)
                else:
                    self.assertEqual(packed_val, ref_val)

            # Test save/load
            loaded_packed = jit_roundtrip(cell_packed)
            out_loaded_packed, _ = loaded_packed(packed_input, hiddens)
            for packed_val, ref_val in zip(out_loaded_packed, ref_out_packed):
                if isinstance(packed_val, torch.Tensor):
                    torch.testing.assert_allclose(packed_val, ref_val)
                else:
                    self.assertEqual(packed_val, ref_val)

            # Test default instantiation
            seq_len = 128
            batch = 16
            input_size = 3
            hidden_size = 7
            num_layers = 2
            bias = True
            bidirectional = False

            x = torch.rand(seq_len, batch, input_size)
            h = torch.rand(num_layers * (bidirectional + 1), batch, hidden_size)
            c = torch.rand(num_layers * (bidirectional + 1), batch, hidden_size)

            dtype = torch.qint8

            cell_dq = torch.nn.quantized.dynamic.LSTM(input_size=input_size,
                                                      hidden_size=hidden_size,
                                                      num_layers=num_layers,
                                                      bias=bias,
                                                      batch_first=False,
                                                      dropout=0.0,
                                                      bidirectional=bidirectional,
                                                      dtype=dtype)

            # Smoke test only: the forward call must run without raising.
            cell_dq(x, (h, c))
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
                     " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
                     " with instruction set support avx2 or newer.")
class TestQuantizationAwareTraining(QuantizationTestCase):
    r"""Eager-mode quantization-aware-training (QAT) tests built around
    ``prepare_qat`` -> train -> ``convert`` and the one-line ``quantize_qat``.
    """

    def test_manual(self):
        r"""QAT on ManualLinearQATModel: after conversion both `fc1` and
        `fc2` must be nnq.Linear, for the step-by-step and one-line flows.
        """
        qat_model = ManualLinearQATModel()
        qat_model = prepare_qat(qat_model)
        self.checkObservers(qat_model)
        test_only_train_fn(qat_model, self.train_data)
        converted = convert(qat_model)

        def verify(quantized):
            self.assertEqual(type(quantized.fc1), nnq.Linear)
            self.assertEqual(type(quantized.fc2), nnq.Linear)
            test_only_eval_fn(quantized, self.calib_data)
            self.checkScriptable(quantized, self.calib_data)

        verify(converted)

        # one line API
        one_line = quantize_qat(ManualLinearQATModel(), test_only_train_fn,
                                self.train_data)
        verify(one_line)

    def test_activations(self):
        r"""QAT on ActivationsQATTestModel: `fc1` and `hardswish` become QAT
        modules after prepare_qat and quantized modules after convert.
        """
        qat_model = ActivationsQATTestModel()
        qat_model = prepare_qat(qat_model)

        self.assertEqual(type(qat_model.fc1), torch.nn.qat.modules.Linear)
        self.assertEqual(type(qat_model.hardswish), torch.nn.qat.modules.Hardswish)

        self.checkObservers(qat_model)
        test_only_train_fn(qat_model, self.train_data)
        converted = convert(qat_model)

        def verify(quantized):
            self.assertEqual(type(quantized.fc1), nnq.Linear)
            self.assertEqual(type(quantized.hardswish), nnq.Hardswish)
            test_only_eval_fn(quantized, self.calib_data)
            self.checkScriptable(quantized, self.calib_data)

        verify(converted)

        # one line API
        one_line = quantize_qat(ActivationsQATTestModel(), test_only_train_fn,
                                self.train_data)
        verify(one_line)

    def test_eval_only_fake_quant(self):
        r"""Using FakeQuant in evaluation only mode,
        this is useful for estimating accuracy loss when we quantize the
        network
        """
        fake_quant_model = prepare_qat(ManualLinearQATModel())
        self.checkObservers(fake_quant_model)

        fake_quant_model.eval()
        test_only_eval_fn(fake_quant_model, self.calib_data)

    def test_conv_linear(self):
        r"""QAT on ManualConvLinearQATModel: after conversion `conv` must be
        nnq.Conv2d and `fc1`/`fc2` must be nnq.Linear.
        """
        qat_model = ManualConvLinearQATModel()
        qat_model = prepare_qat(qat_model)
        self.checkObservers(qat_model)

        test_only_train_fn(qat_model, self.img_data)
        converted = convert(qat_model)

        def verify(quantized):
            self.assertEqual(type(quantized.conv), nnq.Conv2d)
            self.assertEqual(type(quantized.fc1), nnq.Linear)
            self.assertEqual(type(quantized.fc2), nnq.Linear)
            test_only_eval_fn(quantized, self.img_data)
            self.checkScriptable(quantized, self.img_data)

        verify(converted)

        # one line API
        fresh = ManualConvLinearQATModel()
        one_line = quantize_qat(fresh, test_only_train_fn, self.img_data)
        verify(one_line)

    @given(qengine=st.sampled_from(("qnnpack", "fbgemm")))
    def test_train_save_load_eval(self, qengine):
        r"""Test QAT flow of creating a model, doing QAT and saving the quantized state_dict
        During eval, we first call prepare_qat and convert on the model and then load the state_dict
        and compare results against original model
        """
        # Nothing is run for qnnpack on Windows or under UBSAN.
        if qengine == 'qnnpack' and (IS_WINDOWS or TEST_WITH_UBSAN):
            return

        with override_quantized_engine(qengine):
            trained = TwoLayerLinearModel()
            trained = torch.quantization.QuantWrapper(trained)
            trained.qconfig = torch.quantization.get_default_qat_qconfig(qengine)
            trained = prepare_qat(trained)

            fake_quant_sd = trained.state_dict()

            test_only_train_fn(trained, self.train_data)
            trained = convert(trained)

            quantized_sd = trained.state_dict()

            sample = torch.rand(2, 5, dtype=torch.float)
            expected = trained(sample)

            # Create model again for eval. Check result using quantized state_dict
            reloaded = TwoLayerLinearModel()
            reloaded = torch.quantization.QuantWrapper(reloaded)
            reloaded.qconfig = torch.quantization.get_default_qat_qconfig(qengine)
            torch.quantization.prepare_qat(reloaded, inplace=True)
            fresh_sd = reloaded.state_dict()

            # Check to make sure the model after prepare_qat has the same state_dict as original.
            self.assertEqual(set(fake_quant_sd.keys()), set(fresh_sd.keys()))

            torch.quantization.convert(reloaded, inplace=True)
            reloaded.eval()
            reloaded.load_state_dict(quantized_sd)
            actual = reloaded(sample)
            self.assertEqual(expected, actual)

            # Check model created using prepare has same state dict as quantized state_dict
            ptq_model = TwoLayerLinearModel()
            ptq_model.eval()
            ptq_model = torch.quantization.QuantWrapper(ptq_model)
            ptq_model.qconfig = torch.quantization.get_default_qconfig(qengine)
            torch.quantization.prepare(ptq_model, inplace=True)
            torch.quantization.convert(ptq_model, inplace=True)
            self.assertEqual(set(ptq_model.state_dict().keys()), set(quantized_sd.keys()))
            ptq_model.eval()
            ptq_model.load_state_dict(quantized_sd)
            actual = ptq_model(sample)
            self.assertEqual(expected, actual)
@unittest.skipUnless(
'fbgemm' in torch.backends.quantized.supported_engines,
" Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
" with instruction set support avx2 or newer.",
)
class TestGraphModePostTrainingStatic(QuantizationTestCase):
def test_single_linear(self):
    r"""Quantize a single linear layer in eager mode and in graph mode
    (traced and scripted) and check that the outputs agree
    """
    # eager mode
    eager_float = AnnotatedSingleLayerLinearModel().eval()
    graph_float = SingleLayerLinearModel().eval()

    # Share the eager model's parameters with the graph-mode model so the
    # two quantized results are comparable.
    eager_fc = eager_float.fc1.module
    graph_float.fc1.weight = torch.nn.Parameter(eager_fc.weight.detach())
    graph_float.fc1.bias = torch.nn.Parameter(eager_fc.bias.detach())

    eager_quantized = quantize(eager_float, test_only_eval_fn,
                               self.calib_data)

    graph_qconfigs = {'': default_qconfig}
    traced = torch.jit.trace(graph_float, self.calib_data[0][0])
    scripted = torch.jit.script(graph_float)
    expected = eager_quantized(self.calib_data[0][0])

    for graph_model in (traced, scripted):
        graph_quantized = quantize_script(
            graph_model,
            graph_qconfigs,
            test_only_eval_fn,
            [self.calib_data],
            inplace=False)
        self.assertEqual(graph_quantized(self.calib_data[0][0]), expected)
def test_observer_with_ignored_function(self):
r"""Test observers with ignored function and make sure it works in
graph mode
"""
# eager mode
annotated_linear_model = AnnotatedSingleLayerLinearModel().eval()
for qconfig in [
QConfig(
activation=default_observer,
weight=default_weight_observer),
QConfig(
activation=default_histogram_observer,
weight=default_weight_observer),
QConfig(
activation=default_observer,
weight=default_per_channel_weight_observer),
]:
annotated_linear_model.qconfig = qconfig
linear_model = SingleLayerLinearModel().eval()
# copy the weight from eager mode so that we can
# compare the result of the two quantized models later
linear_model.fc1.weight = torch.nn.Parameter(annotated_linear_model.fc1.module.weight.detach())
linear_model.fc1.bias = torch.nn.Parameter(annotated_linear_model.fc1.module.bias.detach())
model_eager = quantize(annotated_linear_model, test_only_eval_fn,
self.calib_data)
qconfig_dict = {'': qconfig}
model_traced = torch.jit.trace(linear_model, self.calib_data[0][0])
model_script = torch.jit.script(linear_model)
result_eager = model_eager(self.calib_data[0][0])
for model_under_test in [model_traced, model_script]:
model_quantized = quantize_script(
model_under_test,
qconfig_dict,
test_only_eval_fn,
[self.calib_data],
inplace=False)
self.assertEqual(model_quantized(self.calib_data[0][0]), result_eager)
def test_conv(self):
r"""Compare the result of quantizing conv layer in
eager mode and graph mode
"""
# eager mode
annotated_conv_model = AnnotatedConvModel().eval()
conv_model = ConvModel().eval()
# copy the weight from eager mode so that we can
# compare the result of the two quantized models later
conv_model.conv.weight = torch.nn.Parameter(annotated_conv_model.conv.weight.detach())
model_eager = quantize(annotated_conv_model, default_eval_fn,
self.img_data)
qconfig_dict = {'': default_qconfig}
model_traced = torch.jit.trace(conv_model, self.img_data[0][0])
model_script = torch.jit.script(conv_model)
result_eager = model_eager(self.img_data[0][0])
for model_under_test in [model_traced, model_script]:
model_quantized = quantize_script(
model_under_test,
qconfig_dict,
default_eval_fn,
[self.img_data],
inplace=False)
self.assertEqual(model_quantized(self.img_data[0][0]), result_eager)
@unittest.skip("This doesn't work right now, re-enable after fold_convbn is fixed")
def test_conv_bn(self):
r"""Compare the result of quantizing conv + bn layer in
eager mode and graph mode
"""
# eager mode
conv_model = AnnotatedConvBnModel().eval()
conv_model_to_script = ConvBnModel().eval()
# copy the weight from eager mode so that we can
# compare the result of the two quantized models later
conv_model_to_script.conv.weight = torch.nn.Parameter(conv_model.conv.weight.detach())
fuse_modules(conv_model, ['conv', 'bn'], inplace=True)
model_eager = quantize(conv_model, default_eval_fn,
self.img_data)
qconfig_dict = {
'': default_qconfig
}
model_script = quantize_script(
torch.jit.script(conv_model_to_script),
qconfig_dict,
default_eval_fn,
[self.img_data],
inplace=False)
result_eager = model_eager(self.img_data[0][0])
result_script = model_script(self.img_data[0][0])
self.assertEqual(result_eager, result_script)
def test_nested(self):
# Eager mode
eager_model = AnnotatedNestedModel().eval()
# Graph mode
script_model = NestedModel().eval()
# Copy weights for eager_model
script_model.sub1.fc.weight = torch.nn.Parameter(eager_model.sub1.fc.weight.detach())
script_model.sub1.fc.bias = torch.nn.Parameter(eager_model.sub1.fc.bias.detach())
script_model.sub2.fc1.weight = torch.nn.Parameter(eager_model.sub2.fc1.module.weight.detach())
script_model.sub2.fc1.bias = torch.nn.Parameter(eager_model.sub2.fc1.module.bias.detach())
script_model.sub2.fc2.weight = torch.nn.Parameter(eager_model.sub2.fc2.weight.detach())
script_model.sub2.fc2.bias = torch.nn.Parameter(eager_model.sub2.fc2.bias.detach())
script_model.fc3.weight = torch.nn.Parameter(eager_model.fc3.module.weight.detach())
script_model.fc3.bias = torch.nn.Parameter(eager_model.fc3.module.bias.detach())
model_eager = quantize(eager_model, test_only_eval_fn, self.calib_data)
qconfig_dict = {
'sub2.fc1': default_per_channel_qconfig,
'fc3': default_qconfig
}
model_traced = torch.jit.trace(script_model, self.calib_data[0][0])
model_script = torch.jit.script(script_model)
result_eager = model_eager(self.calib_data[0][0])
for model_under_test in [model_traced, model_script]:
model_quantized = quantize_script(
model_under_test,
qconfig_dict,
test_only_eval_fn,
[self.calib_data],
inplace=False)
self.assertEqual(model_quantized(self.calib_data[0][0]), result_eager)
def test_skip_quant(self):
""" Test None qconfig
"""
# Eager mode
eager_model = AnnotatedSkipQuantModel().eval()
# Graph mode
script_model = SkipQuantModel().eval()
# Copy weights for eager_model
script_model.sub.fc1.weight = torch.nn.Parameter(eager_model.sub.module.fc1.weight.detach())
script_model.sub.fc1.bias = torch.nn.Parameter(eager_model.sub.module.fc1.bias.detach())
script_model.sub.fc2.weight = torch.nn.Parameter(eager_model.sub.module.fc2.weight.detach())
script_model.sub.fc2.bias = torch.nn.Parameter(eager_model.sub.module.fc2.bias.detach())
script_model.fc.weight = torch.nn.Parameter(eager_model.fc.weight.detach())
script_model.fc.bias = torch.nn.Parameter(eager_model.fc.bias.detach())
model_eager = quantize(eager_model, test_only_eval_fn, self.calib_data)
qconfig_dict = {
'': default_qconfig,
'fc': None
}
model_traced = torch.jit.trace(script_model, self.calib_data[0][0])
model_script = torch.jit.script(script_model)
result_eager = model_eager(self.calib_data[0][0])
for model_under_test in [model_traced, model_script]:
model_quantized = quantize_script(
model_under_test,
qconfig_dict,
test_only_eval_fn,
[self.calib_data],
inplace=False)
self.assertEqual(model_quantized(self.calib_data[0][0]), result_eager)
def test_single_linear_dynamic(self):
r"""Compare the result of dynamic quantization of single linear layer in
eager mode and graph mode.
"""
# eager mode
annotated_linear_model = AnnotatedSingleLayerLinearModel().eval()
linear_model = SingleLayerLinearModel().eval()
# copy the weight from eager mode so that we can
# compare the result of the two quantized models later
linear_model.fc1.weight = torch.nn.Parameter(annotated_linear_model.fc1.module.weight.detach())
linear_model.fc1.bias = torch.nn.Parameter(annotated_linear_model.fc1.module.bias.detach())
qconfig_dict = {'': default_dynamic_qconfig}
model_eager = quantize_dynamic(annotated_linear_model, qconfig_dict)
model_traced = torch.jit.trace(linear_model, self.calib_data[0][0])
model_script = torch.jit.script(linear_model)
result_eager = model_eager(self.calib_data[0][0])
for model_under_test in [model_traced, model_script]:
model_quantized = quantize_dynamic_script(
model_under_test,
qconfig_dict,
[self.calib_data[0][0]])
self.assertEqual(model_quantized(self.calib_data[0][0]), result_eager)
# Check to make sure choose_qparams->quant->dequant->linear is numerically
# equivalent to the final quantized model.
model_fake_quantized = quantize_dynamic_script(
model_under_test,
qconfig_dict,
[self.calib_data[0][0]],
debug=True)
self.assertEqual(model_fake_quantized(self.calib_data[0][0]), result_eager)
class TestFunctionalModule(QuantizationTestCase):
    r"""Quantization flow for a model built from functional wrappers."""

    @given(train_mode=st.booleans())
    def test_functional_module(self, train_mode):
        r"""Prepare, calibrate and convert ``ModelWithFunctionals`` and check
        that its functional members become ``QFunctional``.

        ``train_mode`` selects between the QAT path (``prepare_qat``) and
        the post training path (``prepare``).
        """
        model = ModelWithFunctionals()
        x = torch.rand(10, 1, dtype=torch.float)
        xq = torch.quantize_per_tensor(x, 0.01, 30, torch.quint8)
        self.checkScriptable(model, [(x, x)], check_save_load=True)

        if train_mode:
            qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
            prepare_fn = prepare_qat
        else:
            qconfig = torch.quantization.get_default_qconfig('qnnpack')
            prepare_fn = prepare
        model.qconfig = qconfig
        model = prepare_fn(model)

        # Check if observers and quant/dequant nodes are inserted
        self.checkNoPrepModules(model)
        self.checkObservers(model)

        # Calibrate
        model(xq.dequantize())
        model = convert(model)

        # Every functional wrapper must have been swapped by convert().
        self.checkNoPrepModules(model)
        for attr_name in ('myadd', 'mycat', 'myadd_relu'):
            self.assertEqual(type(getattr(model, attr_name)),
                             torch.nn.quantized.QFunctional)
        self.checkScriptable(model, [(xq, xq)], check_save_load=True)
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
                     " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
                     " with instruction set support avx2 or newer.")
class TestFusion(QuantizationTestCase):
    r"""Tests for ``fuse_modules`` followed by prepare / convert, in both
    train (QAT) and eval (post training) mode.
    """

    def test_fuse_module_train(self):
        r"""Fuse a training-mode model step by step, then run QAT prepare and
        convert and check the module types at each stage. The same check is
        repeated with a single ``fuse_modules`` call plus ``quantize_qat``.
        """
        model = ModelForFusion(default_qat_qconfig).train()
        # Test step by step fusion
        model = fuse_modules(model, ['conv1', 'bn1', 'relu1'])
        model = fuse_modules(model, ['sub1.conv', 'sub1.bn'])
        self.assertEqual(type(model.conv1), nni.ConvBnReLU2d,
                         "Fused Conv + BN + Relu first layer")
        self.assertEqual(type(model.bn1), torch.nn.Identity,
                         "Fused Conv + BN + Relu (skipped BN)")
        self.assertEqual(type(model.relu1), torch.nn.Identity,
                         "Fused Conv + BN + Relu (skipped Relu)")
        self.assertEqual(type(model.sub1.conv), nni.ConvBn2d,
                         "Fused submodule Conv + BN")
        self.assertEqual(type(model.sub1.bn), torch.nn.Identity,
                         "Fused submodule Conv + BN (skipped BN)")
        self.assertEqual(type(model.sub2.conv), torch.nn.Conv2d,
                         "Non-fused submodule Conv")
        self.assertEqual(type(model.sub2.relu), torch.nn.ReLU,
                         "Non-fused submodule ReLU")
        model = prepare_qat(model)
        self.checkObservers(model)

        def checkQAT(model):
            self.assertEqual(type(model.conv1), nniqat.ConvBnReLU2d)
            self.assertEqual(type(model.bn1), nn.Identity)
            self.assertEqual(type(model.relu1), nn.Identity)
            self.assertEqual(type(model.sub1.conv), nniqat.ConvBn2d)
            self.assertEqual(type(model.sub1.bn), nn.Identity)
            self.assertEqual(type(model.sub2.conv), nn.Conv2d)
            self.assertEqual(type(model.sub2.relu), nn.ReLU)

        checkQAT(model)
        test_only_train_fn(model, self.img_data)
        model = convert(model)

        def checkQuantized(model):
            self.assertEqual(type(model.conv1), nniq.ConvReLU2d)
            self.assertEqual(type(model.bn1), nn.Identity)
            self.assertEqual(type(model.relu1), nn.Identity)
            self.assertEqual(type(model.sub1.conv), nnq.Conv2d)
            self.assertEqual(type(model.sub1.bn), nn.Identity)
            self.assertEqual(type(model.sub2.conv), nn.Conv2d)
            self.assertEqual(type(model.sub2.relu), nn.ReLU)
            # The converted model must also be runnable on the image data.
            test_only_eval_fn(model, self.img_data)

        checkQuantized(model)

        # Same fusion expressed as one call, quantized through quantize_qat.
        model = ModelForFusion(default_qat_qconfig).train()
        model = fuse_modules(model, [['conv1', 'bn1', 'relu1'],
                                     ['sub1.conv', 'sub1.bn']])
        model = quantize_qat(model, test_only_train_fn, self.img_data)
        checkQuantized(model)

    def test_fuse_module_eval(self):
        r"""Fuse an eval-mode model, check the fused module types, then run
        prepare / convert and check the quantized module types. The same
        check is repeated through the one-shot ``quantize`` entry point.
        """
        model = ModelForFusion(default_qconfig)
        model.eval()
        model = fuse_modules(model, [['conv1', 'bn1', 'relu1'],
                                     ['conv2', 'relu2'],
                                     ['bn2', 'relu3'],
                                     ['sub1.conv', 'sub1.bn']])
        self.assertEqual(type(model.conv1), nni.ConvReLU2d,
                         "Fused Conv + BN + Relu first layer (BN is folded)")
        self.assertEqual(type(model.conv1[0]), nn.Conv2d,
                         "Fused Conv + BN + Relu (Conv + folded BN only)")
        self.assertEqual(type(model.conv1[1]), nn.ReLU,
                         "Fused Conv + BN + Relu second layer (Relu only)")
        self.assertEqual(type(model.bn1), nn.Identity,
                         "Fused Conv + BN + Relu second layer (Skipped BN)")
        self.assertEqual(type(model.relu1), nn.Identity,
                         "Fused Conv + BN + Relu second layer (Skipped Relu)")
        # conv2 is fused with relu2 only; there is no BN in that group.
        self.assertEqual(type(model.conv2), nni.ConvReLU3d,
                         "Fused Conv + Relu first layer")
        self.assertEqual(type(model.bn2), nni.BNReLU3d,
                         "Fused BN + Relu first layer (Relu is folded)")
        self.assertEqual(type(model.relu3), nn.Identity,
                         "Fused BN + Relu second layer (Skipped Relu)")
        self.assertEqual(type(model.conv2[0]), nn.Conv3d,
                         "Fused Conv + Relu (Conv only)")
        self.assertEqual(type(model.conv2[1]), nn.ReLU,
                         "Fused Conv + Relu second layer (Relu only)")
        self.assertEqual(type(model.relu2), nn.Identity,
                         "Fused Conv + Relu second layer (Skipped Relu)")
        self.assertEqual(type(model.sub1.conv), nn.Conv2d,
                         "Fused submodule Conv + folded BN")
        self.assertEqual(type(model.sub1.bn), nn.Identity,
                         "Fused submodule (skipped BN)")
        self.assertEqual(type(model.sub2.conv), nn.Conv2d,
                         "Non-fused submodule Conv")
        self.assertEqual(type(model.sub2.relu), torch.nn.ReLU,
                         "Non-fused submodule ReLU")
        model = prepare(model)
        self.checkObservers(model)
        test_only_eval_fn(model, self.img_data)
        model = convert(model)

        def checkQuantized(model):
            self.assertEqual(type(model.conv1), nniq.ConvReLU2d)
            self.assertEqual(type(model.bn1), nn.Identity)
            self.assertEqual(type(model.relu1), nn.Identity)
            self.assertEqual(type(model.sub1.conv), nnq.Conv2d)
            self.assertEqual(type(model.sub1.bn), nn.Identity)
            self.assertEqual(type(model.sub2.conv), nn.Conv2d)
            self.assertEqual(type(model.sub2.relu), nn.ReLU)
            self.assertEqual(type(model.bn2), nniq.BNReLU3d)
            # The converted model must also be runnable on the image data.
            test_only_eval_fn(model, self.img_data)

        checkQuantized(model)

        model = ModelForFusion(default_qconfig).eval()
        model = fuse_modules(model, [['conv1', 'bn1', 'relu1'],
                                     ['conv2', 'relu2'],
                                     ['bn2', 'relu3'],
                                     ['sub1.conv', 'sub1.bn']])
        model = quantize(model, test_only_eval_fn, self.img_data)
        checkQuantized(model)

    def test_fusion_sequential_model_train(self):
        r"""Fuse modules addressed through ``nn.Sequential`` indices in train
        mode, then run QAT prepare and convert and check the module types.
        """
        model = ModelWithSequentialFusion().train()
        model.to(torch.float)
        fuse_modules(model, [['conv1', 'relu1'],
                             ['features.0.0', 'features.0.1', 'features.0.2'],
                             ['features.1.0', 'features.1.1', 'features.1.2'],
                             ['features.2.0', 'features.2.1', 'features.2.2'],
                             ['classifier.0', 'classifier.1']], inplace=True)
        self.assertEqual(type(model.conv1), nni.ConvReLU2d,
                         "Fused Conv + Relu: nni.ConvReLU2d")
        self.assertEqual(type(model.conv1[0]), nn.Conv2d,
                         "Fused Conv + Relu: Conv2d")
        self.assertEqual(type(model.conv1[1]), nn.ReLU,
                         "Fused Conv + Relu: Relu")
        self.assertEqual(type(model.relu1), nn.Identity,
                         "Fused Conv + Relu: Identity")
        for i in range(3):
            self.assertEqual(type(model.features[i][0]), nni.ConvBnReLU2d,
                             "Fused submodule Conv + BN + Relu")
            self.assertEqual(type(model.features[i][1]), nn.Identity,
                             "Fused submodule (skipped BN)")
            self.assertEqual(type(model.features[i][2]), nn.Identity,
                             "Fused submodule (skipped Relu)")
        self.assertEqual(type(model.classifier[0]), nni.LinearReLU)
        self.assertEqual(type(model.classifier[1]), nn.Identity)
        model.qconfig = default_qat_qconfig
        prepare_qat(model, inplace=True)
        self.checkObservers(model)
        model(self.img_data[0][0])

        def checkQAT(model):
            self.assertEqual(type(model.conv1), nniqat.ConvReLU2d)
            self.assertEqual(type(model.relu1), nn.Identity)
            for i in range(3):
                self.assertEqual(type(model.features[i][0]), nniqat.ConvBnReLU2d,
                                 "Fused submodule Conv + BN + Relu")
                self.assertEqual(type(model.features[i][1]), nn.Identity,
                                 "Fused submodule (skipped BN)")
                self.assertEqual(type(model.features[i][2]), nn.Identity,
                                 "Fused submodule (skipped Relu)")
            self.assertEqual(type(model.classifier[0]), nniqat.LinearReLU)
            self.assertEqual(type(model.classifier[1]), nn.Identity)

        checkQAT(model)
        model(self.img_data[1][0])
        convert(model, inplace=True)
        model(self.img_data[1][0])
        self.checkModelWithSequentialQuantized(model)

    def test_fusion_sequential_model_eval(self):
        r"""Fuse modules addressed through ``nn.Sequential`` indices in eval
        mode, then run prepare and convert and check the module types.
        """
        model = ModelWithSequentialFusion().eval()
        model.to(torch.float)
        fuse_modules(model, [['conv1', 'relu1'],
                             ['features.0.0', 'features.0.1', 'features.0.2'],
                             ['features.1.0', 'features.1.1', 'features.1.2'],
                             ['features.2.0', 'features.2.1', 'features.2.2'],
                             ['classifier.0', 'classifier.1']], inplace=True)
        self.assertEqual(type(model.conv1), nni.ConvReLU2d,
                         "Fused Conv + Relu: nni.ConvReLU2d")
        self.assertEqual(type(model.conv1[0]), nn.Conv2d,
                         "Fused Conv + Relu: Conv2d")
        self.assertEqual(type(model.conv1[1]), nn.ReLU,
                         "Fused Conv + Relu: Relu")
        self.assertEqual(type(model.relu1), nn.Identity,
                         "Fused Conv + Relu: Identity")
        for i in range(3):
            self.assertEqual(type(model.features[i][0]), nni.ConvReLU2d,
                             "Fused submodule Conv + folded BN + Relu")
            self.assertEqual(type(model.features[i][1]), nn.Identity,
                             "Fused submodule (skipped BN)")
            self.assertEqual(type(model.features[i][2]), nn.Identity,
                             "Fused submodule (skipped Relu)")
        self.assertEqual(type(model.classifier[0]), nni.LinearReLU)
        self.assertEqual(type(model.classifier[1]), nn.Identity)
        model.qconfig = default_qconfig
        prepare(model, inplace=True)
        self.checkObservers(model)
        model(self.img_data[0][0])
        convert(model, inplace=True)
        model(self.img_data[1][0])
        self.checkModelWithSequentialQuantized(model)

    def checkModelWithSequentialQuantized(self, model):
        r"""Assert the module types expected on a converted
        ``ModelWithSequentialFusion``; shared by the train and eval tests.
        """
        self.assertEqual(type(model.conv1), nniq.ConvReLU2d)
        self.assertEqual(type(model.relu1), nn.Identity)
        for i in range(3):
            self.assertEqual(type(model.features[i][0]), nniq.ConvReLU2d)
            self.assertEqual(type(model.features[i][1]), nn.Identity)
            self.assertEqual(type(model.features[i][2]), nn.Identity)
        self.assertEqual(type(model.classifier[0]), nniq.LinearReLU)
        self.assertEqual(type(model.classifier[1]), nn.Identity)

    def test_fusion_conv_with_bias(self):
        r"""Check that fusing conv + bn (+ relu) in train mode leaves the
        output unchanged when ``nn.Identity`` stands in for the observers,
        and that QAT prepare yields the expected fused module types.
        """
        model = ModelForFusionWithBias().train()
        # output with no fusion.
        out_ref = model(self.img_data[0][0])
        # Identity is used for both activation and weight, so the prepared
        # model adds no observer effect to the output.
        model.qconfig = QConfig(activation=torch.nn.Identity,
                                weight=torch.nn.Identity)
        model = fuse_modules(model, [["conv1", "bn1", "relu1"],
                                     ["conv2", "bn2"]])
        prep_model = prepare_qat(model, inplace=False)
        # output with fusion but no observers.
        out_fused = prep_model(self.img_data[0][0])
        self.assertEqual(out_ref, out_fused)
        model.qconfig = default_qat_qconfig
        prepare_qat(model, inplace=True)
        model(self.img_data[0][0])

        def checkQAT(model):
            self.assertEqual(type(model.conv1), nniqat.ConvBnReLU2d)
            self.assertEqual(type(model.bn1), nn.Identity)
            self.assertEqual(type(model.relu1), nn.Identity)
            self.assertEqual(type(model.conv2), nniqat.ConvBn2d)
            self.assertEqual(type(model.bn2), nn.Identity)

        checkQAT(model)
class TestObserver(QuantizationTestCase):
    r"""Tests for the min/max style observers: recorded statistics, computed
    qparams, state dict serialization and scriptability.
    """

    def _save_and_reload(self, state_dict):
        r"""Round-trip ``state_dict`` through ``torch.save`` / ``torch.load``
        on an in-memory buffer, assert that every entry survived, and return
        the reloaded dict.
        """
        buffer = io.BytesIO()
        torch.save(state_dict, buffer)
        buffer.seek(0)
        reloaded = torch.load(buffer)
        for key in state_dict:
            self.assertEqual(state_dict[key], reloaded[key])
        return reloaded

    @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)),
           qscheme=st.sampled_from((torch.per_tensor_affine, torch.per_tensor_symmetric)),
           reduce_range=st.booleans())
    def test_per_tensor_observers(self, qdtype, qscheme, reduce_range):
        r"""Feed two batches to each per-tensor observer and check the tracked
        min/max, the resulting qparams and the state dict round trip.
        """
        # reduce_range cannot be true for symmetric quantization with uint8
        if qdtype == torch.quint8 and qscheme == torch.per_tensor_symmetric:
            reduce_range = False
        # Each observer is paired with the two batches it is going to see.
        # Moving average of min/max for x and y matches that of
        # extreme values for x/y used for minmax observer
        cases = [
            (MinMaxObserver(dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range),
             torch.tensor([1.0, 2.0, 2.0, 3.0, 4.0, 5.0, 6.0]),
             torch.tensor([4.0, 5.0, 5.0, 6.0, 7.0, 8.0])),
            (MovingAverageMinMaxObserver(averaging_constant=0.5,
                                         dtype=qdtype,
                                         qscheme=qscheme,
                                         reduce_range=reduce_range),
             torch.tensor([0.0, 2.0, 2.0, 3.0, 4.0, 5.0, 6.0]),
             torch.tensor([2.0, 5.0, 5.0, 6.0, 7.0, 10.0])),
        ]
        for myobs, x, y in cases:
            # Calculate Qparams should return with a warning for observers with no data
            myobs.calculate_qparams()
            myobs(x)
            result = myobs(y)
            self.assertEqual(result, y)
            self.assertEqual(myobs.min_val, 1.0)
            self.assertEqual(myobs.max_val, 8.0)
            qparams = myobs.calculate_qparams()

            # Reference values for the unreduced range first ...
            is_signed = qdtype is torch.qint8
            if qscheme == torch.per_tensor_symmetric:
                ref_scale = 0.062745
                ref_zero_point = 0 if is_signed else 128
            else:
                ref_scale = 0.0313725
                if is_signed:
                    ref_zero_point = -64 if reduce_range else -128
                else:
                    ref_zero_point = 0
            # ... then the scale adjustment for the reduced range.
            if reduce_range:
                ref_scale = ref_scale * 255 / 127
            self.assertEqual(qparams[1].item(), ref_zero_point)
            self.assertAlmostEqual(qparams[0].item(), ref_scale, delta=1e-5)

            # Test for serializability
            loaded_dict = self._save_and_reload(myobs.state_dict())
            loaded_obs = MinMaxObserver(dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range)
            loaded_obs.load_state_dict(loaded_dict)
            loaded_obs.calculate_qparams()
            self.assertEqual(myobs.min_val, loaded_obs.min_val)
            self.assertEqual(myobs.max_val, loaded_obs.max_val)
            self.assertEqual(myobs.calculate_qparams(), loaded_obs.calculate_qparams())

    @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=2, max_dims=4,
                                              min_side=1, max_side=10),
                       qparams=hu.qparams()),
           reduce_range=st.booleans())
    def test_per_tensor_dynamic_quant_observers(self, X, reduce_range):
        r"""Compare the qparams of ``MinMaxDynamicQuantObserver`` with those of
        ``torch._choose_qparams_per_tensor`` on the same input.
        """
        # Only the array is used; the generated qparams are discarded.
        X, (_scale, _zero_point, _torch_type) = X
        x = torch.from_numpy(X)
        obs = MinMaxDynamicQuantObserver(dtype=torch.quint8, reduce_range=reduce_range)
        obs(x)
        qparams = obs.calculate_qparams()
        ref = torch._choose_qparams_per_tensor(x, reduce_range)
        self.assertEqual(ref[0], qparams[0])
        self.assertEqual(ref[1], qparams[1])

    def test_tensor_list_observer(self):
        r"""Check that ``_MinMaxTensorListObserver`` agrees, element by
        element, with a separate ``MinMaxObserver`` run on each tensor.
        """
        from torch.quantization.observer import _MinMaxTensorListObserver
        x = [torch.tensor([1.0, 2.5, 3.5]),
             torch.tensor([2.0, 4.5, 3.5]),
             torch.tensor([4.0, 2.5, 3.5]), ]
        obs = _MinMaxTensorListObserver()
        obs(x)
        qparams = obs.calculate_qparams()

        # One (min, max, qparams) reference triple per list element.
        references = []
        for tensor in x:
            single_obs = MinMaxObserver()
            single_obs(tensor)
            references.append((single_obs.min_val,
                               single_obs.max_val,
                               single_obs.calculate_qparams()))

        for idx, (ref_min, ref_max, ref_qparams) in enumerate(references):
            self.assertEqual(obs.min_val[idx], ref_min)
            self.assertEqual(obs.max_val[idx], ref_max)
            self.assertEqual(qparams[0][idx], ref_qparams[0])
            self.assertEqual(qparams[1][idx], ref_qparams[1])

    @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)),
           qscheme=st.sampled_from((torch.per_channel_affine, torch.per_channel_symmetric)),
           ch_axis=st.sampled_from((0, 1, 2, 3)), reduce_range=st.booleans())
    def test_per_channel_observers(self, qdtype, qscheme, ch_axis, reduce_range):
        r"""Run each per-channel observer on a fixed 4-d tensor and check the
        per-channel min/max, the qparams and the state dict round trip for
        the channel axis under test.
        """
        # reduce_range cannot be true for symmetric quantization with uint8
        if qdtype == torch.quint8 and qscheme == torch.per_channel_symmetric:
            reduce_range = False
        observers = [
            PerChannelMinMaxObserver(reduce_range=reduce_range,
                                     ch_axis=ch_axis,
                                     dtype=qdtype,
                                     qscheme=qscheme),
            MovingAveragePerChannelMinMaxObserver(averaging_constant=0.5,
                                                  reduce_range=reduce_range,
                                                  ch_axis=ch_axis,
                                                  dtype=qdtype,
                                                  qscheme=qscheme),
        ]
        for myobs in observers:
            # Calculate qparams should work for empty observers
            myobs.calculate_qparams()
            x = torch.tensor(
                [
                    [[[1.0, 2.0], [2.0, 2.5]], [[3.0, 4.0], [4.5, 6.0]]],
                    [[[-4.0, -3.0], [5.0, 5.0]], [[6.0, 3.0], [7.0, 8.0]]],
                ]
            )
            if isinstance(myobs, MovingAveragePerChannelMinMaxObserver):
                # Scaling the input tensor to model change in min/max values
                # across batches
                myobs(0.5 * x)
                result = myobs(1.5 * x)
                self.assertEqual(result, 1.5 * x)
            else:
                result = myobs(x)
                self.assertEqual(result, x)
            qparams = myobs.calculate_qparams()

            # Reference tables, indexed by ch_axis.
            ref_min_vals = [[1.0, -4.0], [-4.0, 3.0], [-4.0, 2.0], [-4.0, -3.0]]
            ref_max_vals = [[6.0, 8.0], [5.0, 8.0], [6.0, 8.0], [7.0, 8.0]]
            per_channel_symmetric_ref_scales = [
                [0.04705882, 0.06274509],
                [0.03921569, 0.0627451],
                [0.04705882, 0.0627451],
                [0.05490196, 0.0627451],
            ]
            per_channel_affine_ref_scales = [
                [0.02352941, 0.04705882],
                [0.03529412, 0.03137255],
                [0.03921569, 0.03137255],
                [0.04313726, 0.04313726],
            ]
            per_channel_affine_qint8_zp = [
                [-128, -43],
                [-15, -128],
                [-26, -128],
                [-35, -58],
            ]
            per_channel_affine_quint8_zp = [[0, 85], [113, 0], [102, 0], [93, 70]]

            self.assertEqual(myobs.min_vals, ref_min_vals[ch_axis])
            self.assertEqual(myobs.max_vals, ref_max_vals[ch_axis])

            is_signed = qdtype is torch.qint8
            if qscheme == torch.per_channel_symmetric:
                ref_scales = per_channel_symmetric_ref_scales[ch_axis]
                ref_zero_points = [0, 0] if is_signed else [128, 128]
            else:
                ref_scales = per_channel_affine_ref_scales[ch_axis]
                if is_signed:
                    ref_zero_points = per_channel_affine_qint8_zp[ch_axis]
                else:
                    ref_zero_points = per_channel_affine_quint8_zp[ch_axis]

            if reduce_range:
                ref_scales = [s * 255 / 127 for s in ref_scales]
                ref_zero_points = [math.floor(z / 2) for z in ref_zero_points]

            expected_scales = torch.tensor(ref_scales, dtype=qparams[0].dtype)
            expected_zero_points = torch.tensor(ref_zero_points, dtype=qparams[1].dtype)
            self.assertTrue(torch.allclose(qparams[0], expected_scales))
            self.assertTrue(torch.allclose(qparams[1], expected_zero_points))

            # Test for serializability
            loaded_dict = self._save_and_reload(myobs.state_dict())
            loaded_obs = PerChannelMinMaxObserver(reduce_range=reduce_range, ch_axis=ch_axis, dtype=qdtype, qscheme=qscheme)
            loaded_obs.load_state_dict(loaded_dict)
            loaded_obs.calculate_qparams()
            self.assertEqual(myobs.min_vals, loaded_obs.min_vals)
            self.assertEqual(myobs.max_vals, loaded_obs.max_vals)
            self.assertEqual(myobs.calculate_qparams(), loaded_obs.calculate_qparams())

    def test_observer_scriptable(self):
        r"""Script several observers and check that the eager, scripted and
        save/load-ed versions compute the same qparams.
        """
        candidates = [MinMaxObserver(), MovingAverageMinMaxObserver(), MinMaxDynamicQuantObserver()]
        for obs in candidates:
            scripted = torch.jit.script(obs)
            x = torch.rand(3, 4)
            obs(x)
            scripted(x)
            self.assertEqual(obs.calculate_qparams(), scripted.calculate_qparams())

            buffer = io.BytesIO()
            torch.jit.save(scripted, buffer)
            buffer.seek(0)
            loaded = torch.jit.load(buffer)
            self.assertEqual(obs.calculate_qparams(), loaded.calculate_qparams())

        # Check TensorListObserver
        from torch.quantization.observer import _MinMaxTensorListObserver
        list_obs = _MinMaxTensorListObserver()
        list_scripted = torch.jit.script(list_obs)
        tensors = [torch.rand(3, 4), torch.rand(4, 5)]
        list_obs(tensors)
        list_scripted(tensors)
        self.assertEqual(list_obs.calculate_qparams(), list_scripted.calculate_qparams())

    def test_no_qconfig_propagation(self):
        r"""After ``prepare``, ``fc1`` must carry a qconfig while
        ``no_quant_module`` must not.
        """
        model = ModelWithNoQconfigPropagation()
        model.qconfig = torch.quantization.default_qconfig
        model = prepare(model)
        self.assertTrue(hasattr(model.fc1, 'qconfig'),
                        "QConfig is expected to propagate")
        self.assertFalse(hasattr(model.no_quant_module, 'qconfig'),
                         "QConfig is expected to NOT propagate")
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
                     " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
                     " with instruction set support avx2 or newer.")
class TestRecordHistogramObserver(QuantizationTestCase):
    r"""Tests for ``RecordingObserver`` and ``HistogramObserver``."""

    def test_record_observer(self):
        r"""Prepare a model with the debug qconfig, evaluate it twice and
        check what the recording observer on ``fc1`` captured.
        """
        model = AnnotatedSingleLayerLinearModel()
        model.qconfig = default_debug_qconfig
        model = prepare(model)
        # run the evaluation and dump all tensors
        for _ in range(2):
            test_only_eval_fn(model, self.calib_data)
        observer_dict = {}
        get_observer_dict(model, observer_dict)

        recorder_key = 'fc1.module.activation_post_process'
        self.assertTrue(recorder_key in observer_dict,
                        'observer is not recorded in the dict')
        # Two evaluation passes were run, hence twice the calibration set.
        self.assertEqual(len(observer_dict[recorder_key].get_tensor_value()), 2 * len(self.calib_data))
        self.assertEqual(observer_dict[recorder_key].get_tensor_value()[0], model(self.calib_data[0][0]))

    @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)),
           qscheme=st.sampled_from((torch.per_tensor_affine, torch.per_tensor_symmetric)))
    def test_observer_scriptable(self, qdtype, qscheme):
        r"""Script a ``RecordingObserver`` and check that the eager, scripted
        and save/load-ed versions hold the same first recorded tensor.
        """
        obs = RecordingObserver(dtype=qdtype, qscheme=qscheme)
        scripted = torch.jit.script(obs)
        x = torch.rand(3, 4)
        obs(x)
        scripted(x)
        self.assertTrue(torch.equal(obs.get_tensor_value()[0], scripted.get_tensor_value()[0]))

        buffer = io.BytesIO()
        torch.jit.save(scripted, buffer)
        buffer.seek(0)
        loaded = torch.jit.load(buffer)
        self.assertTrue(torch.equal(obs.get_tensor_value()[0], loaded.get_tensor_value()[0]))

    @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)),
           qscheme=st.sampled_from((torch.per_tensor_affine, torch.per_tensor_symmetric)),
           reduce_range=st.booleans())
    def test_histogram_observer(self, qdtype, qscheme, reduce_range):
        r"""Feed two batches to a 3-bin ``HistogramObserver`` and check
        min/max, the histogram, the qparams and the state dict round trip.
        """
        myobs = HistogramObserver(bins=3, dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range)
        # Calculate qparams should work for empty observers
        myobs.calculate_qparams()
        x = torch.tensor([2.0, 3.0, 4.0, 5.0], requires_grad=True)
        y = torch.tensor([5.0, 6.0, 7.0, 8.0])
        out_x = myobs(x)
        # The observer output must keep requires_grad from its input.
        self.assertTrue(out_x.requires_grad)
        myobs(y)
        self.assertEqual(myobs.min_val, 2.0)
        self.assertEqual(myobs.max_val, 8.0)
        self.assertEqual(myobs.histogram, [2., 3., 3.])
        qparams = myobs.calculate_qparams()

        # Reference values for the unreduced range first ...
        is_signed = qdtype is torch.qint8
        if qscheme == torch.per_tensor_symmetric:
            ref_scale = 0.0470588
            ref_zero_point = 0 if is_signed else 128
        else:
            ref_scale = 0.0235294
            if is_signed:
                ref_zero_point = -64 if reduce_range else -128
            else:
                ref_zero_point = 0
        # ... then the scale adjustment for the reduced range.
        if reduce_range:
            ref_scale = ref_scale * 255 / 127
        self.assertEqual(qparams[1].item(), ref_zero_point)
        self.assertAlmostEqual(qparams[0].item(), ref_scale, delta=1e-5)

        # Test for serializability
        original_state = myobs.state_dict()
        buffer = io.BytesIO()
        torch.save(original_state, buffer)
        buffer.seek(0)
        reloaded_state = torch.load(buffer)
        for key in original_state:
            self.assertEqual(original_state[key], reloaded_state[key])
        loaded_obs = HistogramObserver(bins=3, dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range)
        loaded_obs.load_state_dict(reloaded_state)
        loaded_obs.calculate_qparams()
        self.assertEqual(myobs.min_val, loaded_obs.min_val)
        self.assertEqual(myobs.max_val, loaded_obs.max_val)
        self.assertEqual(myobs.histogram, loaded_obs.histogram)
        self.assertEqual(myobs.bins, loaded_obs.bins)
        self.assertEqual(myobs.calculate_qparams(), loaded_obs.calculate_qparams())

    def test_histogram_observer_one_sided(self):
        r"""With non-negative inputs only, the observed minimum and the
        computed zero point must both be 0.
        """
        myobs = HistogramObserver(bins=8, dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=True)
        first_batch = torch.tensor([0.0, 0.3, 1.2, 1.7])
        second_batch = torch.tensor([0.1, 1.3, 2.0, 2.7])
        myobs(first_batch)
        myobs(second_batch)
        self.assertEqual(myobs.min_val, 0)
        qparams = myobs.calculate_qparams()
        self.assertEqual(qparams[1].item(), 0)
# Script entry point: hand control to run_tests() when this file is executed
# directly rather than imported.
if __name__ == '__main__':
    run_tests()