compare_model_stub_fx API implementation (#48951)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48951
compare_model_stub_fx API implementation
ghstack-source-id: 120817825
Test Plan:
buck test mode/dev caffe2/test:quantization_fx -- 'test_compare_model_stub_conv_static_fx'
buck test mode/dev caffe2/test:quantization_fx -- 'test_compare_model_stub_linear_static_fx'
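In OSS, the equivalent invocations are presumably:

python test/quantization/test_numeric_suite_fx.py TestGraphModeNumericSuite.test_compare_model_stub_conv_static_fx
python test/quantization/test_numeric_suite_fx.py TestGraphModeNumericSuite.test_compare_model_stub_linear_static_fx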
Reviewed By: vkuzo
Differential Revision: D25379000
fbshipit-source-id: f1321d37b60b56b202e7d227e370ce13addb10cc
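For context, a minimal end-to-end sketch of the new API, mirroring the linear test in the diff below (the toy model and random data here are illustrative, not part of this change):

import copy
import torch
import torch.nn as nn
from torch.quantization import get_default_qconfig
from torch.quantization._numeric_suite_fx import compare_model_stub_fx
from torch.quantization.quantize_fx import convert_fx, prepare_fx

# Toy float model; any symbolically traceable module containing nn.Linear works.
float_model = nn.Sequential(nn.Linear(5, 5)).eval()

qconfig_dict = {"": get_default_qconfig(torch.backends.quantized.engine)}
prepared_model = prepare_fx(float_model, qconfig_dict)
# Keep a float reference before convert; compare_model_stub_fx strips its observers.
backup_prepared_model = copy.deepcopy(prepared_model)

data = torch.randn(2, 5)
prepared_model(data)  # calibration
q_model = convert_fx(prepared_model)

# Attach each quantized Linear's float counterpart as a shadow, then run the data.
ob_dict = compare_model_stub_fx(backup_prepared_model, q_model, [nn.Linear], data)
for key, v in ob_dict.items():
    print(key, v["float"][0].shape, v["quantized"][0].shape)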
diff --git a/test/quantization/test_numeric_suite_fx.py b/test/quantization/test_numeric_suite_fx.py
index 0955b4b..ab580cd 100644
--- a/test/quantization/test_numeric_suite_fx.py
+++ b/test/quantization/test_numeric_suite_fx.py
@@ -1,26 +1,30 @@
import copy
import torch
+import torch.nn as nn
+import torch.nn.intrinsic as nni
from torch.quantization import get_default_qconfig
from torch.quantization._numeric_suite_fx import (
compare_weights_fx,
remove_qconfig_observer_fx,
+ compare_model_stub_fx,
)
from torch.quantization.fx.quantize import is_activation_post_process
from torch.quantization.quantize_fx import convert_fx, fuse_fx, prepare_fx
from torch.testing._internal.common_quantization import (
ConvBnModel,
- ConvBNReLU,
+ ConvBnReLUModel,
ConvModel,
QuantizationTestCase,
SingleLayerLinearDynamicModel,
SingleLayerLinearModel,
- skipIfNoFBGEMM,
+ test_only_eval_fn,
)
+from torch.testing._internal.common_quantized import override_qengines
-@skipIfNoFBGEMM
class TestGraphModeNumericSuite(QuantizationTestCase):
+ @override_qengines
def test_remove_qconfig_observer_fx(self):
r"""Remove activation_post_process node from fx prepred model"""
float_model = SingleLayerLinearModel()
@@ -43,29 +47,29 @@
if node.op == "call_module":
self.assertFalse(is_activation_post_process(modules[node.target]))
- @skipIfNoFBGEMM
+
+
+ def compare_and_validate_model_weights_results_fx(
+ self, float_model, q_model, expected_weight_dict_keys
+ ):
+ weight_dict = compare_weights_fx(float_model.state_dict(), q_model.state_dict())
+
+ self.assertTrue(weight_dict.keys() == expected_weight_dict_keys)
+ self.assertEqual(len(weight_dict), 1)
+
+ for k, v in weight_dict.items():
+ self.assertTrue(v["float"].shape == v["quantized"].shape)
+
+
+ @override_qengines
def test_compare_weights_conv_static_fx(self):
r"""Compare the weights of float and static quantized conv layer"""
- def calibrate(model, calib_data):
- model.eval()
- with torch.no_grad():
- for inp in calib_data:
- model(*inp)
-
- def compare_and_validate_results(float_model, q_model):
- weight_dict = compare_weights_fx(
- float_model.state_dict(), q_model.state_dict()
- )
- self.assertEqual(len(weight_dict), 1)
- for k, v in weight_dict.items():
- self.assertTrue(v["float"].shape == v["quantized"].shape)
-
qengine = torch.backends.quantized.engine
qconfig = get_default_qconfig(qengine)
qconfig_dict = {"": qconfig}
- model_list = [ConvModel(), ConvBnModel(), ConvBNReLU()]
+ model_list = [ConvModel(), ConvBnModel(), ConvBnReLUModel()]
for float_model in model_list:
float_model.eval()
@@ -73,67 +77,49 @@
prepared_model = prepare_fx(float_model, qconfig_dict)
# Run calibration
- calibrate(prepared_model, self.img_data_2d)
+ test_only_eval_fn(prepared_model, self.img_data_2d)
q_model = convert_fx(prepared_model)
- compare_and_validate_results(fused, q_model)
+ expected_weight_dict_keys = {"conv.weight"}
+ self.compare_and_validate_model_weights_results_fx(
+ fused, q_model, expected_weight_dict_keys
+ )
- @skipIfNoFBGEMM
+ @override_qengines
def test_compare_weights_linear_static_fx(self):
r"""Compare the weights of float and static quantized linear layer"""
- def calibrate(model, calib_data):
- model.eval()
- with torch.no_grad():
- for inp in calib_data:
- model(*inp)
-
- def compare_and_validate_results(float_model, q_model):
- weight_dict = compare_weights_fx(
- float_model.state_dict(), q_model.state_dict()
- )
- self.assertEqual(len(weight_dict), 1)
- for k, v in weight_dict.items():
- self.assertTrue(v["float"].shape == v["quantized"].shape)
-
- float_model = SingleLayerLinearModel()
- float_model.eval()
-
qengine = torch.backends.quantized.engine
qconfig = get_default_qconfig(qengine)
qconfig_dict = {"": qconfig}
+ float_model = SingleLayerLinearModel()
+ float_model.eval()
+
prepared_model = prepare_fx(float_model, qconfig_dict)
backup_prepared_model = copy.deepcopy(prepared_model)
backup_prepared_model.eval()
# Run calibration
- calibrate(prepared_model, self.calib_data)
+ test_only_eval_fn(prepared_model, self.calib_data)
q_model = convert_fx(prepared_model)
- compare_and_validate_results(backup_prepared_model, q_model)
+ expected_weight_dict_keys = {"fc1._packed_params._packed_params"}
+ self.compare_and_validate_model_weights_results_fx(
+ backup_prepared_model, q_model, expected_weight_dict_keys
+ )
- @skipIfNoFBGEMM
+ @override_qengines
def test_compare_weights_linear_dynamic_fx(self):
r"""Compare the weights of float and dynamic quantized linear layer"""
- def compare_and_validate_results(float_model, q_model):
- weight_dict = compare_weights_fx(
- float_model.state_dict(), q_model.state_dict()
- )
- self.assertEqual(len(weight_dict), 1)
- for k, v in weight_dict.items():
- self.assertTrue(len(v["float"]) == len(v["quantized"]))
- for i, val in enumerate(v["quantized"]):
- self.assertTrue(v["float"][i].shape == v["quantized"][i].shape)
+ qconfig = torch.quantization.qconfig.default_dynamic_qconfig
+ qconfig_dict = {"": qconfig}
float_model = SingleLayerLinearDynamicModel()
float_model.eval()
- qconfig = torch.quantization.qconfig.default_dynamic_qconfig
- qconfig_dict = {"": qconfig}
-
prepared_model = prepare_fx(float_model, qconfig_dict)
backup_prepared_model = copy.deepcopy(prepared_model)
@@ -141,4 +127,84 @@
q_model = convert_fx(prepared_model)
- compare_and_validate_results(backup_prepared_model, q_model)
+ expected_weight_dict_keys = {"fc1._packed_params._packed_params"}
+ self.compare_and_validate_model_weights_results_fx(
+ backup_prepared_model, q_model, expected_weight_dict_keys
+ )
+
+ def compare_and_validate_model_stub_results_fx(
+ self, float_model, q_model, module_swap_list, data, expected_ob_dict_keys
+ ):
+ ob_dict = compare_model_stub_fx(float_model, q_model, module_swap_list, data)
+
+ self.assertTrue(ob_dict.keys() == expected_ob_dict_keys)
+ self.assertEqual(len(ob_dict), 1)
+
+ for k, v in ob_dict.items():
+ self.assertTrue(len(v["float"]) == len(v["quantized"]))
+ for i, val in enumerate(v["quantized"]):
+ self.assertTrue(v["float"][i].shape == v["quantized"][i].shape)
+
+
+ @override_qengines
+ def test_compare_model_stub_conv_static_fx(self):
+ r"""Compare the output of static quantized conv layer and its float shadow module"""
+
+ qengine = torch.backends.quantized.engine
+ qconfig = get_default_qconfig(qengine)
+ qconfig_dict = {"": qconfig}
+
+ model_list = [ConvModel(), ConvBnReLUModel()]
+
+ for float_model in model_list:
+ float_model.eval()
+
+ prepared_model = prepare_fx(float_model, qconfig_dict)
+
+ backup_prepared_model = copy.deepcopy(prepared_model)
+
+ # Run calibration
+ test_only_eval_fn(prepared_model, self.img_data_2d)
+ q_model = convert_fx(prepared_model)
+
+ module_swap_list = [nn.Conv2d, nni.modules.fused.ConvReLU2d]
+
+ expected_ob_dict_keys = {"conv.stats"}
+ self.compare_and_validate_model_stub_results_fx(
+ backup_prepared_model,
+ q_model,
+ module_swap_list,
+ self.img_data_2d[0][0],
+ expected_ob_dict_keys,
+ )
+
+ @override_qengines
+ def test_compare_model_stub_linear_static_fx(self):
+ r"""Compare the output of static quantized linear layer and its float shadow module"""
+
+ qengine = torch.backends.quantized.engine
+ qconfig = get_default_qconfig(qengine)
+ qconfig_dict = {"": qconfig}
+
+ float_model = SingleLayerLinearModel()
+ float_model.eval()
+
+ prepared_model = prepare_fx(float_model, qconfig_dict)
+
+ backup_prepared_model = copy.deepcopy(prepared_model)
+
+ # Run calibration
+ test_only_eval_fn(prepared_model, self.calib_data)
+ q_model = convert_fx(prepared_model)
+
+ linear_data = self.calib_data[0][0]
+ module_swap_list = [nn.Linear]
+
+ expected_ob_dict_keys = {"fc1.stats"}
+ self.compare_and_validate_model_stub_results_fx(
+ backup_prepared_model,
+ q_model,
+ module_swap_list,
+ linear_data,
+ expected_ob_dict_keys,
+ )
diff --git a/torch/quantization/_numeric_suite_fx.py b/torch/quantization/_numeric_suite_fx.py
index aeba95b..cfb806e 100644
--- a/torch/quantization/_numeric_suite_fx.py
+++ b/torch/quantization/_numeric_suite_fx.py
@@ -7,6 +7,12 @@
from torch.fx import GraphModule # type: ignore
from torch.fx import map_arg # type: ignore
from torch.fx.graph import Graph
+from torch.quantization._numeric_suite import (
+ get_logger_dict,
+ prepare_model_with_stubs,
+ compare_weights,
+ ShadowLogger,
+)
from torch.quantization.fx.quantize import _remove_qconfig, is_activation_post_process
@@ -45,34 +51,6 @@
return model
-def _find_match(str_list, key_str, postfix):
- split_str = key_str.split(".")
- if split_str[-1] == postfix:
- match_string = "".join(key_str.split(".")[0:-1])
- for s2 in str_list:
- pattern1 = "".join(s2.split(".")[0:-1])
- pattern2 = "".join(s2.split(".")[0:-2])
- if match_string == pattern1:
- return s2
- if match_string == pattern2:
- return s2
-
- # For matching "fc.weight" and "fc._packed_params._packed_params"
- if postfix == "_packed_params":
- match_string = "".join(key_str.split(".")[0:-2])
- if len(match_string) == 0:
- return None
- for s2 in str_list:
- pattern1 = "".join(s2.split(".")[0:-1])
- pattern2 = "".join(s2.split(".")[0:-2])
- if match_string == pattern1:
- return s2
- if match_string == pattern2:
- return s2
- else:
- return None
-
-
def compare_weights_fx(float_dict, quantized_dict):
r"""Compare the weights of the float module with its corresponding quantized
module. Return a dict with key corresponding to module names and each entry being
@@ -86,7 +64,7 @@
quantized_model = convert_fx(prepared_model)
qmodel = quantized_model
- wt_compare_dict = compare_weights(backup_prepared_model.state_dict(), qmodel.state_dict())
+ wt_compare_dict = compare_weights_fx(backup_prepared_model.state_dict(), qmodel.state_dict())
for key in wt_compare_dict:
print(key, compute_error(wt_compare_dict[key]['float'], wt_compare_dict[key]['quantized'].dequantize()))
@@ -102,38 +80,69 @@
torch._C._log_api_usage_once(
"quantization_api._numeric_suite_fx.compare_weights_fx"
)
- weight_dict: Dict[str, Dict] = {}
- for key in quantized_dict:
- match_key = _find_match(float_dict, key, "weight")
- if match_key is not None:
- weight_dict[key] = {}
- weight_dict[key]["float"] = float_dict[match_key]
- weight_dict[key]["quantized"] = quantized_dict[key]
- continue
+ return compare_weights(float_dict, quantized_dict)
- # For matching "fc.weight" and "fc._packed_params._packed_params"
- match_key = _find_match(float_dict, key, "_packed_params")
- if match_key is not None:
- weight_dict[key] = {}
- weight_dict[key]["float"] = float_dict[match_key]
- weight_dict[key]["quantized"] = quantized_dict[key][0]
- # For LSTM
- split_str = key.split(".")
- if split_str[-1] == "param" and split_str[-3] == "_all_weight_values":
- layer = split_str[-2]
- module_name = ".".join(split_str[:-3])
- float_weight_ih_key = module_name + ".weight_ih_l" + layer
- float_weight_hh_key = module_name + ".weight_hh_l" + layer
- if float_weight_ih_key in float_dict and float_weight_hh_key in float_dict:
- weight_dict[key] = {}
- weight_dict[key]["float"] = float_dict[float_weight_ih_key]
- weight_dict[key]["quantized"] = (
- quantized_dict[key].__getstate__()[0][4][0].__getstate__()[0][0]
- )
- weight_dict[key]["float"] = float_dict[float_weight_hh_key]
- weight_dict[key]["quantized"] = (
- quantized_dict[key].__getstate__()[0][4][1].__getstate__()[0][0]
- )
+def prepare_model_with_stubs_fx(float_module, q_module, module_swap_list, Logger):
+ r"""Prepare the model by attaching the float module to its matching quantized
+ module as the shadow if the float module type is in module_swap_list.
- return weight_dict
+ Example usage:
+ prepare_model_with_stubs_fx(float_model, q_model, module_swap_list, Logger)
+ q_model(data)
+ ob_dict = get_logger_dict(q_model)
+
+ Args:
+ float_module: float module used to generate the q_module
+ q_module: module quantized from float_module
+ module_swap_list: list of float module types to attach the shadow
+ Logger: type of logger to be used in shadow module to process the outputs of
+ quantized module and its float shadow module
+ """
+ torch._C._log_api_usage_once(
+ "quantization_api._numeric_suite.prepare_model_with_stubs_fx"
+ )
+ return prepare_model_with_stubs(float_module, q_module, module_swap_list, Logger)
+
+
+def compare_model_stub_fx(
+ float_model, q_model, module_swap_list, *data, Logger=ShadowLogger
+):
+ r"""Compare quantized module in a model with its floating point counterpart,
+ feeding both of them the same input. Return a dict with key corresponding to
+ module names and each entry being a dictionary with two keys 'float' and
+ 'quantized', containing the output tensors of quantized and its matching
+ float shadow module. This dict can be used to compare and compute the module
+ level quantization error.
+
+ This function first calls prepare_model_with_stubs_fx() to wrap each quantized
+ module that we want to compare in a Shadow module, which takes the quantized
+ module, its corresponding float module, and a logger as input, and creates a
+ forward path inside it so that the float module shadows the quantized module
+ on the same input. The logger is customizable; the default logger is
+ ShadowLogger, and it saves the outputs of the quantized module and the float
+ module so they can be used to compute the module-level quantization error.
+
+ Example usage:
+ module_swap_list = [nn.Linear]
+ ob_dict = compare_model_stub_fx(float_model, qmodel, module_swap_list, data)
+ for key in ob_dict:
+ print(key, compute_error(ob_dict[key]['float'], ob_dict[key]['quantized'].dequantize()))
+
+ Args:
+ float_model: float model used to generate the q_model
+ q_model: model quantized from float_model
+ module_swap_list: list of float module types at which shadow modules will
+ be attached.
+ data: input data used to run the prepared q_model
+ Logger: type of logger to be used in shadow module to process the outputs of
+ quantized module and its float shadow module
+ """
+ torch._C._log_api_usage_once(
+ "quantization_api._numeric_suite.compare_model_stub_fx"
+ )
+ float_model = remove_qconfig_observer_fx(float_model)
+ prepare_model_with_stubs_fx(float_model, q_model, module_swap_list, Logger)
+ q_model(*data)
+ ob_dict = get_logger_dict(q_model)
+ return ob_dict
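The docstring above uses compute_error without defining it; a minimal SQNR helper in the spirit of the eager-mode numeric suite examples (the name and formula are assumed here for illustration):

import torch

def compute_error(x, y):
    # Signal-to-quantization-noise ratio, in dB, between reference x and test y.
    Ps = torch.norm(x) ** 2
    Pn = torch.norm(x - y) ** 2
    return 10 * torch.log10(Ps / Pn)

# ob_dict as returned by compare_model_stub_fx(float_model, q_model, module_swap_list, data)
for key in ob_dict:
    for f, q in zip(ob_dict[key]["float"], ob_dict[key]["quantized"]):
        print(key, compute_error(f, q.dequantize()))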
diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py
index 64c9c91..459bc4b 100644
--- a/torch/testing/_internal/common_quantization.py
+++ b/torch/testing/_internal/common_quantization.py
@@ -910,6 +910,19 @@
x = self.dequant(x)
return x
+class ConvBnReLUModel(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
+ self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float)
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ x = self.relu(x)
+ return x
+
class AnnotatedConvBnReLUModel(torch.nn.Module):
def __init__(self, qengine='fbgemm'):
super(AnnotatedConvBnReLUModel, self).__init__()