blob: 1b528d7677d3cb8751caf3dc83498d5bb8862234 [file] [log] [blame]
import torch
from collections import OrderedDict
from typing import Union, Callable, Any, Dict, Tuple, Set
import re
from torch.fx.graph import (
Graph,
)
from .utils import _parent_name
# TODO: move this to quantization_types.py?
# Type alias used in the annotations below for a qconfig value: either a
# QConfig, a QConfigDynamic, or None.
QConfigAny = Union[torch.quantization.QConfig,
torch.quantization.QConfigDynamic, None]
def get_flattened_qconfig_dict(qconfig_dict):
    """ flatten the global, object_type and module_name qconfig
    to the same qconfig_dict so that it can be used by
    propagate_qconfig_ function.

    "module_name_regex" is ignored for now since it's not supported
    in propagate_qconfig_, but it can be fixed later.

    The "object_type" and "module_name" entries may be given either as a
    list of (key, qconfig) pairs (the form shown below) or as a mapping
    such as an OrderedDict.

    For example:
    Input: {
      "": qconfig,
      "object_type": [
        (torch.add, qconfig)
      ],
      "module_name": [
        ("conv", qconfig)
      ]
    }

    Output: {
      "": qconfig,
      torch.add: qconfig,
      "conv": qconfig
    }

    Args:
      `qconfig_dict`: dictionary with the optional keys "", "object_type"
        and "module_name"; any other key is ignored

    Returns:
      a new flat dict; `qconfig_dict` itself is not modified. When the same
      key appears under both "object_type" and "module_name", the
      "module_name" entry wins because it is merged last.
    """
    flattened = {}
    if '' in qconfig_dict:
        flattened[''] = qconfig_dict['']

    for key in ('object_type', 'module_name'):
        if key in qconfig_dict:
            # dict.update accepts both a mapping and an iterable of
            # (key, value) pairs, so the documented list form and the
            # OrderedDict form are both handled
            flattened.update(qconfig_dict[key])
    return flattened
def convert_dict_to_ordered_dict(qconfig_dict: Any) -> Dict[str, Dict[Any, Any]]:
    """ Replace the per-section qconfig lists in qconfig_dict with OrderedDicts

    The 'object_type', 'module_name_regex' and 'module_name' entries are
    rebuilt as OrderedDict; a section that is absent becomes an empty
    OrderedDict. The conversion is done in place and the same qconfig_dict
    object is returned.
    """
    for section in ('object_type', 'module_name_regex', 'module_name'):
        entries = qconfig_dict.get(section, [])
        qconfig_dict[section] = OrderedDict(entries)
    return qconfig_dict
def get_object_type_qconfig(
        qconfig_dict: Any,
        object_type: Union[Callable, str],
        fallback_qconfig: QConfigAny) -> QConfigAny:
    """ Look up the qconfig registered for object_type

    object_type is one of:
      1. a module type (call_module)
      2. a function (call_function)
      3. a string (call_method)

    Returns the entry stored under qconfig_dict['object_type'] for
    object_type, or fallback_qconfig when there is no such entry.
    """
    object_type_table = qconfig_dict['object_type']
    found = object_type_table.get(object_type, fallback_qconfig)
    return found
def get_module_name_regex_qconfig(qconfig_dict, module_name, fallback_qconfig):
    """ Return the qconfig of the first regex pattern that matches module_name

    Patterns are taken from qconfig_dict['module_name_regex'] in iteration
    order and tested with re.match (anchored at the start of module_name).
    When no pattern matches, fallback_qconfig is returned.
    """
    regex_table = qconfig_dict['module_name_regex']
    matching_qconfigs = (
        candidate
        for pattern, candidate in regex_table.items()
        if re.match(pattern, module_name)
    )
    # first match wins
    return next(matching_qconfigs, fallback_qconfig)
def get_module_name_qconfig(qconfig_dict, module_name, fallback_qconfig):
    """ Return the qconfig registered for module_name or its nearest parent

    Looks module_name up in qconfig_dict['module_name']; when it is absent,
    the lookup is retried with the first element returned by _parent_name,
    and so on until the name becomes the empty string, at which point
    fallback_qconfig is returned.
    """
    current_name = module_name
    while current_name != '':
        module_name_table = qconfig_dict['module_name']
        if current_name in module_name_table:
            return module_name_table[current_name]
        # not registered under this name: retry with the parent name
        current_name, _ = _parent_name(current_name)
    # module name qconfig not found
    return fallback_qconfig
def get_qconfig(qconfig_dict, module_type, module_name, global_qconfig):
    """ Resolve the qconfig for a module

    Precedence, from highest to lowest: module_name qconfig,
    module_name_regex qconfig, module_type (object_type) qconfig,
    global_qconfig. Each lookup falls back to the result of the next
    lower-precedence lookup.
    """
    # start from the lowest-precedence source and let each more specific
    # lookup override the result when it has an entry
    resolved = get_object_type_qconfig(qconfig_dict, module_type, global_qconfig)
    resolved = get_module_name_regex_qconfig(qconfig_dict, module_name, resolved)
    resolved = get_module_name_qconfig(qconfig_dict, module_name, resolved)
    return resolved
def generate_qconfig_map(
        root: torch.nn.Module,
        modules: Dict[str, torch.nn.Module],
        input_graph: Graph,
        qconfig_dict: Any,
        node_name_to_scope: Dict[str, Tuple[str, type]]) -> Dict[str, QConfigAny]:
    """ Build a mapping from node name to the qconfig that applies to the node

    Args:
      `root`: not used by this function
      `modules`: mapping from module name to module instance
      `input_graph`: graph whose nodes are visited
      `qconfig_dict`: qconfig dictionary; the "" key holds the global qconfig
      `node_name_to_scope`: mapping from node name to the
        (module path, module type) the node belongs to

    Returns:
      dict with one entry per node in `input_graph`; nodes whose op is not
      get_attr / call_function / call_method / call_module map to None.

    Side effect: for call_module nodes, the resolved qconfig is also
    assigned to the `qconfig` attribute of the target module.
    """
    default_qconfig = qconfig_dict.get("", None)
    node_qconfigs = {}
    for node in input_graph.nodes:
        op_kind = node.op
        node_qconfig = None
        if op_kind == "get_attr":
            owner_name, _ = _parent_name(node.target)
            owner_type = type(modules[owner_name])
            node_qconfig = get_qconfig(
                qconfig_dict, owner_type, owner_name, default_qconfig)
        elif op_kind == "call_function":
            # precedence: module_name_qconfig
            # > function_qconfig > global_qconfig
            # module_name takes precedence over function qconfig
            function_qconfig = get_object_type_qconfig(
                qconfig_dict, node.target, default_qconfig)
            scope_path, scope_type = node_name_to_scope[node.name]
            node_qconfig = get_qconfig(
                qconfig_dict, scope_type, scope_path, function_qconfig)
        elif op_kind == "call_method":
            scope_path, scope_type = node_name_to_scope[node.name]
            # use the qconfig of the module that the node belongs to
            node_qconfig = get_qconfig(
                qconfig_dict, scope_type, scope_path, default_qconfig)
        elif op_kind == 'call_module':
            target_type = type(modules[node.target])
            node_qconfig = get_qconfig(
                qconfig_dict, target_type, node.target, default_qconfig)
            # regex is not supported eager mode propagate_qconfig_, we'll
            # need to set the qconfig explicitly here in case regex
            # is used
            modules[node.target].qconfig = node_qconfig
        node_qconfigs[node.name] = node_qconfig
    return node_qconfigs
def check_is_valid_config_dict(config_dict: Any, allowed_keys: Set[str], dict_name: str) -> None:
    r""" Checks if the given config_dict has the correct keys

    Args:
      `config_dict`: dictionary whose keys we want to check
      `allowed_keys`: set of keys that `config_dict` is allowed to contain
      `dict_name`: name of the dictionary, used in the error message

    Raises:
      ValueError: if `config_dict` contains a key that is not in
        `allowed_keys`
    """
    for k in config_dict:
        if k not in allowed_keys:
            # str(k): a non-string key (e.g. a class used as a top-level
            # key by mistake) must still be reported as ValueError instead
            # of failing the string concatenation with TypeError
            raise ValueError(
                'Expected ' + dict_name + ' to have the following keys: ' +
                str(allowed_keys) + '. But found \'' + str(k) +
                '\' instead.')
def check_is_valid_qconfig_dict(qconfig_dict: Any) -> None:
    r""" Validates the top-level keys of a qconfig_dict

    Delegates to check_is_valid_config_dict with the keys a qconfig_dict
    may contain: "", "object_type", "module_name_regex" and "module_name".

    Args:
      `qconfig_dict`: dictionary whose keys we want to check
    """
    check_is_valid_config_dict(
        qconfig_dict,
        {"", "object_type", "module_name_regex", "module_name"},
        "qconfig_dict")
def check_is_valid_prepare_custom_config_dict(
        prepare_custom_config_dict: Union[Dict[str, Any], None] = None) -> None:
    r""" Checks if the given prepare_custom_config_dict has the correct keys

    Args:
      `prepare_custom_config_dict`: customization configuration dictionary for
      quantization tool; None or an empty dict is accepted without checks

    Raises:
      ValueError: (from check_is_valid_config_dict) if the dictionary
        contains a key that is not in the allowed set below
    """
    if not prepare_custom_config_dict:
        return

    prepare_custom_config_dict_allowed_keys = {
        "standalone_module_name",
        "standalone_module_class",
        "float_to_observed_custom_module_class",
        "non_traceable_module_name",
        "non_traceable_module_class",
        "additional_fuser_method_mapping",
        # NOTE(review): the single-underscore spelling appears to be the
        # key documented for prepare_fx - confirm against the consumer of
        # this dict. It was previously rejected by this check.
        "additional_qat_module_mapping",
        # misspelled variant (double underscore) kept so that dictionaries
        # accepted before this change are still accepted
        "additional_qat__module_mapping",
        "additional_fusion_pattern",
        "additional_quant_pattern",
        "input_quantized_idxs",
        "output_quantized_idxs",
        "preserved_attributes",
    }
    check_is_valid_config_dict(prepare_custom_config_dict,
                               prepare_custom_config_dict_allowed_keys,
                               "prepare_custom_config_dict")
def check_is_valid_convert_custom_config_dict(convert_custom_config_dict: Dict[str, Any] = None) -> None:
    r""" Validates the keys of a convert_custom_config_dict

    A missing (None) or empty dictionary is accepted without any check;
    otherwise the keys are validated by check_is_valid_config_dict.

    Args:
      `convert_custom_config_dict`: dictionary for custom configurations for
      convert function
    """
    if convert_custom_config_dict:
        allowed_keys = {"additional_object_mapping",
                        "observed_to_quantized_custom_module_class",
                        "preserved_attributes"}
        check_is_valid_config_dict(
            convert_custom_config_dict, allowed_keys, "convert_custom_config_dict")
def check_is_valid_fuse_custom_config_dict(fuse_custom_config_dict: Dict[str, Any] = None) -> None:
    r""" Validates the keys of a fuse_custom_config_dict

    A missing (None) or empty dictionary is accepted without any check;
    otherwise the keys are validated by check_is_valid_config_dict.

    Args:
      `fuse_custom_config_dict`: dictionary for custom configurations for fuse_fx
    """
    if fuse_custom_config_dict:
        allowed_keys = {"additional_fuser_method_mapping",
                        "preserved_attributes"}
        check_is_valid_config_dict(
            fuse_custom_config_dict, allowed_keys, "fuse_custom_config_dict")