compare_model_stub_fx API implementation (#48951)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48951

compare_model_stub_fx API implementation: an FX graph mode counterpart of
Numeric Suite's compare_model_stub. It attaches each quantized module whose
float type is in module_swap_list to its float counterpart via a Shadow
module, feeds both the same input, and logs both outputs with ShadowLogger
so that module-level quantization error can be computed.
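
For reference, a minimal usage sketch of the new API (`MyModel`, the
calibration step, and `data` below are illustrative placeholders, not part
of this change):

    import copy
    import torch.nn as nn
    from torch.quantization import get_default_qconfig
    from torch.quantization.quantize_fx import convert_fx, prepare_fx
    from torch.quantization._numeric_suite_fx import compare_model_stub_fx

    float_model = MyModel().eval()  # hypothetical model with an nn.Linear
    prepared = prepare_fx(float_model, {"": get_default_qconfig("fbgemm")})
    float_copy = copy.deepcopy(prepared)  # keep a float reference pre-convert
    # ... run calibration batches through `prepared` here ...
    q_model = convert_fx(prepared)

    # Shadow each quantized Linear with its float counterpart and log outputs.
    ob_dict = compare_model_stub_fx(float_copy, q_model, [nn.Linear], data)
    for key in ob_dict:  # keys look like "fc1.stats"
        print(key, ob_dict[key]["float"][0].shape)
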
ghstack-source-id: 120817825

Test Plan:
buck test mode/dev caffe2/test:quantization_fx -- 'test_compare_model_stub_conv_static_fx'
buck test mode/dev caffe2/test:quantization_fx -- 'test_compare_model_stub_linear_static_fx'

Reviewed By: vkuzo

Differential Revision: D25379000

fbshipit-source-id: f1321d37b60b56b202e7d227e370ce13addb10cc
diff --git a/test/quantization/test_numeric_suite_fx.py b/test/quantization/test_numeric_suite_fx.py
index 0955b4b..ab580cd 100644
--- a/test/quantization/test_numeric_suite_fx.py
+++ b/test/quantization/test_numeric_suite_fx.py
@@ -1,26 +1,30 @@
 import copy
 
 import torch
+import torch.nn as nn
+import torch.nn.intrinsic as nni
 from torch.quantization import get_default_qconfig
 from torch.quantization._numeric_suite_fx import (
     compare_weights_fx,
     remove_qconfig_observer_fx,
+    compare_model_stub_fx,
 )
 from torch.quantization.fx.quantize import is_activation_post_process
 from torch.quantization.quantize_fx import convert_fx, fuse_fx, prepare_fx
 from torch.testing._internal.common_quantization import (
     ConvBnModel,
-    ConvBNReLU,
+    ConvBnReLUModel,
     ConvModel,
     QuantizationTestCase,
     SingleLayerLinearDynamicModel,
     SingleLayerLinearModel,
-    skipIfNoFBGEMM,
+    test_only_eval_fn,
 )
+from torch.testing._internal.common_quantized import override_qengines
 
 
-@skipIfNoFBGEMM
 class TestGraphModeNumericSuite(QuantizationTestCase):
+    @override_qengines
     def test_remove_qconfig_observer_fx(self):
         r"""Remove activation_post_process node from fx prepred model"""
         float_model = SingleLayerLinearModel()
@@ -43,29 +47,29 @@
             if node.op == "call_module":
                 self.assertFalse(is_activation_post_process(modules[node.target]))
 
-    @skipIfNoFBGEMM
+    def compare_and_validate_model_weights_results_fx(
+        self, float_model, q_model, expected_weight_dict_keys
+    ):
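+        # compare_weights_fx matches float weights to their quantized
+        # counterparts by name, including packed params such as
+        # "fc1._packed_params._packed_params".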
+        weight_dict = compare_weights_fx(float_model.state_dict(), q_model.state_dict())
+
+        self.assertTrue(weight_dict.keys() == expected_weight_dict_keys)
+        self.assertEqual(len(weight_dict), 1)
+
+        for k, v in weight_dict.items():
+            self.assertTrue(v["float"].shape == v["quantized"].shape)
+
+    @override_qengines
     def test_compare_weights_conv_static_fx(self):
         r"""Compare the weights of float and static quantized conv layer"""
 
-        def calibrate(model, calib_data):
-            model.eval()
-            with torch.no_grad():
-                for inp in calib_data:
-                    model(*inp)
-
-        def compare_and_validate_results(float_model, q_model):
-            weight_dict = compare_weights_fx(
-                float_model.state_dict(), q_model.state_dict()
-            )
-            self.assertEqual(len(weight_dict), 1)
-            for k, v in weight_dict.items():
-                self.assertTrue(v["float"].shape == v["quantized"].shape)
-
         qengine = torch.backends.quantized.engine
         qconfig = get_default_qconfig(qengine)
         qconfig_dict = {"": qconfig}
 
-        model_list = [ConvModel(), ConvBnModel(), ConvBNReLU()]
+        model_list = [ConvModel(), ConvBnModel(), ConvBnReLUModel()]
         for float_model in model_list:
             float_model.eval()
 
@@ -73,67 +77,49 @@
             prepared_model = prepare_fx(float_model, qconfig_dict)
 
             # Run calibration
-            calibrate(prepared_model, self.img_data_2d)
+            test_only_eval_fn(prepared_model, self.img_data_2d)
             q_model = convert_fx(prepared_model)
 
-            compare_and_validate_results(fused, q_model)
+            expected_weight_dict_keys = {"conv.weight"}
+            self.compare_and_validate_model_weights_results_fx(
+                fused, q_model, expected_weight_dict_keys
+            )
 
-    @skipIfNoFBGEMM
+    @override_qengines
     def test_compare_weights_linear_static_fx(self):
         r"""Compare the weights of float and static quantized linear layer"""
 
-        def calibrate(model, calib_data):
-            model.eval()
-            with torch.no_grad():
-                for inp in calib_data:
-                    model(*inp)
-
-        def compare_and_validate_results(float_model, q_model):
-            weight_dict = compare_weights_fx(
-                float_model.state_dict(), q_model.state_dict()
-            )
-            self.assertEqual(len(weight_dict), 1)
-            for k, v in weight_dict.items():
-                self.assertTrue(v["float"].shape == v["quantized"].shape)
-
-        float_model = SingleLayerLinearModel()
-        float_model.eval()
-
         qengine = torch.backends.quantized.engine
         qconfig = get_default_qconfig(qengine)
         qconfig_dict = {"": qconfig}
 
+        float_model = SingleLayerLinearModel()
+        float_model.eval()
+
         prepared_model = prepare_fx(float_model, qconfig_dict)
 
         backup_prepared_model = copy.deepcopy(prepared_model)
         backup_prepared_model.eval()
 
         # Run calibration
-        calibrate(prepared_model, self.calib_data)
+        test_only_eval_fn(prepared_model, self.calib_data)
         q_model = convert_fx(prepared_model)
 
-        compare_and_validate_results(backup_prepared_model, q_model)
+        expected_weight_dict_keys = {"fc1._packed_params._packed_params"}
+        self.compare_and_validate_model_weights_results_fx(
+            backup_prepared_model, q_model, expected_weight_dict_keys
+        )
 
-    @skipIfNoFBGEMM
+    @override_qengines
     def test_compare_weights_linear_dynamic_fx(self):
         r"""Compare the weights of float and dynamic quantized linear layer"""
 
-        def compare_and_validate_results(float_model, q_model):
-            weight_dict = compare_weights_fx(
-                float_model.state_dict(), q_model.state_dict()
-            )
-            self.assertEqual(len(weight_dict), 1)
-            for k, v in weight_dict.items():
-                self.assertTrue(len(v["float"]) == len(v["quantized"]))
-                for i, val in enumerate(v["quantized"]):
-                    self.assertTrue(v["float"][i].shape == v["quantized"][i].shape)
+        qconfig = torch.quantization.qconfig.default_dynamic_qconfig
+        qconfig_dict = {"": qconfig}
 
         float_model = SingleLayerLinearDynamicModel()
         float_model.eval()
 
-        qconfig = torch.quantization.qconfig.default_dynamic_qconfig
-        qconfig_dict = {"": qconfig}
-
         prepared_model = prepare_fx(float_model, qconfig_dict)
 
         backup_prepared_model = copy.deepcopy(prepared_model)
@@ -141,4 +127,84 @@
 
         q_model = convert_fx(prepared_model)
 
-        compare_and_validate_results(backup_prepared_model, q_model)
+        expected_weight_dict_keys = {"fc1._packed_params._packed_params"}
+        self.compare_and_validate_model_weights_results_fx(
+            backup_prepared_model, q_model, expected_weight_dict_keys
+        )
+
+    def compare_and_validate_model_stub_results_fx(
+        self, float_model, q_model, module_swap_list, data, expected_ob_dict_keys
+    ):
+        ob_dict = compare_model_stub_fx(float_model, q_model, module_swap_list, data)
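+        # Logger keys look like "<module>.stats"; each entry pairs the lists
+        # of float and quantized output tensors captured by the ShadowLogger.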
+
+        self.assertTrue(ob_dict.keys() == expected_ob_dict_keys)
+        self.assertEqual(len(ob_dict), 1)
+
+        for k, v in ob_dict.items():
+            self.assertTrue(len(v["float"]) == len(v["quantized"]))
+            for i, val in enumerate(v["quantized"]):
+                self.assertTrue(v["float"][i].shape == v["quantized"][i].shape)
+
+    @override_qengines
+    def test_compare_model_stub_conv_static_fx(self):
+        r"""Compare the output of static quantized conv layer and its float shadow module"""
+
+        qengine = torch.backends.quantized.engine
+        qconfig = get_default_qconfig(qengine)
+        qconfig_dict = {"": qconfig}
+
+        model_list = [ConvModel(), ConvBnReLUModel()]
+
+        for float_model in model_list:
+            float_model.eval()
+
+            prepared_model = prepare_fx(float_model, qconfig_dict)
+
+            backup_prepared_model = copy.deepcopy(prepared_model)
+
+            # Run calibration
+            test_only_eval_fn(prepared_model, self.img_data_2d)
+            q_model = convert_fx(prepared_model)
+
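+            # prepare_fx fuses Conv-BN-ReLU into ConvReLU2d, so the float side
+            # of ConvBnReLUModel needs the fused type in the swap list too.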
+            module_swap_list = [nn.Conv2d, nni.modules.fused.ConvReLU2d]
+
+            expected_ob_dict_keys = {"conv.stats"}
+            self.compare_and_validate_model_stub_results_fx(
+                backup_prepared_model,
+                q_model,
+                module_swap_list,
+                self.img_data_2d[0][0],
+                expected_ob_dict_keys,
+            )
+
+    @override_qengines
+    def test_compare_model_stub_linear_static_fx(self):
+        r"""Compare the output of static quantized linear layer and its float shadow module"""
+
+        qengine = torch.backends.quantized.engine
+        qconfig = get_default_qconfig(qengine)
+        qconfig_dict = {"": qconfig}
+
+        float_model = SingleLayerLinearModel()
+        float_model.eval()
+
+        prepared_model = prepare_fx(float_model, qconfig_dict)
+
+        backup_prepared_model = copy.deepcopy(prepared_model)
+
+        # Run calibration
+        test_only_eval_fn(prepared_model, self.calib_data)
+        q_model = convert_fx(prepared_model)
+
+        linear_data = self.calib_data[0][0]
+        module_swap_list = [nn.Linear]
+
+        expected_ob_dict_keys = {"fc1.stats"}
+        self.compare_and_validate_model_stub_results_fx(
+            backup_prepared_model,
+            q_model,
+            module_swap_list,
+            linear_data,
+            expected_ob_dict_keys,
+        )
diff --git a/torch/quantization/_numeric_suite_fx.py b/torch/quantization/_numeric_suite_fx.py
index aeba95b..cfb806e 100644
--- a/torch/quantization/_numeric_suite_fx.py
+++ b/torch/quantization/_numeric_suite_fx.py
@@ -7,6 +7,12 @@
 from torch.fx import GraphModule  # type: ignore
 from torch.fx import map_arg  # type: ignore
 from torch.fx.graph import Graph
+from torch.quantization._numeric_suite import (
+    get_logger_dict,
+    prepare_model_with_stubs,
+    compare_weights,
+    ShadowLogger,
+)
 from torch.quantization.fx.quantize import _remove_qconfig, is_activation_post_process
 
 
@@ -45,34 +51,6 @@
     return model
 
 
-def _find_match(str_list, key_str, postfix):
-    split_str = key_str.split(".")
-    if split_str[-1] == postfix:
-        match_string = "".join(key_str.split(".")[0:-1])
-        for s2 in str_list:
-            pattern1 = "".join(s2.split(".")[0:-1])
-            pattern2 = "".join(s2.split(".")[0:-2])
-            if match_string == pattern1:
-                return s2
-            if match_string == pattern2:
-                return s2
-
-        # For matching "fc.weight" and "fc._packed_params._packed_params"
-        if postfix == "_packed_params":
-            match_string = "".join(key_str.split(".")[0:-2])
-            if len(match_string) == 0:
-                return None
-            for s2 in str_list:
-                pattern1 = "".join(s2.split(".")[0:-1])
-                pattern2 = "".join(s2.split(".")[0:-2])
-                if match_string == pattern1:
-                    return s2
-                if match_string == pattern2:
-                    return s2
-    else:
-        return None
-
-
 def compare_weights_fx(float_dict, quantized_dict):
     r"""Compare the weights of the float module with its corresponding quantized
     module. Return a dict with key corresponding to module names and each entry being
@@ -86,7 +64,7 @@
         quantized_model = convert_fx(prepared_model)
 
         qmodel = quantized_model
-        wt_compare_dict = compare_weights(backup_prepared_model.state_dict(), qmodel.state_dict())
+        wt_compare_dict = compare_weights_fx(backup_prepared_model.state_dict(), qmodel.state_dict())
         for key in wt_compare_dict:
             print(key, compute_error(wt_compare_dict[key]['float'], wt_compare_dict[key]['quantized'].dequantize()))
 
@@ -102,38 +80,69 @@
     torch._C._log_api_usage_once(
         "quantization_api._numeric_suite_fx.compare_weights_fx"
     )
-    weight_dict: Dict[str, Dict] = {}
-    for key in quantized_dict:
-        match_key = _find_match(float_dict, key, "weight")
-        if match_key is not None:
-            weight_dict[key] = {}
-            weight_dict[key]["float"] = float_dict[match_key]
-            weight_dict[key]["quantized"] = quantized_dict[key]
-            continue
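+    # Delegate to the eager-mode implementation: state_dict key matching
+    # works the same way for FX-prepared models.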
+    return compare_weights(float_dict, quantized_dict)
 
-        # For matching "fc.weight" and "fc._packed_params._packed_params"
-        match_key = _find_match(float_dict, key, "_packed_params")
-        if match_key is not None:
-            weight_dict[key] = {}
-            weight_dict[key]["float"] = float_dict[match_key]
-            weight_dict[key]["quantized"] = quantized_dict[key][0]
 
-        # For LSTM
-        split_str = key.split(".")
-        if split_str[-1] == "param" and split_str[-3] == "_all_weight_values":
-            layer = split_str[-2]
-            module_name = ".".join(split_str[:-3])
-            float_weight_ih_key = module_name + ".weight_ih_l" + layer
-            float_weight_hh_key = module_name + ".weight_hh_l" + layer
-            if float_weight_ih_key in float_dict and float_weight_hh_key in float_dict:
-                weight_dict[key] = {}
-                weight_dict[key]["float"] = float_dict[float_weight_ih_key]
-                weight_dict[key]["quantized"] = (
-                    quantized_dict[key].__getstate__()[0][4][0].__getstate__()[0][0]
-                )
-                weight_dict[key]["float"] = float_dict[float_weight_hh_key]
-                weight_dict[key]["quantized"] = (
-                    quantized_dict[key].__getstate__()[0][4][1].__getstate__()[0][0]
-                )
+def prepare_model_with_stubs_fx(float_module, q_module, module_swap_list, Logger):
+    r"""Prepare the model by attaching the float module to its matching quantized
+    module as the shadow if the float module type is in module_swap_list.
 
-    return weight_dict
+    Example usage:
+        prepare_model_with_stubs_fx(float_model, q_model, module_swap_list, Logger)
+        q_model(data)
+        ob_dict = get_logger_dict(q_model)
+
+    Args:
+        float_module: float module used to generate the q_module
+        q_module: module quantized from float_module
+        module_swap_list: list of float module types to attach the shadow
+        Logger: type of logger to be used in shadow module to process the outputs of
+            quantized module and its float shadow module
+    """
+    torch._C._log_api_usage_once(
+        "quantization_api._numeric_suite.prepare_model_with_stubs_fx"
+    )
+    return prepare_model_with_stubs(float_module, q_module, module_swap_list, Logger)
+
+
+def compare_model_stub_fx(
+    float_model, q_model, module_swap_list, *data, Logger=ShadowLogger
+):
+    r"""Compare quantized module in a model with its floating point counterpart,
+    feeding both of them the same input. Return a dict with key corresponding to
+    module names and each entry being a dictionary with two keys 'float' and
+    'quantized', containing the output tensors of quantized and its matching
+    float shadow module. This dict can be used to compare and compute the module
+    level quantization error.
+
+    This function first call prepare_model_with_stubs_fx() to swap the quantized
+    module that we want to compare with the Shadow module, which takes quantized
+    module, corresponding float module and logger as input, and creates a forward
+    path inside to make the float module to shadow quantized module sharing the
+    same input. The logger can be customizable, default logger is ShadowLogger
+    and it will save the outputs of the quantized module and float module that
+    can be used to compute the module level quantization error.
+
+    Example usage:
+        module_swap_list = [nn.Linear]
+        ob_dict = compare_model_stub_fx(float_model, q_model, module_swap_list, data)
+        for key in ob_dict:
+            print(key, compute_error(ob_dict[key]['float'], ob_dict[key]['quantized'].dequantize()))
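+        # compute_error is a user-defined metric (e.g. SQNR) and is not
+        # provided by this module.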
+
+    Args:
+        float_model: float model used to generate the q_model
+        q_model: model quantized from float_model
+        module_swap_list: list of float module types at which shadow modules will
+            be attached.
+        data: input data used to run the stub-prepared q_model
+        Logger: type of logger to be used in shadow module to process the outputs of
+            quantized module and its float shadow module
+    """
+    torch._C._log_api_usage_once(
+        "quantization_api._numeric_suite.compare_model_stub_fx"
+    )
+    float_model = remove_qconfig_observer_fx(float_model)
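+    # Swap in Shadow modules: each quantized module whose float counterpart's
+    # type is in module_swap_list also runs that float module on the same input.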
+    prepare_model_with_stubs_fx(float_model, q_model, module_swap_list, Logger)
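+    # One forward pass populates every ShadowLogger; get_logger_dict then
+    # collects the logged float/quantized outputs keyed by module name.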
+    q_model(*data)
+    ob_dict = get_logger_dict(q_model)
+    return ob_dict
diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py
index 64c9c91..459bc4b 100644
--- a/torch/testing/_internal/common_quantization.py
+++ b/torch/testing/_internal/common_quantization.py
@@ -910,6 +910,19 @@
         x = self.dequant(x)
         return x
 
+class ConvBnReLUModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
+        self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float)
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        x = self.relu(x)
+        return x
+
 class AnnotatedConvBnReLUModel(torch.nn.Module):
     def __init__(self, qengine='fbgemm'):
         super(AnnotatedConvBnReLUModel, self).__init__()