from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import torch
import torch.nn.quantized as nnq
import torch.nn._intrinsic.quantized as nnq_fused
import torch.nn.quantized.functional as qF
from torch.nn.quantized.modules import Conv2d
from torch.nn._intrinsic.quantized import ConvReLU2d
from common_utils import run_tests, tempfile
from common_quantization import QuantizationTestCase
from hypothesis import given
from hypothesis import strategies as st
'''
Note that the tests in this file are only API tests, meant to make sure that the
quantized operator implementations are wrapped correctly in the user-facing
APIs; they are not correctness tests for the underlying quantized operators.
For correctness tests, please see `caffe2/test/test_quantized.py`.
'''
class FunctionalAPITest(QuantizationTestCase):
def test_relu_api(self):
X = torch.arange(-5, 5, dtype=torch.float)
scale = 2.0
zero_point = 1
qX = torch.quantize_linear(X, scale=scale, zero_point=zero_point, dtype=torch.quint8)
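        # The functional wrapper should produce the same result as calling the quantized relu op directly.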
qY = torch.ops.quantized.relu(qX)
qY_hat = qF.relu(qX)
self.assertEqual(qY, qY_hat)
class ModuleAPITest(QuantizationTestCase):
@given(
batch_size=st.integers(1, 5),
in_features=st.integers(16, 32),
out_features=st.integers(4, 8),
use_bias=st.booleans(),
use_fused=st.booleans(),
)
def test_linear_api(self, batch_size, in_features, out_features, use_bias, use_fused):
"""test API functionality for nn.quantized.linear and nn._intrinsic.quantized.linear_relu"""
W = torch.rand(out_features, in_features).float()
W_q = torch.quantize_linear(W, 0.1, 4, torch.qint8)
X = torch.rand(batch_size, in_features).float()
X_q = torch.quantize_linear(X, 0.2, 10, torch.quint8)
B = torch.rand(out_features).float() if use_bias else None
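        # The bias is quantized to qint32 with scale = weight_scale * input_scale and zero_point 0,
        # which (as this test assumes) is the format the FBGEMM linear ops expect.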
B_q = torch.quantize_linear(B, W_q.q_scale() * X_q.q_scale(), 0, torch.qint32) if use_bias else None
scale = 0.5
zero_point = 3
if use_fused:
qlinear = nnq_fused.LinearReLU(in_features, out_features)
else:
qlinear = nnq.Linear(in_features, out_features)
qlinear.set_weight(W_q)
        # Simple round-trip test to ensure that the weight()/set_weight() API works
self.assertEqual(qlinear.weight(), W_q)
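        # _packed_weight holds the FBGEMM-prepacked form of the weight, which the quantized linear ops consume directly.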
W_pack = qlinear._packed_weight
qlinear.bias = B_q if use_bias else None
qlinear.scale = float(scale)
qlinear.zero_point = int(zero_point)
Z_q = qlinear(X_q)
# Check if the module implementation matches calling the
# ops directly
if use_fused:
Z_ref = torch.ops.quantized.fbgemm_linear_relu(X_q, W_pack, B_q, scale, zero_point)
else:
Z_ref = torch.ops.quantized.fbgemm_linear(X_q, W_pack, B_q, scale, zero_point)
self.assertEqual(Z_ref, Z_q)
# Test serialization of quantized Linear Module using state_dict
model_dict = qlinear.state_dict()
self.assertEqual(model_dict['weight'], W_q)
if use_bias:
self.assertEqual(model_dict['bias'], B_q)
with tempfile.NamedTemporaryFile() as f:
torch.save(model_dict, f)
f.seek(0)
loaded_dict = torch.load(f)
for key in model_dict:
self.assertEqual(model_dict[key], loaded_dict[key])
if use_fused:
loaded_qlinear = nnq_fused.LinearReLU(in_features, out_features)
else:
loaded_qlinear = nnq.Linear(in_features, out_features)
loaded_qlinear.load_state_dict(loaded_dict)
linear_unpack = torch.ops.quantized.fbgemm_linear_unpack
self.assertEqual(linear_unpack(qlinear._packed_weight),
linear_unpack(loaded_qlinear._packed_weight))
if use_bias:
self.assertEqual(qlinear.bias, loaded_qlinear.bias)
self.assertEqual(qlinear.scale, loaded_qlinear.scale)
self.assertEqual(qlinear.zero_point, loaded_qlinear.zero_point)
self.assertTrue(dir(qlinear) == dir(loaded_qlinear))
self.assertTrue(hasattr(qlinear, '_packed_weight'))
self.assertTrue(hasattr(loaded_qlinear, '_packed_weight'))
self.assertTrue(hasattr(qlinear, 'weight'))
self.assertTrue(hasattr(loaded_qlinear, 'weight'))
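        # weight() unpacks the packed weight, so it should round-trip to the quantized weight that was originally set.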
self.assertEqual(qlinear.weight(), loaded_qlinear.weight())
self.assertEqual(qlinear.weight(), torch.ops.quantized.fbgemm_linear_unpack(qlinear._packed_weight))
Z_q2 = loaded_qlinear(X_q)
self.assertEqual(Z_q, Z_q2)
        # Test serialization of the module directly
with tempfile.NamedTemporaryFile() as f:
torch.save(qlinear, f)
f.seek(0)
loaded = torch.load(f)
# This check is disabled pending an issue in PyTorch serialization:
# https://github.com/pytorch/pytorch/issues/24045
# self.assertEqual(qlinear.weight(), loaded.weight())
self.assertEqual(qlinear.bias, loaded.bias)
self.assertEqual(qlinear.scale, loaded.scale)
self.assertEqual(qlinear.zero_point, loaded.zero_point)
# Test JIT
self.checkScriptable(qlinear, zip([X_q], [Z_ref]), check_save_load=True)
def test_quant_dequant_api(self):
r = torch.tensor([[1., -1.], [1., -1.]], dtype=torch.float)
scale, zero_point, dtype = 1.0, 2, torch.qint8
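        # nnq.Quantize and nnq.DeQuantize are module wrappers around torch.quantize_linear and Tensor.dequantize.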
# testing Quantize API
qr = torch.quantize_linear(r, scale, zero_point, dtype)
quant_m = nnq.Quantize(scale, zero_point, dtype)
qr2 = quant_m(r)
self.assertEqual(qr, qr2)
# testing Dequantize API
rqr = qr.dequantize()
dequant_m = nnq.DeQuantize()
rqr2 = dequant_m(qr2)
self.assertEqual(rqr, rqr2)
@given(
use_bias=st.booleans(),
use_fused=st.booleans(),
)
def test_conv_api(self, use_bias, use_fused):
"""Tests the correctness of the conv module.
The correctness is defined against the functional implementation.
"""
N, iC, H, W = 10, 10, 10, 3
oC, g, kH, kW = 16, 1, 3, 3
scale, zero_point = 1.0 / 255, 128
X = torch.randn(N, iC, H, W, dtype=torch.float32)
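        # Permute to NHWC (channels-last) before quantizing; the FBGEMM-based quantized conv assumes this layout.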
X = X.permute([0, 2, 3, 1]).contiguous()
        qX = torch.quantize_linear(X, scale=scale, zero_point=zero_point, dtype=torch.quint8)
w = torch.randn(oC, iC // g, kH, kW, dtype=torch.float32)
qw = torch.quantize_linear(w, scale=scale, zero_point=0, dtype=torch.qint8)
b = torch.randn(oC, dtype=torch.float32) if use_bias else None
qb = torch.quantize_linear(b, scale=1.0 / 1024, zero_point=0, dtype=torch.qint32) if use_bias else None
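        # ConvReLU2d fuses the ReLU into the quantized conv and exposes the same API as the plain quantized Conv2d.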
if use_fused:
conv_under_test = ConvReLU2d(in_channels=iC,
out_channels=oC,
kernel_size=(kH, kW),
stride=1,
padding=0,
dilation=1,
groups=g,
bias=use_bias,
padding_mode='zeros')
else:
conv_under_test = Conv2d(in_channels=iC,
out_channels=oC,
kernel_size=(kH, kW),
stride=1,
padding=0,
dilation=1,
groups=g,
bias=use_bias,
padding_mode='zeros')
conv_under_test.set_weight(qw)
conv_under_test.bias = qb
conv_under_test.scale = scale
conv_under_test.zero_point = zero_point
# Test members
self.assertTrue(hasattr(conv_under_test, '_packed_weight'))
self.assertTrue(hasattr(conv_under_test, 'scale'))
self.assertTrue(hasattr(conv_under_test, 'zero_point'))
# Test properties
self.assertEqual(qw, conv_under_test.weight())
self.assertEqual(qb, conv_under_test.bias)
self.assertEqual(scale, conv_under_test.scale)
self.assertEqual(zero_point, conv_under_test.zero_point)
# Test forward
result_under_test = conv_under_test(qX)
result_reference = qF.conv2d(qX, qw, bias=qb,
scale=scale, zero_point=zero_point,
stride=1, padding=0,
dilation=1, groups=g, dtype=torch.quint8
)
if use_fused:
# result_reference < zero_point doesn't work for qtensor yet
# result_reference[result_reference < zero_point] = zero_point
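            # Instead, emulate ReLU element-wise: any entry whose int repr is below zero_point corresponds to a
            # negative float value and is clamped to zero_point.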
MB, OC, OH, OW = result_reference.size()
for i in range(MB):
for j in range(OC):
for h in range(OH):
for w in range(OW):
if result_reference[i][j][h][w].int_repr() < zero_point:
                                # Assign 0. in float, which maps to zero_point in the quantized representation
result_reference[i][j][h][w] = 0.
self.assertEqual(result_reference, result_under_test,
message="Tensors are not equal.")
# Test serialization of quantized Conv Module using state_dict
model_dict = conv_under_test.state_dict()
self.assertEqual(model_dict['weight'], qw)
if use_bias:
self.assertEqual(model_dict['bias'], qb)
with tempfile.NamedTemporaryFile() as f:
torch.save(model_dict, f)
f.seek(0)
loaded_dict = torch.load(f)
for key in model_dict:
self.assertEqual(loaded_dict[key], model_dict[key])
if use_fused:
loaded_conv_under_test = ConvReLU2d(in_channels=iC,
out_channels=oC,
kernel_size=(kH, kW),
stride=1,
padding=0,
dilation=1,
groups=g,
bias=use_bias,
padding_mode='zeros')
else:
loaded_conv_under_test = Conv2d(in_channels=iC,
out_channels=oC,
kernel_size=(kH, kW),
stride=1,
padding=0,
dilation=1,
groups=g,
bias=use_bias,
padding_mode='zeros')
loaded_conv_under_test.load_state_dict(loaded_dict)
self.assertEqual(loaded_conv_under_test.weight(), conv_under_test.weight())
if use_bias:
self.assertEqual(loaded_conv_under_test.bias, conv_under_test.bias)
self.assertEqual(loaded_conv_under_test.scale, conv_under_test.scale)
self.assertEqual(loaded_conv_under_test.zero_point, conv_under_test.zero_point)
self.assertTrue(dir(loaded_conv_under_test) == dir(conv_under_test))
self.assertTrue(hasattr(conv_under_test, '_packed_weight'))
self.assertTrue(hasattr(loaded_conv_under_test, '_packed_weight'))
self.assertTrue(hasattr(conv_under_test, 'weight'))
self.assertTrue(hasattr(loaded_conv_under_test, 'weight'))
self.assertEqual(loaded_conv_under_test.weight(), conv_under_test.weight())
self.assertEqual(loaded_conv_under_test.weight(), qw)
loaded_result = loaded_conv_under_test(qX)
self.assertEqual(loaded_result, result_reference)
with tempfile.NamedTemporaryFile() as f:
torch.save(conv_under_test, f)
f.seek(0)
loaded_conv = torch.load(f)
self.assertEqual(conv_under_test.bias, loaded_conv.bias)
self.assertEqual(conv_under_test.scale, loaded_conv.scale)
self.assertEqual(conv_under_test.zero_point, loaded_conv.zero_point)
# JIT testing
self.checkScriptable(conv_under_test, zip([qX], [result_reference]), check_save_load=True)
def test_pool_api(self):
"""Tests the correctness of the pool module.
The correctness is defined against the functional implementation.
"""
N, C, H, W = 10, 10, 10, 3
kwargs = {
'kernel_size': 2,
'stride': None,
'padding': 0,
'dilation': 1
}
scale, zero_point = 1.0 / 255, 128
X = torch.randn(N, C, H, W, dtype=torch.float32)
qX = torch.quantize_linear(X, scale=scale, zero_point=zero_point,
dtype=torch.quint8)
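        # max_pool2d can run directly on the quantized tensor: pooling only selects existing values, so no requantization is needed.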
qX_expect = torch.nn.functional.max_pool2d(qX, **kwargs)
pool_under_test = torch.nn.quantized.MaxPool2d(**kwargs)
qX_hat = pool_under_test(qX)
self.assertEqual(qX_expect, qX_hat)
# JIT Testing
self.checkScriptable(pool_under_test, zip([X], [qX_expect]))
if __name__ == '__main__':
run_tests()