import torch
from torch._subclasses import FakeTensor
from torch.ao.quantization.fx.prepare import (
    _get_arg_as_input_act_obs_or_fq,
    _get_output_act_obs_or_fq,
    _get_dtype_and_is_dynamic,
    _insert_obs_or_fq,
    _maybe_insert_output_observer_for_node,
    _save_state,
    _is_activation_post_process_node,
    _get_qspec_for_arg,
)
from torch.fx import (
    GraphModule,
    Node,
)
from torch.fx.node import Argument

from torch.ao.quantization import QConfigMapping
from torch.ao.quantization.qconfig import QConfigAny
from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
from typing import Dict, Tuple, Union, Any
from torch.ao.quantization.quantizer import (
    QuantizationAnnotation,
    EdgeOrNode,
    SharedQuantizationSpec,
)
from torch.ao.quantization import ObserverOrFakeQuantize
def _maybe_insert_input_observer_for_arg_or_kwarg(
    node: Union[Node, Any],
    arg: Argument,
    qconfig: QConfigAny,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
) -> Argument:
    """
    Given a `node` and an `arg`, inserts an input observer between
    `node` and `arg` if necessary.
    """
    # for ops such as torch.cat([x0, x1]),
    # traverse through the list
    if isinstance(arg, (list, tuple)):
        new_arg_to_return = []
        for inner_arg in arg:
            new_inner_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
                node, inner_arg, qconfig, model, named_modules, obs_or_fq_map, is_qat,
            )
            new_arg_to_return.append(new_inner_arg)
        return type(arg)(new_arg_to_return)

    if not isinstance(arg, Node):
        return arg
    assert isinstance(arg, Node)
    # default (no observer)
    new_arg = arg

    quantization_annotation = node.meta.get(
        "quantization_annotation", QuantizationAnnotation()
    )
    arg_as_input_act_obs_or_fq = _get_arg_as_input_act_obs_or_fq(
        arg, node, named_modules, obs_or_fq_map, is_qat
    )
    arg_as_input_target_dtype, arg_as_input_target_is_dynamic = \
        _get_dtype_and_is_dynamic(arg_as_input_act_obs_or_fq)

    arg_as_output_act_obs_or_fq = _get_output_act_obs_or_fq(
        arg, named_modules, obs_or_fq_map, is_qat
    )
    arg_as_output_target_dtype, arg_as_output_target_is_dynamic = \
        _get_dtype_and_is_dynamic(arg_as_output_act_obs_or_fq)

    if arg_as_input_target_is_dynamic or arg_as_input_target_dtype not in [torch.float, None]:
        # this input needs quantization; check whether an observer producing a
        # compatible dtype is already attached to the argument
        if arg_as_input_target_dtype == arg_as_output_target_dtype and \
                arg_as_input_target_is_dynamic == arg_as_output_target_is_dynamic:
            assert _is_activation_post_process_node(arg, named_modules)
            assert arg_as_input_act_obs_or_fq is not None
            observed_arg = arg.args[0]
            assert isinstance(observed_arg, Node), \
                f"expect observed argument to be a Node, but got: {type(observed_arg)}"
            assert observed_arg in obs_or_fq_map, \
                f"can't refer to a node that does not have observer/fake_quant inserted yet: {observed_arg}"
            input_qspec_map = quantization_annotation.input_qspec_map
            input_arg_qspec = _get_qspec_for_arg(arg, input_qspec_map, named_modules)
            if isinstance(input_arg_qspec, SharedQuantizationSpec):
                # if the argument is set to use SharedQuantizationSpec, reset
                # the observer instance to align with the configured edge/node
                obs_or_fq_name = arg.target
                setattr(model, obs_or_fq_name, arg_as_input_act_obs_or_fq)
                named_modules[obs_or_fq_name] = arg_as_input_act_obs_or_fq
            else:
                # otherwise reuse the existing obs/fq instance
                arg_as_input_act_obs_or_fq = obs_or_fq_map[observed_arg]
            # we don't need to insert a new observer node
            new_arg = arg
            # record the observer for this input edge
            obs_or_fq_map[(observed_arg, node)] = arg_as_input_act_obs_or_fq
        else:
            # dtypes differ, so a new observer/fake_quant node is needed on this edge
            assert arg_as_input_act_obs_or_fq is not None
            new_obs_node = _insert_obs_or_fq(
                arg, arg_as_input_act_obs_or_fq, model, named_modules, model.graph)  # type: ignore[arg-type]
            new_arg = new_obs_node
            obs_or_fq_map[(arg, node)] = arg_as_input_act_obs_or_fq

    return new_arg
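
# Bookkeeping sketch (names illustrative, not from this code): for a graph
# x -> conv, the branches above record entries such as
#
#   obs_or_fq_map[(x, conv)] = obs   # observer for the edge from x into conv
#   obs_or_fq_map[x] = obs           # observer for the output of x
#
# Reusing one observer instance across several keys is what ties
# SharedQuantizationSpec edges/nodes together.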

def _maybe_insert_input_observers_for_node(
    node: Node,
    qconfig: QConfigAny,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
) -> None:
    """
    If needed, inserts observers for the input args and kwargs of `node`.
    Note: modifies `node` in place.

    For example, if cur_node needs an observer after prev_node, we change

        prev_node -> cur_node

    to

        prev_node -> obs -> cur_node
    """
    # Look through every input arg. If that arg's target dtype does not
    # match the current node's target dtype, insert an observer.
    new_args = []
    for arg in node.args:
        new_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
            node, arg, qconfig, model, named_modules, obs_or_fq_map, is_qat,
        )
        new_args.append(new_arg)

    # aten op IR is expected to carry no kwargs; the one exception is clone,
    # whose memory_format kwarg persists in the exported graph, so it is
    # special-cased here as a workaround.
    assert (
        node.target == torch.ops.aten.clone.default or len(node.kwargs) == 0
    ), "expecting kwargs for aten op IR to be empty"

    # assign the new args to the node, in place
    node.args = tuple(new_args)
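
# Illustrative effect (hypothetical node names): for an annotated add node,
# the rewrite above turns
#
#   add = torch.ops.aten.add.Tensor(x, y)
#
# into
#
#   x_obs = activation_post_process_0(x)
#   y_obs = activation_post_process_1(y)
#   add = torch.ops.aten.add.Tensor(x_obs, y_obs)
#
# assuming both inputs are annotated with a quantized dtype; the
# activation_post_process_* names follow the FX observer naming convention,
# and the modules are registered as submodules on `model`.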

def _maybe_insert_input_and_output_observers_for_node(
    node: Node,
    model: torch.fx.GraphModule,
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
):
    this_node_quantization_annotation = node.meta.get("quantization_annotation", None)
    if "val" in node.meta:
        output_is_a_tensor = (
            this_node_quantization_annotation is not None and
            isinstance(node.meta["val"], FakeTensor)
        )
    else:
        output_is_a_tensor = this_node_quantization_annotation is not None

    # nothing to do for nodes the quantizer did not annotate
    if this_node_quantization_annotation is None:
        return

    named_modules = dict(model.named_modules(remove_duplicate=False))

    _maybe_insert_input_observers_for_node(
        node,
        None,  # qconfig
        model,
        named_modules,
        obs_or_fq_map,
        is_qat,
    )

    # only tensor outputs can be observed
    if not output_is_a_tensor:
        return

    # this returns the new observer node if one was needed
    maybe_output_obs_node = _maybe_insert_output_observer_for_node(
        node, model, named_modules, model.graph, obs_or_fq_map, is_qat
    )

    if maybe_output_obs_node is None:
        return
    # Update users of the original node to use the output observer
    # instead. For example, change
    #
    #           next_node
    #          /
    #   cur_node -> obs
    #
    # to
    #
    #                  next_node
    #                 /
    #   cur_node -> obs
    #
    # We need to save the original users before updating uses because
    # the list of users will change as we update them
    orig_users = list(node.users.keys())
    for user_node in orig_users:
        if user_node is maybe_output_obs_node:
            continue
        user_node.replace_input_with(node, maybe_output_obs_node)

def prepare(
    model: GraphModule,
    node_name_to_scope: Dict[str, Tuple[str, type]],
    is_qat: bool,
) -> GraphModule:
    # Since we are mutating the graph as we go, we iterate over the original
    # nodes from before observer insertion, instead of model.graph.nodes.
    nodes_before_observation = list(model.graph.nodes)
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize] = {}

    for node in nodes_before_observation:
        _maybe_insert_input_and_output_observers_for_node(node, model, obs_or_fq_map, is_qat)

    # rebuild the GraphModule so its forward() reflects the mutated graph
    model = GraphModule(model, model.graph)

    _save_state(
        model,
        {},  # node_name_to_qconfig
        node_name_to_scope,
        PrepareCustomConfig(),
        {},  # equalization_node_name_to_qconfig
        QConfigMapping(),
        is_qat,
        set(),  # observed_node_names
    )
    return model
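
# Minimal usage sketch (illustrative; MyModel, quantizer, and example_inputs
# are placeholders, and in the public API this function is typically driven by
# torch.ao.quantization.quantize_pt2e.prepare_pt2e):
#
#   m = torch.export.export_for_training(MyModel(), example_inputs).module()
#   quantizer.annotate(m)  # fills node.meta["quantization_annotation"]
#   prepared = prepare(m, node_name_to_scope={}, is_qat=False)
#   prepared(*example_inputs)  # calibration populates the inserted observers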