[quant][fx][graphmode][refactor] Factor out generate_qconfig_map to qconfig_utils.py (#58453)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58453
Moves the class method _generate_qconfig_map out of Quantizer and into qconfig_utils.py as the free function generate_qconfig_map. Follow-up PRs will move more functions out of Quantizer, with the eventual goal of removing the Quantizer object entirely.
Test Plan:
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps
Imported from OSS
Reviewed By: vkuzo
Differential Revision: D28497965
fbshipit-source-id: 3c78cfe676965d20a8834a859ffed4d8e9ecade4
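
For context, the function being factored out resolves a per-node qconfig from a qconfig_dict, with precedence module_name > module_name_regex > object_type > "" (global). A minimal sketch of such a dict follows; the backend choice and the module paths ("features.0", "features.*") are hypothetical:

import torch
from torch.quantization import get_default_qconfig

qconfig_dict = {
    # global fallback, lowest precedence
    "": get_default_qconfig("fbgemm"),
    # per-type override for every nn.Linear (module or functional)
    "object_type": [(torch.nn.Linear, get_default_qconfig("fbgemm"))],
    # regex override on module paths; beats object_type
    "module_name_regex": [("features.*", get_default_qconfig("fbgemm"))],
    # exact-path override, highest precedence; None disables quantization
    "module_name": [("features.0", None)],
}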
diff --git a/torch/quantization/fx/qconfig_utils.py b/torch/quantization/fx/qconfig_utils.py
index 2b2938c..d9d1161 100644
--- a/torch/quantization/fx/qconfig_utils.py
+++ b/torch/quantization/fx/qconfig_utils.py
@@ -1,8 +1,12 @@
import torch
from collections import OrderedDict
-from typing import Union, Callable, Any, Dict
+from typing import Union, Callable, Any, Dict, Tuple
import re
+from torch.fx.graph import (
+ Graph,
+)
+
from .utils import _parent_name
QConfigAny = Union[torch.quantization.QConfig,
@@ -97,3 +101,41 @@
module_name_qconfig = get_module_name_qconfig(
qconfig_dict, module_name, module_name_regex_qconfig)
return module_name_qconfig
+
+def generate_qconfig_map(
+        root: torch.nn.Module,
+        modules: Dict[str, torch.nn.Module],
+        input_graph: Graph,
+        qconfig_dict: Any,
+        node_name_to_scope: Dict[str, Tuple[str, type]]) -> Dict[str, QConfigAny]:
+    global_qconfig = qconfig_dict.get("", None)
+    qconfig_map = dict()
+    for node in input_graph.nodes:
+        qconfig = None
+        if node.op == "get_attr":
+            module_name, _ = _parent_name(node.target)
+            qconfig = get_qconfig(
+                qconfig_dict, type(modules[module_name]), module_name, global_qconfig)
+        elif node.op == "call_function":
+            # precedence: module_name_qconfig > function_qconfig
+            # > global_qconfig, i.e. the module_name setting takes
+            # precedence over the function (object_type) qconfig
+            function_qconfig = get_object_type_qconfig(
+                qconfig_dict, node.target, global_qconfig)
+            module_path, module_type = node_name_to_scope[node.name]
+            qconfig = get_qconfig(
+                qconfig_dict, module_type, module_path, function_qconfig)
+        elif node.op == "call_method":
+            module_path, module_type = node_name_to_scope[node.name]
+            # use the qconfig of the module that the node belongs to
+            qconfig = get_qconfig(
+                qconfig_dict, module_type, module_path, global_qconfig)
+        elif node.op == "call_module":
+            qconfig = get_qconfig(
+                qconfig_dict, type(modules[node.target]), node.target, global_qconfig)
+            # regex is not supported in eager mode propagate_qconfig_,
+            # so we need to set the qconfig explicitly here in case
+            # a regex was used to configure this module
+            modules[node.target].qconfig = qconfig
+        qconfig_map[node.name] = qconfig
+    return qconfig_map
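
Since generate_qconfig_map is now a free function, it can be exercised in isolation. The sketch below is hypothetical: the toy Sub/Model classes are invented, and the node_name_to_scope mapping (node name -> (owning module path, owning module type)) that prepare_fx normally records with its scope-aware tracer is filled in by hand, giving every node the same scope just to satisfy the signature:

import torch
from torch.quantization.fx.qconfig_utils import (
    convert_dict_to_ordered_dict,
    generate_qconfig_map,
)

class Sub(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return torch.nn.functional.relu(self.linear(x))

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.sub = Sub()

    def forward(self, x):
        return self.sub(x)

model = Model()
graph_module = torch.fx.symbolic_trace(model)
modules = dict(model.named_modules())

# prepare_fx records this mapping during tracing; here every node is
# given the same toy scope so the call_function/call_method lookups succeed
node_name_to_scope = {
    node.name: ("sub", Sub) for node in graph_module.graph.nodes
}

qconfig_dict = {"": torch.quantization.get_default_qconfig("fbgemm")}
# normalize the per-key lists into OrderedDicts, as _prepare does
# before building the map
convert_dict_to_ordered_dict(qconfig_dict)

qconfig_map = generate_qconfig_map(
    model, modules, graph_module.graph, qconfig_dict, node_name_to_scope)
# qconfig_map now maps each node name to its resolved qconfig
# (None for placeholder/output nodes)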
diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py
index a6da3f9..5eb96cf 100644
--- a/torch/quantization/fx/quantize.py
+++ b/torch/quantization/fx/quantize.py
@@ -77,9 +77,8 @@
from .qconfig_utils import (
convert_dict_to_ordered_dict,
+ generate_qconfig_map,
get_flattened_qconfig_dict,
- get_object_type_qconfig,
- get_qconfig,
QConfigAny,
)
@@ -94,7 +93,6 @@
# ------------------------
# Helper Functions
# ------------------------
-
def get_standalone_module_configs(
node: Node,
modules: Dict[str, torch.nn.Module],
@@ -965,46 +963,6 @@
get_default_qat_module_mappings(), additional_qat_module_mapping)
convert(root, mapping=all_mappings, inplace=True, remove_qconfig=False)
- def _generate_qconfig_map(
- self,
- root: torch.nn.Module,
- input_graph: Graph,
- qconfig_dict: Any,
- node_name_to_scope: Dict[str, Tuple[str, type]]) -> None:
- global_qconfig = qconfig_dict.get("", None)
- self.node_name_to_scope = node_name_to_scope
- self.qconfig_map = dict()
- for node in input_graph.nodes:
- if node.op == "get_attr":
- module_name, _ = _parent_name(node.target)
- self.qconfig_map[node.name] = get_qconfig(
- qconfig_dict, type(self.modules[module_name]), module_name, global_qconfig)
- elif node.op == "call_function":
- # precedence: [TODO] module_name_qconfig (need scope support
- # from fx)
- # > function_qconfig > global_qconfig
- # module_name takes precedence over function qconfig
- function_qconfig = get_object_type_qconfig(
- qconfig_dict, node.target, global_qconfig)
- module_path, module_type = node_name_to_scope[node.name]
- qconfig = get_qconfig(
- qconfig_dict, module_type, module_path, function_qconfig)
- self.qconfig_map[node.name] = qconfig
- elif node.op == "call_method":
- module_path, module_type = node_name_to_scope[node.name]
- # use the qconfig of the module that the node belongs to
- qconfig = get_qconfig(
- qconfig_dict, module_type, module_path, global_qconfig)
- self.qconfig_map[node.name] = qconfig
- elif node.op == 'call_module':
- module_qconfig = get_qconfig(
- qconfig_dict, type(self.modules[node.target]), node.target, global_qconfig)
- # regex is not supported eager mode propagate_qconfig_, we'll
- # need to set the qconfig explicitly here in case regex
- # is used
- self.modules[node.target].qconfig = module_qconfig
- self.qconfig_map[node.name] = module_qconfig
-
def _prepare(
self,
model: GraphModule,
@@ -1049,8 +1007,9 @@
self.modules = dict(model.named_modules())
+ self.node_name_to_scope = node_name_to_scope
# fill self.qconfig_map, a map from node name to qconfig, used in _find_matches
- self._generate_qconfig_map(model, model.graph, qconfig_dict, node_name_to_scope)
+ self.qconfig_map = generate_qconfig_map(model, self.modules, model.graph, qconfig_dict, node_name_to_scope)
# match the patterns that will get quantized
standalone_module_name_configs = prepare_custom_config_dict.get(
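
The public entry point is unchanged by this refactor; end-to-end usage still looks like the following sketch (the float model and input shape are hypothetical, reusing the toy Model class from the earlier sketch):

import torch
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

model = Model().eval()  # prepare_fx expects an eval-mode float model
qconfig_dict = {"": get_default_qconfig("fbgemm")}

# prepare_fx traces the model and builds the node-name -> qconfig map
# internally (now via the free function) before inserting observers
prepared = prepare_fx(model, qconfig_dict)

prepared(torch.randn(1, 4))  # calibrate with representative data

quantized = convert_fx(prepared)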