# -*- coding: utf-8 -*-
# Owner(s): ["oncall: quantization"]
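"""Tests for on-device dynamic post-training quantization (PTQ) with TorchScript.

On-device PTQ rewrites a scripted model so that observation, quantization, and
quantized inference are all callable methods on the module itself. The flow
exercised by these tests is, roughly:

    m = _quantize_ondevice_dynamic_jit(scripted_model, qconfig_dict, 'forward', True)
    m.reset_observers_forward()         # clear any previously recorded stats
    m.observe_forward(*inputs)          # record weight statistics
    m.quantize_forward(*inputs)         # compute qparams and pack weights
    out = m.quantized_forward(*inputs)  # run dynamically quantized inference
"""
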
import io
from typing import Dict

import torch
import torch._C
from torch.ao.quantization import (
    default_dynamic_qconfig,
    per_channel_dynamic_qconfig,
)
from torch.ao.quantization.quantize_jit import (
    _prepare_ondevice_dynamic_jit,
    _quantize_ondevice_dynamic_jit,
    convert_dynamic_jit,
    prepare_dynamic_jit,
)
from torch.jit.mobile import _load_for_lite_interpreter, LiteScriptModule
from torch.testing import FileCheck
from torch.testing._internal.common_quantization import (
    get_script_module,
    LinearAddModel,
)
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.utils import bundled_inputs
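

# Test fixtures: a small Linear -> Linear module and a Conv -> Linear module
# (with both a module weight and a functional-linear weight) used below to
# exercise observer insertion and on-device quantization.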
class myMod(torch.nn.Module):
    def __init__(self, weight):
        super().__init__()
        self.fc1 = torch.nn.Linear(5, 5).float()
        self.fc1.weight = weight
        self.fc2 = torch.nn.Linear(5, 5).float()

    def forward(self, x):
        return self.fc2(self.fc1(x))


class MyConvLinearModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 5, 3)
        weight = torch.nn.Parameter(torch.ones(5, 5))
        self.weight1 = torch.nn.Parameter(torch.ones(5, 5))
        self.mymod = myMod(weight)

    def forward(self, x):
        conv_output = self.conv(x)
        y = self.mymod(conv_output)
        z = torch.nn.functional.linear(y, self.weight1)
        return z

    def get_example_inputs(self):
        return (torch.rand(1, 3, 12, 7),)
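

# Helpers shared by the test cases below. On-device PTQ adds observe_forward,
# reset_observers_forward, quantize_forward, and quantized_forward methods to
# the scripted module; these utilities prepare/quantize a model and inspect
# the resulting JIT graphs.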
class OnDevicePTQUtils:
    observer_module_name = ['MinMaxObserver', 'PerChannelMinMaxObserver']

    @staticmethod
    def insert_observers(model, qconfig_dict):
        inputs = model.get_example_inputs()
        scripted_model = get_script_module(model, False, inputs)
        scripted_model = _prepare_ondevice_dynamic_jit(scripted_model, qconfig_dict)
        return scripted_model

    @staticmethod
    def ptq_dynamic_quantize(model, qconfig_dict):
        inputs = model.get_example_inputs()
        m = get_script_module(model, False, inputs)
        m = _quantize_ondevice_dynamic_jit(m, qconfig_dict, 'forward', True)
        return m

    @staticmethod
    def find_observer_modules(m):
        observer_modules = []
        for child_module in m.children():
            if child_module.original_name in OnDevicePTQUtils.observer_module_name:
                observer_modules.append(child_module)
        return observer_modules

    @staticmethod
    def is_value_type_observer(value):
        type_name = value.type()
        for observer_type in OnDevicePTQUtils.observer_module_name:
            if observer_type in type_name.str():
                return True
        return False

    @staticmethod
    def is_calculate_qparam(node):
        return (
            node.kind() == "prim::CallMethod"
            and node.s('name') == "calculate_qparams"
        )

    @staticmethod
    def get_linear_packed_param_fp_weight(node):
        weight = node.inputsAt(0).node()
        if weight.kind() not in ("aten::quantize_per_tensor", "aten::quantize_per_channel"):
            raise ValueError("Quantized weight must be produced.")
        fp_weight = weight.inputsAt(0).node()
        assert fp_weight.kind() == "prim::GetAttr", "Weight must be an attribute of the module."
        fp_weight_name = fp_weight.s('name')
        return fp_weight_name

    @staticmethod
    def is_per_channel_quantized_packed_param(node):
        assert node.kind() == 'quantized::linear_prepack', "Node must correspond to linear_prepack."
        weight = node.inputsAt(0).node()
        # The packed weight must come from one of the two quantize ops.
        assert weight.kind() in ("aten::quantize_per_tensor", "aten::quantize_per_channel")
        return weight.kind() != "aten::quantize_per_tensor"
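

# Checks that _prepare_ondevice_dynamic_jit inserts the expected weight-only
# observers and generates the observe_forward / reset_observers_forward
# methods.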
class TestOnDeviceDynamicPTQInsertObservers(TestCase):
    def _check_num_and_type_of_observers(self, model, num_observers):
        qconfig_dict = {"": default_dynamic_qconfig}
        scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict)
        observer_modules = OnDevicePTQUtils.find_observer_modules(scripted_model)
        self.assertTrue(len(observer_modules) == num_observers)
        for observer in observer_modules:
            self.assertTrue(observer.original_name == 'MinMaxObserver')
        qconfig_dict = {"": per_channel_dynamic_qconfig}
        scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict)
        observer_modules = OnDevicePTQUtils.find_observer_modules(scripted_model)
        self.assertTrue(len(observer_modules) == num_observers)
        for observer in observer_modules:
            self.assertTrue(observer.original_name == 'PerChannelMinMaxObserver')

    def _check_observer_method(self, model, num_observers):
        qconfig_dict = {"": default_dynamic_qconfig}
        inputs = model.get_example_inputs()
        orig_scripted_model = get_script_module(model, False, inputs)
        torch._C._jit_pass_inline(orig_scripted_model.graph)
        orig_forward_graph = orig_scripted_model.graph.str()
        scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict)
        quant_forward_graph = scripted_model.graph.str()
        # Exact graph matching is difficult, so compare the number of lines
        # instead of implementing full graph matching.
        self.assertEqual(len(orig_forward_graph.splitlines()), len(quant_forward_graph.splitlines()))
        observe_method = scripted_model.observe_forward.graph
        FileCheck().check_count(
            "prim::CallMethod[name=\"forward\"](%_observer",
            num_observers,
            exactly=True,
        ).run(observe_method)
        reset_observers_method = scripted_model.reset_observers_forward.graph
        FileCheck().check_count(
            "prim::CallMethod[name=\"reset_min_max_vals\"](%_observer",
            num_observers,
            exactly=True,
        ).run(reset_observers_method)

    def _observer_is_weight_only(self, node):
        if node.kind() == "prim::CallMethod" and node.s("name") == "forward":
            if OnDevicePTQUtils.is_value_type_observer(node.inputsAt(0)):
                # A weight-only observer observes a value fetched via GetAttr
                # (i.e., a module parameter) rather than an activation.
                return node.inputsAt(1).node().kind() == "prim::GetAttr"
        return False

    def test_num_observers(self):
        model = LinearAddModel()
        self._check_num_and_type_of_observers(model, 2)
        model = MyConvLinearModule()
        self._check_num_and_type_of_observers(model, 3)

    def test_observe_method(self):
        model = MyConvLinearModule()
        self._check_observer_method(model, 3)

    def test_weight_only_observers(self):
        model = MyConvLinearModule()
        qconfig_dict = {"": default_dynamic_qconfig}
        scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict)
        observe_forward_graph = scripted_model.observe_forward.graph
        num_weight_only_observers = 0
        for node in observe_forward_graph.nodes():
            if self._observer_is_weight_only(node):
                num_weight_only_observers += 1
        self.assertEqual(num_weight_only_observers, 3)
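

# Checks that quantize_forward contains the expected quantize/dequantize and
# calculate_qparams calls, and that observer forward calls have been removed.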
class TestOnDeviceDynamicPTQInsertQuantDequant(TestCase):
    def _validate_quant_dequant_nodes(self, model, num_nodes, per_channel=0):
        quantize_forward_graph = model.quantize_forward.graph
        quantize_per_tensor = quantize_per_channel = 0
        for n in quantize_forward_graph.nodes():
            if "aten::quantize_per_tensor" in n.kind():
                quantize_per_tensor += 1
            if "aten::quantize_per_channel" in n.kind():
                quantize_per_channel += 1
        self.assertEqual(quantize_per_tensor + quantize_per_channel, num_nodes)
        self.assertEqual(quantize_per_channel, per_channel)

    def _validate_calculate_qparams(self, model, num_nodes):
        quantize_forward_graph = model.quantize_forward.graph
        num_calculate_qparams = 0
        for n in quantize_forward_graph.nodes():
            if OnDevicePTQUtils.is_calculate_qparam(n):
                num_calculate_qparams += 1
        self.assertEqual(num_calculate_qparams, num_nodes)

    def _validate_no_observer_forward(self, model):
        quantize_forward_graph = model.quantize_forward.graph
        for n in quantize_forward_graph.nodes():
            if (n.kind() == "prim::CallMethod") and n.s("name") == "forward":
                if OnDevicePTQUtils.is_value_type_observer(n.inputsAt(0)):
                    return False
        return True

    def _check_quant_dequant_and_calc_qparams(self, model, num_nodes):
        qconfig_dict = {"": default_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        self._validate_quant_dequant_nodes(m, num_nodes)
        self._validate_calculate_qparams(m, num_nodes)
        self.assertTrue(self._validate_no_observer_forward(m))
        qconfig_dict = {"": per_channel_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        self._validate_quant_dequant_nodes(m, num_nodes, num_nodes)
        self._validate_calculate_qparams(m, num_nodes)
        self.assertTrue(self._validate_no_observer_forward(m))

    def _check_quantize_forward_runs(self, model):
        inputs = model.get_example_inputs()
        qconfig_dict = {"": default_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        m.observe_forward(*inputs)
        m.quantize_forward(*inputs)
        qconfig_dict = {"": per_channel_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        # observe_forward must run first to record the stats that produce
        # correct scales and zero points.
        m.observe_forward(*inputs)
        m.quantize_forward(*inputs)

    def test_num_quant_dequant_nodes(self):
        model = LinearAddModel()
        self._check_quant_dequant_and_calc_qparams(model, 2)
        model = MyConvLinearModule()
        self._check_quant_dequant_and_calc_qparams(model, 3)

    def test_quantize_forward_runs(self):
        model = LinearAddModel()
        self._check_quantize_forward_runs(model)
        model = MyConvLinearModule()
        self._check_quantize_forward_runs(model)
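

# Checks the finalized module: packed params are installed via SetAttr, fp
# weights are reset to constants, quantized_forward uses
# quantized::linear_dynamic, and results match off-device dynamic PTQ.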
class TestOnDeviceDynamicPTQFinalize(TestCase):
    def _validate_packed_params(self, model, num_nodes, per_channel=0):
        quantize_forward_graph = model.quantize_forward.graph
        quantize_per_tensor = quantize_per_channel = 0
        linear_prepack = 0
        linear_prepack_uses = 0
        for n in quantize_forward_graph.nodes():
            if n.kind() == 'prim::SetAttr':
                maybe_packed_param_value = n.inputsAt(1)
                maybe_packed_param = maybe_packed_param_value.node()
                if maybe_packed_param.kind() == 'quantized::linear_prepack':
                    linear_prepack += 1
                    linear_prepack_uses += len(maybe_packed_param_value.uses())
                    if OnDevicePTQUtils.is_per_channel_quantized_packed_param(maybe_packed_param):
                        quantize_per_channel += 1
                    else:
                        quantize_per_tensor += 1
        self.assertEqual(quantize_per_tensor + quantize_per_channel, num_nodes)
        self.assertEqual(quantize_per_channel, per_channel)
        self.assertEqual(linear_prepack, num_nodes)
        self.assertEqual(linear_prepack_uses, num_nodes)

    def _validate_no_linear_unpack(self, model):
        quantize_forward_graph = model.quantize_forward.graph
        for n in quantize_forward_graph.nodes():
            if n.kind() == 'quantized::linear_unpack':
                return False
        return True

    def _validate_setattr_fp_weights(self, model, num_nodes):
        quantize_forward_graph = model.quantize_forward.graph
        fp_weights_setattr = 0
        fp_weight_names = []
        for n in quantize_forward_graph.nodes():
            if n.kind() == 'prim::SetAttr':
                maybe_packed_param = n.inputsAt(1).node()
                if maybe_packed_param.kind() == 'quantized::linear_prepack':
                    weight_name = OnDevicePTQUtils.get_linear_packed_param_fp_weight(maybe_packed_param)
                    fp_weight_names.append(weight_name)
        for n in quantize_forward_graph.nodes():
            # This is basically detecting
            #   %x = prim::Constant
            #   prim::SetAttr[name="<weight_name>"](%module, %x)
            # thus making sure that the original fp weights are reset.
            if n.kind() == 'prim::SetAttr':
                weight_name = n.s('name')
                if weight_name in fp_weight_names:
                    maybe_constant = n.inputsAt(1).node()
                    if maybe_constant.kind() == 'prim::Constant':
                        fp_weights_setattr += 1
        self.assertEqual(fp_weights_setattr, num_nodes)

    def _validate_quantized_forward(self, model, num_nodes):
        quantized_forward_graph = model.quantized_forward.graph
        quantize_per_tensor = quantize_per_channel = 0
        quantized_linear_dynamic = 0
        linear_packed_params = 0
        num_setattr = 0
        for n in quantized_forward_graph.nodes():
            if "aten::quantize_per_tensor" in n.kind():
                quantize_per_tensor += 1
            if "aten::quantize_per_channel" in n.kind():
                quantize_per_channel += 1
            if "quantized::linear_dynamic" in n.kind():
                quantized_linear_dynamic += 1
            if n.kind() == 'prim::GetAttr':
                output = n.outputsAt(0)
                output_type = output.type()
                if "LinearPackedParamsBase" in output_type.str():
                    linear_packed_params += 1
            if n.kind() == 'prim::SetAttr':
                num_setattr += 1
        # quantized_forward must not re-quantize weights; it should only fetch
        # the prepacked params and call quantized::linear_dynamic.
        self.assertEqual(quantize_per_tensor, 0)
        self.assertEqual(quantize_per_channel, 0)
        self.assertEqual(quantized_linear_dynamic, num_nodes)
        self.assertEqual(linear_packed_params, num_nodes)
        # self.assertEqual(num_setattr, 0)

    def _check_quantize_forward(self, model, num_nodes):
        qconfig_dict = {"": default_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        self._validate_packed_params(m, num_nodes)
        self.assertTrue(self._validate_no_linear_unpack(m))
        self._validate_setattr_fp_weights(m, num_nodes)
        qconfig_dict = {"": per_channel_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        self._validate_packed_params(m, num_nodes, num_nodes)
        self.assertTrue(self._validate_no_linear_unpack(m))
        self._validate_setattr_fp_weights(m, num_nodes)

    def _check_quantized_forward(self, model, num_nodes):
        qconfig_dict = {"": default_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        self._validate_quantized_forward(m, num_nodes)
        qconfig_dict = {"": per_channel_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        self._validate_quantized_forward(m, num_nodes)

    def _check_against_ref_dynamic_ptq(self, model):
        model.eval()
        inputs = model.get_example_inputs()
        ref_m = torch.jit.script(model)
        torch._C._jit_pass_inline(ref_m.graph)
        qconfig_dict = {"": default_dynamic_qconfig}
        ref_m = prepare_dynamic_jit(ref_m, qconfig_dict)
        ref_m = convert_dynamic_jit(ref_m)
        ref_output = ref_m(*inputs)
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        m.observe_forward(*inputs)
        m.quantize_forward(*inputs)
        output = m.quantized_forward(*inputs)
        self.assertTrue(torch.allclose(ref_output, output))
        # Calling forward directly must fail: the original fp weights have
        # been reset, so only quantized_forward is usable.
        thrown = False
        try:
            m(*inputs)
        except Exception:
            thrown = True
        self.assertTrue(thrown)

        # Test with per channel quantization.
        ref_m = torch.jit.script(model)
        torch._C._jit_pass_inline(ref_m.graph)
        qconfig_dict = {"": per_channel_dynamic_qconfig}
        ref_m = prepare_dynamic_jit(ref_m, qconfig_dict)
        ref_m = convert_dynamic_jit(ref_m)
        ref_output = ref_m(*inputs)
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        m.observe_forward(*inputs)
        m.quantize_forward(*inputs)
        output = m.quantized_forward(*inputs)
        self.assertTrue(torch.allclose(ref_output, output))
        thrown = False
        try:
            m(*inputs)
        except Exception:
            thrown = True
        self.assertTrue(thrown)

    def _check_serdes_and_device_side_api_helper(self, model, check_device_side_api=False):
        model.eval()
        inputs = model.get_example_inputs()
        ref_m = torch.jit.script(model)
        torch._C._jit_pass_inline(ref_m.graph)
        qconfig_dict = {"": default_dynamic_qconfig}
        ref_m = prepare_dynamic_jit(ref_m, qconfig_dict)
        ref_m = convert_dynamic_jit(ref_m)
        buffer = io.BytesIO()
        torch.jit.save(ref_m, buffer)
        buffer.seek(0)
        ref_m = torch.jit.load(buffer)
        ref_output = ref_m(*inputs)
        if not check_device_side_api:
            m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
            buffer = io.BytesIO()
            torch.jit.save(m, buffer)
            buffer.seek(0)
            m = torch.jit.load(buffer)
            m.reset_observers_forward()
            m.observe_forward(*inputs)
            m.quantize_forward(*inputs)
            output = m.quantized_forward(*inputs)
            self.assertTrue(torch.allclose(ref_output, output))
        else:
            # Check the lite interpreter path.
            m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
            first_input, = inputs
            rand_input = bundled_inputs.bundle_randn(first_input.size(), dtype=first_input.dtype)
            m = bundled_inputs.bundle_inputs(m, inputs=[(rand_input,)])
            buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
            buffer.seek(0)
            m = _load_for_lite_interpreter(buffer)
            torch._C._quantize_ondevice_ptq_dynamic(m._c, "forward")
            # After on-device quantization the helper methods are stripped;
            # only the rewritten forward remains.
            self.assertFalse(m.find_method("quantized_forward"))
            self.assertFalse(m.find_method("quantize_forward"))
            self.assertFalse(m.find_method("observe_forward"))
            self.assertFalse(m.find_method("reset_observers_forward"))
            output = m(*inputs)
            self.assertTrue(torch.allclose(ref_output, output))
            # Now serialize to flatbuffer, load back, and check again.
            extra_files: Dict[str, str] = {}
            module_bytes = torch._C._save_mobile_module_to_bytes(m._c, extra_files)
            m = LiteScriptModule(torch._C._load_mobile_module_from_bytes(module_bytes))
            fb_output = m(*inputs)
            self.assertTrue(torch.allclose(ref_output, fb_output))

        # Repeat the checks above with per channel quantization.
        model.eval()
        inputs = model.get_example_inputs()
        ref_m = torch.jit.script(model)
        torch._C._jit_pass_inline(ref_m.graph)
        qconfig_dict = {"": per_channel_dynamic_qconfig}
        ref_m = prepare_dynamic_jit(ref_m, qconfig_dict)
        ref_m = convert_dynamic_jit(ref_m)
        buffer = io.BytesIO()
        torch.jit.save(ref_m, buffer)
        buffer.seek(0)
        ref_m = torch.jit.load(buffer)
        ref_output = ref_m(*inputs)
        if not check_device_side_api:
            m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
            buffer = io.BytesIO()
            torch.jit.save(m, buffer)
            buffer.seek(0)
            m = torch.jit.load(buffer)
            m.reset_observers_forward()
            m.observe_forward(*inputs)
            m.quantize_forward(*inputs)
            output = m.quantized_forward(*inputs)
            self.assertTrue(torch.allclose(ref_output, output))
        else:
            # Check the lite interpreter path.
            m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
            first_input, = inputs
            rand_input = bundled_inputs.bundle_randn(first_input.size(), dtype=first_input.dtype)
            m = bundled_inputs.bundle_inputs(m, inputs=[(rand_input,)])
            buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
            buffer.seek(0)
            m = _load_for_lite_interpreter(buffer)
            torch._C._quantize_ondevice_ptq_dynamic(m._c, "forward")
            self.assertFalse(m.find_method("quantized_forward"))
            self.assertFalse(m.find_method("quantize_forward"))
            self.assertFalse(m.find_method("observe_forward"))
            self.assertFalse(m.find_method("reset_observers_forward"))
            output = m(*inputs)
            self.assertTrue(torch.allclose(ref_output, output))

    def _check_serialization_deserialization(self, model):
        self._check_serdes_and_device_side_api_helper(model, False)

    def _check_device_side_api(self, model):
        self._check_serdes_and_device_side_api_helper(model, True)

    def test_quantize_forward(self):
        model = LinearAddModel()
        self._check_quantize_forward(model, 2)
        model = MyConvLinearModule()
        self._check_quantize_forward(model, 3)

    def test_quantized_forward(self):
        model = LinearAddModel()
        self._check_quantized_forward(model, 2)
        model = MyConvLinearModule()
        self._check_quantized_forward(model, 3)

    def test_against_offdevice_dynamic_ptq(self):
        model = LinearAddModel()
        self._check_against_ref_dynamic_ptq(model)
        model = MyConvLinearModule()
        self._check_against_ref_dynamic_ptq(model)

    def test_serialization_deserialization(self):
        model = MyConvLinearModule()
        self._check_serialization_deserialization(model)

    def test_device_side_api(self):
        model = MyConvLinearModule()
        self._check_device_side_api(model)
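

if __name__ == "__main__":
    # Convenience entry point so the file can be run standalone; upstream
    # test runners may invoke these TestCases through their own harness.
    run_tests()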