[quant][fx][graphmode][refactor] Factor out generate_qconfig_map to qconfig_utils.py (#58453)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58453

Move the class method generate_qconfig_map to qconfig_utils, will add more PRs
to remove functions out of Quantizer and eventually remove the Quantizer object

Test Plan:
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps

Imported from OSS

Reviewed By: vkuzo

Differential Revision: D28497965

fbshipit-source-id: 3c78cfe676965d20a8834a859ffed4d8e9ecade4
diff --git a/torch/quantization/fx/qconfig_utils.py b/torch/quantization/fx/qconfig_utils.py
index 2b2938c..d9d1161 100644
--- a/torch/quantization/fx/qconfig_utils.py
+++ b/torch/quantization/fx/qconfig_utils.py
@@ -1,8 +1,12 @@
 import torch
 from collections import OrderedDict
-from typing import Union, Callable, Any, Dict
+from typing import Union, Callable, Any, Dict, Tuple
 import re
 
+from torch.fx.graph import (
+    Graph,
+)
+
 from .utils import _parent_name
 
 QConfigAny = Union[torch.quantization.QConfig,
@@ -97,3 +101,41 @@
     module_name_qconfig = get_module_name_qconfig(
         qconfig_dict, module_name, module_name_regex_qconfig)
     return module_name_qconfig
+
+def generate_qconfig_map(
+        root: torch.nn.Module,
+        modules: Dict[str, torch.nn.Module],
+        input_graph: Graph,
+        qconfig_dict: Any,
+        node_name_to_scope: Dict[str, Tuple[str, type]]) -> Dict[str, QConfigAny]:
+    global_qconfig = qconfig_dict.get("", None)
+    qconfig_map = dict()
+    for node in input_graph.nodes:
+        qconfig = None
+        if node.op == "get_attr":
+            module_name, _ = _parent_name(node.target)
+            qconfig = get_qconfig(
+                qconfig_dict, type(modules[module_name]), module_name, global_qconfig)
+        elif node.op == "call_function":
+            # precedence: module_name_qconfig
+            # > function_qconfig > global_qconfig
+            # module_name takes precedence over function qconfig
+            function_qconfig = get_object_type_qconfig(
+                qconfig_dict, node.target, global_qconfig)
+            module_path, module_type = node_name_to_scope[node.name]
+            qconfig = get_qconfig(
+                qconfig_dict, module_type, module_path, function_qconfig)
+        elif node.op == "call_method":
+            module_path, module_type = node_name_to_scope[node.name]
+            # use the qconfig of the module that the node belongs to
+            qconfig = get_qconfig(
+                qconfig_dict, module_type, module_path, global_qconfig)
+        elif node.op == 'call_module':
+            qconfig = get_qconfig(
+                qconfig_dict, type(modules[node.target]), node.target, global_qconfig)
+            # regex is not supported eager mode propagate_qconfig_, we'll
+            # need to set the qconfig explicitly here in case regex
+            # is used
+            modules[node.target].qconfig = qconfig
+        qconfig_map[node.name] = qconfig
+    return qconfig_map
diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py
index a6da3f9..5eb96cf 100644
--- a/torch/quantization/fx/quantize.py
+++ b/torch/quantization/fx/quantize.py
@@ -77,9 +77,8 @@
 
 from .qconfig_utils import (
     convert_dict_to_ordered_dict,
+    generate_qconfig_map,
     get_flattened_qconfig_dict,
-    get_object_type_qconfig,
-    get_qconfig,
     QConfigAny,
 )
 
@@ -94,7 +93,6 @@
 # ------------------------
 # Helper Functions
 # ------------------------
-
 def get_standalone_module_configs(
     node: Node,
     modules: Dict[str, torch.nn.Module],
@@ -965,46 +963,6 @@
             get_default_qat_module_mappings(), additional_qat_module_mapping)
         convert(root, mapping=all_mappings, inplace=True, remove_qconfig=False)
 
-    def _generate_qconfig_map(
-            self,
-            root: torch.nn.Module,
-            input_graph: Graph,
-            qconfig_dict: Any,
-            node_name_to_scope: Dict[str, Tuple[str, type]]) -> None:
-        global_qconfig = qconfig_dict.get("", None)
-        self.node_name_to_scope = node_name_to_scope
-        self.qconfig_map = dict()
-        for node in input_graph.nodes:
-            if node.op == "get_attr":
-                module_name, _ = _parent_name(node.target)
-                self.qconfig_map[node.name] = get_qconfig(
-                    qconfig_dict, type(self.modules[module_name]), module_name, global_qconfig)
-            elif node.op == "call_function":
-                # precedence: [TODO] module_name_qconfig (need scope support
-                # from fx)
-                # > function_qconfig > global_qconfig
-                # module_name takes precedence over function qconfig
-                function_qconfig = get_object_type_qconfig(
-                    qconfig_dict, node.target, global_qconfig)
-                module_path, module_type = node_name_to_scope[node.name]
-                qconfig = get_qconfig(
-                    qconfig_dict, module_type, module_path, function_qconfig)
-                self.qconfig_map[node.name] = qconfig
-            elif node.op == "call_method":
-                module_path, module_type = node_name_to_scope[node.name]
-                # use the qconfig of the module that the node belongs to
-                qconfig = get_qconfig(
-                    qconfig_dict, module_type, module_path, global_qconfig)
-                self.qconfig_map[node.name] = qconfig
-            elif node.op == 'call_module':
-                module_qconfig = get_qconfig(
-                    qconfig_dict, type(self.modules[node.target]), node.target, global_qconfig)
-                # regex is not supported eager mode propagate_qconfig_, we'll
-                # need to set the qconfig explicitly here in case regex
-                # is used
-                self.modules[node.target].qconfig = module_qconfig
-                self.qconfig_map[node.name] = module_qconfig
-
     def _prepare(
             self,
             model: GraphModule,
@@ -1049,8 +1007,9 @@
 
         self.modules = dict(model.named_modules())
 
+        self.node_name_to_scope = node_name_to_scope
         # fill self.qconfig_map, a map from node name to qconfig, used in _find_matches
-        self._generate_qconfig_map(model, model.graph, qconfig_dict, node_name_to_scope)
+        self.qconfig_map = generate_qconfig_map(model, self.modules, model.graph, qconfig_dict, node_name_to_scope)
 
         # match the patterns that will get quantized
         standalone_module_name_configs = prepare_custom_config_dict.get(