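"""
Observer / fake-quantize insertion for the PT2 export (pt2e) quantization
flow. Unlike FX graph mode quantization, configuration here comes from the
"quantization_annotation" metadata that a Quantizer attaches to each node,
rather than from a QConfigMapping.
"""
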
from typing import Any, Dict, Tuple, Union

import torch
from torch._subclasses import FakeTensor
from torch.ao.quantization import ObserverOrFakeQuantize, QConfigMapping
from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
from torch.ao.quantization.fx.prepare import (
    _get_arg_as_input_act_obs_or_fq,
    _get_dtype_and_is_dynamic,
    _get_output_act_obs_or_fq,
    _get_qspec_for_arg,
    _insert_obs_or_fq,
    _is_activation_post_process_node,
    _maybe_insert_output_observer_for_node,
    _save_state,
)
from torch.ao.quantization.qconfig import QConfigAny
from torch.ao.quantization.quantizer import (
    EdgeOrNode,
    QuantizationAnnotation,
    SharedQuantizationSpec,
)
from torch.fx import GraphModule, Node
from torch.fx.node import Argument


def _maybe_insert_input_observer_for_arg_or_kwarg(
node: Union[Node, Any],
arg: Argument,
qconfig: QConfigAny,
model: torch.nn.Module,
named_modules: Dict[str, torch.nn.Module],
obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
is_qat: bool,
) -> Argument:
"""
Given a `node` and an `arg`, inserts an input observer between
`node` and `arg` if necessary.
"""
    # for ops that take a list/tuple of Tensors, e.g. torch.cat([x0, x1]),
    # recurse into the container and handle each inner arg
if isinstance(arg, (list, tuple)):
new_arg_to_return = []
for inner_arg in arg:
new_inner_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
node, inner_arg, qconfig, model, named_modules, obs_or_fq_map, is_qat,
)
new_arg_to_return.append(new_inner_arg)
return type(arg)(new_arg_to_return)
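    # non-Node args (e.g. scalar constants) never need observation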
if not isinstance(arg, Node):
return arg
assert isinstance(arg, Node)
# default (no observer)
new_arg = arg
quantization_annotation = node.meta.get("quantization_annotation", QuantizationAnnotation())
arg_as_input_act_obs_or_fq = _get_arg_as_input_act_obs_or_fq(arg, node, named_modules, obs_or_fq_map, is_qat)
arg_as_input_target_dtype, arg_as_input_target_is_dynamic = _get_dtype_and_is_dynamic(arg_as_input_act_obs_or_fq)
arg_as_output_act_obs_or_fq = _get_output_act_obs_or_fq(arg, named_modules, obs_or_fq_map, is_qat)
arg_as_output_target_dtype, arg_as_output_target_is_dynamic = _get_dtype_and_is_dynamic(arg_as_output_act_obs_or_fq)
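    # an input observer is needed only if the annotation requests dynamic
    # quantization or a non-fp32 target dtype for this input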
if arg_as_input_target_is_dynamic or arg_as_input_target_dtype not in [torch.float, None]:
        if (
            arg_as_input_target_dtype == arg_as_output_target_dtype
            and arg_as_input_target_is_dynamic == arg_as_output_target_is_dynamic
        ):
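            # the producer's output is already observed with a matching
            # configuration; share that observer instead of inserting a
            # duplicate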
assert _is_activation_post_process_node(arg, named_modules)
assert arg_as_input_act_obs_or_fq is not None
observed_arg = arg.args[0]
assert isinstance(observed_arg, Node), f"expect observed argument to be a Node, but got: {type(observed_arg)}"
assert observed_arg in obs_or_fq_map, \
f"can't refer to a node that does not have observer/fake_quant inserted yet: {observed_arg}"
input_qspec_map = quantization_annotation.input_qspec_map
input_arg_qspec = _get_qspec_for_arg(arg, input_qspec_map, named_modules)
if isinstance(input_arg_qspec, SharedQuantizationSpec):
# if the argument is set to use SharedQuantizationSpec, we will
# reset the observer instance to align with the configured edge/node
obs_or_fq_name = arg.target
setattr(model, obs_or_fq_name, arg_as_input_act_obs_or_fq)
named_modules[obs_or_fq_name] = arg_as_input_act_obs_or_fq
else:
# otherwise reuse the existing obs/fq
arg_as_input_act_obs_or_fq = obs_or_fq_map[observed_arg]
# we don't need to insert new observer node
new_arg = arg
obs_or_fq_map[(observed_arg, node)] = arg_as_input_act_obs_or_fq
else:
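            # the input needs a different dtype/dynamic-ness than what the
            # producer currently provides, so insert a fresh observer on the
            # (arg, node) edge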
assert arg_as_input_act_obs_or_fq is not None
new_obs_node = _insert_obs_or_fq(
arg, arg_as_input_act_obs_or_fq, model, named_modules, model.graph) # type: ignore[arg-type]
new_arg = new_obs_node
obs_or_fq_map[(arg, node)] = arg_as_input_act_obs_or_fq
return new_arg


def _maybe_insert_input_observers_for_node(
node: Node,
qconfig: QConfigAny,
model: torch.nn.Module,
named_modules: Dict[str, torch.nn.Module],
obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
is_qat: bool,
) -> None:
"""
If needed, inserts observers to the input args and kwargs of `node`.
Note: modifies `node` inplace.
For example, if cur_node needs an observer after prev_node, we change from
prev_node -> cur_node
To
prev_node -> obs -> cur_node
"""
    # Look through every input arg. If the quantization annotation calls for
    # that arg to be observed, insert (or share) an observer.
new_args = []
for arg in node.args:
new_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
node, arg, qconfig, model, named_modules, obs_or_fq_map, is_qat,
)
new_args.append(new_arg)
    # aten.clone carries a memory_format kwarg that persists in the exported
    # graph; this is just a workaround for that one exception.
    assert (
        node.target == torch.ops.aten.clone.default or len(node.kwargs) == 0
    ), "expecting kwargs for aten op IR to be empty"
# assign the new args to the node, inplace
node.args = tuple(new_args)


def _maybe_insert_input_and_output_observers_for_node(
node: Node,
model: torch.fx.GraphModule,
obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
is_qat: bool,
):
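    """
    Inserts input observers for `node`, and an output observer if `node`
    produces a single Tensor, based on the node's "quantization_annotation"
    metadata. Nodes without an annotation are left untouched.
    """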
    this_node_quantization_annotation = node.meta.get("quantization_annotation")
if "val" in node.meta:
output_is_a_tensor = (
this_node_quantization_annotation is not None and
isinstance(node.meta["val"], FakeTensor)
)
else:
output_is_a_tensor = this_node_quantization_annotation is not None
    if this_node_quantization_annotation is None:
        return
named_modules = dict(model.named_modules(remove_duplicate=False))
_maybe_insert_input_observers_for_node(
node,
None, # qconfig
model,
named_modules,
obs_or_fq_map,
is_qat,
)
    if not output_is_a_tensor:
        return
# this returns the new observer node if it was needed
    maybe_output_obs_node = _maybe_insert_output_observer_for_node(
        node, model, named_modules, model.graph, obs_or_fq_map, is_qat
    )
if maybe_output_obs_node is None:
return
    # Update users of original node to use the output observer
    # instead. For example, change
    #
    #           next_node
    #          /
    #   cur_node -> obs
    #
    # to
    #
    #                  next_node
    #                 /
    #   cur_node -> obs
    #
    # We need to save orig users before updating uses because
    # the list of users will change as we update uses
orig_users = list(node.users.keys())
for user_node in orig_users:
if user_node is maybe_output_obs_node:
continue
user_node.replace_input_with(node, maybe_output_obs_node)


def prepare(
model: GraphModule,
node_name_to_scope: Dict[str, Tuple[str, type]],
is_qat: bool,
) -> GraphModule:
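    """
    Top-level entry point: walks the aten IR graph and inserts observers /
    fake_quantizes according to each node's "quantization_annotation", then
    records the prepare-time state on the returned GraphModule.
    """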
# Since we are mutating the graph as we go, we iterate over the original
# nodes before observer insertion, instead of model.graph.nodes.
nodes_before_observation = list(model.graph.nodes)
obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize] = {}
for node in nodes_before_observation:
_maybe_insert_input_and_output_observers_for_node(node, model, obs_or_fq_map, is_qat)
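    # rebuild the GraphModule so the mutated graph (and the observer
    # submodules attached via setattr) are picked up and forward() is
    # regenerated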
model = GraphModule(model, model.graph)
_save_state(
model,
{}, # node_name_to_qconfig
node_name_to_scope,
PrepareCustomConfig(),
{}, # equalization_node_name_to_qconfig
QConfigMapping(),
is_qat,
set() # observed_node_names
)
return model
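

# A minimal usage sketch (illustrative only: `capture` and `quantizer` are
# hypothetical stand-ins for whatever export step and Quantizer produced and
# annotated the aten graph):
#
#   graph_module = capture(model, example_inputs)
#   quantizer.annotate(graph_module)
#   prepared = prepare(graph_module, node_name_to_scope={}, is_qat=False)
#   prepared(*example_inputs)  # run calibration data through the observers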