blob: d1977b47346486d72110065661791e1e7f687272 [file] [log] [blame]
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import torch.jit
import torch.nn as nn
import torch.nn.functional as F
from common_utils import TestCase
# TODO: Integrate the quantizer tests with CI once the quantizer interface is hardened
r"""
Default Weight Observer:
Records the min and max of the observed tensor in the stats.
Arguments:
    value: Tensor to be observed
    stats: Computed stats. Injected by the observer wrapper
Output:
    stats: Modified stats
"""
def weightObserver(value, stats):
    """Record the extrema of ``value`` in a two-element stats tensor.

    Arguments:
        value: Tensor to be observed.
        stats: Stats tensor from an earlier call, or None when no stats
            exist yet, in which case a zero tensor of size 2 is created.

    Output:
        stats: The stats tensor, with index 0 set to the minimum of
            ``value`` and index 1 set to its maximum.
    """
    observed = torch.zeros(2) if stats is None else stats
    # Earlier contents are overwritten rather than merged, so only the
    # range of the most recently observed tensor is retained.
    observed[0] = torch.min(value)
    observed[1] = torch.max(value)
    return observed
r"""
Default Activation Observer:
This implementation keeps an exponential moving average of the
observed min and max.
Arguments:
    value: Tensor to be observed
    stats: Computed stats. Injected by the observer wrapper
Output:
    stats: Modified stats
"""
def activationObserver(value, stats, averaging_constant=0.001):
    """Fold the min/max of ``value`` into exponentially averaged stats.

    Arguments:
        value: Tensor to be observed.
        stats: Two-element stats tensor from earlier observations, or None
            when nothing has been observed yet (a zero tensor is used, so
            the averages start from 0).
        averaging_constant: Weight given to the newly observed min/max;
            the previous stats keep weight ``1 - averaging_constant``.
            Defaults to 0.001, the value that was previously hard-coded,
            so two-argument calls behave exactly as before.

    Output:
        stats: The stats tensor, updated in place when one was passed in;
            index 0 is the averaged minimum, index 1 the averaged maximum.
    """
    if stats is None:
        stats = torch.zeros(2)
    retained = 1 - averaging_constant
    stats[0] = retained * stats[0] + averaging_constant * torch.min(value)
    stats[1] = retained * stats[1] + averaging_constant * torch.max(value)
    return stats
r"""
Default QParam computation: This is stateless.
value_stats is supplied by the observer.
Arguments:
    name: Key name in the stats dictionary
    value_stats: Stats dict from observer wrapper
Output:
    scale, zero_point
"""
def calcQParamFunc(name, value_stats):
    """Derive a scale and zero point from the stats stored under ``name``.

    Arguments:
        name: Key name in the stats dictionary.
        value_stats: Stats dict; ``value_stats[name][0]`` and
            ``value_stats[name][1]`` are read as the observed min and max.

    Output:
        scale, zero_point: ``scale`` is a Python number equal to
            ``2 * max(max, -min) / 255``; ``zero_point`` is always 0.
    """
    entry = value_stats[name]
    # Use the larger of the two magnitudes so the range is symmetric
    # around the fixed zero point of 0.
    largest_magnitude = torch.max(entry[1], -entry[0])
    scale_tensor = 2.0 * (largest_magnitude / 255.0)
    return scale_tensor.item(), 0
r"""
Merges the qparams of a quant object into a unified dictionary of all qparams
"""
def getAllQParamDict(allqparam_dict, quantObj):
    """Merge the qparams of ``quantObj`` into a unified dictionary.

    Arguments:
        allqparam_dict: Dictionary accumulating qparams; it is updated in
            place. When None, a new dictionary is created.
        quantObj: Object exposing ``getQParamDict()``.

    Output:
        The unified dictionary. It is returned so that a dictionary
        created here (when ``allqparam_dict`` is None) is not lost; callers
        that pass their own dictionary may keep ignoring the return value.
    """
    if allqparam_dict is None:
        allqparam_dict = {}
    qparam_dict = quantObj.getQParamDict()
    # A quant object with no qparams contributes nothing.
    if qparam_dict is not None:
        allqparam_dict.update(qparam_dict)
    return allqparam_dict
r"""
This is an example QuantTemplate which collects stats across batches
from the observer nodes inserted in the graph while a TorchScript or
traced module runs. These stats are used to compute the Quantization
Parameters, which are then passed to the quantizer to be used as
arguments for the quant ops in the quantization pass.
"""
class QuantTemplate:
    """Collects per-name stats through ``observer`` and turns them into
    quantization parameters through ``calcQParam``.

    The actual stat accumulation and qparam computation are delegated to
    the ``observerImpl`` and ``calcQParamImpl`` callables given at
    construction time.
    """

    def __init__(self, qscheme, observerImpl=None, calcQParamImpl=None):
        """
        Arguments:
            qscheme: Value stored as the first element of every qparam
                tuple produced by ``calcQParam``.
            observerImpl: Callable ``(value, stats) -> stats``; when None,
                ``observer`` collects nothing.
            calcQParamImpl: Callable ``(name, value_stats) ->
                (scale, zero_point)``; when None, ``calcQParam`` produces
                an empty dictionary.
        """
        # name -> stats as returned by observerImpl
        self.value_stats = {}
        # name -> (qscheme, scale, zero_point)
        self.qparam_dict = {}
        # Not read by any method of this class; kept as a public attribute.
        self.averaging_constant = 0.001
        self.observerImpl = observerImpl
        self.calcQParamImpl = calcQParamImpl
        self.qscheme = qscheme

    def resetStats(self):
        """Discard all collected stats."""
        self.value_stats = {}

    def observer(self, value, name):
        """Update the stats stored under ``name`` and pass ``value`` through.

        ``value`` is always returned unchanged, including when no
        ``observerImpl`` is configured, so the call behaves as a
        pass-through in every configuration.
        """
        if self.observerImpl is not None:
            # The impl receives None when nothing is stored for this name.
            # The entry is written only after the impl returns, so a
            # failing impl leaves no placeholder behind.
            self.value_stats[name] = self.observerImpl(
                value, self.value_stats.get(name))
        return value

    def calcQParam(self):
        """Recompute ``qparam_dict`` from the currently collected stats."""
        self.qparam_dict = {}
        if self.calcQParamImpl is None:
            return
        for name in self.value_stats:
            # This can change depending on type of quantization which will
            # be known to QuantTemplate object
            scale, zero_point = self.calcQParamImpl(name, self.value_stats)
            self.qparam_dict[name] = (self.qscheme, scale, zero_point)

    def getQParam(self, name):
        """Return the qparam tuple for ``name``, or ``()`` if there is none."""
        return self.qparam_dict.get(name, ())

    def getQParamDict(self):
        """Return the dictionary of all computed qparams."""
        return self.qparam_dict
class QuantizerTestCase(TestCase):
    """Tests comparing qparams collected in eager mode and script mode."""

    def test_compare_qparam_eager_script_default(self):
        """Run the same conv->relu->maxpool model as an eager module with
        hand-placed observer calls and as a script module with observers
        inserted by a JIT pass, then check both yield the same qparams for
        the max-pool output.
        """
        # Simple test case with conv->relu->maxpool
        class TestScriptM(torch.jit.ScriptModule):
            def __init__(self, init_weight=None):
                super(TestScriptM, self).__init__()
                self.conv1 = nn.Conv2d(1, 20, 5, 1)
                self.conv1.weight.data.fill_(1.0)
                self.conv1.bias.data.fill_(0.01)

            @torch.jit.script_method
            def forward(self, x):
                y = F.relu(self.conv1(x))
                z = F.max_pool2d(y, 2, 2)
                return z

        class TestM(nn.Module):
            def __init__(self, quantObj=None):
                super(TestM, self).__init__()
                self.conv1 = nn.Conv2d(1, 20, 5, 1)
                self.conv1.weight.data.fill_(1.0)
                self.conv1.bias.data.fill_(0.01)
                self.quantObj = quantObj

            def forward(self, x):
                y = F.relu(self.conv1(x))
                if self.quantObj is not None:
                    self.quantObj.observer(y, "y")
                z = F.max_pool2d(y, 2, 2)
                if self.quantObj is not None:
                    self.quantObj.observer(z, "z")
                return z

        # Test Data
        data = torch.ones(1, 1, 28, 28)

        # Eager mode
        # Create QuantConfig object for eager mode
        eagerQuantObj = QuantTemplate(qscheme='per_tensor_quant',
                                      observerImpl=activationObserver,
                                      calcQParamImpl=calcQParamFunc)
        eagerM = TestM(quantObj=eagerQuantObj)
        # Run EagerMode Model and Collect stats
        eagerM.forward(data)
        eagerM.quantObj.calcQParam()

        # Script mode
        scriptM = TestScriptM()
        # Create QuantConfig object for script mode
        activationQuantObj = QuantTemplate(qscheme='per_tensor_quant',
                                           observerImpl=activationObserver,
                                           calcQParamImpl=calcQParamFunc)
        # This performs type analysis to identify tensors from other
        # types. This info needed for further quantizer passes
        torch._C._jit_pass_constant_propagation(scriptM.graph)
        # Insert observers
        torch._C._jit_pass_insert_observers(scriptM._c, "forward", activationQuantObj.observer)
        # Run ScriptM Model and Collect statistics
        scriptM.forward(data)
        activationQuantObj.calcQParam()

        # Compare results for eager and graph mode
        eagerDict = eagerQuantObj.getQParamDict()
        activationDict = activationQuantObj.getQParamDict()
        # TODO - fix @eellison
        self.assertIn('z', eagerDict)
        self.assertIn('z.1', activationDict)
        # Each entry is a (qscheme, scale, zero_point) tuple. Only the scale
        # is a float needing a tolerance; qscheme and zero_point are
        # compared exactly so a mismatch fails cleanly.
        eagerQScheme, eagerScale, eagerZeroPoint = eagerDict["z"]
        scriptQScheme, scriptScale, scriptZeroPoint = activationDict["z.1"]
        self.assertEqual(eagerQScheme, scriptQScheme)
        self.assertAlmostEqual(eagerScale, scriptScale, places=15)
        self.assertEqual(eagerZeroPoint, scriptZeroPoint)