import copy
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.intrinsic as nni
from torch.quantization import get_default_qconfig, default_dynamic_qconfig
import torch.nn.quantized as nnq
toq = torch.ops.quantized
from torch.quantization._numeric_suite_fx import (
remove_qconfig_observer_fx,
compare_model_outputs_fx,
compare_weights_fx,
compare_model_stub_fx,
)
from torch.quantization.fx.quantize import is_activation_post_process
from torch.quantization.quantize_fx import (
convert_fx,
fuse_fx,
prepare_fx,
prepare_qat_fx,
)
from torch.testing._internal.common_quantization import (
ConvBnModel,
ConvBnReLUModel,
ConvModel,
QuantizationTestCase,
SingleLayerLinearDynamicModel,
SingleLayerLinearModel,
LSTMwithHiddenDynamicModel,
SparseNNModel,
skip_if_no_torchvision,
test_only_eval_fn,
)
from torch.testing._internal.common_quantization import NodeSpec as ns
from torch.testing._internal.common_quantized import override_qengines
from torch.quantization.ns.graph_matcher import (
get_matching_subgraph_pairs,
GraphMatchingException,
)
from torch.quantization.ns.numeric_suite_core_apis_fx import (
compare_weights,
prepare_model_outputs,
OutputLogger,
prepare_model_with_stubs,
get_matching_activations,
get_matching_activations_a_shadows_b,
)
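# Tests for FX graph mode Numeric Suite: comparing weights, activations, and
# shadow-module outputs between float and quantized models produced with
# prepare_fx / convert_fx.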
class TestGraphModeNumericSuite(QuantizationTestCase):
@override_qengines
def test_remove_qconfig_observer_fx(self):
r"""Remove activation_post_process node from fx prepred model"""
float_model = SingleLayerLinearModel()
float_model.eval()
qengine = torch.backends.quantized.engine
qconfig = get_default_qconfig(qengine)
qconfig_dict = {"": qconfig}
prepared_model = prepare_fx(float_model, qconfig_dict)
prepared_float_model = copy.deepcopy(prepared_model)
prepared_float_model.eval()
model = remove_qconfig_observer_fx(prepared_float_model)
modules = dict(model.named_modules())
for node in model.graph.nodes:
if node.op == "call_module":
self.assertFalse(is_activation_post_process(modules[node.target]))
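# Helper: run compare_weights_fx on the two state dicts and check that the
# returned dict has the expected keys and matching float/quantized shapes.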
def compare_and_validate_model_weights_results_fx(
self, prepared_float_model, q_model, expected_weight_dict_keys
):
weight_dict = compare_weights_fx(
prepared_float_model.state_dict(), q_model.state_dict()
)
self.assertTrue(weight_dict.keys() == expected_weight_dict_keys)
self.assertEqual(len(weight_dict), 1)
for k, v in weight_dict.items():
self.assertTrue(v["float"].shape == v["quantized"].shape)
@override_qengines
def test_compare_weights_conv_static_fx(self):
r"""Compare the weights of float and static quantized conv layer"""
qengine = torch.backends.quantized.engine
qconfig = get_default_qconfig(qengine)
qconfig_dict = {"": qconfig}
model_list = [ConvModel(), ConvBnModel(), ConvBnReLUModel()]
for float_model in model_list:
float_model.eval()
fused = fuse_fx(float_model)
prepared_model = prepare_fx(float_model, qconfig_dict)
# Run calibration
test_only_eval_fn(prepared_model, self.img_data_2d)
q_model = convert_fx(prepared_model)
expected_weight_dict_keys = {"conv.weight"}
self.compare_and_validate_model_weights_results_fx(
fused, q_model, expected_weight_dict_keys
)
@override_qengines
def test_compare_weights_linear_static_fx(self):
r"""Compare the weights of float and static quantized linear layer"""
qengine = torch.backends.quantized.engine
qconfig = get_default_qconfig(qengine)
qconfig_dict = {"": qconfig}
float_model = SingleLayerLinearModel()
float_model.eval()
prepared_model = prepare_fx(float_model, qconfig_dict)
prepared_float_model = copy.deepcopy(prepared_model)
prepared_float_model.eval()
# Run calibration
test_only_eval_fn(prepared_model, self.calib_data)
q_model = convert_fx(prepared_model)
expected_weight_dict_keys = {"fc1._packed_params._packed_params"}
self.compare_and_validate_model_weights_results_fx(
prepared_float_model, q_model, expected_weight_dict_keys
)
@override_qengines
def test_compare_weights_linear_dynamic_fx(self):
r"""Compare the weights of float and dynamic quantized linear layer"""
qconfig_dict = {"object_type": [(nn.Linear, default_dynamic_qconfig)]}
float_model = SingleLayerLinearDynamicModel()
float_model.eval()
prepared_model = prepare_fx(float_model, qconfig_dict)
prepared_float_model = copy.deepcopy(prepared_model)
prepared_float_model.eval()
q_model = convert_fx(prepared_model)
expected_weight_dict_keys = {"fc1._packed_params._packed_params"}
self.compare_and_validate_model_weights_results_fx(
prepared_float_model, q_model, expected_weight_dict_keys
)
@override_qengines
def test_compare_weights_lstm_dynamic_fx(self):
r"""Compare the weights of float and dynamic quantized lstm layer"""
qconfig_dict = {"object_type": [(nn.LSTM, default_dynamic_qconfig)]}
float_model = LSTMwithHiddenDynamicModel()
float_model.eval()
prepared_model = prepare_fx(float_model, qconfig_dict)
prepared_float_model = copy.deepcopy(prepared_model)
prepared_float_model.eval()
q_model = convert_fx(prepared_model)
expected_weight_dict_keys = {"lstm._all_weight_values.0.param"}
self.compare_and_validate_model_weights_results_fx(
prepared_float_model, q_model, expected_weight_dict_keys
)
# TODO: Add submodule and functional test cases for compare_model_stub_fx
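# Helper: run compare_model_stub_fx with the given module_swap_list and check
# that the logged float and quantized outputs have matching lengths and shapes.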
def compare_and_validate_model_stub_results_fx(
self,
prepared_float_model,
q_model,
module_swap_list,
expected_ob_dict_keys,
*data,
):
ob_dict = compare_model_stub_fx(
prepared_float_model, q_model, module_swap_list, *data
)
self.assertTrue(expected_ob_dict_keys == ob_dict.keys())
self.assertEqual(len(ob_dict), 1)
for k, v in ob_dict.items():
self.assertTrue(len(v["float"]) == len(v["quantized"]))
for i, val in enumerate(v["quantized"]):
self.assertTrue(v["float"][i].shape == v["quantized"][i].shape)
@override_qengines
def test_compare_model_stub_conv_static_fx(self):
r"""Compare the output of static quantized conv layer and its float shadow module"""
qengine = torch.backends.quantized.engine
qconfig = get_default_qconfig(qengine)
qconfig_dict = {"": qconfig}
model_list = [ConvModel(), ConvBnReLUModel()]
for float_model in model_list:
float_model.eval()
prepared_model = prepare_fx(float_model, qconfig_dict)
prepared_float_model = copy.deepcopy(prepared_model)
# Run calibration
test_only_eval_fn(prepared_model, self.img_data_2d)
q_model = convert_fx(prepared_model)
module_swap_list = [nn.Conv2d, nni.modules.fused.ConvReLU2d]
expected_ob_dict_keys = {"conv.stats"}
self.compare_and_validate_model_stub_results_fx(
prepared_float_model,
q_model,
module_swap_list,
expected_ob_dict_keys,
self.img_data_2d[0][0],
)
@override_qengines
def test_compare_model_stub_linear_static_fx(self):
r"""Compare the output of static quantized linear layer and its float shadow module"""
qengine = torch.backends.quantized.engine
qconfig = get_default_qconfig(qengine)
qconfig_dict = {"": qconfig}
float_model = SingleLayerLinearModel()
float_model.eval()
prepared_model = prepare_fx(float_model, qconfig_dict)
prepared_float_model = copy.deepcopy(prepared_model)
# Run calibration
test_only_eval_fn(prepared_model, self.calib_data)
q_model = convert_fx(prepared_model)
linear_data = self.calib_data[0][0]
module_swap_list = [nn.Linear]
expected_ob_dict_keys = {"fc1.stats"}
self.compare_and_validate_model_stub_results_fx(
prepared_float_model,
q_model,
module_swap_list,
expected_ob_dict_keys,
linear_data,
)
@override_qengines
def test_compare_model_stub_linear_dynamic_fx(self):
r"""Compare the output of dynamic quantized linear layer and its float shadow module"""
qconfig_dict = {"object_type": [(nn.Linear, default_dynamic_qconfig)]}
float_model = SingleLayerLinearDynamicModel()
float_model.eval()
prepared_model = prepare_fx(float_model, qconfig_dict)
prepared_float_model = copy.deepcopy(prepared_model)
prepared_float_model.eval()
q_model = convert_fx(prepared_model)
linear_data = self.calib_data[0][0]
module_swap_list = [nn.Linear]
expected_ob_dict_keys = {"fc1.stats"}
self.compare_and_validate_model_stub_results_fx(
prepared_float_model,
q_model,
module_swap_list,
expected_ob_dict_keys,
linear_data,
)
@override_qengines
def test_compare_model_stub_lstm_dynamic_fx(self):
r"""Compare the output of dynamic quantized linear layer and its float shadow module"""
qconfig_dict = {"object_type": [(nn.LSTM, default_dynamic_qconfig)]}
float_model = LSTMwithHiddenDynamicModel()
float_model.eval()
prepared_model = prepare_fx(float_model, qconfig_dict)
prepared_float_model = copy.deepcopy(prepared_model)
prepared_float_model.eval()
q_model = convert_fx(prepared_model)
module_swap_list = [nn.LSTM]
lstm_input = torch.rand((1, 1, 2))
lstm_hidden = (torch.rand(1, 1, 2), torch.rand(1, 1, 2))
expected_ob_dict_keys = {"lstm.stats"}
self.compare_and_validate_model_stub_results_fx(
prepared_float_model,
q_model,
module_swap_list,
expected_ob_dict_keys,
lstm_input,
lstm_hidden,
)
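# Helper: run compare_model_outputs_fx and check that matching activations were
# logged under the expected keys with matching shapes (LSTM entries hold tuples,
# which get per-element shape checks).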
def compare_and_validate_model_outputs_results_fx(
self, prepared_float_model, q_model, expected_act_compare_dict_keys, *data
):
act_compare_dict = compare_model_outputs_fx(
prepared_float_model, q_model, *data
)
self.assertTrue(act_compare_dict.keys() == expected_act_compare_dict_keys)
for k, v in act_compare_dict.items():
self.assertTrue(len(v["float"]) == 1)
self.assertTrue(len(v["float"]) == len(v["quantized"]))
for i, val in enumerate(v["quantized"]):
if "lstm.stats" not in act_compare_dict:
self.assertTrue(v["float"][i].shape == v["quantized"][i].shape)
else:
self.assertTrue(
v["float"][i][0].shape == v["quantized"][i][0].shape
)
if i == 1:
self.assertTrue(
v["float"][i][1].shape == v["quantized"][i][1].shape
)
@override_qengines
def test_compare_model_outputs_conv_static_fx(self):
r"""Compare the output of conv layer in static quantized model and corresponding
output of conv layer in float model
"""
qengine = torch.backends.quantized.engine
qconfig = get_default_qconfig(qengine)
qconfig_dict = {"": qconfig}
model_list = [ConvModel(), ConvBnReLUModel()]
for float_model in model_list:
float_model.eval()
prepared_model = prepare_fx(float_model, qconfig_dict)
prepared_float_model = copy.deepcopy(prepared_model)
# Run calibration
test_only_eval_fn(prepared_model, self.img_data_2d)
q_model = convert_fx(prepared_model)
expected_act_compare_dict_keys = {"x.stats", "conv.stats"}
self.compare_and_validate_model_outputs_results_fx(
prepared_float_model,
q_model,
expected_act_compare_dict_keys,
self.img_data_2d[0][0],
)
@override_qengines
def test_compare_model_outputs_linear_static_fx(self):
r"""Compare the output of linear layer in static quantized model and corresponding
output of linear layer in float model
"""
qengine = torch.backends.quantized.engine
qconfig = get_default_qconfig(qengine)
qconfig_dict = {"": qconfig}
float_model = SingleLayerLinearModel()
float_model.eval()
prepared_model = prepare_fx(float_model, qconfig_dict)
prepared_float_model = copy.deepcopy(prepared_model)
# Run calibration
test_only_eval_fn(prepared_model, self.calib_data)
q_model = convert_fx(prepared_model)
linear_data = self.calib_data[0][0]
expected_act_compare_dict_keys = {"x.stats", "fc1.stats"}
self.compare_and_validate_model_outputs_results_fx(
prepared_float_model, q_model, expected_act_compare_dict_keys, linear_data
)
@override_qengines
def test_compare_model_outputs_linear_dynamic_fx(self):
r"""Compare the output of linear layer in dynamic quantized model and corresponding
output of linear layer in float model
"""
qconfig_dict = {"object_type": [(nn.Linear, default_dynamic_qconfig)]}
float_model = SingleLayerLinearDynamicModel()
float_model.eval()
prepared_model = prepare_fx(float_model, qconfig_dict)
prepared_float_model = copy.deepcopy(prepared_model)
q_model = convert_fx(prepared_model)
linear_data = self.calib_data[0][0]
expected_act_compare_dict_keys = {"x.stats", "fc1.stats"}
self.compare_and_validate_model_outputs_results_fx(
prepared_float_model, q_model, expected_act_compare_dict_keys, linear_data
)
@override_qengines
def test_compare_model_outputs_lstm_dynamic_fx(self):
r"""Compare the output of LSTM layer in dynamic quantized model and corresponding
output of LSTM layer in float model
"""
qconfig_dict = {"object_type": [(nn.LSTM, default_dynamic_qconfig)]}
float_model = LSTMwithHiddenDynamicModel()
float_model.eval()
prepared_model = prepare_fx(float_model, qconfig_dict)
prepared_float_model = copy.deepcopy(prepared_model)
q_model = convert_fx(prepared_model)
lstm_input = torch.rand((1, 1, 2))
lstm_hidden = (torch.rand(1, 1, 2), torch.rand(1, 1, 2))
expected_act_compare_dict_keys = {"x.stats", "hid.stats", "lstm.stats"}
self.compare_and_validate_model_outputs_results_fx(
prepared_float_model,
q_model,
expected_act_compare_dict_keys,
lstm_input,
lstm_hidden,
)
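# Tests for get_matching_subgraph_pairs, which pairs corresponding subgraphs
# between an fp32 prepared model and its quantized counterpart.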
class TestFXGraphMatcher(QuantizationTestCase):
@override_qengines
def test_simple_mod(self):
m = nn.Sequential(nn.Conv2d(1, 1, 1)).eval()
mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
# TODO(future PR): prevent the need for copying here, we can copy the
# modules but should reuse the underlying tensors
mp_copy = copy.deepcopy(mp)
mq = convert_fx(mp_copy)
results = get_matching_subgraph_pairs(mp, mq)
expected_types = {
'base_op_torch.nn.Conv2d_0':
((nn.Conv2d, nn.Conv2d), (nnq.Conv2d, nnq.Conv2d)),
}
self.assert_types_for_matched_subgraph_pairs(results, expected_types, mp, mq)
@override_qengines
def test_simple_fun(self):
class M(nn.Module):
def __init__(self):
super().__init__()
self.w = nn.Parameter(torch.Tensor(1, 4))
self.b = nn.Parameter(torch.zeros(1))
torch.nn.init.kaiming_uniform_(self.w, a=math.sqrt(5))
def forward(self, x):
return F.linear(x, self.w, self.b)
m = M().eval()
mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
# TODO(future PR): prevent the need for copying here, we can copy the
# modules but should reuse the underlying tensors
mp_copy = copy.deepcopy(mp)
mq = convert_fx(mp_copy)
results = get_matching_subgraph_pairs(mp, mq)
expected_types = {
'base_op_torch.nn.functional.linear_0':
((F.linear, F.linear), (toq.linear, toq.linear))
}
self.assert_types_for_matched_subgraph_pairs(results, expected_types, mp, mq)
@override_qengines
def test_simple_fusion(self):
class M(nn.Module):
def __init__(self):
super().__init__()
self.w = nn.Parameter(torch.Tensor(4, 1))
self.b = nn.Parameter(torch.zeros(4))
torch.nn.init.kaiming_uniform_(self.w, a=math.sqrt(5))
def forward(self, x):
x = F.linear(x, self.w, self.b)
x = F.relu(x)
return x
m = M().eval()
mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
# TODO(future PR): prevent the need for copying here, we can copy the
# modules but should reuse the underlying tensors
mp_copy = copy.deepcopy(mp)
mq = convert_fx(mp_copy)
results = get_matching_subgraph_pairs(mp, mq)
expected_types = {
'base_op_torch.nn.functional.linear_0':
((F.linear, F.relu), (toq.linear_relu, toq.linear_relu)),
}
self.assert_types_for_matched_subgraph_pairs(results, expected_types, mp, mq)
@override_qengines
def test_simple_mod_multi(self):
m = nn.Sequential(
nn.Sequential(
nn.Conv2d(1, 1, 1),
),
nn.Conv2d(1, 1, 1),
).eval()
mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
# TODO(future PR): prevent the need for copying here, we can copy the
# modules but should reuse the underlying tensors
mp_copy = copy.deepcopy(mp)
mq = convert_fx(mp_copy)
# assume success if no exceptions
results = get_matching_subgraph_pairs(mp, mq)
@override_qengines
def test_simple_tensor_ops(self):
class M(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, y):
z = x + y
return z
m = M().eval()
mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
# TODO(future PR): prevent the need for copying here, we can copy the
# modules but should reuse the underlying tensors
mp_copy = copy.deepcopy(mp)
mq = convert_fx(mp_copy)
# assume success if no exceptions
results = get_matching_subgraph_pairs(mp, mq)
@override_qengines
def test_matching_failure_node_count(self):
# verify that matching graphs with matching node types but
# different counts of matchable nodes fails
m1 = nn.Sequential(nn.Conv2d(1, 1, 1)).eval()
m2 = nn.Sequential(nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1)).eval()
mp1 = prepare_fx(m1, {'': torch.quantization.default_qconfig})
mp2 = prepare_fx(m2, {'': torch.quantization.default_qconfig})
with self.assertRaises(GraphMatchingException) as ex:
results = get_matching_subgraph_pairs(mp1, mp2)
@override_qengines
def test_matching_failure_node_type(self):
# verify that matching graphs with non-matching node types fails
m1 = nn.Sequential(nn.Conv2d(1, 1, 1)).eval()
m2 = nn.Sequential(nn.Linear(1, 1)).eval()
mp1 = prepare_fx(m1, {'': torch.quantization.default_qconfig})
mp2 = prepare_fx(m2, {'': torch.quantization.default_qconfig})
with self.assertRaises(GraphMatchingException) as ex:
results = get_matching_subgraph_pairs(mp1, mp2)
@override_qengines
def test_nodes_before_cat(self):
# verify that nodes before cat get matched
class M(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x0):
x1 = torch.add(x0, 1.0)
y1 = torch.add(x0, 1.0)
x2 = torch.cat([x1, y1])
return x2
m = M().eval()
mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
# TODO(future PR): prevent the need for copying here, we can copy the
# modules but should reuse the underlying tensors
mp_copy = copy.deepcopy(mp)
mq = convert_fx(mp_copy)
results = get_matching_subgraph_pairs(mp, mq)
expected_types = {
'base_op_torch.cat_0': ((torch.cat, torch.cat), (toq.cat, toq.cat)),
'base_op_torch.add_0': ((torch.add, torch.add), (toq.add, toq.add)),
'base_op_torch.add_1': ((torch.add, torch.add), (toq.add, toq.add)),
}
self.assert_types_for_matched_subgraph_pairs(results, expected_types, mp, mq)
@override_qengines
def test_dict_return_type(self):
# verify that we can traverse up nodes which return dictionaries
class M(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x0):
x1 = torch.add(x0, 1.0)
y1 = torch.add(x0, 1.0)
z1 = torch.add(x0, 1.0)
a1 = {'x1': x1, 'y1': (y1,), 'z1': [{'key': (z1,)}]}
return a1
m = M().eval()
mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
# TODO(future PR): prevent the need for copying here, we can copy the
# modules but should reuse the underlying tensors
mp_copy = copy.deepcopy(mp)
mq = convert_fx(mp_copy)
results = get_matching_subgraph_pairs(mp, mq)
expected_types = {
'base_op_torch.add_0': ((torch.add, torch.add), (toq.add, toq.add)),
'base_op_torch.add_1': ((torch.add, torch.add), (toq.add, toq.add)),
'base_op_torch.add_2': ((torch.add, torch.add), (toq.add, toq.add)),
}
self.assert_types_for_matched_subgraph_pairs(results, expected_types, mp, mq)
@override_qengines
def test_nodes_with_equal_types_do_not_get_matched(self):
# verifies that by default, nodes with equivalent types do not get matched.
# This is important for user defined types, for which we do not know
# the weight extraction functions or input type. In the future, this can
# be made configurable.
class M(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 1, 1)
self.conv2 = nn.Conv2d(1, 1, 1)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = torch.mul(x, x)
x = torch.sigmoid(x)
x = F.relu(x)
return x
m = M().eval()
# prevent conv2 from getting quantized, so we can test
# modules with equal types
qconfig_dict = {
'': torch.quantization.default_qconfig,
'module_name': [('conv2', None)],
}
mp = prepare_fx(m, qconfig_dict)
mp_copy = copy.deepcopy(mp)
mq = convert_fx(mp_copy)
results = get_matching_subgraph_pairs(mp, mq)
# Conv2 should not be matched because we disabled quantization for it,
# so its type is the same in mp and mq. sigmoid and relu should not be
# matched because they use the same function in mp and mq.
expected_types = {
'base_op_torch.nn.Conv2d_0':
((nn.Conv2d, nn.Conv2d), (nnq.Conv2d, nnq.Conv2d)),
'base_op_torch.mul_0': ((torch.mul, torch.mul), (toq.mul, toq.mul)),
}
self.assert_types_for_matched_subgraph_pairs(results, expected_types, mp, mq)
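# Graph matching tests on real torchvision models (mobilenet_v2, PTQ and QAT prepared).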
class TestFXGraphMatcherModels(QuantizationTestCase):
@override_qengines
@skip_if_no_torchvision
def test_mobilenet_v2(self):
# verify that mobilenetv2 graph is able to be matched
import torchvision
m = torchvision.models.__dict__['mobilenet_v2'](pretrained=False).eval().float()
mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
# TODO(future PR): prevent the need for copying here, we can copy the
# modules but should reuse the underlying tensors
mp_copy = copy.deepcopy(mp)
mq = convert_fx(mp_copy)
# assume success if no exceptions
results = get_matching_subgraph_pairs(mp, mq)
@override_qengines
@skip_if_no_torchvision
def test_mobilenet_v2_qat(self):
# verify that mobilenetv2 graph is able to be matched
import torchvision
m = torchvision.models.__dict__['mobilenet_v2'](pretrained=False).float()
mp = prepare_qat_fx(m, {'': torch.quantization.get_default_qat_qconfig('fbgemm')})
# TODO(future PR): prevent the need for copying here, we can copy the
# modules but should reuse the underlying tensors
mp_copy = copy.deepcopy(mp)
mq = convert_fx(mp_copy)
# assume success if no exceptions
results = get_matching_subgraph_pairs(mp, mq)
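# Tests for the new Numeric Suite core APIs: compare_weights,
# prepare_model_outputs / get_matching_activations, and
# prepare_model_with_stubs / get_matching_activations_a_shadows_b.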
class TestFXNumericSuiteCoreAPIs(QuantizationTestCase):
@override_qengines
def test_compare_weights_mod(self):
m = nn.Sequential(nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1)).eval()
mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
# TODO(future PR): prevent the need for copying here, we can copy the
# modules but should reuse the underlying tensors
mp_copy = copy.deepcopy(mp)
mq = convert_fx(mp_copy)
results = compare_weights('fp32_prepared', mp, 'int8', mq)
self.assertTrue(len(results) == 2)
self.assert_ns_compare_dict_valid(results)
@override_qengines
def test_compare_weights_fun(self):
class M(nn.Module):
def __init__(self):
super().__init__()
self.w = nn.Parameter(torch.Tensor(4, 4))
self.b = nn.Parameter(torch.zeros(4))
torch.nn.init.kaiming_uniform_(self.w, a=math.sqrt(5))
def forward(self, x):
x = F.linear(x, self.w, self.b)
x = F.relu(x)
x = F.linear(x, self.w, self.b)
return x
m = M().eval()
mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
mp(torch.randn(1, 4))
# TODO(future PR): prevent the need for copying here, we can copy the
# modules but should reuse the underlying tensors
mp_copy = copy.deepcopy(mp)
mq = convert_fx(mp_copy)
results = compare_weights('fp32_prepared', mp, 'int8', mq)
self.assertTrue(len(results) == 2)
self.assert_ns_compare_dict_valid(results)
@override_qengines
def test_match_activations_mod(self):
m = nn.Sequential(
torch.quantization.QuantStub(),
nn.Conv2d(1, 1, 1),
nn.Conv2d(1, 1, 1),
).eval()
mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
mp(torch.randn(2, 1, 2, 2))
# TODO(future PR): prevent the need for copying here, we can copy the
# modules but should reuse the underlying tensors
mp_copy = copy.deepcopy(mp)
mq = convert_fx(mp_copy)
mp_ns, mq_ns = prepare_model_outputs(
'fp32_prepared', mp, 'int8', mq, OutputLogger)
expected_occurrence = {
ns.call_module(OutputLogger): 2,
}
self.checkGraphModuleNodes(
mp_ns, expected_node_occurrence=expected_occurrence)
self.checkGraphModuleNodes(
mq_ns, expected_node_occurrence=expected_occurrence)
# TODO(before land): test both scripted and non-scripted
mp_ns = torch.jit.script(mp_ns)
mq_ns = torch.jit.script(mq_ns)
# calibrate
input_fp32 = torch.randn(2, 1, 2, 2)
mp_ns(input_fp32)
mq_ns(input_fp32)
# check activation result correctness
act_compare_dict = get_matching_activations(mp_ns, mq_ns, OutputLogger)
self.assertTrue(len(act_compare_dict) == 2)
self.assert_ns_compare_dict_valid(act_compare_dict)
@override_qengines
def test_match_activations_fun(self):
class M(nn.Module):
def __init__(self):
super().__init__()
self.w1 = nn.Parameter(torch.Tensor(4, 4))
self.b1 = nn.Parameter(torch.zeros(4))
self.w2 = nn.Parameter(torch.Tensor(4, 4))
self.b2 = nn.Parameter(torch.zeros(4))
torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))
torch.nn.init.kaiming_uniform_(self.w2, a=math.sqrt(5))
def forward(self, x):
x = F.linear(x, self.w1, self.b1)
x = F.linear(x, self.w2, self.b2)
x = F.relu(x)
return x
m = M().eval()
mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
mp(torch.randn(4, 4))
# TODO(future PR): prevent the need for copying here, we can copy the
# modules but should reuse the underlying tensors
mp_copy = copy.deepcopy(mp)
mq = convert_fx(mp_copy)
mp_ns, mq_ns = prepare_model_outputs(
'fp32_prepared', mp, 'int8', mq, OutputLogger)
expected_occurrence = {
ns.call_module(OutputLogger): 2,
}
self.checkGraphModuleNodes(
mp_ns, expected_node_occurrence=expected_occurrence)
self.checkGraphModuleNodes(
mq_ns, expected_node_occurrence=expected_occurrence)
# TODO(before land): test both scripted and non-scripted
mp_ns = torch.jit.script(mp_ns)
mq_ns = torch.jit.script(mq_ns)
# calibrate
input_fp32 = torch.randn(4, 4)
mp_ns(input_fp32)
mq_ns(input_fp32)
# check activation result correctness
act_compare_dict = get_matching_activations(mp_ns, mq_ns, OutputLogger)
self.assertTrue(len(act_compare_dict) == 2)
self.assert_ns_compare_dict_valid(act_compare_dict)
@override_qengines
def test_prepare_model_with_stubs_mod(self):
m = nn.Sequential(
nn.Conv2d(1, 1, 1),
nn.Conv2d(1, 1, 1),
).eval()
mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
mp(torch.randn(1, 1, 4, 4))
# TODO(future PR): prevent the need for copying here, we can copy the
# modules but should reuse the underlying tensors
mp_copy = copy.deepcopy(mp)
mq = convert_fx(mp_copy)
mp_shadows_mq = prepare_model_with_stubs('fp32_prepared', mp, 'int8', mq, OutputLogger)
# TODO(before land): test both scripted and non-scripted
mp_shadows_mq = torch.jit.script(mp_shadows_mq)
# calibrate
input_fp32 = torch.randn(1, 1, 4, 4)
mp_shadows_mq(input_fp32)
# check activation result correctness
act_compare_dict = get_matching_activations_a_shadows_b(
mp_shadows_mq, OutputLogger)
self.assertTrue(len(act_compare_dict) == 2)
self.assert_ns_compare_dict_valid(act_compare_dict)
@override_qengines
def test_prepare_model_with_stubs_fun(self):
class M(nn.Module):
def __init__(self):
super().__init__()
self.w1 = nn.Parameter(torch.Tensor(4, 4))
self.b1 = nn.Parameter(torch.zeros(4))
self.w2 = nn.Parameter(torch.Tensor(4, 4))
self.b2 = nn.Parameter(torch.zeros(4))
torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))
torch.nn.init.kaiming_uniform_(self.w2, a=math.sqrt(5))
def forward(self, x):
x = F.linear(x, self.w1, self.b1)
x = F.linear(x, self.w2, self.b2)
x = F.relu(x)
return x
m = M().eval()
mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
mp(torch.randn(4, 4))
# TODO(future PR): prevent the need for copying here, we can copy the
# modules but should reuse the underlying tensors
mp_copy = copy.deepcopy(mp)
mq = convert_fx(mp_copy)
mp_shadows_mq = prepare_model_with_stubs('fp32_prepared', mp, 'int8', mq, OutputLogger)
# TODO(before land): test both scripted and non-scripted
mp_shadows_mq = torch.jit.script(mp_shadows_mq)
# calibrate
input_fp32 = torch.randn(4, 4)
mp_shadows_mq(input_fp32)
# check activation result correctness
act_compare_dict = get_matching_activations_a_shadows_b(
mp_shadows_mq, OutputLogger)
self.assertTrue(len(act_compare_dict) == 2)
self.assert_ns_compare_dict_valid(act_compare_dict)
class TestFXNumericSuiteCoreAPIsModels(QuantizationTestCase):
"""
Tests numeric suite core APIs on non-toy models.
"""
@override_qengines
def test_sparsenn_compare_activations(self):
sparse_nn = SparseNNModel().eval()
# quantize the embeddings and the dense part separately, using FX graph mode
sparse_nn.dense_top = prepare_fx(
sparse_nn.dense_top,
{'': torch.quantization.default_qconfig},
)
# calibrate
idx = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9])
offsets = torch.LongTensor([0, 4])
x = torch.randn(2, 4)
sparse_nn(idx, offsets, x)
# convert
sparse_nn_q = copy.deepcopy(sparse_nn)
sparse_nn_q.dense_top = convert_fx(sparse_nn_q.dense_top)
# test out compare activations API
sparse_nn.dense_top, sparse_nn_q.dense_top = prepare_model_outputs(
'fp32_prepared', sparse_nn.dense_top, 'int8', sparse_nn_q.dense_top, OutputLogger)
# calibrate
sparse_nn(idx, offsets, x)
sparse_nn_q(idx, offsets, x)
# inspect results
act_compare_dict = get_matching_activations(
sparse_nn, sparse_nn_q, OutputLogger)
self.assertTrue(len(act_compare_dict) == 4)
self.assert_ns_compare_dict_valid(act_compare_dict)
@override_qengines
def test_sparsenn_shadow(self):
sparse_nn = SparseNNModel().eval()
# quantize the embeddings and the dense part separately, using FX graph mode
sparse_nn.dense_top = prepare_fx(
sparse_nn.dense_top,
{'': torch.quantization.default_qconfig},
)
# calibrate
idx = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9])
offsets = torch.LongTensor([0, 4])
x = torch.randn(2, 4)
sparse_nn(idx, offsets, x)
# convert
sparse_nn_q = copy.deepcopy(sparse_nn)
sparse_nn_q.dense_top = convert_fx(sparse_nn_q.dense_top)
# test out compare shadow activations API
sparse_nn_q.dense_top = prepare_model_with_stubs(
'fp32_prepared', sparse_nn.dense_top,
'int8', sparse_nn_q.dense_top, OutputLogger)
# calibrate
sparse_nn_q(idx, offsets, x)
# check activation result correctness
act_compare_dict = get_matching_activations_a_shadows_b(
sparse_nn_q, OutputLogger)
self.assertTrue(len(act_compare_dict) == 4)
self.assert_ns_compare_dict_valid(act_compare_dict)