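# Tests for the FX graph mode Numeric Suite
# (torch.quantization._numeric_suite_fx): subgraph matching between fp32 and
# int8 models, weight extraction, unshadowed activation logging, and shadow
# logging.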
import copy
import math
import operator
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.quantization import default_dynamic_qconfig
import torch.nn.quantized as nnq
toq = torch.ops.quantized
from torch.quantization.quantize_fx import (
convert_fx,
prepare_fx,
prepare_qat_fx,
)
from torch.testing._internal.common_quantization import (
ConvBnModel,
ConvBnReLUModel,
ConvModel,
QuantizationTestCase,
skipIfNoFBGEMM,
SingleLayerLinearDynamicModel,
SingleLayerLinearModel,
LSTMwithHiddenDynamicModel,
SparseNNModel,
skip_if_no_torchvision,
)
from torch.quantization.quantization_mappings import (
get_default_static_quant_module_mappings,
get_default_dynamic_quant_module_mappings,
get_default_float_to_quantized_operator_mappings,
)
from torch.testing._internal.common_quantization import NodeSpec as ns
from torch.testing._internal.common_quantized import override_qengines
from torch.quantization.fx.pattern_utils import get_default_quant_patterns
import torch.quantization.fx.quantization_patterns as qp
from torch.quantization.ns.pattern_utils import (
get_base_name_to_sets_of_related_ops,
get_type_a_related_to_b,
)
from torch.quantization.ns.graph_matcher import (
get_matching_subgraph_pairs,
GraphMatchingException,
)
from torch.quantization._numeric_suite_fx import (
extract_weights,
_extract_weights_impl,
add_loggers,
_add_loggers_impl,
OutputLogger,
add_shadow_loggers,
_add_shadow_loggers_impl,
extract_logger_info,
extract_shadow_logger_info,
)
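# Rough flow exercised by the tests below (a sketch distilled from the calls
# in this file, not an authoritative API reference):
#
#   mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
#   mq = convert_fx(copy.deepcopy(mp))
#   # compare weights of matching subgraphs
#   weight_dict = extract_weights('fp32_prepared', mp, 'int8', mq)
#   # log activations of both models separately, then compare
#   mp_ns, mq_ns = add_loggers('fp32_prepared', mp, 'int8', mq, OutputLogger)
#   mp_ns(*data); mq_ns(*data)
#   act_dict = extract_logger_info(mp_ns, mq_ns, OutputLogger)
#   # or shadow the int8 model inside the fp32 model and compare in one pass
#   mp_shadows_mq = add_shadow_loggers(
#       'fp32_prepared', mp, 'int8', mq, OutputLogger)
#   mp_shadows_mq(*data)
#   shadow_dict = extract_shadow_logger_info(mp_shadows_mq, OutputLogger)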
# Note: these models are not for use outside of this file. While it's good
# to reuse code, we also need to be able to iterate on tests quickly when
# debugging. If a test model has a large number of callsites across many
# different files, debugging individual test cases becomes slower.
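# Functional linear followed by a relu; exercises matching of functional
# fusion patterns.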
class LinearReluFunctional(nn.Module):
def __init__(self):
super().__init__()
self.w1 = nn.Parameter(torch.empty(4, 4))
self.b1 = nn.Parameter(torch.zeros(4))
torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))
def forward(self, x):
x = F.linear(x, self.w1, self.b1)
x = F.relu(x)
return x
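# linear - relu - linear chain built from functionals, reusing a single
# weight and bias; exercises weight extraction for functional linear.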
class LinearReluLinearFunctional(nn.Module):
def __init__(self):
super().__init__()
        self.w = nn.Parameter(torch.empty(4, 4))
self.b = nn.Parameter(torch.zeros(4))
torch.nn.init.kaiming_uniform_(self.w, a=math.sqrt(5))
def forward(self, x):
x = F.linear(x, self.w, self.b)
x = F.relu(x)
x = F.linear(x, self.w, self.b)
return x
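# One module containing every conv/linear fusion pattern of interest:
# convNd, convNd-relu, convNd-bn (QAT only) and convNd-bn-relu (QAT only)
# for N in {1, 2, 3}, plus linear and linear-relu.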
class AllConvAndLinearFusionModules(torch.nn.Module):
def __init__(self):
super().__init__()
# conv1d
self.conv1d_0 = nn.Conv1d(1, 1, 1)
# conv1d - relu
self.conv1d_1 = nn.Conv1d(1, 1, 1)
self.relu_0 = nn.ReLU()
# conv1d - bn (qat only)
self.conv1d_2 = nn.Conv1d(1, 1, 1)
self.bn1d_0 = nn.BatchNorm1d(1)
# conv1d - bn - relu (qat only)
self.conv1d_3 = nn.Conv1d(1, 1, 1)
self.bn1d_1 = nn.BatchNorm1d(1)
self.relu_4 = nn.ReLU()
# conv2d
self.conv2d_0 = nn.Conv2d(1, 1, 1)
# conv2d - relu
self.conv2d_1 = nn.Conv2d(1, 1, 1)
self.relu_1 = nn.ReLU()
# conv2d - bn (qat only)
self.conv2d_2 = nn.Conv2d(1, 1, 1)
self.bn2d_0 = nn.BatchNorm2d(1)
# conv2d - bn - relu (qat only)
self.conv2d_3 = nn.Conv2d(1, 1, 1)
self.bn2d_1 = nn.BatchNorm2d(1)
self.relu_5 = nn.ReLU()
# conv3d
self.conv3d_0 = nn.Conv3d(1, 1, 1)
# conv3d - relu
self.conv3d_1 = nn.Conv3d(1, 1, 1)
self.relu_2 = nn.ReLU()
# conv3d - bn (qat only)
self.conv3d_2 = nn.Conv3d(1, 1, 1)
self.bn3d_0 = nn.BatchNorm3d(1)
# conv3d - bn - relu (qat only)
self.conv3d_3 = nn.Conv3d(1, 1, 1)
self.bn3d_1 = nn.BatchNorm3d(1)
self.relu_6 = nn.ReLU()
# linear
self.linear_0 = nn.Linear(1, 1)
# linear - relu
self.linear_1 = nn.Linear(1, 1)
self.relu_3 = nn.ReLU()
def forward(self, x):
# conv1d
x = self.conv1d_0(x)
x = self.conv1d_1(x)
x = self.relu_0(x)
x = self.conv1d_2(x)
x = self.bn1d_0(x)
x = self.conv1d_3(x)
x = self.bn1d_1(x)
x = self.relu_4(x)
# conv2d
x = x.reshape(1, 1, 1, 1)
x = self.conv2d_0(x)
x = self.conv2d_1(x)
x = self.relu_1(x)
x = self.conv2d_2(x)
x = self.bn2d_0(x)
x = self.conv2d_3(x)
x = self.bn2d_1(x)
x = self.relu_5(x)
# conv3d
x = x.reshape(1, 1, 1, 1, 1)
x = self.conv3d_0(x)
x = self.conv3d_1(x)
x = self.relu_2(x)
x = self.conv3d_2(x)
x = self.bn3d_0(x)
x = self.conv3d_3(x)
x = self.bn3d_1(x)
x = self.relu_6(x)
# linear
x = x.reshape(1, 1)
x = self.linear_0(x)
x = self.linear_1(x)
x = self.relu_3(x)
return x
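# Functional convolutions: for each of conv1d/conv2d/conv3d, one bare call
# and one call followed by a relu, with externally supplied weights and biases.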
class AllConvFunctional(torch.nn.Module):
def __init__(self, weight1d, weight2d, weight3d, bias1d, bias2d, bias3d):
super().__init__()
self.weight1d = torch.nn.Parameter(weight1d)
self.weight2d = torch.nn.Parameter(weight2d)
self.weight3d = torch.nn.Parameter(weight3d)
self.bias1d = torch.nn.Parameter(bias1d)
self.bias2d = torch.nn.Parameter(bias2d)
self.bias3d = torch.nn.Parameter(bias3d)
self.stride1d = 1
self.padding1d = 0
self.dilation1d = 1
self.stride2d = (1, 1)
self.padding2d = (0, 0)
self.dilation2d = (1, 1)
self.groups = 1
self.stride3d = (1, 1, 1)
self.padding3d = (0, 0, 0)
self.dilation3d = (1, 1, 1)
def forward(self, x):
x = F.conv1d(
x, self.weight1d, self.bias1d, self.stride1d, self.padding1d,
self.dilation1d, self.groups)
x = F.conv1d(
x, self.weight1d, self.bias1d, self.stride1d, self.padding1d,
self.dilation1d, self.groups)
x = F.relu(x)
x = F.conv2d(
x, self.weight2d, self.bias2d, self.stride2d, self.padding2d,
self.dilation2d, self.groups)
x = F.conv2d(
x, self.weight2d, self.bias2d, self.stride2d, self.padding2d,
self.dilation2d, self.groups)
x = F.relu(x)
x = F.conv3d(
x, self.weight3d, self.bias3d, self.stride3d, self.padding3d,
self.dilation3d, self.groups)
x = F.conv3d(
x, self.weight3d, self.bias3d, self.stride3d, self.padding3d,
self.dilation3d, self.groups)
x = F.relu(x)
return x
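# Tests for get_matching_subgraph_pairs on small hand-written models.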
class TestFXGraphMatcher(QuantizationTestCase):
@override_qengines
def test_simple_mod(self):
m = nn.Sequential(nn.Conv2d(1, 1, 1)).eval()
mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
# TODO(future PR): prevent the need for copying here, we can copy the
# modules but should reuse the underlying tensors
mp_copy = copy.deepcopy(mp)
mq = convert_fx(mp_copy)
results = get_matching_subgraph_pairs(mp, mq)
expected_types = {
'base_op_torch.nn.Conv2d_0':
((nn.Conv2d, nn.Conv2d), (nnq.Conv2d, nnq.Conv2d)),
}
self.assert_types_for_matched_subgraph_pairs(results, expected_types, mp, mq)
@override_qengines
def test_simple_fun(self):
class M(nn.Module):
def __init__(self):
super().__init__()
self.w = nn.Parameter(torch.empty(1, 4))
self.b = nn.Parameter(torch.zeros(1))
torch.nn.init.kaiming_uniform_(self.w, a=math.sqrt(5))
def forward(self, x):
return F.linear(x, self.w, self.b)
m = M().eval()
mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
# TODO(future PR): prevent the need for copying here, we can copy the
# modules but should reuse the underlying tensors
mp_copy = copy.deepcopy(mp)
mq = convert_fx(mp_copy)
results = get_matching_subgraph_pairs(mp, mq)
expected_types = {
'base_op_torch.nn.functional.linear_0':
((F.linear, F.linear), (toq.linear, toq.linear))
}
self.assert_types_for_matched_subgraph_pairs(results, expected_types, mp, mq)
@override_qengines
def test_simple_fusion(self):
m = LinearReluFunctional().eval()
mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
# TODO(future PR): prevent the need for copying here, we can copy the
# modules but should reuse the underlying tensors
mp_copy = copy.deepcopy(mp)
mq = convert_fx(mp_copy)
results = get_matching_subgraph_pairs(mp, mq)
expected_types = {
'base_op_torch.nn.functional.linear_0':
((F.linear, F.relu), (toq.linear_relu, toq.linear_relu)),
}
self.assert_types_for_matched_subgraph_pairs(results, expected_types, mp, mq)
@override_qengines
def test_simple_mod_multi(self):
m = nn.Sequential(
nn.Sequential(
nn.Conv2d(1, 1, 1),
),
nn.Conv2d(1, 1, 1),
).eval()
mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
# TODO(future PR): prevent the need for copying here, we can copy the
# modules but should reuse the underlying tensors
mp_copy = copy.deepcopy(mp)
mq = convert_fx(mp_copy)
# assume success if no exceptions
results = get_matching_subgraph_pairs(mp, mq)
@override_qengines
def test_simple_tensor_ops(self):
class M(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, y):
z = x + y
return z
m = M().eval()
mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
# TODO(future PR): prevent the need for copying here, we can copy the
# modules but should reuse the underlying tensors
mp_copy = copy.deepcopy(mp)
mq = convert_fx(mp_copy)
# assume success if no exceptions
results = get_matching_subgraph_pairs(mp, mq)
@override_qengines
def test_matching_failure_node_count(self):
        # verify that matching graphs with matching node types but
        # different counts of matchable nodes fails
m1 = nn.Sequential(nn.Conv2d(1, 1, 1)).eval()
m2 = nn.Sequential(nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1)).eval()
mp1 = prepare_fx(m1, {'': torch.quantization.default_qconfig})
mp2 = prepare_fx(m2, {'': torch.quantization.default_qconfig})
with self.assertRaises(GraphMatchingException) as ex:
results = get_matching_subgraph_pairs(mp1, mp2)
@override_qengines
def test_matching_failure_node_type(self):
# verify that matching graphs with non-matching node types fails
m1 = nn.Sequential(nn.Conv2d(1, 1, 1)).eval()
m2 = nn.Sequential(nn.Linear(1, 1)).eval()
mp1 = prepare_fx(m1, {'': torch.quantization.default_qconfig})
mp2 = prepare_fx(m2, {'': torch.quantization.default_qconfig})
with self.assertRaises(GraphMatchingException) as ex:
results = get_matching_subgraph_pairs(mp1, mp2)
@override_qengines
def test_nodes_before_cat(self):
# verify that nodes before cat get matched
class M(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x0):
x1 = torch.add(x0, 1.0)
y1 = torch.add(x0, 1.0)
x2 = torch.cat([x1, y1])
return x2
m = M().eval()
mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
# TODO(future PR): prevent the need for copying here, we can copy the
# modules but should reuse the underlying tensors
mp_copy = copy.deepcopy(mp)
mq = convert_fx(mp_copy)
results = get_matching_subgraph_pairs(mp, mq)
expected_types = {
'base_op_torch.cat_0': ((torch.cat, torch.cat), (toq.cat, toq.cat)),
'base_op_torch.add_0': ((torch.add, torch.add), (toq.add, toq.add)),
'base_op_torch.add_1': ((torch.add, torch.add), (toq.add, toq.add)),
}
self.assert_types_for_matched_subgraph_pairs(results, expected_types, mp, mq)
@override_qengines
def test_dict_return_type(self):
# verify that we can traverse up nodes which return dictionaries
class M(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x0):
x1 = torch.add(x0, 1.0)
y1 = torch.add(x0, 1.0)
z1 = torch.add(x0, 1.0)
a1 = {'x1': x1, 'y1': (y1,), 'z1': [{'key': (z1,)}]}
return a1
m = M().eval()
mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
# TODO(future PR): prevent the need for copying here, we can copy the
# modules but should reuse the underlying tensors
mp_copy = copy.deepcopy(mp)
mq = convert_fx(mp_copy)
results = get_matching_subgraph_pairs(mp, mq)
expected_types = {
'base_op_torch.add_0': ((torch.add, torch.add), (toq.add, toq.add)),
'base_op_torch.add_1': ((torch.add, torch.add), (toq.add, toq.add)),
'base_op_torch.add_2': ((torch.add, torch.add), (toq.add, toq.add)),
}
self.assert_types_for_matched_subgraph_pairs(results, expected_types, mp, mq)
@skipIfNoFBGEMM
def test_nodes_with_equal_types_get_matched(self):
class M(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 1, 1)
self.conv2 = nn.Conv2d(1, 1, 1)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = torch.mul(x, x)
x = torch.sigmoid(x)
x = F.relu(x)
return x
m = M().eval()
        # prevent conv2 from getting quantized, so we can test
        # modules with equal types
qconfig_dict = {
'': torch.quantization.default_qconfig,
'module_name': [('conv2', None)],
}
mp = prepare_fx(m, qconfig_dict)
mp_copy = copy.deepcopy(mp)
mq = convert_fx(mp_copy)
results = get_matching_subgraph_pairs(mp, mq)
# all of these should be matched
expected_types = {
'base_op_torch.nn.Conv2d_1':
((nn.Conv2d, nn.Conv2d), (nnq.Conv2d, nnq.Conv2d)),
'base_op_torch.nn.Conv2d_0':
((nn.Conv2d, nn.Conv2d), (nn.Conv2d, nn.Conv2d)),
'base_op_torch.mul_0': ((torch.mul, torch.mul), (toq.mul, toq.mul)),
'base_op_torch.relu_0': ((F.relu, F.relu), (F.relu, F.relu)),
'base_op_torch.sigmoid_0':
((torch.sigmoid, torch.sigmoid), (torch.sigmoid, torch.sigmoid)),
}
self.assert_types_for_matched_subgraph_pairs(results, expected_types, mp, mq)
def test_op_relationship_mapping(self):
"""
Tests that the mapping of op relationships is complete.
"""
base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
type_a_related_to_b = \
get_type_a_related_to_b(base_name_to_sets_of_related_ops)
# 1. check static quant module mappings
static_quant_mod_mappings = get_default_static_quant_module_mappings()
for fp32_type, int8_type in static_quant_mod_mappings.items():
            # skip quants and dequants, for the purposes of the Numeric Suite
types_to_skip = (
torch.quantization.QuantStub,
torch.quantization.DeQuantStub,
nnq.FloatFunctional,
)
if fp32_type in types_to_skip:
continue
# verify relatedness
in_type_a_related_to_b = \
(fp32_type, int8_type) in type_a_related_to_b
self.assertTrue(
in_type_a_related_to_b,
f"{fp32_type} and {int8_type} need a relationship mapping")
# 2. check static quant op mappings
static_quant_fun_mappings = get_default_float_to_quantized_operator_mappings()
for fp32_type, int8_type in static_quant_fun_mappings.items():
# verify relatedness
in_type_a_related_to_b = \
(fp32_type, int8_type) in type_a_related_to_b
self.assertTrue(
in_type_a_related_to_b,
f"{fp32_type} and {int8_type} need a relationship mapping")
# 3. check dynamic quant mappings
dynamic_quant_mappings = get_default_dynamic_quant_module_mappings()
for fp32_type, int8_type in dynamic_quant_mappings.items():
            # TODO(future PR): enable correct weight extraction for these
            # and remove from this list.
types_to_skip = (
nn.GRUCell,
nn.GRU,
nn.LSTMCell,
nn.RNNCell,
)
if fp32_type in types_to_skip:
continue
# verify relatedness
in_type_a_related_to_b = \
(fp32_type, int8_type) in type_a_related_to_b
self.assertTrue(
in_type_a_related_to_b,
f"{fp32_type} and {int8_type} need a relationship mapping")
        # 4. go through the ops mapped to each QuantizeHandler type, and verify
        # correctness.
def _op_in_base_sets_of_related_ops(op):
for name, ops in base_name_to_sets_of_related_ops.items():
if op in ops:
return True
return False
default_quant_patterns = get_default_quant_patterns()
for pattern, qhandler_cls in default_quant_patterns.items():
base_op = None
if isinstance(pattern, tuple):
base_op = pattern[-1]
elif isinstance(pattern, str):
# TODO(future PR): add handling for these
continue
else:
base_op = pattern
qhandler_cls_all_ops_quantizeable = [
qp.CatQuantizeHandler,
qp.ConvReluQuantizeHandler,
qp.LinearReLUQuantizeHandler,
qp.BatchNormQuantizeHandler,
qp.EmbeddingQuantizeHandler,
qp.RNNDynamicQuantizeHandler,
qp.ELUQuantizeHandler,
]
qhandler_cls_quant_op_same_signature = [
qp.FixedQParamsOpQuantizeHandler,
qp.CopyNodeQuantizeHandler,
]
if qhandler_cls == qp.BinaryOpQuantizeHandler:
# these ops do not have quantized equivalents
ops_to_skip = [
torch.bmm,
torch.sum,
torch.div,
torch.sub,
operator.truediv,
operator.sub
]
if base_op in ops_to_skip:
continue
self.assertTrue(
_op_in_base_sets_of_related_ops(base_op),
f"{base_op} not in sets of related ops")
elif qhandler_cls == qp.RNNDynamicQuantizeHandler:
            # TODO(future PR): add support for all classes in
            # RNNDynamicQuantizeHandler
pass
elif qhandler_cls == qp.DefaultNodeQuantizeHandler:
ops_to_skip = [
torch.nn.SiLU,
torch.nn.functional.silu,
]
if base_op in ops_to_skip:
continue
self.assertTrue(
_op_in_base_sets_of_related_ops(base_op),
f"{base_op} not in sets of related ops")
elif qhandler_cls in qhandler_cls_quant_op_same_signature:
            # these ops use the same op signature for fp32 and quantized
            # tensors
pass
elif qhandler_cls in qhandler_cls_all_ops_quantizeable:
self.assertTrue(
_op_in_base_sets_of_related_ops(base_op),
f"{base_op} not in sets of related ops")
else:
raise AssertionError(
f"handing for {qhandler_cls} not implemented")
class TestFXGraphMatcherModels(QuantizationTestCase):
@override_qengines
@skip_if_no_torchvision
def test_mobilenet_v2(self):
# verify that mobilenetv2 graph is able to be matched
import torchvision
m = torchvision.models.__dict__['mobilenet_v2'](pretrained=False).eval().float()
mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
# TODO(future PR): prevent the need for copying here, we can copy the
# modules but should reuse the underlying tensors
mp_copy = copy.deepcopy(mp)
mq = convert_fx(mp_copy)
# assume success if no exceptions
results = get_matching_subgraph_pairs(mp, mq)
@override_qengines
@skip_if_no_torchvision
def test_mobilenet_v2_qat(self):
# verify that mobilenetv2 graph is able to be matched
import torchvision
m = torchvision.models.__dict__['mobilenet_v2'](pretrained=False).float()
mp = prepare_qat_fx(m, {'': torch.quantization.get_default_qat_qconfig('fbgemm')})
# TODO(future PR): prevent the need for copying here, we can copy the
# modules but should reuse the underlying tensors
mp_copy = copy.deepcopy(mp)
mq = convert_fx(mp_copy)
# assume success if no exceptions
results = get_matching_subgraph_pairs(mp, mq)
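# Shared helpers: each one prepares and converts a copy of the model, runs a
# Numeric Suite API on the (fp32_prepared, int8) pair, and checks the number
# and validity of the returned comparison entries.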
class FXNumericSuiteQuantizationTestCase(QuantizationTestCase):
def _test_extract_weights(
self, m, results_len=0, qconfig_dict=None, prepare_fn=prepare_fx
):
if qconfig_dict is None:
qconfig_dict = {'': torch.quantization.default_qconfig}
mp = prepare_fn(m, qconfig_dict)
# TODO(future PR): prevent the need for copying here, we can copy the
# modules but should reuse the underlying tensors
mp_copy = copy.deepcopy(mp)
mq = convert_fx(mp_copy)
# test both the public API as well as the internal GraphModule API
for extract_weights_fun in (extract_weights, _extract_weights_impl):
results = extract_weights_fun('fp32_prepared', mp, 'int8', mq)
self.assertTrue(
len(results) == results_len,
f"expected len {results_len}, got len {len(results)}")
self.assert_ns_compare_dict_valid(results)
def _test_match_activations(
self, m, data, prepared_expected_node_occurrence=None, results_len=0,
should_log_inputs=False,
qconfig_dict=None,
skip_scripting=False,
prepare_fn=prepare_fx,
):
if qconfig_dict is None:
qconfig_dict = {'': torch.quantization.default_qconfig}
if prepare_fn == prepare_fx:
m.eval()
else:
m.train()
mp = prepare_fn(m, qconfig_dict)
mp(*data)
# TODO(future PR): prevent the need for copying here, we can copy the
# modules but should reuse the underlying tensors
mp_copy = copy.deepcopy(mp)
mq = convert_fx(mp_copy)
mp_ns, mq_ns = add_loggers(
'fp32_prepared', mp, 'int8', mq, OutputLogger,
should_log_inputs=should_log_inputs)
if prepared_expected_node_occurrence:
self.checkGraphModuleNodes(
mp_ns, expected_node_occurrence=prepared_expected_node_occurrence)
self.checkGraphModuleNodes(
mq_ns, expected_node_occurrence=prepared_expected_node_occurrence)
if not skip_scripting:
mp_ns = torch.jit.script(mp_ns)
mq_ns = torch.jit.script(mq_ns)
# calibrate
mp_ns(*data)
mq_ns(*data)
# check activation result correctness
act_compare_dict = extract_logger_info(mp_ns, mq_ns, OutputLogger)
self.assertTrue(
len(act_compare_dict) == results_len,
f"expected len {results_len}, got len {len(act_compare_dict)}")
self.assert_ns_compare_dict_valid(act_compare_dict)
return act_compare_dict
def _test_match_shadow_activations(
self, m, data, prepared_expected_node_occurrence=None, results_len=0,
should_log_inputs=False, qconfig_dict=None, skip_scripting=False,
prepare_fn=prepare_fx,
):
if qconfig_dict is None:
qconfig_dict = {'': torch.quantization.default_qconfig}
if prepare_fn == prepare_fx:
m.eval()
else:
m.train()
mp = prepare_fn(m, qconfig_dict)
mp(*data)
# TODO(future PR): prevent the need for copying here, we can copy the
# modules but should reuse the underlying tensors
mp_copy = copy.deepcopy(mp)
mq = convert_fx(mp_copy)
mp_shadows_mq = add_shadow_loggers(
'fp32_prepared', mp, 'int8', mq, OutputLogger,
should_log_inputs=should_log_inputs)
if prepared_expected_node_occurrence:
self.checkGraphModuleNodes(
mp_shadows_mq, expected_node_occurrence=prepared_expected_node_occurrence)
if not skip_scripting:
mp_shadows_mq = torch.jit.script(mp_shadows_mq)
# calibrate
mp_shadows_mq(*data)
# check activation result correctness
act_compare_dict = extract_shadow_logger_info(
mp_shadows_mq, OutputLogger)
self.assertTrue(
len(act_compare_dict) == results_len,
f"expected len {results_len}, got len {len(act_compare_dict)}")
self.assert_ns_compare_dict_valid(act_compare_dict)
return act_compare_dict
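# Core Numeric Suite API tests (weight extraction, activation logging, shadow
# logging) on small hand-written models, covering both PTQ and QAT flows.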
class TestFXNumericSuiteCoreAPIs(FXNumericSuiteQuantizationTestCase):
@skipIfNoFBGEMM
def test_extract_weights_mod_ptq(self):
m = AllConvAndLinearFusionModules().eval()
self._test_extract_weights(m, results_len=14)
@skipIfNoFBGEMM
def test_extract_weights_mod_qat(self):
m = AllConvAndLinearFusionModules().train()
qconfig_dict = {'': torch.quantization.get_default_qat_qconfig('fbgemm')}
self._test_extract_weights(
m, results_len=14, qconfig_dict=qconfig_dict, prepare_fn=prepare_qat_fx)
@skipIfNoFBGEMM
def test_extract_weights_linear_fun_ptq(self):
m = LinearReluLinearFunctional().eval()
self._test_extract_weights(m, results_len=2)
@skipIfNoFBGEMM
def test_extract_weights_linear_fun_qat(self):
m = LinearReluLinearFunctional().train()
qconfig_dict = {'': torch.quantization.get_default_qat_qconfig('fbgemm')}
self._test_extract_weights(
m, results_len=2, qconfig_dict=qconfig_dict, prepare_fn=prepare_qat_fx)
@skipIfNoFBGEMM
def test_extract_weights_conv_fun_ptq(self):
w1d = torch.randn(1, 1, 1)
w2d = torch.randn(1, 1, 1, 1)
w3d = torch.randn(1, 1, 1, 1, 1)
b1d = torch.randn(1)
b2d = torch.randn(1)
b3d = torch.randn(1)
m = AllConvFunctional(w1d, w2d, w3d, b1d, b2d, b3d).eval()
self._test_extract_weights(m, results_len=6)
@skipIfNoFBGEMM
def test_extract_weights_conv_fun_qat(self):
w1d = torch.randn(1, 1, 1)
w2d = torch.randn(1, 1, 1, 1)
w3d = torch.randn(1, 1, 1, 1, 1)
b1d = torch.randn(1)
b2d = torch.randn(1)
b3d = torch.randn(1)
m = AllConvFunctional(w1d, w2d, w3d, b1d, b2d, b3d).train()
qconfig_dict = {'': torch.quantization.get_default_qat_qconfig('fbgemm')}
self._test_extract_weights(
m, results_len=6, qconfig_dict=qconfig_dict, prepare_fn=prepare_qat_fx)
@skipIfNoFBGEMM
def test_extract_weights_dynamic(self):
# TODO(future PR): add Linear-ReLU, after #55393 is fixed.
m = nn.Sequential(nn.Linear(1, 1)).eval()
qconfig_dict = {
'object_type': [
(nn.Linear, default_dynamic_qconfig),
],
}
self._test_extract_weights(m, results_len=1, qconfig_dict=qconfig_dict)
def _test_match_activations_mod_impl(self, prepare_fn=prepare_fx):
m = nn.Sequential(
torch.quantization.QuantStub(),
nn.Conv2d(1, 1, 1),
nn.Conv2d(1, 1, 1),
).eval()
qconfig_dict = None
if prepare_fn == prepare_qat_fx:
qconfig_dict = {'': torch.quantization.get_default_qat_qconfig('fbgemm')}
expected_occurrence = {
ns.call_module(OutputLogger): 2,
}
self._test_match_activations(
m, (torch.randn(2, 1, 2, 2),),
prepared_expected_node_occurrence=expected_occurrence,
results_len=2, qconfig_dict=qconfig_dict, prepare_fn=prepare_fn)
@skipIfNoFBGEMM
def test_match_activations_mod_ptq(self):
self._test_match_activations_mod_impl(prepare_fn=prepare_fx)
@skipIfNoFBGEMM
def test_match_activations_mod_qat(self):
self._test_match_activations_mod_impl(prepare_fn=prepare_qat_fx)
def _test_match_activations_fun_impl(self, prepare_fn=prepare_fx):
m = LinearReluLinearFunctional().eval()
qconfig_dict = None
if prepare_fn == prepare_qat_fx:
qconfig_dict = {'': torch.quantization.get_default_qat_qconfig('fbgemm')}
expected_occurrence = {
ns.call_module(OutputLogger): 2,
}
self._test_match_activations(
m, (torch.randn(4, 4),),
prepared_expected_node_occurrence=expected_occurrence,
results_len=2, prepare_fn=prepare_fn, qconfig_dict=qconfig_dict)
@skipIfNoFBGEMM
def test_match_activations_fun_ptq(self):
self._test_match_activations_fun_impl(prepare_fn=prepare_fx)
@skipIfNoFBGEMM
def test_match_activations_fun_qat(self):
self._test_match_activations_fun_impl(prepare_fn=prepare_qat_fx)
def _test_add_shadow_loggers_mod_impl(self, prepare_fn=prepare_fx):
m = nn.Sequential(
nn.Conv2d(1, 1, 1),
nn.Conv2d(1, 1, 1),
).eval()
qconfig_dict = None
if prepare_fn == prepare_qat_fx:
qconfig_dict = {'': torch.quantization.get_default_qat_qconfig('fbgemm')}
res = self._test_match_shadow_activations(
m, (torch.randn(1, 1, 4, 4),), results_len=2,
prepare_fn=prepare_fn, qconfig_dict=qconfig_dict)
@skipIfNoFBGEMM
def test_add_shadow_loggers_mod_ptq(self):
self._test_add_shadow_loggers_mod_impl(prepare_fn=prepare_fx)
@skipIfNoFBGEMM
def test_add_shadow_loggers_mod_qat(self):
self._test_add_shadow_loggers_mod_impl(prepare_fn=prepare_qat_fx)
def _test_add_shadow_loggers_fun_impl(self, prepare_fn=prepare_fx):
m = LinearReluLinearFunctional()
qconfig_dict = None
if prepare_fn == prepare_qat_fx:
qconfig_dict = {'': torch.quantization.get_default_qat_qconfig('fbgemm')}
res = self._test_match_shadow_activations(
m, (torch.randn(4, 4),), results_len=2, prepare_fn=prepare_fn,
qconfig_dict=qconfig_dict)
@skipIfNoFBGEMM
def test_add_shadow_loggers_fun_ptq(self):
self._test_add_shadow_loggers_fun_impl(prepare_fn=prepare_fx)
@skipIfNoFBGEMM
def test_add_shadow_loggers_fun_qat(self):
self._test_add_shadow_loggers_fun_impl(prepare_fn=prepare_qat_fx)
@skipIfNoFBGEMM
def test_add_shadow_loggers_multiple_dtype_casts(self):
"""
Verifies that for nodes where the first input arg is a list,
such as `cat`, we insert an individual dtype cast for each
arg of the list.
"""
class M(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
x = torch.cat([x, x, x], dim=0)
return x
m = M().eval()
expected_occurrence = {
# 3 dequantize function calls from the 3 dtype casts for [x, x, x]
ns.call_function(torch.dequantize): 3,
# 1 dequantize method call for module output
ns.call_method("dequantize"): 1,
}
self._test_match_shadow_activations(
m, (torch.randn(4, 4),),
prepared_expected_node_occurrence=expected_occurrence,
results_len=1)
@skipIfNoFBGEMM
def test_logging_inputs(self):
"""
Verifies that logging inputs works correctly
"""
class M(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(1, 1, 1)
def forward(self, x):
x = self.conv(x)
x = torch.cat([x, x], dim=0)
return x
m = M().eval()
self._test_match_shadow_activations(
m, (torch.randn(1, 1, 4, 4),),
results_len=2,
should_log_inputs=True)
@skipIfNoFBGEMM
def test_ops_with_same_fp32_and_int8_signature(self):
"""
Verifies that we can match pairs of ops which have the same aten
signature for fp32 and int8 tensors.
"""
class M(nn.Module):
def __init__(self):
super().__init__()
self.max_pool_2d = nn.MaxPool2d(2)
def forward(self, x):
x = self.max_pool_2d(x)
x = F.relu(x)
return x
m = M().eval()
self._test_match_activations(
m, (torch.randn(1, 1, 2, 2),),
results_len=2)
@skipIfNoFBGEMM
def test_linear_fp16_weights(self):
qconfig_dict = {'': torch.quantization.float16_static_qconfig}
m = LinearReluFunctional().eval()
self._test_extract_weights(m, results_len=1, qconfig_dict=qconfig_dict)
@skipIfNoFBGEMM
def test_linear_fp16_activations(self):
for should_log_inputs in (True, False):
qconfig_dict = {'': torch.quantization.float16_static_qconfig}
m = LinearReluFunctional().eval()
num_loggers = 2 if should_log_inputs else 1
expected_occurrence = {
ns.call_module(OutputLogger): num_loggers,
}
res = self._test_match_activations(
m, (torch.randn(4, 4),),
prepared_expected_node_occurrence=expected_occurrence,
results_len=1,
qconfig_dict=qconfig_dict,
should_log_inputs=should_log_inputs)
@skipIfNoFBGEMM
def test_linear_fp16_shadow_activations(self):
for should_log_inputs in (True, False):
qconfig_dict = {'': torch.quantization.float16_static_qconfig}
m = LinearReluFunctional().eval()
num_loggers = 4 if should_log_inputs else 2
expected_occurrence = {
ns.call_module(OutputLogger): num_loggers,
}
res2 = self._test_match_shadow_activations(
m, (torch.randn(4, 4),),
prepared_expected_node_occurrence=expected_occurrence,
results_len=1,
qconfig_dict=qconfig_dict,
should_log_inputs=should_log_inputs)
@skipIfNoFBGEMM
def test_op_with_either_fp32_or_int8_input(self):
"""
Verify that shadowing works with ops which accept either fp32 or
int8 inputs.
"""
class M(nn.Module):
def __init__(self):
super().__init__()
self.relu = nn.ReLU()
def forward(self, x):
x = self.relu(x)
x = F.relu(x)
return x
m = M()
res = self._test_match_shadow_activations(
m, (torch.randn(4, 4),),
results_len=2)
@skipIfNoFBGEMM
def test_user_module(self):
"""
        For user-defined modules,
        1. weight extraction should not crash
        2. unshadowed activations should have loggers; loggers will only log if
           the output dtype is in the allowlist
        3. shadowed activations should not have loggers
           (since the I/O dtype is unknown)
"""
class UserModule(nn.Module):
def forward(self, x):
return x
class M(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(1, 1)
self.user_module = UserModule()
def forward(self, x):
x = self.linear(x)
x = self.user_module(x)
return x
m = M().eval()
# quantize without tracing through UserModule
qconfig_dict = {'': torch.quantization.default_qconfig}
prepare_custom_config_dict = {'non_traceable_module_name': ['user_module']}
mp = prepare_fx(m, qconfig_dict, prepare_custom_config_dict)
mp(torch.randn(1, 1, 1))
mq = convert_fx(copy.deepcopy(mp))
# weight extraction should not crash
weights = _extract_weights_impl('fp32_prepared', mp, 'int8', mq)
# unshadowed activations should have loggers
# add loggers, without retracing
# note: converting again because we cannot copy a quantized linear
mp_ns, mq_ns = _add_loggers_impl(
'fp32_prepared', copy.deepcopy(mp), 'int8',
convert_fx(copy.deepcopy(mp)), OutputLogger,
should_log_inputs=True)
# both fp32 and int8 models should have 4 loggers each, 2 for I/O
# of linear, and 2 for I/O of user_module
unshadowed_expected_occurrence = {
ns.call_module(OutputLogger): 4,
}
self.checkGraphModuleNodes(
mp_ns, expected_node_occurrence=unshadowed_expected_occurrence)
self.checkGraphModuleNodes(
mq_ns, expected_node_occurrence=unshadowed_expected_occurrence)
# shadowed activations should only have loggers for nodes where
# the types are known and we can do a dtype cast
# add shadow loggers, without retracing
mp_shadows_mq_ns = _add_shadow_loggers_impl(
'fp32_prepared', mp, 'int8', mq, OutputLogger,
should_log_inputs=True)
# 2 loggers for I/O of linear, 0 loggers for I/O of user_module
shadowed_expected_occurrence = {
ns.call_module(OutputLogger): 2,
}
self.checkGraphModuleNodes(
            mp_shadows_mq_ns, expected_node_occurrence=shadowed_expected_occurrence)
class TestFXNumericSuiteCoreAPIsModels(FXNumericSuiteQuantizationTestCase):
"""
Tests numeric suite core APIs on non-toy models.
"""
@skipIfNoFBGEMM
def test_compare_weights_conv(self):
test_cases = (
(ConvModel(),),
(ConvBnModel(),),
(ConvBnReLUModel(),),
)
for m, in test_cases:
m.eval()
self._test_extract_weights(m, results_len=1)
@skipIfNoFBGEMM
def test_compare_weights_linear(self):
test_cases = (
(SingleLayerLinearModel(), None),
(
SingleLayerLinearDynamicModel(),
{"object_type": [(nn.Linear, default_dynamic_qconfig)]},
),
)
for m, qconfig_dict in test_cases:
m.eval()
res = self._test_extract_weights(
m, results_len=1, qconfig_dict=qconfig_dict)
@skipIfNoFBGEMM
def test_compare_weights_lstm_dynamic(self):
qconfig_dict = {"object_type": [(nn.LSTM, default_dynamic_qconfig)]}
m = LSTMwithHiddenDynamicModel().eval()
res = self._test_extract_weights(
m, results_len=1, qconfig_dict=qconfig_dict)
@skipIfNoFBGEMM
def test_compare_activations_conv(self):
test_cases = (
(ConvModel(),),
(ConvBnModel(),),
(ConvBnReLUModel(),),
)
for m, in test_cases:
m.eval()
res = self._test_match_activations(
m, (torch.randn(1, 3, 4, 4),), results_len=1)
@skipIfNoFBGEMM
def test_compare_activations_linear(self):
test_cases = (
(SingleLayerLinearModel(), None),
(
SingleLayerLinearDynamicModel(),
{"object_type": [(nn.Linear, default_dynamic_qconfig)]},
),
)
for m, qconfig_dict in test_cases:
m.eval()
res = self._test_match_activations(
m, (torch.randn(5, 5),), results_len=1, qconfig_dict=qconfig_dict)
@skipIfNoFBGEMM
def test_compare_activations_lstm_dynamic(self):
qconfig_dict = {"object_type": [(nn.LSTM, default_dynamic_qconfig)]}
m = LSTMwithHiddenDynamicModel().eval()
lstm_input = torch.rand((1, 1, 2))
lstm_hidden = (torch.rand(1, 1, 2), torch.rand(1, 1, 2))
# TODO(future PR): enable scripting (quant prepared LSTM not scriptable)
res = self._test_match_activations(
m, (lstm_input, lstm_hidden), results_len=1, qconfig_dict=qconfig_dict,
skip_scripting=True)
@skipIfNoFBGEMM
def test_compare_shadow_activations_conv(self):
test_cases = (
(ConvModel(),),
(ConvBnModel(),),
(ConvBnReLUModel(),),
)
for m, in test_cases:
m.eval()
res = self._test_match_shadow_activations(
m, (torch.randn(1, 3, 4, 4),), results_len=1)
@skipIfNoFBGEMM
def test_compare_shadow_activations_linear(self):
test_cases = (
(SingleLayerLinearModel(), None),
(
SingleLayerLinearDynamicModel(),
{"object_type": [(nn.Linear, default_dynamic_qconfig)]},
),
)
for m, qconfig_dict in test_cases:
m.eval()
res = self._test_match_shadow_activations(
m, (torch.randn(5, 5),), results_len=1, qconfig_dict=qconfig_dict)
@skipIfNoFBGEMM
def test_compare_shadow_activations_lstm_dynamic(self):
qconfig_dict = {"object_type": [(nn.LSTM, default_dynamic_qconfig)]}
m = LSTMwithHiddenDynamicModel().eval()
lstm_input = torch.rand((1, 1, 2))
lstm_hidden = (torch.rand(1, 1, 2), torch.rand(1, 1, 2))
# TODO(future PR): enable scripting (quant prepared LSTM not scriptable)
res = self._test_match_shadow_activations(
m, (lstm_input, lstm_hidden), results_len=1, qconfig_dict=qconfig_dict,
skip_scripting=True)
@skipIfNoFBGEMM
def test_sparsenn_compare_activations(self):
for should_log_inputs in (True, False):
sparse_nn = SparseNNModel().eval()
idx = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9])
offsets = torch.LongTensor([0, 4])
x = torch.randn(2, 4)
self._test_match_activations(
sparse_nn, (idx, offsets, x),
results_len=5,
should_log_inputs=should_log_inputs)
@skipIfNoFBGEMM
def test_sparsenn_shadow(self):
for should_log_inputs in (True, False):
sparse_nn = SparseNNModel().eval()
idx = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9])
offsets = torch.LongTensor([0, 4])
x = torch.randn(2, 4)
self._test_match_shadow_activations(
sparse_nn, (idx, offsets, x),
results_len=4,
should_log_inputs=should_log_inputs)