# blob: 87740960e1234cba72428b3eae8565e922e0f446 [file] [log] [blame]
import torch
import torch.nn as nn
import torch.nn.quantized as nnq
from torch.quantization import (
DeQuantStub,
QuantStub,
convert,
default_eval_fn,
default_qconfig,
prepare,
quantize,
)
from torch.quantization._numeric_suite import (
Shadow,
ShadowLogger,
compare_model_outputs,
compare_model_stub,
compare_weights,
)
from torch.testing._internal.common_quantization import (
AnnotatedConvBnReLUModel,
AnnotatedConvModel,
QuantizationTestCase,
)
from torch.testing._internal.common_quantized import (
override_quantized_engine,
supported_qengines,
)
class SubModule(torch.nn.Module):
    r"""Small float submodule: ``mod1`` (``nn.Identity``) followed by
    ``mod2`` (``nn.ReLU``).
    """

    def __init__(self):
        super().__init__()
        self.mod1 = nn.Identity()
        self.mod2 = nn.ReLU()

    def forward(self, x):
        # Identity passes the input through unchanged; ReLU then zeroes negatives.
        return self.mod2(self.mod1(x))
class ModelWithSubModules(torch.nn.Module):
    r"""Quantizable model: QuantStub -> SubModule -> 3x3 conv (3 -> 5
    channels, no bias) -> DeQuantStub, configured with ``default_qconfig``.
    """

    def __init__(self):
        super().__init__()
        self.qconfig = default_qconfig
        # Keep this registration order: it fixes the submodule order and
        # the point at which the conv weights are randomly initialized.
        self.mod1 = SubModule()
        self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        out = self.quant(x)
        out = self.mod1(out)
        out = self.conv(out)
        return self.dequant(out)
class ModelWithFunctionals(torch.nn.Module):
    r"""Model that routes its input through one ``nnq.FloatFunctional`` per
    op (cat, add, mul, add_relu, add_scalar, mul_scalar), wrapped between a
    QuantStub and a DeQuantStub.
    """

    def __init__(self):
        super().__init__()
        # A separate FloatFunctional instance for every op, registered in
        # the same order in which forward() uses them.
        for name in (
            "mycat",
            "myadd",
            "mymul",
            "myadd_relu",
            "my_scalar_add",
            "my_scalar_mul",
        ):
            setattr(self, name, nnq.FloatFunctional())
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        t = self.quant(x)
        t = self.mycat.cat([t, t, t])
        t = self.myadd.add(t, t)
        t = self.mymul.mul(t, t)
        t = self.myadd_relu.add_relu(t, t)
        # Scalar ops: add -0.5, then multiply by 0.5.
        shifted = self.my_scalar_add.add_scalar(t, -0.5)
        scaled = self.my_scalar_mul.mul_scalar(shifted, 0.5)
        return self.dequant(scaled)
class TestEagerModeNumericSuite(QuantizationTestCase):
    r"""Eager-mode tests for ``compare_weights``, ``compare_model_stub`` and
    ``compare_model_outputs``, each repeated for every engine in
    ``supported_qengines``.
    """

    # FloatFunctional submodule names of ModelWithFunctionals, in the order
    # they are registered and called.
    _FUNCTIONAL_NAMES = (
        "mycat",
        "myadd",
        "mymul",
        "myadd_relu",
        "my_scalar_add",
        "my_scalar_mul",
    )

    @staticmethod
    def _annotated_float_models(qengine):
        r"""Return fresh annotated conv models for ``qengine``.

        Each model is put in eval mode and fused when it defines
        ``fuse_model``, so it is ready to be passed to ``quantize``.
        """
        models = [
            AnnotatedConvModel(qengine),
            AnnotatedConvBnReLUModel(qengine),
        ]
        for model in models:
            model.eval()
            if hasattr(model, "fuse_model"):
                model.fuse_model()
        return models

    @staticmethod
    def _quantize_functionals_model(qengine, data):
        r"""Statically quantize a fresh ``ModelWithFunctionals``.

        Args:
            qengine: engine name whose default qconfig is used.
            data: calibration input, run once through the prepared copy.

        Returns:
            ``(float_model, q_model)``. ``prepare`` is called with
            ``inplace=False``, so observers go on a copy and not on
            ``float_model``.
        """
        model = ModelWithFunctionals().eval()
        # Use the qconfig of the engine currently under test. Hard-coding
        # "fbgemm" here would be inconsistent with the surrounding loop
        # over all supported engines.
        model.qconfig = torch.quantization.get_default_qconfig(qengine)
        q_model = prepare(model, inplace=False)
        q_model(data)  # calibration pass to populate the observers
        q_model = convert(q_model)
        return model, q_model

    def test_compare_weights(self):
        r"""Compare the weights of float and quantized conv layer
        """

        def compare_and_validate_results(float_model, q_model):
            weight_dict = compare_weights(
                float_model.state_dict(), q_model.state_dict()
            )
            self.assertEqual(len(weight_dict), 1)
            for v in weight_dict.values():
                self.assertEqual(v["float"].shape, v["quantized"].shape)

        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                for model in self._annotated_float_models(qengine):
                    q_model = quantize(model, default_eval_fn, self.img_data)
                    compare_and_validate_results(model, q_model)

    def test_compare_model_stub(self):
        r"""Compare the output of quantized conv layer and its float shadow module
        """

        def compare_and_validate_results(float_model, q_model, module_swap_list, data):
            ob_dict = compare_model_stub(
                float_model, q_model, module_swap_list, data, ShadowLogger
            )
            self.assertEqual(len(ob_dict), 1)
            for v in ob_dict.values():
                self.assertEqual(v["float"].shape, v["quantized"].shape)

        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                data = self.img_data[0][0]
                module_swap_list = [nn.Conv2d, nn.intrinsic.modules.fused.ConvReLU2d]
                for model in self._annotated_float_models(qengine):
                    q_model = quantize(model, default_eval_fn, self.img_data)
                    compare_and_validate_results(model, q_model, module_swap_list, data)

                # Test adding stub to sub module: only ``mod1`` is shadowed,
                # and the quantized conv is left untouched.
                model = ModelWithSubModules().eval()
                q_model = quantize(model, default_eval_fn, self.img_data)
                ob_dict = compare_model_stub(
                    model, q_model, [SubModule], data, ShadowLogger
                )
                self.assertIsInstance(q_model.mod1, Shadow)
                self.assertNotIsInstance(q_model.conv, Shadow)
                for v in ob_dict.values():
                    torch.testing.assert_allclose(
                        v["float"], v["quantized"].dequantize()
                    )

                # Test adding stub to functionals: every FloatFunctional is
                # shadowed and logged.
                model, q_model = self._quantize_functionals_model(qengine, data)
                ob_dict = compare_model_stub(
                    model, q_model, [nnq.FloatFunctional], data, ShadowLogger
                )
                self.assertEqual(len(ob_dict), len(self._FUNCTIONAL_NAMES))
                for name in self._FUNCTIONAL_NAMES:
                    self.assertIsInstance(getattr(q_model, name), Shadow, msg=name)
                for v in ob_dict.values():
                    self.assertEqual(v["float"].shape, v["quantized"].shape)

    def test_compare_model_outputs(self):
        r"""Compare the output of conv layer in quantized model and corresponding
        output of conv layer in float model
        """

        def compare_and_validate_results(float_model, q_model, data):
            act_compare_dict = compare_model_outputs(float_model, q_model, data)
            self.assertEqual(len(act_compare_dict), 2)
            self.assertSetEqual(
                set(act_compare_dict), {"conv.stats", "quant.stats"}
            )
            for v in act_compare_dict.values():
                self.assertEqual(v["float"].shape, v["quantized"].shape)

        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                data = self.img_data[0][0]
                for model in self._annotated_float_models(qengine):
                    q_model = quantize(model, default_eval_fn, self.img_data)
                    compare_and_validate_results(model, q_model, data)

                # Test functionals: one logged output per FloatFunctional,
                # plus the QuantStub.
                model, q_model = self._quantize_functionals_model(qengine, data)
                act_compare_dict = compare_model_outputs(model, q_model, data)
                expected_act_compare_dict_keys = {
                    name + ".stats" for name in self._FUNCTIONAL_NAMES
                }
                expected_act_compare_dict_keys.add("quant.stats")
                self.assertEqual(len(act_compare_dict), 7)
                self.assertSetEqual(
                    set(act_compare_dict), expected_act_compare_dict_keys
                )
                for v in act_compare_dict.values():
                    self.assertEqual(v["float"].shape, v["quantized"].shape)