from __future__ import absolute_import, division, print_function, unicode_literals
import torch
import torch.nn.quantized.functional as F
import torch.nn.quantized as nnq
from common_utils import TestCase, run_tests, tempfile

'''
Note that the tests in this file are API tests only: they check that the
quantized operator implementations are wrapped correctly in the user-facing
APIs; they are not correctness tests for the underlying quantized operators.
For correctness tests, please see `caffe2/test/test_quantized.py`.
'''


class FunctionalAPITest(TestCase):
    def test_relu_api(self):
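        # Quantize a small float tensor and check that the functional wrapper
        # F.relu matches calling the quantized relu op directly.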
        X = torch.arange(-5, 5, dtype=torch.float)
        scale = 2.0
        zero_point = 1
        qX = torch.quantize_linear(X, scale=scale, zero_point=zero_point, dtype=torch.quint8)
        qY = torch.ops.quantized.relu(qX)
        qY_hat = F.relu(qX)
        self.assertEqual(qY, qY_hat)


class ModuleAPITest(TestCase):
    def test_linear_api(self):
        """Test API functionality for nn.quantized.Linear."""
        in_features = 10
        out_features = 20
        batch_size = 5
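        # Quantize the weight to qint8 and the input to quint8, and prepack the
        # weight into the format used by the FBGEMM-based quantized linear op.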
        W = torch.rand(out_features, in_features).float()
        W_q = torch.quantize_linear(W, 0.1, 4, torch.qint8)
        W_pack = torch.ops.quantized.fbgemm_linear_prepack(W_q)
        X = torch.rand(batch_size, in_features).float()
        X_q = torch.quantize_linear(X, 0.2, 10, torch.quint8)
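        # The bias is quantized to qint32 with scale = weight_scale * input_scale
        # and zero_point = 0.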
        B = torch.rand(out_features).float()
        B_q = torch.quantize_linear(B, W_q.q_scale() * X_q.q_scale(), 0, torch.qint32)
        out_scale = 0.5
        out_zero_point = 3
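        # Construct the quantized Linear module and set its packed weight, bias,
        # and output quantization parameters directly.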
        qlinear = nnq.Linear(in_features, out_features)
        qlinear._packed_weight = W_pack
        qlinear.bias = B_q
        qlinear.out_scale = torch.tensor([out_scale])
        qlinear.out_zero_point = torch.tensor([out_zero_point])
        Z_q = qlinear(X_q)
        # Check if the module implementation matches calling the
        # ops directly
        Z_ref = torch.ops.quantized.fbgemm_linear(X_q, W_pack, B_q, out_scale, out_zero_point)
        self.assertEqual(Z_ref, Z_q)

        # Test serialization of quantized Linear Module using state_dict
        model_dict = qlinear.state_dict()
        self.assertEqual(model_dict['weight'], W_q)
        self.assertEqual(model_dict['bias'], B_q)
        with tempfile.NamedTemporaryFile() as f:
            torch.save(model_dict, f)
            f.seek(0)
            loaded_dict = torch.load(f)
        for key in model_dict:
            self.assertEqual(model_dict[key], loaded_dict[key])
        loaded_qlinear = nnq.Linear(in_features, out_features)
        loaded_qlinear.load_state_dict(loaded_dict)
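        # Unpack the packed weights before comparing the loaded module against
        # the original.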
        linear_unpack = torch.ops.quantized.fbgemm_linear_unpack
        self.assertEqual(linear_unpack(qlinear._packed_weight),
                         linear_unpack(loaded_qlinear._packed_weight))
        self.assertEqual(qlinear.bias, loaded_qlinear.bias)
        self.assertEqual(qlinear.out_scale, loaded_qlinear.out_scale)
        self.assertEqual(qlinear.out_zero_point, loaded_qlinear.out_zero_point)
        self.assertTrue(dir(qlinear) == dir(loaded_qlinear))
        self.assertTrue(hasattr(qlinear, '_packed_weight'))
        self.assertTrue(hasattr(loaded_qlinear, '_packed_weight'))
        self.assertTrue(hasattr(qlinear, 'weight'))
        self.assertTrue(hasattr(loaded_qlinear, 'weight'))
        self.assertEqual(qlinear.weight, loaded_qlinear.weight)
        self.assertEqual(qlinear.weight, torch.ops.quantized.fbgemm_linear_unpack(qlinear._packed_weight))
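        # The original module should still produce the same output after the
        # state_dict round trip.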
        Z_q2 = qlinear(X_q)
        self.assertEqual(Z_q, Z_q2)
        # test serialization of module directly - will add this later
        # with tempfile.NamedTemporaryFile() as f:
        #     torch.save(qLinear, f)
        #     f.seek(0)
        #     loaded = torch.load(f)
        # state = qLinear.__getstate__()
        # compareUnpackedWeight(qLinear._packed_weight, loaded._packed_weight)
        # self.assertEqual(qLinear.bias, loaded.bias)
        # self.assertEqual(qLinear.out_scale, loaded.out_scale)
        # self.assertEqual(qLinear.out_zero_point, loaded.out_zero_point)

    def test_quant_dequant_api(self):
        r = torch.tensor([[1., -1.], [1., -1.]], dtype=torch.float)
        scale, zero_point, dtype = 1.0, 2, torch.qint8
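        # Quantizing via the nnq.Quantize module should match calling
        # torch.quantize_linear directly, and nnq.DeQuantize should match
        # Tensor.dequantize().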
        # testing Quantize API
        qr = torch.quantize_linear(r, scale, zero_point, dtype)
        quant_m = nnq.Quantize(scale, zero_point, dtype)
        qr2 = quant_m(r)
        self.assertEqual(qr, qr2)
        # testing Dequantize API
        rqr = qr.dequantize()
        dequant_m = nnq.DeQuantize()
        rqr2 = dequant_m(qr2)
        self.assertEqual(rqr, rqr2)


if __name__ == '__main__':
    run_tests()