| import torch |
| import operator |
| from torch.fx import ( |
| GraphModule, |
| ) |
| |
| from torch.quantization import ( |
| propagate_qconfig_, |
| ) |
| from torch.fx.graph import ( |
| Graph, |
| Node, |
| ) |
| from torch.fx.node import Argument |
| |
| from ..qconfig import QConfigAny |
| from .qconfig_utils import ( |
| convert_dict_to_ordered_dict, |
| generate_qconfig_map, |
| get_flattened_qconfig_dict, |
| ) |
| |
| from .quantization_patterns import ( |
| QuantizeHandler, |
| CustomModuleQuantizeHandler, |
| StandaloneModuleQuantizeHandler, |
| ) |
| |
| from .quantization_types import Pattern |
| |
| from ._equalize import ( |
| is_equalization_observer, |
| node_supports_equalization, |
| ) |
| |
| from .graph_module import ( |
| ObservedGraphModule, |
| ObservedStandaloneGraphModule, |
| ) |
| |
| from .pattern_utils import ( |
| MatchResult, |
| get_default_quant_patterns, |
| get_default_output_activation_post_process_map, |
| ) |
| |
| from .match_utils import ( |
| find_matches, |
| ) |
| |
| from .utils import ( |
| _parent_name, |
| get_custom_module_class_keys, |
| all_node_args_have_no_tensors, |
| assert_and_get_unique_device, |
| node_bool_tensor_arg_indexes, |
| get_new_attr_name_with_prefix, |
| NON_QUANTIZABLE_WEIGHT_OPS, |
| WEIGHT_INDEX_DICT, |
| FUNCTIONAL_OPS_WITH_BIAS, |
| ) |
| |
| from ..fuser_method_mappings import DEFAULT_OP_LIST_TO_FUSER_METHOD |
| |
| from ..quantization_mappings import ( |
| get_default_qat_module_mappings, |
| ) |
| |
| from ..quantize import ( |
| is_activation_post_process, |
| convert |
| ) |
| |
| from ..utils import ( |
| get_combined_dict, |
| get_qconfig_dtypes, |
| get_swapped_custom_module_class, |
| weight_is_quantized, |
| activation_is_statically_quantized, |
| activation_is_int8_quantized, |
| activation_dtype, |
| weight_dtype, |
| ) |
| |
| from typing import Any, Callable, Dict, List, Optional, Tuple, Union |
| |
| def is_activation_post_process_node(node: Node, modules: Dict[str, torch.nn.Module]) -> bool: |
| return node.op == "call_module" and \ |
| is_activation_post_process(modules[str(node.target)]) |
| |
| def node_arg_is_weight(node: Node, arg: Any) -> bool: |
| if isinstance(node, Node) and node.op == 'call_function' and \ |
| node.target in WEIGHT_INDEX_DICT: |
| for i, node_arg in enumerate(node.args): |
| if arg is node_arg and i in \ |
| WEIGHT_INDEX_DICT[node.target]: # type: ignore[index] |
| return True |
| for kwarg_name, kwarg_value in node.kwargs.items(): |
| if kwarg_name == 'weight' and arg is kwarg_value: |
| return True |
| return False |
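|
| # For example, in a graph traced from `F.linear(x, w)`, WEIGHT_INDEX_DICT |
| # marks positional arg index 1 of torch.nn.functional.linear as the |
| # weight, so (a minimal sketch; node names depend on the traced module): |
| # |
| # linear_node.args == (x_node, w_node) |
| # node_arg_is_weight(linear_node, linear_node.args[1]) # True |
| # node_arg_is_weight(linear_node, linear_node.args[0]) # False |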
| |
| CONV_OPS_WITH_BIAS = { |
| torch.nn.functional.conv1d, |
| torch.nn.functional.conv2d, |
| torch.nn.functional.conv3d, |
| } |
| CONV_BIAS_ARG_INDEX = 2 |
| |
| def node_arg_is_bias(node: Node, arg: Any) -> bool: |
| if isinstance(node, Node) and node.op == 'call_function': |
| if node.target in CONV_OPS_WITH_BIAS: |
| for i, node_arg in enumerate(node.args): |
| if arg is node_arg and i == CONV_BIAS_ARG_INDEX: |
| return True |
| elif node.target in FUNCTIONAL_OPS_WITH_BIAS: |
| for kwarg_name, kwarg_value in node.kwargs.items(): |
| if kwarg_name == 'bias' and arg is kwarg_value: |
| return True |
| return False |
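|
| # For example, for a node calling `F.conv2d(x, w, b)` the bias is the |
| # positional arg at index 2; for the ops listed in FUNCTIONAL_OPS_WITH_BIAS |
| # the bias is only recognized when passed as the `bias` keyword |
| # (a sketch; node names are illustrative): |
| # |
| # node_arg_is_bias(conv_node, conv_node.args[2]) # True |
| # node_arg_is_bias(conv_node, conv_node.args[1]) # False (weight) |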
| |
| def get_standalone_module_configs( |
| node: Node, |
| modules: Dict[str, torch.nn.Module], |
| prepare_custom_config_dict: Dict[str, Any], |
| qconfig: QConfigAny, |
| ) -> Tuple[Dict[str, Any], Dict[str, Any]]: |
| """ |
| Returns the standalone module qconfig_dict and prepare_config_dict |
| for `node`, assuming that the module pointed to by `node` is |
| a standalone module. |
| """ |
| standalone_module = modules[node.target] # type: ignore[index] |
| standalone_module_name_configs = \ |
| prepare_custom_config_dict.get("standalone_module_name", []) |
| standalone_module_class_configs = \ |
| prepare_custom_config_dict.get("standalone_module_class", []) |
| class_config_map = {x[0]: (x[1], x[2]) for x in standalone_module_class_configs} |
| name_config_map = {x[0]: (x[1], x[2]) for x in standalone_module_name_configs} |
| config = class_config_map.get(type(standalone_module), (None, None)) |
| config = name_config_map.get(node.target, config) |
| sm_qconfig_dict = {"": qconfig} if config[0] is None else config[0] |
| sm_prepare_config_dict = {} if config[1] is None else config[1] |
| return sm_qconfig_dict, sm_prepare_config_dict |
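|
| # A sketch of the prepare_custom_config_dict entries consumed above; the |
| # module name and class below are hypothetical placeholders: |
| # |
| # prepare_custom_config_dict = { |
| # # (module name, qconfig_dict, prepare_custom_config_dict) |
| # "standalone_module_name": [ |
| # ("submodule.lstm", None, {"input_quantized_idxs": [0]}), |
| # ], |
| # # (module class, qconfig_dict, prepare_custom_config_dict) |
| # "standalone_module_class": [ |
| # (MyStandaloneModule, None, {}), |
| # ], |
| # } |
| # |
| # A None qconfig_dict falls back to {"": qconfig} of the parent, and a |
| # None prepare config dict falls back to {}. |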
| |
| def qat_swap_modules( |
| root: torch.nn.Module, |
| additional_qat_module_mapping: Dict[Callable, Callable]) -> None: |
| all_mappings = get_combined_dict( |
| get_default_qat_module_mappings(), additional_qat_module_mapping) |
| convert(root, mapping=all_mappings, inplace=True, remove_qconfig=False) |
| |
| def update_qconfig_for_qat( |
| qconfig_dict: Any, |
| additional_qat_module_mapping: Dict[Callable, Callable] |
| ) -> Any: |
| """ |
| Update the qconfig_dict to account for module swaps during QAT. |
| During QAT we perform a module swap on the nn.Module types to the corresponding nn.qat.modules types. |
| """ |
| all_qat_mappings = get_combined_dict( |
| get_default_qat_module_mappings(), additional_qat_module_mapping) |
| object_type_dict = qconfig_dict.get("object_type", None) |
| if object_type_dict is None: |
| return qconfig_dict |
| # iterate over a copy, since we add new entries to the dict in the loop |
| for k, v in list(object_type_dict.items()): |
| if k in all_qat_mappings: |
| object_type_dict[all_qat_mappings[k]] = v |
| return qconfig_dict |
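|
| # For example (a sketch, relying on the default QAT mapping containing |
| # nn.Linear -> nn.qat.Linear): |
| # |
| # qconfig_dict = {"object_type": {torch.nn.Linear: qconfig}} |
| # qconfig_dict = update_qconfig_for_qat(qconfig_dict, {}) |
| # # qconfig_dict["object_type"] now also maps torch.nn.qat.Linear to |
| # # the same qconfig, so the module swapped in by QAT keeps its config |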
| |
| def update_qconfig_for_fusion( |
| model: GraphModule, |
| qconfig_dict: Any, |
| ) -> Any: |
| """ |
| Update the qconfig_dict to account for fused modules such as LinearReLU. |
| """ |
| object_type_dict = qconfig_dict.get("object_type", None) |
| if object_type_dict is None: |
| return qconfig_dict |
| |
| modules = dict(model.named_modules()) |
| |
| for node in model.graph.nodes: |
| if node.op == 'call_module': |
| module_type = type(modules[str(node.target)]) |
| if module_type not in DEFAULT_OP_LIST_TO_FUSER_METHOD.values(): |
| continue |
| |
| for ops, fuser in DEFAULT_OP_LIST_TO_FUSER_METHOD.items(): |
| if module_type == fuser: |
| fused_qconfig = object_type_dict.get(ops[0], None) |
| |
| # Raise an error if the modules in the fused module have |
| # different qconfigs specified in the qconfig_dict |
| for op in ops: |
| if object_type_dict.get(op, None) != fused_qconfig: |
| raise LookupError("During fusion, we need to specify the same " + |
| f"qconfigs for both modules in {module_type}.") |
| |
| if fused_qconfig is not None: |
| object_type_dict[module_type] = fused_qconfig |
| |
| return qconfig_dict |
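|
| # For example, after fusion creates an nni.LinearReLU module (a sketch, |
| # assuming both constituent types share one qconfig, as required above): |
| # |
| # qconfig_dict = {"object_type": { |
| # torch.nn.Linear: qconfig, torch.nn.ReLU: qconfig}} |
| # qconfig_dict = update_qconfig_for_fusion(fused_model, qconfig_dict) |
| # # qconfig_dict["object_type"][torch.nn.intrinsic.LinearReLU] == qconfig |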
| |
| def insert_observer( |
| node: Node, |
| observer: torch.quantization.ObserverBase, |
| model: torch.nn.Module, |
| modules: Dict[str, torch.nn.Module], |
| graph: Graph, |
| ) -> Node: |
| """ |
| Attaches `observer` to `model`, and creates a node which calls |
| `observer` on the output of `node`. |
| """ |
| model_device = assert_and_get_unique_device(model) |
| if model_device: |
| observer.to(model_device) |
| # add observer module as attribute |
| if is_equalization_observer(observer): |
| prefix = node.name + '_equalization_process_' |
| else: |
| prefix = node.name + '_activation_post_process_' |
| get_new_observer_name = get_new_attr_name_with_prefix(prefix) |
| observer_name = get_new_observer_name(model) |
| setattr(model, observer_name, observer) |
| modules[observer_name] = observer |
| with graph.inserting_after(node): |
| new_obs = graph.create_node( |
| 'call_module', observer_name, (node,), {}) |
| return new_obs |
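|
| # For example, observing the output of a `conv` node rewrites |
| # |
| # x -> conv -> ... |
| # |
| # into |
| # |
| # x -> conv -> conv_activation_post_process_0 -> ... |
| # |
| # where conv_activation_post_process_0 is a new submodule attribute on |
| # `model` holding the observer instance (the generated name is |
| # illustrative; a numeric suffix is chosen to avoid collisions). |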
| |
| def get_target_activation_dtype_for_node( |
| node: Node, |
| qconfig: QConfigAny, |
| inputs_seen_counter: int, |
| outputs_seen_counter: int, |
| input_quantized_idxs: List[int], |
| output_quantized_idxs: List[int], |
| qhandler: Optional[QuantizeHandler], |
| modules: Dict[str, torch.nn.Module], |
| cache_for_no_tensor_check: Dict[Node, bool], |
| ) -> Optional[torch.dtype]: |
| """ |
| Returns the expected dtype of the input and output of this node after |
| convert. If the value is not None, it represents the dtype of the |
| Tensor. If the value is None, it means the value is not a Tensor. |
| |
| Note: this is for activations only, weight dtypes are not handled here. |
| |
| TODO(future PR, if needed): explicitly spell out the non-Tensor |
| dtypes. |
| """ |
| if node.op == 'placeholder': |
| if inputs_seen_counter in input_quantized_idxs: |
| return torch.quint8 |
| else: |
| # if dtype is fp32 (default), do nothing |
| # note: other dtypes are not supported |
| return torch.float |
| |
| elif node.op in ('call_module', 'call_method', 'call_function'): |
| args_have_no_tensors = \ |
| all_node_args_have_no_tensors( |
| node, modules, cache_for_no_tensor_check) |
| if args_have_no_tensors: |
| return None |
| |
| # TODO(future PR): consider stopping matching getitem |
| is_getitem = node.op == 'call_function' and \ |
| node.target == operator.getitem |
| if is_getitem: |
| return torch.float |
| |
| # get qconfig to determine the eventual dtype of this node |
| if qconfig is not None: |
| if qhandler is not None and qhandler.input_output_observed(): |
| # underscores avoid shadowing the imported weight_dtype helper |
| act_dtype, _, _ = get_qconfig_dtypes(qconfig) |
| return act_dtype |
| else: |
| return torch.float |
| else: |
| return torch.float |
| |
| elif node.op == 'get_attr': |
| return torch.float |
| |
| elif node.op == 'output': |
| if outputs_seen_counter in output_quantized_idxs: |
| return torch.quint8 |
| else: |
| # if dtype is fp32 (default), do nothing |
| # note: other dtypes are not supported |
| return torch.float |
| |
| else: |
| raise AssertionError(f'need to handle {node.format_node()}') |
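|
| # A few concrete cases of the mapping above (a sketch, assuming a |
| # standard static int8 qconfig where applicable): |
| # |
| # placeholder listed in input_quantized_idxs -> torch.quint8 |
| # placeholder not listed -> torch.float |
| # call_function with int8 qconfig and observed I/O -> torch.quint8 |
| # call_function(operator.getitem) -> torch.float |
| # node whose args carry no Tensors (e.g. shape math) -> None |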
| |
| def maybe_insert_input_observer_for_arg_or_kwarg( |
| node: Union[Node, Any], |
| arg: Argument, |
| qconfig: QConfigAny, |
| model: torch.nn.Module, |
| modules: Dict[str, torch.nn.Module], |
| graph: Graph, |
| node_name_to_target_dtype: Dict[str, Any], |
| qhandler: Optional[QuantizeHandler], |
| prepare_custom_config_dict: Dict[str, Any], |
| ) -> Argument: |
| """ |
| Given a `node` and an `arg`, inserts an input observer between |
| `node` and `arg` if necessary. |
| """ |
| # for ops such as torch.cat([x0, x1]), |
| # traverse through the list |
| if isinstance(arg, (list, tuple)): |
| new_arg_to_return = [] |
| for inner_arg in arg: |
| new_inner_arg = maybe_insert_input_observer_for_arg_or_kwarg( |
| node, inner_arg, qconfig, model, modules, |
| graph, node_name_to_target_dtype, |
| qhandler, prepare_custom_config_dict) |
| new_arg_to_return.append(new_inner_arg) |
| return new_arg_to_return |
| |
| if not isinstance(arg, Node): |
| return arg |
| assert isinstance(arg, Node) |
| |
| # default (no observer) |
| new_arg = arg |
| |
| is_standalone_module = qhandler is not None and \ |
| isinstance(qhandler, StandaloneModuleQuantizeHandler) |
| |
| if not is_standalone_module: |
| # regular flow for most nodes, except standalone modules |
| is_weight = node_arg_is_weight(node, arg) |
| assert qconfig is not None |
| |
| act_post_process_ctr = qconfig.weight if is_weight else \ |
| qconfig.activation |
| |
| is_bias = node_arg_is_bias(node, arg) |
| is_activation = not (is_weight or is_bias) |
| weight_needs_obs = is_weight and weight_is_quantized(qconfig) and node.target not in NON_QUANTIZABLE_WEIGHT_OPS |
| bias_needs_obs = \ |
| (is_bias and activation_dtype(qconfig) == torch.float16) and \ |
| weight_dtype(qconfig) == torch.float16 |
| |
| arg_dtype = node_name_to_target_dtype[arg.name] |
| node_dtype = node_name_to_target_dtype[node.name] |
| dtype_changes_and_second_dtype_not_float = ( |
| # if the dtypes are different, we need an observer |
| (arg_dtype != node_dtype) and |
| # except if the second dtype is float, a dequant will be inserted |
| # without an observer in convert |
| # TODO(future PR): change this so a placeholder is inserted for |
| # future dequants, to make the logic easier to understand |
| (node_dtype != torch.float) and |
| # if arg is a bool tensor or not a tensor, do not insert observer |
| (arg_dtype not in (torch.bool, None)) and |
| (is_activation and activation_is_statically_quantized(qconfig)) |
| ) |
| |
| needs_obs = ( |
| weight_needs_obs or |
| bias_needs_obs or |
| dtype_changes_and_second_dtype_not_float |
| ) |
| |
| else: |
| # custom flow for standalone modules |
| _sm_qconfig_dict, sm_prepare_config_dict = \ |
| get_standalone_module_configs( |
| node, modules, prepare_custom_config_dict, qconfig) |
| |
| sm_input_quantized_idxs = \ |
| sm_prepare_config_dict.get('input_quantized_idxs', []) |
| # for args, this is set to the index of the current arg |
| # for kwargs, this is left at None |
| cur_input_idx = None |
| for arg_idx, arg_to_check in enumerate(node.args): |
| if arg_to_check is arg: |
| cur_input_idx = arg_idx |
| break |
| |
| if cur_input_idx is None: |
| needs_obs = False |
| else: |
| arg_dtype = node_name_to_target_dtype[arg.name] |
| node_dtype = torch.quint8 if cur_input_idx in sm_input_quantized_idxs \ |
| else torch.float |
| needs_obs = ( |
| (arg_dtype != node_dtype) and |
| (node_dtype != torch.float) |
| ) |
|
| # the observer constructor for standalone module inputs comes from |
| # the activation field of the qconfig |
| act_post_process_ctr = qconfig.activation |
| |
| if needs_obs: |
| |
| new_obs_mod = act_post_process_ctr() |
| existing_obs_node = None |
| |
| # Before using the new observer, check if an observer |
| # of the correct type already exists. If it does, use it. |
| # This prevents duplicate observer insertions if a node is |
| # used by multiple nodes. |
| for maybe_obs_node, _ in arg.users.items(): |
| if maybe_obs_node.op == 'call_module': |
| maybe_obs_mod = modules[maybe_obs_node.target] # type: ignore[index] |
| if ( |
| type(maybe_obs_mod) == type(new_obs_mod) and |
| node_name_to_target_dtype[maybe_obs_node.name] == node_dtype |
| ): |
| existing_obs_node = maybe_obs_node |
| break |
| |
| if existing_obs_node is None: |
| new_obs_node = insert_observer( |
| arg, new_obs_mod, model, modules, graph) |
| # set the type, so the next node can read it |
| node_name_to_target_dtype[new_obs_node.name] = node_dtype |
| # override this arg to be the observed arg |
| new_arg = new_obs_node |
| else: |
| new_arg = existing_obs_node |
| |
| return new_arg |
| |
| |
| def maybe_insert_input_observers_for_node( |
| node: Node, |
| qconfig: QConfigAny, |
| model: torch.nn.Module, |
| modules: Dict[str, torch.nn.Module], |
| graph: Graph, |
| node_name_to_target_dtype: Dict[str, Any], |
| qhandler: Optional[QuantizeHandler], |
| prepare_custom_config_dict: Dict[str, Any], |
| ) -> None: |
| """ |
| If needed, inserts observers for the input args and kwargs of `node`. |
| Note: modifies `node` inplace. |
| |
| For example, if cur_node needs an observer after prev_node, we change from |
| |
| prev_node -> cur_node |
| |
| To |
| |
| prev_node -> obs -> cur_node |
| """ |
| if qconfig is None: |
| # if quantization is turned off for this node, we do not need |
| # to insert input observers |
| return |
| assert qconfig is not None |
| |
| # Look through every input arg. If that arg's target dtype does not |
| # match the current node's target dtype, insert an observer. |
| new_args = [] |
| for arg in node.args: |
| new_arg = maybe_insert_input_observer_for_arg_or_kwarg( |
| node, arg, qconfig, model, modules, graph, |
| node_name_to_target_dtype, |
| qhandler, prepare_custom_config_dict) |
| new_args.append(new_arg) |
| |
| new_kwargs = {} |
| for k, kwarg in node.kwargs.items(): |
| new_kwarg = maybe_insert_input_observer_for_arg_or_kwarg( |
| node, kwarg, qconfig, model, modules, graph, |
| node_name_to_target_dtype, |
| qhandler, prepare_custom_config_dict) |
| new_kwargs[k] = new_kwarg |
| |
| # assign the new args and kwargs to the node, inplace |
| node.args = tuple(new_args) |
| node.kwargs = new_kwargs |
| |
| def maybe_insert_input_equalization_observers_for_node( |
| node: Node, |
| equalization_qconfig: Any, |
| model: torch.nn.Module, |
| modules: Dict[str, torch.nn.Module], |
| graph: Graph, |
| node_name_to_target_dtype: Dict[str, Any], |
| ) -> None: |
| """ |
| If `node` needs to be equalized, finds the input/weight observers it needs |
| in `equalization_qconfig`, creates them, and inserts them into `graph`. |
| |
| If `node` does not need an equalization observer, returns None. |
| """ |
| if equalization_qconfig is None or not node_supports_equalization(node, modules): |
| return |
| |
| new_args = [] |
| for arg in node.args: |
| if not isinstance(arg, Node) or node_arg_is_bias(node, arg): |
| new_args.append(arg) |
| continue |
| |
| is_weight = node_arg_is_weight(node, arg) |
| |
| act_eq_process_ctr = equalization_qconfig.weight if is_weight else \ |
| equalization_qconfig.input_activation |
| |
| new_eq_obs_mod = act_eq_process_ctr() |
| new_eq_obs_node = insert_observer( |
| arg, new_eq_obs_mod, model, modules, graph) |
| |
| # set the type, so the next node can read it |
| node_name_to_target_dtype[new_eq_obs_node.name] = node_name_to_target_dtype[arg.name] |
| |
| new_args.append(new_eq_obs_node) |
| |
| # assign the new args and kwargs to the node, inplace |
| node.args = tuple(new_args) |
| |
| def maybe_insert_output_observer_for_node( |
| node: Node, |
| model: torch.nn.Module, |
| modules: Dict[str, torch.nn.Module], |
| graph: Graph, |
| matches: Dict[str, MatchResult], |
| node_name_to_target_dtype: Dict[str, Any], |
| matched_pattern: Any, |
| qhandler: Optional[QuantizeHandler], |
| ) -> Optional[Node]: |
| """ |
| If `node` needs an output observer, creates it, inserts it into `graph` |
| and returns it. |
| |
| If `node` does not need an output observer, returns None. |
| """ |
| root_node, matched_nodes, pattern, qhandler, qconfig = matches.get( |
| node.name, (None, None, None, None, None)) |
| |
| if qhandler is None: |
| return None |
| |
| assert qconfig is not None |
| assert node.op != 'output', 'observer insertion for outputs is handled elsewhere' |
| |
| is_standalone_module = qhandler is not None and \ |
| isinstance(qhandler, StandaloneModuleQuantizeHandler) |
| |
| dtype = node_name_to_target_dtype[node.name] |
| should_insert_observer = \ |
| qhandler.should_insert_observer_for_output( |
| qconfig, model.training) and dtype not in (torch.bool, None, torch.float) |
| # TODO(future PR): move the following logic to |
| # should_insert_observer_for_output |
| should_insert_observer = should_insert_observer and \ |
| activation_is_statically_quantized(qconfig) |
| |
| # we never insert observers at the output of a standalone module; we |
| # assume that, if needed, they are inserted inside the standalone module |
| should_insert_observer = should_insert_observer and \ |
| (not is_standalone_module) |
| |
| if should_insert_observer: |
| act_post_process_ctr = qconfig.activation |
| if activation_is_int8_quantized(qconfig): |
| act_post_process_ctr = \ |
| get_default_output_activation_post_process_map().get( |
| matched_pattern, |
| act_post_process_ctr) |
| observer = act_post_process_ctr() |
| new_obs = insert_observer(node, observer, model, modules, graph) |
| # set the type, so the next node can read it |
| node_name_to_target_dtype[new_obs.name] = \ |
| node_name_to_target_dtype[node.name] |
| return new_obs |
| else: |
| return None |
| |
| def maybe_insert_observers_before_graph_output( |
| graph_output_node: Node, |
| output_quantized_idxs: List[int], |
| node_name_to_target_dtype: Dict[str, torch.dtype], |
| qconfig_map: Dict[str, QConfigAny], |
| model: torch.nn.Module, |
| modules: Dict[str, torch.nn.Module], |
| graph: Graph, |
| ) -> None: |
| """ |
| If the output needs to be quantized and there are any nodes |
| in the output which are not already observed, inserts observers |
| for those nodes. |
| """ |
| |
| # TODO(future PR): update the output_quantized_idxs API to match |
| # arbitrary data structures. There is always a single output, and |
| # that output can have arbitrary nesting of values. List[int] is |
| # not the right data type for this. |
| assert output_quantized_idxs == [0] or output_quantized_idxs == [], \ |
| 'unrecognized format of output_quantized_idxs' |
| |
| # Currently dequants are inserted in the convert step. So, we only |
| # have to do anything if the output is hardcoded to be quantized |
| if output_quantized_idxs == []: |
| return |
| # TODO(future PR): support more dtypes in model outputs, if necessary |
| output_target_dtype = torch.quint8 |
| |
| def _recursive_maybe_replace_node_with_obs( |
| maybe_node: Argument, |
| target_dtype: torch.dtype, |
| node_name_to_target_dtype: Dict[str, torch.dtype], |
| qconfig_map: Dict[str, QConfigAny], |
| model: torch.nn.Module, |
| modules: Dict[str, torch.nn.Module], |
| graph: Graph, |
| ) -> Argument: |
| """ |
| Navigate an arbitrary data structure of lists, tuples, dicts. |
| For each container type, recurse on all inputs. Once any Node |
| is found, insert an observer if needed and do not recurse further. |
| |
| For example, given a structure of |
| |
| {'foo1': [[bar1]], 'foo2': {'foo3': [[[bar3]]]}} |
| |
| we recurse down to bar1 and bar3, observe them if necessary, |
| and if we inserted an observer then replace the original node |
| with its observer. |
| |
| Returns the data structure with all nodes needing observation being |
| replaced by their observers. |
| """ |
| if isinstance(maybe_node, Node): |
| # check dtype of this node |
| this_node_dtype = node_name_to_target_dtype[maybe_node.name] |
| if this_node_dtype != target_dtype: |
| # insert observer |
| qconfig = qconfig_map.get(maybe_node.name) |
| # TODO(future PR): see if we need to allow specifying qconfig |
| # on output nodes, to remove the restriction below. |
| assert qconfig is not None, \ |
| 'Quantizing the output node without a qconfig is not supported' |
| observer_mod = qconfig.activation() |
| observer_node = insert_observer( |
| maybe_node, observer_mod, model, modules, graph) |
| return observer_node |
| else: |
| return maybe_node |
| elif isinstance(maybe_node, (list, tuple)): |
| results = [] |
| for inner_node in maybe_node: |
| results.append(_recursive_maybe_replace_node_with_obs( |
| inner_node, target_dtype, node_name_to_target_dtype, |
| qconfig_map, model, modules, graph)) |
| if isinstance(maybe_node, list): |
| return results |
| else: |
| return tuple(results) |
| elif isinstance(maybe_node, dict): |
| results_dict = {} |
| for k, inner_v in maybe_node.items(): |
| results_dict[k] = _recursive_maybe_replace_node_with_obs( |
| inner_v, target_dtype, node_name_to_target_dtype, |
| qconfig_map, model, modules, graph) |
| return results_dict |
| else: |
| return maybe_node |
| |
| new_args = [] |
| for old_arg in graph_output_node.args: |
| new_args.append( |
| _recursive_maybe_replace_node_with_obs( |
| old_arg, output_target_dtype, node_name_to_target_dtype, |
| qconfig_map, model, modules, graph)) |
| |
| graph_output_node.args = new_args # type: ignore[assignment] |
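|
| # For example, with output_quantized_idxs == [0] and a model returning |
| # {'out': [y]} where `y` is not yet observed at quint8, an observer is |
| # inserted after `y` and the output arg is rewritten to {'out': [y_obs]} |
| # (a sketch; arbitrary list/tuple/dict nesting is handled above). |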
| |
| |
| def maybe_propagate_dtype_for_node( |
| node: Node, |
| target_dtype: torch.dtype, |
| node_name_to_target_dtype: Dict[str, torch.dtype], |
| matches: Dict[str, MatchResult], |
| ) -> None: |
| """ |
| Assigns `target_dtype` to `node`. If `node` is a general tensor shape op |
| (see GeneralTensorShapeOpQuantizeHandler in quantization_patterns.py for more details) |
| also call this function recursively on |
| the first argument, to propagate the dtype to the caller. |
| """ |
| node_name_to_target_dtype[node.name] = target_dtype |
| # if this is a general tensor shape op, propagate to the first arg |
| root_node, matched_nodes, pattern, qhandler, qconfig = matches.get( |
| node.name, (None, None, None, None, None)) |
| if qhandler is not None and qhandler.is_general_tensor_shape_op(): |
| prev_node = node.args[0] |
| if isinstance(prev_node, Node): |
| maybe_propagate_dtype_for_node( |
| prev_node, target_dtype, node_name_to_target_dtype, matches) |
| |
| def propagate_dtypes_for_known_nodes( |
| graph: Graph, |
| node_name_to_target_dtype: Dict[str, torch.dtype], |
| matches: Dict[str, MatchResult], |
| ) -> None: |
| """ |
| Currently we assume that inputs to the graph are either `torch.float` or |
| `torch.quint8`, which is not always correct. For ops such as |
| `x.masked_fill(mask, value)`, we know that the dtype of `mask` is a |
| `BoolTensor`. Propagate this information throughout the graph. |
| |
| Note: not all dtypes in the graph will be correct after this pass, but a |
| higher percentage of them will be correct. Hopefully in the future we can |
| replace this with a better way to reason about dtypes of tensors. |
| """ |
| for node in graph.nodes: |
| bool_arg_idxs = node_bool_tensor_arg_indexes(node) |
| for bool_arg_idx in bool_arg_idxs: |
| cur_node = node.args[bool_arg_idx] |
| maybe_propagate_dtype_for_node( |
| cur_node, torch.bool, node_name_to_target_dtype, matches) |
| |
| def maybe_make_input_output_share_observers( |
| node: Node, |
| model: torch.nn.Module, |
| modules: Dict[str, torch.nn.Module], |
| ) -> bool: |
| """ |
| Ensures that we share an observer |
| for all input arguments as well as the output argument. In detail, given |
| a graph of |
| |
| x0 -> obs0 -> op -> obs2 -> x2 |
| / |
| x1 -> obs1 / |
| |
| where node obs0 points to observer instance observer0, |
| obs1 points to observer1 and obs2 points to observer2, we make nodes obs1 |
| and obs2 point to observer0. |
| Returns: whether the operation succeeded or not |
| """ |
| first_arg = None |
| # find the first arg that is a Node, list or tuple (i.e. may hold Tensors) |
| for i in range(len(node.args)): |
| if isinstance(node.args[i], (Node, list, tuple)): |
| first_arg = node.args[i] |
| break |
| |
| # if there is no such arg, return directly |
| if first_arg is None: |
| return False |
| |
| if isinstance(first_arg, (list, tuple)): |
| first_arg_arg = first_arg[0] |
| elif isinstance(first_arg, Node): |
| first_arg_arg = first_arg |
| else: |
| return False |
| |
| # if we have a graph such as |
| # observed_node -> non_observed_node -> cat |
| # we need to navigate up to the first observer |
| iteration_guard = 0 |
| while not is_activation_post_process_node(first_arg_arg, modules): |
| # did not find an activation_post_process for the op |
| if first_arg_arg.op == "placeholder": |
| return False |
| # trace back the args until we find the first Tensor/Node |
| trace_back_node = None |
| for i in range(len(first_arg_arg.args)): |
| trace_back_node = first_arg_arg.args[i] |
| if isinstance(trace_back_node, Node): |
| break |
| if trace_back_node is None: |
| return False |
| first_arg_arg = trace_back_node |
| |
| iteration_guard += 1 |
| if iteration_guard > 10000: |
| raise AssertionError('Unable to find observer of previous node') |
| |
| assert isinstance(first_arg_arg, Node) |
| target_to_use = first_arg_arg.target |
| assert isinstance(target_to_use, str) |
| obs_mod_to_use = modules[target_to_use] |
| |
| if isinstance(first_arg, (list, tuple)): |
| # set all other input observer nodes to use that module |
| for input_idx, input_arg in enumerate(first_arg): |
| if input_idx == 0: |
| continue |
| iteration_guard = 0 |
| while not is_activation_post_process_node(input_arg, modules): |
| input_arg = input_arg.args[0] |
| iteration_guard += 1 |
| if iteration_guard > 10000: |
| raise AssertionError('Unable to find observer of previous node') |
| |
| parent_name, name = _parent_name(input_arg.target) |
| setattr(modules[parent_name], name, obs_mod_to_use) |
| |
| # set the output observer node to use that module |
| for output_obs_node, _ in node.users.items(): |
| assert is_activation_post_process_node(output_obs_node, modules) |
| parent_name, name = _parent_name(output_obs_node.target) |
| setattr(modules[parent_name], name, obs_mod_to_use) |
| |
| # TODO(future PR): delete the orphaned observer modules |
| return True |
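|
| # For example, for `cat = torch.cat([x0, x1])` where each input and the |
| # output have their own observer node, the module references are rewired |
| # so that all of them point at one observer instance (a sketch): |
| # |
| # before: x0 -> obs_a, x1 -> obs_b, cat -> obs_c (three instances) |
| # after: x0 -> obs_a, x1 -> obs_a, cat -> obs_a (one shared instance) |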
| |
| def remove_output_observer( |
| node: Node, |
| model: torch.nn.Module, |
| modules: Dict[str, torch.nn.Module]): |
| items = list(node.users.items()) |
| for output_obs_node, _ in items: |
| assert is_activation_post_process_node(output_obs_node, modules) |
| output_obs_node.replace_all_uses_with(node) |
| model.graph.erase_node(output_obs_node) |
| |
| def swap_custom_module_to_observed( |
| node: Node, |
| qconfig: QConfigAny, |
| modules: Dict[str, torch.nn.Module], |
| prepare_custom_config_dict: Dict[str, Any]): |
| custom_module = modules[node.target] # type: ignore[index] |
| custom_module_class_mapping = prepare_custom_config_dict.get( |
| "float_to_observed_custom_module_class", {}) |
| observed_custom_module_class = \ |
| get_swapped_custom_module_class( |
| custom_module, custom_module_class_mapping, qconfig) |
| observed_custom_module = \ |
| observed_custom_module_class.from_float(custom_module) |
| parent_name, name = _parent_name(node.target) |
| setattr(modules[parent_name], name, observed_custom_module) |
| |
| def insert_observers_for_model( |
| model: GraphModule, |
| modules: Dict[str, torch.nn.Module], |
| matches: Dict[str, MatchResult], |
| qconfig_map: Dict[str, QConfigAny], |
| graph: Graph, |
| prepare_custom_config_dict: Dict[str, Any], |
| equalization_config_map: Dict[str, Any], |
| input_quantized_idxs: List[int], |
| output_quantized_idxs: List[int], |
| ) -> Optional[Node]: |
| """ |
| Inserts observers, using the following high level algorithm: |
| |
| For each node in the graph: |
| 1. determine the target dtype of this node in the quantized graph, and save |
| it for future steps |
| 2. determine the target dtype of all args and kwargs of this node |
| 3. if any arg or kwarg's target dtype does not match the current node's |
| dtype, insert an observer |
| 4. if the current node needs an output observer, insert it |
| |
| For example: |
| |
| - starting graph: |
| x0 -> linear -> x1 |
| |
| - observed graph after processing x0: |
| x0(fp32) |
| |
| - observed graph after processing linear: |
| x0(fp32) -> x0_obs0(int8) -> linear(int8) -> linear_obs0(int8) |
| |
| - observed graph after processing x1: |
| x0(fp32) -> x0_obs0(int8) -> linear(int8) -> linear_obs0(int8) -> x1 |
| |
| After a node is processed, the naive observer placement is guaranteed to be |
| complete for that node and all of its predecessors. There can be future |
| passes which optimize the graph by deduplicating observers, etc. |
| """ |
| |
| node_name_to_target_dtype: Dict[str, Any] = {} |
| cache_for_no_tensor_check: Dict[Node, bool] = dict() |
| |
| inputs_seen_counter = 0 |
| outputs_seen_counter = 0 |
| results_node = None |
| |
| # first, populate the dtype map based only on qconfig and qhandler |
| # this assumes: |
| # graph inputs are fp32 by default, and int8 where overridden |
| # other nodes output dtype is specified by the qconfig |
| modules = dict(model.named_modules(remove_duplicate=False)) |
| for node in model.graph.nodes: |
| root_node, matched_nodes, pattern, qhandler, qconfig = matches.get( |
| node.name, (None, None, None, None, None)) |
| node_name_to_target_dtype[node.name] = get_target_activation_dtype_for_node( |
| node, qconfig, inputs_seen_counter, outputs_seen_counter, |
| input_quantized_idxs, output_quantized_idxs, qhandler, |
| modules, cache_for_no_tensor_check) |
| |
| # Second, for nodes with known input dtypes, propagate them throughout the |
| # graph. For example, if there is a call such as |
| # x1 = x0.masked_fill(mask, 1) |
| # we propagate the type of mask to be torch.bool |
| propagate_dtypes_for_known_nodes( |
| model.graph, node_name_to_target_dtype, matches) |
| |
| # After this point, the current node and all of its arguments |
| # have a dtype assigned. Now, we insert observers for inputs |
| # of this node (if needed for this node), and the output of this node |
| # (if needed for this node). |
| |
| # Since we are mutating the graph as we go, we iterate over the original |
| # nodes before observer insertion, instead of model.graph.nodes. |
| nodes_before_observation = list(model.graph.nodes) |
| |
| for node in nodes_before_observation: |
| |
| if node.op == 'placeholder': |
| # if a graph input is in fp32, it does not need observation |
| # if a graph input is in int8, we assume the observation happens |
| # outside of the graph, and no additional observation is needed |
| pass |
| |
| elif node.op in ('call_module', 'call_method', 'call_function', 'output'): |
| # check for matches |
| root_node, matched_nodes, pattern, qhandler, qconfig = matches.get( |
| node.name, (None, None, None, None, None)) |
| equalization_qconfig = equalization_config_map.get(node.name, None) |
| |
| this_node_dtype = node_name_to_target_dtype[node.name] |
| output_not_a_tensor = this_node_dtype is None |
| # TODO(future PR): consider stopping matching getitem |
| is_getitem = node.op == 'call_function' and \ |
| node.target == operator.getitem |
| |
| skip_inserting_observers = ( |
| (qconfig is None) or |
| output_not_a_tensor or |
| is_getitem |
| ) and (not node.op == 'output') |
| |
| if not skip_inserting_observers: |
| modules = dict(model.named_modules(remove_duplicate=False)) |
| if node.op != 'output': |
| # this modifies node inplace |
| maybe_insert_input_observers_for_node( |
| node, qconfig, model, modules, graph, |
| node_name_to_target_dtype, |
| qhandler, prepare_custom_config_dict) |
| |
| # Insert equalization input observers if needed |
| maybe_insert_input_equalization_observers_for_node( |
| node, equalization_qconfig, model, modules, graph, |
| node_name_to_target_dtype) |
| |
| is_last_node_of_pattern = root_node is node |
| is_general_tensor_value_op = \ |
| (qhandler is not None and qhandler.is_general_tensor_value_op()) |
| |
| is_general_tensor_shape_op = \ |
| (qhandler is not None and qhandler.is_general_tensor_shape_op()) |
| |
| if is_last_node_of_pattern and not is_general_tensor_shape_op: |
| # this returns the new observer node if it was needed |
| maybe_output_obs_node = maybe_insert_output_observer_for_node( |
| node, model, modules, graph, matches, |
| node_name_to_target_dtype, pattern, qhandler) |
| if maybe_output_obs_node is not None: |
| # Update users of original node to use the output observer |
| # instead. For example, change |
| # |
| # cur_node -> next_node |
| # cur_node -> obs |
| # |
| # to |
| # |
| # cur_node -> obs -> next_node |
| # |
| # We need to save orig users before updating uses because |
| # the list of users will change as we update uses |
| orig_users = list(node.users.keys()) |
| for user_node in orig_users: |
| if user_node is maybe_output_obs_node: |
| continue |
| user_node.replace_input_with(node, maybe_output_obs_node) |
| |
| # for general tensor value ops, we modify the graph |
| # to make all inputs and outputs use the first input's |
| # observer |
| if is_general_tensor_value_op: |
| if not maybe_make_input_output_share_observers(node, model, modules): |
| remove_output_observer(node, model, modules) |
| |
| if isinstance(qhandler, CustomModuleQuantizeHandler): |
| swap_custom_module_to_observed(node, qconfig, modules, prepare_custom_config_dict) |
| |
| else: # output |
| maybe_insert_observers_before_graph_output( |
| node, output_quantized_idxs, |
| node_name_to_target_dtype, qconfig_map, |
| model, modules, graph) |
| |
| # |
| # After this point, the current node has input and output observers |
| # that it needs for itself inserted. |
| # |
| |
| # increment the counters, so future inputs and outputs are assigned |
| # correct dtypes |
| if node.op == 'placeholder': |
| inputs_seen_counter += 1 |
| elif node.op == 'output': |
| outputs_seen_counter += 1 |
| results_node = node |
| |
| return results_node |
| |
| def run_prepare_fx_on_standalone_modules( |
| model: torch.nn.Module, |
| modules: Dict[str, torch.nn.Module], |
| matches: Any, |
| prepare_custom_config_dict: Dict[str, Any], |
| ) -> None: |
| """ |
| Runs prepare_fx on each standalone module. Note: this does |
| not modify the graph, it just replaces the unobserved modules with |
| their observed versions. |
| """ |
| for ( |
| node_name, |
| (root_node, matched_nodes, pattern, qhandler, qconfig), |
| ) in matches.items(): |
| if qhandler is None: |
| continue |
| elif not isinstance(qhandler, StandaloneModuleQuantizeHandler): |
| continue |
| |
| sm_qconfig_dict, sm_prepare_config_dict = \ |
| get_standalone_module_configs( |
| root_node, modules, prepare_custom_config_dict, qconfig) |
| |
| standalone_module = modules[root_node.target] |
| prepare = \ |
| torch.quantization.quantize_fx._prepare_standalone_module_fx # type: ignore[attr-defined] |
| observed_standalone_module = \ |
| prepare(standalone_module, sm_qconfig_dict, sm_prepare_config_dict) |
| preserved_attributes = \ |
| set(sm_prepare_config_dict.get("preserved_attributes", [])) |
| observed_standalone_module = ObservedStandaloneGraphModule( |
| observed_standalone_module, observed_standalone_module.graph, |
| preserved_attributes) |
| parent_name, name = _parent_name(root_node.target) |
| setattr(modules[parent_name], name, |
| observed_standalone_module) |
| modules[root_node.target] = observed_standalone_module |
| |
| def save_state( |
| observed: GraphModule, |
| qconfig_map: Dict[str, QConfigAny], |
| node_name_to_scope: Dict[str, Tuple[str, type]], |
| patterns: Dict[Pattern, QuantizeHandler], |
| prepare_custom_config_dict: Dict[str, Any], |
| equalization_qconfig_map: Dict[str, Any], |
| ) -> None: |
| observed._patterns = patterns # type: ignore[assignment] |
| observed._qconfig_map = qconfig_map # type: ignore[assignment] |
| observed._prepare_custom_config_dict = \ |
| prepare_custom_config_dict # type: ignore[assignment] |
| observed._node_name_to_scope = node_name_to_scope # type: ignore[assignment] |
| observed._equalization_qconfig_map = equalization_qconfig_map # type: ignore[assignment] |
| |
| def prepare( |
| model: GraphModule, |
| qconfig_dict: Any, |
| node_name_to_scope: Dict[str, Tuple[str, type]], |
| prepare_custom_config_dict: Optional[Dict[str, Any]] = None, |
| equalization_qconfig_dict: Optional[Dict[str, Any]] = None, |
| is_standalone_module: bool = False) -> ObservedGraphModule: |
| """ standalone_module means it a submodule that is not inlined in |
| parent module, and will be quantized separately as one unit. |
| |
| How the standalone module is observed is specified by `input_quantized_idxs` and |
| `output_quantized_idxs` in the prepare_custom_config for the standalone module |
| Args: |
| node_name_to_scope: mapping from node name to the scope of the module which contains the node. |
| The scope is a tuple of fully qualified path of the module and the type of the module |
| Returns: |
| model(GraphModule): prepared standalone module |
| attributes: |
| _standalone_module_input_quantized_idxs(List[Int]): a list of |
| indexes for the graph inputs that are expected to be quantized, |
| same as the input_quantized_idxs configuration provided |
| for the standalone module |
| _standalone_module_output_quantized_idxs(List[Int]): a list of |
| indexes for the graph outputs that are quantized, |
| same as the output_quantized_idxs configuration provided |
| for the standalone module |
| """ |
| if prepare_custom_config_dict is None: |
| prepare_custom_config_dict = {} |
| if equalization_qconfig_dict is None: |
| equalization_qconfig_dict = {} |
| |
| additional_quant_patterns = \ |
| prepare_custom_config_dict.get("additional_quant_pattern", {}) |
| # mapping from a tuple of nodes in reverse order to uninitialized |
| # QuantizeHandler subclass. For example, |
| # { |
| # # match a single node |
| # (<class 'torch.nn.modules.conv.Conv3d'>: |
| # <class 'torch.quantization.fx.quantize.ConvRelu'>), |
| # # match multiple nodes in reverse order |
| # ((<function relu at 0x7f766a7360d0>, <built-in function add>): |
| # <class 'torch.quantization.fx.quantize.Add'>), |
| # } |
| patterns: Dict[Pattern, QuantizeHandler] = get_combined_dict( |
| get_default_quant_patterns(), additional_quant_patterns) |
| |
| convert_dict_to_ordered_dict(qconfig_dict) |
| convert_dict_to_ordered_dict(equalization_qconfig_dict) |
| flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict) |
| # TODO: support regex as well |
| propagate_qconfig_(model, flattened_qconfig_dict) |
| |
| if model.training: |
| additional_qat_module_mapping = prepare_custom_config_dict.get( |
| "additional_qat_module_mapping", {}) |
| qat_swap_modules(model, additional_qat_module_mapping) |
| qconfig_dict = update_qconfig_for_qat(qconfig_dict, additional_qat_module_mapping) |
| |
| qconfig_dict = update_qconfig_for_fusion(model, qconfig_dict) |
| equalization_qconfig_dict = update_qconfig_for_fusion(model, equalization_qconfig_dict) |
| |
| # mapping from fully qualified module name to module instance |
| # for example, |
| # { |
| # '': Model(...), |
| # 'linear': Linear(...), |
| # 'linear.weight_fake_quant': PerChannelMinMaxObserver(...), |
| # } |
| modules = dict(model.named_modules()) |
| |
| # fill qconfig_map, a map from node name to qconfig, used in find_matches |
| equalization_qconfig_map = generate_qconfig_map(model, modules, model.graph, equalization_qconfig_dict, node_name_to_scope) |
| qconfig_map = generate_qconfig_map(model, modules, model.graph, qconfig_dict, node_name_to_scope) |
| |
| # match the patterns that will get quantized |
| standalone_module_name_configs = prepare_custom_config_dict.get( |
| "standalone_module_name", []) |
| standalone_module_class_configs = prepare_custom_config_dict.get( |
| "standalone_module_class", []) |
| |
| standalone_module_names = [config[0] for config in standalone_module_name_configs] |
| standalone_module_classes = [config[0] for config in standalone_module_class_configs] |
| custom_module_classes = get_custom_module_class_keys( |
| prepare_custom_config_dict, "float_to_observed_custom_module_class") |
| matches = find_matches( |
| model.graph, modules, patterns, qconfig_map, standalone_module_names, |
| standalone_module_classes, custom_module_classes) |
| |
| input_quantized_idxs: List[int] = prepare_custom_config_dict.get( |
| "input_quantized_idxs", []) |
| output_quantized_idxs: List[int] = prepare_custom_config_dict.get( |
| "output_quantized_idxs", []) |
| |
| run_prepare_fx_on_standalone_modules( |
| model, modules, matches, prepare_custom_config_dict) |
| |
| result_node = insert_observers_for_model( |
| model, modules, matches, qconfig_map, |
| model.graph, prepare_custom_config_dict, |
| equalization_qconfig_map, |
| input_quantized_idxs, output_quantized_idxs) |
| |
| save_state(model, qconfig_map, node_name_to_scope, patterns, |
| prepare_custom_config_dict, equalization_qconfig_map) |
| preserved_attributes = set(prepare_custom_config_dict.get("preserved_attributes", [])) |
| model = ObservedGraphModule(model, model.graph, preserved_attributes) |
| if is_standalone_module: |
| assert result_node is not None |
| assert isinstance(result_node.args[0], Node), \ |
| "standalone module only supports returning simple value currently"\ |
| "(not tuple, dict etc.)" |
| # these inputs are observed in parent |
| # converting List[int] to Tensor since module attribute is |
| # Union[Tensor, Module] |
| model._standalone_module_input_quantized_idxs = \ |
| torch.tensor(input_quantized_idxs) |
| model._standalone_module_output_quantized_idxs = torch.tensor(output_quantized_idxs) |
| return model |
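|
| # Minimal end-to-end usage sketch. `prepare` above is internal; the public |
| # entry point is torch.quantization.quantize_fx.prepare_fx (model_fp32 and |
| # example_input below are user-provided placeholders): |
| # |
| # from torch.quantization import get_default_qconfig, quantize_fx |
| # qconfig_dict = {"": get_default_qconfig("fbgemm")} |
| # prepared = quantize_fx.prepare_fx(model_fp32.eval(), qconfig_dict) |
| # prepared(example_input) # calibrate the inserted observers |
| # quantized = quantize_fx.convert_fx(prepared) |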