from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import unittest
import torch.jit
from jit_utils import _tmp_donotuse_dont_inline_everything
from torch._jit_internal import Optional
import torch.nn as nn
from common_utils import TestCase, run_tests
from common_quantization import NestedModel, AnnotatedNestedModel
from torch.quantization import QuantStub, DeQuantStub, \
    quantize, default_eval_fn, QConfig
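

# Test-only observer that always reports fixed quantization parameters
# (scale=2.0, zero_point=3) instead of computing them from the observed data.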
class Observer(torch.nn.Module):
    __annotations__ = {'scale': Optional[torch.Tensor], 'zero_point': Optional[torch.Tensor]}

    def __init__(self):
        super(Observer, self).__init__()
        self.dtype = torch.quint8
        self.qscheme = torch.per_tensor_affine
        self.scale, self.zero_point = None, None

    def forward(self, x):
        self.scale = torch.tensor([2.0])
        self.zero_point = torch.tensor([3])
        return x

    @torch.jit.export
    def calculate_qparams(self):
        return self.scale, self.zero_point
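

# Same fake observer, but with the qint8 dtype used for quantizing weights.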
class WeightObserver(Observer):
    def __init__(self):
        super(WeightObserver, self).__init__()
        self.dtype = torch.qint8


@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
                     " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
                     " with instruction set support avx2 or newer.")
@unittest.skip("temporarily disable the test")
class QuantizerTestCase(TestCase):
    @_tmp_donotuse_dont_inline_everything
    def test_default(self):
        class TestM(nn.Module):
            def __init__(self, qconfig):
                super(TestM, self).__init__()
                self.conv = nn.Conv2d(3, 1, 3).float()
                self.conv.weight.data.fill_(1.0)
                self.conv.bias.data.fill_(0.01)
                self.qconfig = qconfig
                self.quant = QuantStub()
                self.dequant = DeQuantStub()

            def forward(self, x):
                return self.dequant(self.conv(self.quant(x)))
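
        # Script-mode counterpart of TestM: same conv layer, but without
        # QuantStub/DeQuantStub, since the quantize/dequantize ops are
        # inserted by the JIT passes below.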
        class TestScriptM(torch.jit.ScriptModule):
            def __init__(self):
                super(TestScriptM, self).__init__()
                self.conv = nn.Conv2d(3, 1, 3).float()
                self.conv.bias.data.fill_(0.01)

            @torch.jit.script_method
            def forward(self, x):
                y = self.conv(x)
                return y
        # Test Data
        data = [(torch.randn(10, 3, 10, 10, dtype=torch.float), 1)]

        # Eager mode
        fake_qconfig = QConfig(activation=Observer, weight=WeightObserver)
        eager_module = TestM(fake_qconfig)

        # Script mode
        script_module = TestScriptM()
        script_module.conv.weight = torch.nn.Parameter(eager_module.conv.weight.detach())
        quantized_eager_module = quantize(eager_module, default_eval_fn, data)
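
        # Helper to fetch the lowered 'forward' method from the underlying C++
        # module, so calls go through the graph rewritten by the passes below.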
        def get_forward(m):
            return m._c._get_method('forward')
        # TODO: test jit.script as well
        ScriptedObserver = torch.jit.script(Observer())
        ScriptedWeightObserver = torch.jit.script(WeightObserver())
        qconfig_dict = {
            '': QConfig(
                activation=ScriptedObserver._c,
                weight=ScriptedWeightObserver._c)
        }
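        # The empty-string key refers to the top-level module, so the whole
        # model is quantized with this qconfig.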
        torch._C._jit_pass_insert_observers(script_module._c,
                                            "forward",
                                            qconfig_dict)

        # Run the script module and collect statistics
        get_forward(script_module)(data[0][0])

        # Insert quantize and dequantize calls
        script_module._c = torch._C._jit_pass_insert_quant_dequant(script_module._c, "forward")

        # Note that observer modules are not removed right now
        torch._C._jit_pass_quant_fusion(script_module._c._get_method('forward').graph)
        get_forward(script_module)(data[0][0])

        eager_result = quantized_eager_module(data[0][0])
        script_result = get_forward(script_module)(data[0][0])
        self.assertEqual(eager_result, script_result)

    @_tmp_donotuse_dont_inline_everything
    def test_qconfig_dict(self):
        data = [(torch.randn(10, 5, dtype=torch.float) * 20, 1)]

        # Eager mode
        qconfig = QConfig(activation=Observer, weight=WeightObserver)
        eager_module = AnnotatedNestedModel()
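        # Only fc3 and sub2.fc1 are given a qconfig in eager mode; the other
        # submodules stay in floating point.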
        eager_module.fc3.qconfig = qconfig
        eager_module.sub2.fc1.qconfig = qconfig

        # Assign weights
        eager_module.sub1.fc.weight.data.fill_(1.0)
        eager_module.sub2.fc1.module.weight.data.fill_(1.0)
        eager_module.sub2.fc2.weight.data.fill_(1.0)
        eager_module.fc3.module.weight.data.fill_(1.0)
        script_module = torch.jit.script(NestedModel())
        # Copy weights from eager_module
        script_module.sub1.fc.weight = eager_module.sub1.fc.weight
        script_module.sub2.fc1.weight = eager_module.sub2.fc1.module.weight
        script_module.sub2.fc2.weight = eager_module.sub2.fc2.weight
        script_module.fc3.weight = eager_module.fc3.module.weight

        # Quantize eager module
        quantized_eager_module = quantize(eager_module, default_eval_fn, data)

        def get_forward(m):
            return m._c._get_method('forward')
        # Quantize script_module
        torch._C._jit_pass_constant_propagation(get_forward(script_module).graph)

        ScriptedObserver = torch.jit.script(Observer())
        ScriptedWeightObserver = torch.jit.script(WeightObserver())
        scripted_qconfig = QConfig(
            activation=ScriptedObserver._c,
            weight=ScriptedWeightObserver._c)
        qconfig_dict = {
            'sub2.fc1': scripted_qconfig,
            'fc3': scripted_qconfig
        }
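        # Keys are submodule paths, mirroring the modules that received a
        # qconfig in the eager-mode model above.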
        torch._C._jit_pass_insert_observers(script_module._c,
                                            "forward",
                                            qconfig_dict)

        # Run script_module and collect statistics
        get_forward(script_module)(data[0][0])

        # Insert quantize and dequantize calls
        script_module._c = torch._C._jit_pass_insert_quant_dequant(script_module._c, "forward")

        # Note that observer modules are not removed right now
        torch._C._jit_pass_quant_fusion(script_module._c._get_method('forward').graph)
        get_forward(script_module)(data[0][0])

        eager_result = quantized_eager_module(data[0][0])
        script_result = get_forward(script_module)(data[0][0])
        self.assertEqual(eager_result, script_result)


if __name__ == '__main__':
    run_tests()