blob: 5e2da85a1c99831235404991e29a5663a9ad60e6 [file] [log] [blame]
import torch
from torch.fx import (
GraphModule,
Proxy,
map_arg
)
from torch.fx.graph import (
Graph,
Node,
)
from torch.fx.node import Argument
from torch.quantization import (
propagate_qconfig_,
convert,
)
from ..quantization_mappings import (
get_default_qat_module_mappings,
)
from ..quantize import (
_remove_qconfig,
is_activation_post_process
)
from ..utils import (
get_combined_dict,
get_qconfig_dtypes,
weight_is_quantized,
activation_is_statically_quantized,
activation_is_int8_quantized,
activation_dtype,
weight_dtype,
)
from .pattern_utils import (
is_match,
get_default_quant_patterns,
get_default_output_activation_post_process_map,
Pattern,
)
from .graph_module import (
is_observed_module,
is_observed_standalone_module,
ObservedGraphModule,
ObservedStandaloneGraphModule,
QuantizedGraphModule,
)
from .quantization_patterns import (
binary_op_supported_dtypes,
binary_reference_op_supported_dtypes,
BinaryOpQuantizeHandler,
CatQuantizeHandler,
CopyNodeQuantizeHandler,
CustomModuleQuantizeHandler,
QuantizeHandler,
StandaloneModuleQuantizeHandler,
)
from .utils import (
_parent_name,
all_node_args_have_no_tensors,
is_get_tensor_info_node,
quantize_node,
get_custom_module_class_keys,
get_new_attr_name_with_prefix,
collect_producer_nodes,
graph_module_from_producer_nodes,
assert_and_get_unique_device,
node_return_type_is_int,
node_bool_tensor_arg_indexes,
)
from .qconfig_utils import (
convert_dict_to_ordered_dict,
generate_qconfig_map,
get_flattened_qconfig_dict,
QConfigAny,
)
import operator
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
# Define helper types
# A match for one node; it is unpacked elsewhere in this file as
# (root_node, matched_nodes, pattern, qhandler, qconfig).
MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler,
                    QConfigAny]
# ------------------------
# Helper Functions
# ------------------------
def qat_swap_modules(
        root: torch.nn.Module,
        additional_qat_module_mapping: Dict[Callable, Callable]) -> None:
    """
    Swap modules of `root` for their QAT counterparts, in place.

    The swap mapping is the default QAT module mapping combined (via
    `get_combined_dict`) with `additional_qat_module_mapping`. qconfig
    attributes are kept on the modules (`remove_qconfig=False`).
    """
    mapping = get_combined_dict(
        get_default_qat_module_mappings(),
        additional_qat_module_mapping,
    )
    convert(root, mapping=mapping, inplace=True, remove_qconfig=False)
def get_standalone_module_configs(
    node: Node,
    modules: Dict[str, torch.nn.Module],
    prepare_custom_config_dict: Dict[str, Any],
    qconfig: QConfigAny,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Return the (qconfig_dict, prepare_config_dict) pair to use for the
    standalone module called by `node`.

    Each entry of `prepare_custom_config_dict["standalone_module_class"]` and
    `prepare_custom_config_dict["standalone_module_name"]` is indexed as
    (key, qconfig_dict, prepare_config_dict). Class entries are matched on the
    exact type of `modules[node.target]`; name entries are matched on
    `node.target` and win over class entries. A missing (None) qconfig_dict
    falls back to `{"": qconfig}` and a missing prepare_config_dict falls back
    to `{}`.
    """
    module_type = type(modules[node.target])  # type: ignore[index]

    by_name_entries = prepare_custom_config_dict.get("standalone_module_name", [])
    by_class_entries = prepare_custom_config_dict.get("standalone_module_class", [])
    by_class = {entry[0]: (entry[1], entry[2]) for entry in by_class_entries}
    by_name = {entry[0]: (entry[1], entry[2]) for entry in by_name_entries}

    fallback = by_class.get(module_type, (None, None))
    sm_qconfig_dict, sm_prepare_config_dict = by_name.get(node.target, fallback)

    if sm_qconfig_dict is None:
        sm_qconfig_dict = {"": qconfig}
    if sm_prepare_config_dict is None:
        sm_prepare_config_dict = {}
    return sm_qconfig_dict, sm_prepare_config_dict
def insert_observer(
    node: Node,
    observer: torch.quantization.ObserverBase,
    model: torch.nn.Module,
    modules: Dict[str, torch.nn.Module],
    graph: Graph,
) -> Node:
    """
    Register `observer` on `model` and create a graph node that calls it on
    the output of `node`.

    The observer is moved to the device reported by
    `assert_and_get_unique_device(model)` (when truthy). It is stored on `model`
    under a fresh attribute name prefixed with
    `<node.name>_activation_post_process_`, and recorded in `modules`. A
    `call_module` node taking `node` as its only argument is inserted right
    after `node` in `graph` and returned. Existing users of `node` are NOT
    rewired here.
    """
    device = assert_and_get_unique_device(model)
    if device:
        observer.to(device)

    make_name = get_new_attr_name_with_prefix(
        node.name + '_activation_post_process_')
    attr_name = make_name(model)
    setattr(model, attr_name, observer)
    modules[attr_name] = observer

    with graph.inserting_after(node):
        observer_node = graph.create_node(
            'call_module', attr_name, (node,), {})
    return observer_node
def get_target_activation_dtype_for_node(
    node: Node,
    qconfig: QConfigAny,
    inputs_seen_counter: int,
    outputs_seen_counter: int,
    input_quantized_idxs: List[int],
    output_quantized_idxs: List[int],
    qhandler: Optional[QuantizeHandler],
    modules: Dict[str, torch.nn.Module],
    cache_for_no_tensor_check: Dict[Node, bool],
) -> Optional[torch.dtype]:
    """
    Returns the expected dtype of the input and output of this node after
    convert. If the value is not None, it represents the dtype of the
    Tensor. If the value is None, it means the value is not a Tensor.

    Note: this is for activations only, weight dtypes are not handled here.

    Args:
        node: the node to classify.
        qconfig: qconfig matched to `node` (None if quantization is off).
        inputs_seen_counter: index of `node` among the graph placeholders
            (only used when `node.op == 'placeholder'`).
        outputs_seen_counter: index of `node` among the graph outputs
            (only used when `node.op == 'output'`).
        input_quantized_idxs / output_quantized_idxs: indexes of graph
            inputs / outputs that are quint8 instead of fp32.
        qhandler: QuantizeHandler matched to `node`, if any.
        modules, cache_for_no_tensor_check: forwarded to
            `all_node_args_have_no_tensors`.

    Raises:
        AssertionError: for node ops not handled here.

    TODO(future PR, if needed): explicitly spell out the non-Tensor
    dtypes.
    """
    if node.op == 'placeholder':
        # graph inputs are fp32 unless listed as quantized;
        # note: other dtypes are not supported
        if inputs_seen_counter in input_quantized_idxs:
            return torch.quint8
        return torch.float

    elif node.op in ('call_module', 'call_method', 'call_function'):
        if all_node_args_have_no_tensors(
                node, modules, cache_for_no_tensor_check):
            return None

        # TODO(future PR): consider stopping matching getitem
        if node.op == 'call_function' and node.target == operator.getitem:
            return torch.float

        # the qconfig decides the dtype only for nodes whose handler
        # observes inputs/outputs; everything else stays fp32
        if qconfig is not None and qhandler is not None and \
                qhandler.input_output_observed():
            # Only the activation dtype is needed. The other two values are
            # bound to underscore-prefixed names so the module-level
            # `weight_dtype` helper is not shadowed inside this function.
            act_dtype, _weight_dtype, _act_compute_dtype = \
                get_qconfig_dtypes(qconfig)
            return act_dtype
        return torch.float

    elif node.op == 'get_attr':
        return torch.float

    elif node.op == 'output':
        # graph outputs are fp32 unless listed as quantized;
        # note: other dtypes are not supported
        if outputs_seen_counter in output_quantized_idxs:
            return torch.quint8
        return torch.float

    else:
        raise AssertionError(f'need to handle {node.format_node()}')
def maybe_insert_input_observer_for_arg_or_kwarg(
    node: Union[Node, Any],
    arg: Argument,
    qconfig: QConfigAny,
    model: torch.nn.Module,
    modules: Dict[str, torch.nn.Module],
    graph: Graph,
    node_name_to_target_dtype: Dict[str, Any],
    qhandler: Optional[QuantizeHandler],
    prepare_custom_config_dict: Dict[str, Any],
) -> Argument:
    """
    Given a `node` and an `arg`, inserts an input observer between
    `node` and `arg` if necessary.

    Returns the value `node` should use in place of `arg`: a newly inserted
    observer node, an already existing observer node of the same type and
    target dtype among `arg`'s users, or `arg` itself. Lists and tuples are
    processed element-wise and returned as a list; non-Node values are
    returned unchanged.

    Regular nodes need an observer when `arg` is a quantized weight, an fp16
    bias, or an activation whose target dtype differs from the node's
    non-float target dtype. Standalone modules only observe positional args
    whose index is in the standalone module's `input_quantized_idxs`.
    """
    # for ops such as torch.cat([x0, x1]),
    # traverse through the list
    if isinstance(arg, (list, tuple)):
        new_arg_to_return = []
        for inner_arg in arg:
            new_inner_arg = maybe_insert_input_observer_for_arg_or_kwarg(
                node, inner_arg, qconfig, model, modules,
                graph, node_name_to_target_dtype,
                qhandler, prepare_custom_config_dict)
            new_arg_to_return.append(new_inner_arg)
        return new_arg_to_return

    if not isinstance(arg, Node):
        return arg
    assert isinstance(arg, Node)

    # default (no observer)
    new_arg = arg

    is_standalone_module = qhandler is not None and \
        isinstance(qhandler, StandaloneModuleQuantizeHandler)

    if not is_standalone_module:
        # regular flow for most nodes, except standalone modules
        is_weight = node_arg_is_weight(node, arg)
        assert qconfig is not None
        act_post_process_ctr = qconfig.weight if is_weight else \
            qconfig.activation
        is_bias = node_arg_is_bias(node, arg)
        is_activation = not (is_weight or is_bias)
        weight_needs_obs = is_weight and weight_is_quantized(qconfig)
        bias_needs_obs = \
            (is_bias and activation_dtype(qconfig) == torch.float16) and \
            weight_dtype(qconfig) == torch.float16

        arg_dtype = node_name_to_target_dtype[arg.name]
        node_dtype = node_name_to_target_dtype[node.name]
        dtype_changes_and_second_dtype_not_float = (
            # if the dtypes are different, we need an observer
            (arg_dtype != node_dtype) and
            # except if the second dtype is float, a dequant will be inserted
            # without an observer in convert
            # TODO(future PR): change this so a placeholder is inserted for
            # future dequants, to make the logic easier to understand
            (node_dtype != torch.float) and
            # if arg is a bool tensor or not a tensor, do not insert observer
            (arg_dtype not in (torch.bool, None)) and
            (is_activation and activation_is_statically_quantized(qconfig))
        )

        needs_obs = (
            weight_needs_obs or
            bias_needs_obs or
            dtype_changes_and_second_dtype_not_float
        )

    else:
        # custom flow for standalone modules
        _sm_qconfig_dict, sm_prepare_config_dict = \
            get_standalone_module_configs(
                node, modules, prepare_custom_config_dict, qconfig)

        sm_input_quantized_idxs = \
            sm_prepare_config_dict.get('input_quantized_idxs', [])
        # for args, this is set to the index of the current arg
        # for kwargs, this is left at None
        cur_input_idx = None
        for arg_idx, arg_to_check in enumerate(node.args):
            if arg_to_check is arg:
                cur_input_idx = arg_idx
                break

        if cur_input_idx is None:
            needs_obs = False
        else:
            arg_dtype = node_name_to_target_dtype[arg.name]
            node_dtype = torch.quint8 if cur_input_idx in sm_input_quantized_idxs \
                else torch.float
            needs_obs = (
                (arg_dtype != node_dtype) and
                (node_dtype != torch.float)
            )

        if needs_obs:
            # Standalone module inputs are observed as activations. The
            # observer constructor must be bound on this path too; otherwise
            # the `act_post_process_ctr()` call below raises
            # UnboundLocalError.
            assert qconfig is not None
            act_post_process_ctr = qconfig.activation

    if needs_obs:

        new_obs_mod = act_post_process_ctr()
        existing_obs_node = None

        # Before using the new observer, check if an observer
        # of the correct type already exists. If it does, use it.
        # This prevents duplicate observer insertions if a node is
        # used by multiple nodes.
        for maybe_obs_node, _ in arg.users.items():
            if maybe_obs_node.op == 'call_module':
                maybe_obs_mod = modules[maybe_obs_node.target]  # type: ignore[index]
                if (
                    type(maybe_obs_mod) == type(new_obs_mod) and
                    node_name_to_target_dtype[maybe_obs_node.name] == node_dtype
                ):
                    existing_obs_node = maybe_obs_node
                    break

        if existing_obs_node is None:
            new_obs_node = insert_observer(
                arg, new_obs_mod, model, modules, graph)
            # set the type, so the next node can read it
            node_name_to_target_dtype[new_obs_node.name] = node_dtype
            # override this arg to be the observed arg
            new_arg = new_obs_node
        else:
            new_arg = existing_obs_node

    return new_arg
def maybe_insert_input_observers_for_node(
    node: Node,
    qconfig: QConfigAny,
    model: torch.nn.Module,
    modules: Dict[str, torch.nn.Module],
    graph: Graph,
    node_name_to_target_dtype: Dict[str, Any],
    qhandler: Optional[QuantizeHandler],
    prepare_custom_config_dict: Dict[str, Any],
) -> None:
    """
    Insert observers on the inputs of `node` where needed, rewriting
    `node.args` and `node.kwargs` in place.

    For example, if cur_node needs an observer after prev_node, the graph
        prev_node -> cur_node
    becomes
        prev_node -> obs -> cur_node

    Nothing is done when `qconfig` is None (quantization is turned off for
    this node). Otherwise each positional arg, then each kwarg, is passed
    through `maybe_insert_input_observer_for_arg_or_kwarg`. The results are
    stored back as a tuple of args and a dict of kwargs.
    """
    if qconfig is None:
        return

    def _observe(value: Argument) -> Argument:
        return maybe_insert_input_observer_for_arg_or_kwarg(
            node, value, qconfig, model, modules, graph,
            node_name_to_target_dtype,
            qhandler, prepare_custom_config_dict)

    observed_args = tuple(_observe(value) for value in node.args)
    observed_kwargs = {
        key: _observe(value) for key, value in node.kwargs.items()
    }

    node.args = observed_args
    node.kwargs = observed_kwargs  # type: ignore[assignment]
def maybe_insert_output_observer_for_node(
    node: Node,
    model: torch.nn.Module,
    modules: Dict[str, torch.nn.Module],
    graph: Graph,
    matches: Dict[str, MatchResult],
    node_name_to_target_dtype: Dict[str, Any],
    matched_pattern: Any,
    qhandler: Optional[QuantizeHandler],
) -> Optional[Node]:
    """
    Create and insert an output observer for `node` if one is needed, and
    return it; return None otherwise.

    The QuantizeHandler and qconfig are read from `matches[node.name]`. Note
    that this overwrites the `qhandler` argument.

    - Matched nodes get an observer when the handler asks for one, the
      qconfig statically quantizes activations, and the handler is not a
      standalone-module handler (standalone modules observe their own
      outputs). For int8 activations, `matched_pattern` may select a
      pattern-specific observer constructor.
    - Unmatched `output` nodes get an observer on their first argument when
      that argument is fp32 but the output is expected to be non-fp32.

    The new observer node's target dtype is recorded in
    `node_name_to_target_dtype`.
    """
    _root_node, _matched_nodes, _pattern, qhandler, qconfig = matches.get(
        node.name, (None, None, None, None, None))

    if qhandler is not None:
        assert qconfig is not None
        wants_observer = (
            qhandler.should_insert_observer_for_output(
                qconfig, model.training)
            # TODO(future PR): move the following logic to
            # should_insert_observer_for_output
            and activation_is_statically_quantized(qconfig)
            # we never insert observers to output of standalone module, we
            # assume if needed, they are inserted inside the standalone module
            and not isinstance(qhandler, StandaloneModuleQuantizeHandler)
        )
        if not wants_observer:
            return None

        observer_ctr = qconfig.activation
        if activation_is_int8_quantized(qconfig):
            observer_ctr = get_default_output_activation_post_process_map().get(
                matched_pattern, observer_ctr)
        obs_node = insert_observer(node, observer_ctr(), model, modules, graph)
        # set the type, so the next node can read it
        node_name_to_target_dtype[obs_node.name] = \
            node_name_to_target_dtype[node.name]
        return obs_node

    if node.op == 'output':
        source = node.args[0]
        assert isinstance(source, Node)
        source_dtype = node_name_to_target_dtype[source.name]
        output_dtype = node_name_to_target_dtype[node.name]
        if source_dtype == torch.float and output_dtype != torch.float:
            assert qconfig is not None
            obs_node = insert_observer(
                source, qconfig.activation(), model, modules, graph)
            # set the type, so the next node can read it
            node_name_to_target_dtype[obs_node.name] = output_dtype
            return obs_node

    return None
def maybe_insert_observers_before_graph_output(
    graph_output_node: Node,
    output_quantized_idxs: List[int],
    node_name_to_target_dtype: Dict[str, torch.dtype],
    qconfig_map: Dict[str, QConfigAny],
    model: torch.nn.Module,
    modules: Dict[str, torch.nn.Module],
    graph: Graph,
) -> None:
    """
    If the output needs to be quantized and there are any nodes
    in the output which are not already observed, inserts observers
    for those nodes.

    `output_quantized_idxs` must be `[]` (nothing to do) or `[0]` (the whole
    graph output is expected to be quint8). In the latter case every Node
    found inside the (arbitrarily nested list/tuple/dict) output args whose
    target dtype is not quint8 is replaced by a newly inserted observer.
    Non-Node leaves (constants such as None, ints, strings) are left as-is.

    Raises:
        AssertionError: for any other `output_quantized_idxs`, or when a node
            needing observation has no qconfig in `qconfig_map`.
    """
    # TODO(future PR): update the output_quantized_idxs API to match
    # arbitrary data structures. There is always a single output, and
    # that output can have arbitrary nesting of values. List[int] is
    # not the right data type for this.
    assert output_quantized_idxs == [0] or output_quantized_idxs == [], \
        'unrecognized format of output_quantized_idxs'

    # Currently dequants are inserted in the convert step. So, we only
    # have to do anything if the output is hardcoded to be quantized
    if output_quantized_idxs == []:
        return
    # TODO(future PR): support more dtypes in model outputs, if necessary
    output_target_dtype = torch.quint8

    def _recursive_maybe_replace_node_with_obs(
        maybe_node: Argument,
        target_dtype: torch.dtype,
        node_name_to_target_dtype: Dict[str, torch.dtype],
        qconfig_map: Dict[str, QConfigAny],
        model: torch.nn.Module,
        modules: Dict[str, torch.nn.Module],
        graph: Graph,
    ) -> Argument:
        """
        Navigate an arbitrary data structure of lists, tuples, dicts.
        For each container type, recurse on all inputs. Once any Node
        is found, insert an observer if needed and do not recurse further.

        For example, given a structure of

          {'foo1': [[bar1]], 'foo2': {'foo3': [[[bar3]]]}}

        we recurse down to bar1 and bar3, observe them if necessary,
        and if we inserted an observer then replace the original node
        with its observer.

        Returns the data structure with all nodes needing observation being
        replaced by their observers; any other leaf value is returned as is.
        """
        if isinstance(maybe_node, Node):
            # check dtype of this node
            this_node_dtype = node_name_to_target_dtype[maybe_node.name]
            if this_node_dtype != target_dtype:
                # insert observer
                qconfig = qconfig_map.get(maybe_node.name)
                # TODO(future PR): see if we need to allow specifying qconfig
                # on output nodes, to remove the restriction below.
                assert qconfig is not None, \
                    'Quantizing the output node without a qconfig is not supported'
                observer_mod = qconfig.activation()
                observer_node = insert_observer(
                    maybe_node, observer_mod, model, modules, graph)
                # set the type, like every other observer insertion does
                node_name_to_target_dtype[observer_node.name] = target_dtype
                return observer_node
            else:
                return maybe_node
        elif isinstance(maybe_node, (list, tuple)):
            results = []
            for inner_node in maybe_node:
                results.append(_recursive_maybe_replace_node_with_obs(
                    inner_node, target_dtype, node_name_to_target_dtype,
                    qconfig_map, model, modules, graph))
            if isinstance(maybe_node, list):
                return results
            else:
                return tuple(results)
        elif isinstance(maybe_node, dict):
            results_dict = {}
            for k, inner_v in maybe_node.items():
                results_dict[k] = _recursive_maybe_replace_node_with_obs(
                    inner_v, target_dtype, node_name_to_target_dtype,
                    qconfig_map, model, modules, graph)
            return results_dict
        else:
            # Non-Node leaves (None, numbers, strings, ...) never need an
            # observer. The original code returned the unbound name
            # `results` here, which raised UnboundLocalError.
            return maybe_node

    new_args = []
    for old_arg in graph_output_node.args:
        new_args.append(
            _recursive_maybe_replace_node_with_obs(
                old_arg, output_target_dtype, node_name_to_target_dtype,
                qconfig_map, model, modules, graph))

    # node args are a tuple, matching how other nodes' args are rewritten
    graph_output_node.args = tuple(new_args)
def maybe_propagate_dtype_for_node(
    node: Node,
    target_dtype: torch.dtype,
    node_name_to_target_dtype: Dict[str, torch.dtype],
    matches: Dict[str, MatchResult],
) -> None:
    """
    Record `target_dtype` as the target dtype of `node`.

    If `node` is matched (in `matches`) to a `CopyNodeQuantizeHandler`, the
    same dtype is pushed recursively to its first positional argument, as
    long as that argument is a Node.
    """
    node_name_to_target_dtype[node.name] = target_dtype

    _root, _matched, _pattern, qhandler, _qconfig = matches.get(
        node.name, (None, None, None, None, None))
    if not isinstance(qhandler, CopyNodeQuantizeHandler):
        return

    # copy nodes forward their input's dtype, so the input shares it
    upstream = node.args[0]
    if not isinstance(upstream, Node):
        return
    maybe_propagate_dtype_for_node(
        upstream, target_dtype, node_name_to_target_dtype, matches)
def propagate_dtypes_for_known_nodes(
    graph: Graph,
    node_name_to_target_dtype: Dict[str, torch.dtype],
    matches: Dict[str, MatchResult],
) -> None:
    """
    Refine target dtypes using ops whose argument dtypes are known.

    Graph inputs are assumed to be `torch.float` or `torch.quint8`, which is
    not always correct. For ops such as `x.masked_fill(mask, value)` the
    `mask` is a `BoolTensor`. For every positional arg index reported by
    `node_bool_tensor_arg_indexes(node)`, that arg (and, through copy nodes,
    its producers) is marked as `torch.bool`.

    Note: not all dtypes in the graph will be correct after this pass, but a
    higher percentage of them will be correct. Hopefully in the future we can
    replace this with a better way to reason about dtypes of tensors.
    """
    for current in graph.nodes:
        for idx in node_bool_tensor_arg_indexes(current):
            maybe_propagate_dtype_for_node(
                current.args[idx], torch.bool,
                node_name_to_target_dtype, matches)
def adjust_observers_for_cat(
    node: Node,
    model: torch.nn.Module,
    modules: Dict[str, torch.nn.Module],
) -> None:
    """
    Make a quantized `torch.cat` node share a single observer module between
    all of its inputs and its output.

    Given a graph such as

        x0 -> obs0 ---+
                      +--> cat -> obs2
        x1 -> obs1 ---+

    where node obs0 calls observer0, obs1 calls observer1 and obs2 calls
    observer2, the modules behind obs1 and obs2 are replaced (via setattr on
    their parent modules) by observer0. For each cat input the nearest
    observer is found by following `args[0]` upwards (e.g. through
    `observed_node -> non_observed_node -> cat`). Every user of `node` must
    be an observer node. `model` is not used; the replaced observer modules
    are left orphaned.

    TODO(future PR): delete the orphaned observer modules
    """
    def _nearest_observer(start: Any) -> Any:
        # follow the first argument upwards until an observer node is found
        current = start
        hops = 0
        while not is_activation_post_process_node(current, modules):
            current = current.args[0]
            hops += 1
            if hops > 10000:
                raise AssertionError('Unable to find observer of previous node')
        return current

    cat_inputs = node.args[0]
    assert isinstance(cat_inputs, (list, tuple))

    # the observer module of the first input is the one everybody shares
    shared_obs_node = _nearest_observer(cat_inputs[0])
    assert isinstance(shared_obs_node, Node)
    shared_target = shared_obs_node.target
    assert isinstance(shared_target, str)
    shared_obs_mod = modules[shared_target]

    # point every other input observer at the shared module
    for other_input in cat_inputs[1:]:
        obs_node = _nearest_observer(other_input)
        parent_name, attr_name = _parent_name(obs_node.target)
        setattr(modules[parent_name], attr_name, shared_obs_mod)

    # and the output observer(s) as well
    for user in node.users:
        assert is_activation_post_process_node(user, modules)
        parent_name, attr_name = _parent_name(user.target)
        setattr(modules[parent_name], attr_name, shared_obs_mod)
def insert_observers_for_model(
    model: GraphModule,
    modules: Dict[str, torch.nn.Module],
    matches: Dict[str, MatchResult],
    qconfig_map: Dict[str, QConfigAny],
    graph: Graph,
    prepare_custom_config_dict: Dict[str, Any],
    input_quantized_idxs: List[int],
    output_quantized_idxs: List[int],
) -> Optional[Node]:
    """
    Inserts observers, using the following high level algorithm:

    For each node in the graph:
      1. determine the target dtype of this node in the quantized graph, and save
         it for future steps
      2. determine the target dtype of all args and kwargs of this node
      3. if any arg or kwarg's target dtype does not match the current node's
         dtype, insert an observer
      4. if the current node needs an output observer, insert it

    For example:

    - starting graph:
        x0 -> linear -> x1

    - observed graph after processing x0:
        x0(fp32)

    - observed graph after processing linear:
        x0(fp32) -> x0_obs0(int8) -> linear(int8) -> linear_obs0(int8)

    - observed graph after processing x1:
        x0(fp32) -> x0_obs0(int8) -> linear(int8) -> linear_obs0(int8) -> x1

    After a node is processed, the naive observer placement is guaranteed to be
    complete for that node and all of its predecessors. There can be future
    passes which optimize the graph by deduplicating observers, etc.

    The i-th graph input (placeholder) is targeted as quint8 iff i is in
    `input_quantized_idxs`, and the i-th output node iff i is in
    `output_quantized_idxs`; all others default to fp32.

    Note: the `modules` argument is not read; the module map is rebuilt from
    `model` before use.

    Returns:
        The graph's `output` node, or None if the graph has none.
    """
    node_name_to_target_dtype: Dict[str, Any] = {}
    cache_for_no_tensor_check: Dict[Node, bool] = dict()

    inputs_seen_counter = 0
    outputs_seen_counter = 0
    results_node = None

    # first, populate the dtype map based only on qconfig and qhandler
    # this assumes:
    # graph inputs are fp32 by default, and int8 where overriden
    # other nodes output dtype is specified by the qconfig
    modules = dict(model.named_modules(remove_duplicate=False))
    for node in model.graph.nodes:
        root_node, matched_nodes, pattern, qhandler, qconfig = matches.get(
            node.name, (None, None, None, None, None))
        node_name_to_target_dtype[node.name] = get_target_activation_dtype_for_node(
            node, qconfig, inputs_seen_counter, outputs_seen_counter,
            input_quantized_idxs, output_quantized_idxs, qhandler,
            modules, cache_for_no_tensor_check)
        # Advance the counters here, while the dtypes are being assigned, so
        # the i-th graph input / output is checked against index i. Without
        # this every placeholder would be compared against index 0.
        if node.op == 'placeholder':
            inputs_seen_counter += 1
        elif node.op == 'output':
            outputs_seen_counter += 1

    # Second, for nodes with known input dtypes, propagate them throughout the
    # graph. For example, if there is a call such as
    #   x1 = x0.masked_fill(mask, 1)
    # we propagate the type of mask to be torch.bool
    propagate_dtypes_for_known_nodes(
        model.graph, node_name_to_target_dtype, matches)

    # After this point, the current node and all of its arguments
    # have a dtype assigned. Now, we insert observers for inputs
    # of this node (if needed for this node), and the output of this node
    # (if needed for this node).

    # Since we are mutating the graph as we go, we iterate over the original
    # nodes before observer insertion, instead of model.graph.nodes.
    nodes_before_observation = list(model.graph.nodes)

    for node in nodes_before_observation:

        # check for matches
        root_node, matched_nodes, pattern, qhandler, qconfig = matches.get(
            node.name, (None, None, None, None, None))

        if node.op == 'placeholder':
            # if a graph input is in fp32, it does not need observation
            # if a graph input is in int8, we assume the observation happens
            #   outside of the graph, and no additional observation is needed
            pass

        elif node.op in ('call_module', 'call_method', 'call_function', 'output'):
            # rebuild the module map, since earlier iterations may have added
            # or swapped submodules on `model`
            modules = dict(model.named_modules(remove_duplicate=False))
            this_node_dtype = node_name_to_target_dtype[node.name]
            output_not_a_tensor = this_node_dtype is None
            # TODO(future PR): consider stopping matching getitem
            is_getitem = node.op == 'call_function' and \
                node.target == operator.getitem

            skip_inserting_observers = (
                (qconfig is None) or
                output_not_a_tensor or
                is_getitem
            ) and (not node.op == 'output')

            if not skip_inserting_observers:
                if node.op != 'output':
                    # this modifies node inplace
                    maybe_insert_input_observers_for_node(
                        node, qconfig, model, modules, graph,
                        node_name_to_target_dtype,
                        qhandler, prepare_custom_config_dict)

                    is_last_node_of_pattern = root_node is node
                    is_like_copy_node = \
                        (qhandler is not None and (
                            isinstance(qhandler, CopyNodeQuantizeHandler)
                        ))
                    if is_last_node_of_pattern and (not is_like_copy_node):
                        # this returns the new observer node if it was needed
                        maybe_output_obs_node = maybe_insert_output_observer_for_node(
                            node, model, modules, graph, matches,
                            node_name_to_target_dtype, pattern, qhandler)
                        if maybe_output_obs_node is not None:
                            # Update users of original node to use the output
                            # observer instead. For example, change
                            #
                            #   cur_node -> next_node
                            #
                            # to
                            #
                            #   cur_node -> obs -> next_node
                            #
                            # We need to save orig users before updating uses
                            # because the list of users will change as we
                            # update uses
                            orig_users = list(node.users.keys())
                            for user_node in orig_users:
                                if user_node is maybe_output_obs_node:
                                    continue
                                user_node.replace_input_with(node, maybe_output_obs_node)

                            # for quantized cat nodes only, we modify the graph
                            # to make all inputs and outputs use the first
                            # input's observer
                            if isinstance(qhandler, CatQuantizeHandler):
                                adjust_observers_for_cat(node, model, modules)

                else:  # output
                    maybe_insert_observers_before_graph_output(
                        node, output_quantized_idxs,
                        node_name_to_target_dtype, qconfig_map,
                        model, modules, graph)

        #
        # After this point, the current node has input and output observers
        # that it needs for itself inserted.
        #

        if node.op == 'output':
            results_node = node

    return results_node
def run_prepare_fx_on_standalone_modules(
    model: torch.nn.Module,
    modules: Dict[str, torch.nn.Module],
    matches: Any,
    prepare_custom_config_dict: Dict[str, Any],
) -> None:
    """
    Prepare every standalone module found in `matches` with prepare_fx.

    The graph itself is left untouched. Each matched standalone submodule is
    replaced, on its parent module and in `modules`, by an
    `ObservedStandaloneGraphModule` built from
    `torch.quantization.quantize_fx._prepare_standalone_module_fx` using the
    standalone module's own qconfig_dict / prepare_config_dict. The
    `preserved_attributes` listed in that prepare_config_dict are kept.
    """
    for _node_name, match in matches.items():
        root_node, _matched_nodes, _pattern, qhandler, qconfig = match
        if qhandler is None or \
                not isinstance(qhandler, StandaloneModuleQuantizeHandler):
            continue

        sm_qconfig_dict, sm_prepare_config_dict = get_standalone_module_configs(
            root_node, modules, prepare_custom_config_dict, qconfig)

        float_module = modules[root_node.target]  # type: ignore[index]
        prepare_fn = \
            torch.quantization.quantize_fx._prepare_standalone_module_fx  # type: ignore[attr-defined]
        prepared = prepare_fn(
            float_module, sm_qconfig_dict, sm_prepare_config_dict)

        keep_attrs = set(sm_prepare_config_dict.get("preserved_attributes", []))
        observed = ObservedStandaloneGraphModule(
            prepared, prepared.graph, keep_attrs)

        parent_name, attr_name = _parent_name(root_node.target)
        setattr(modules[parent_name], attr_name, observed)
        modules[root_node.target] = observed  # type: ignore[index]
def is_activation_post_process_node(node: Node, modules: Dict[str, torch.nn.Module]) -> bool:
    """
    Return whether `node` is a `call_module` node whose target module
    satisfies `is_activation_post_process`.
    """
    if node.op != "call_module":
        return False
    target_module = modules[str(node.target)]
    return is_activation_post_process(target_module)
# A dictionary for querying the weight index for a given op:
# maps a functional op to the positional-arg indexes that hold its weight
# (used by node_arg_is_weight and run_weight_observers).
WEIGHT_INDEX_DICT = {
    torch.nn.functional.conv1d : [1],
    torch.nn.functional.conv2d : [1],
    torch.nn.functional.conv3d : [1],
    torch.nn.functional.linear : [1],
}
def node_arg_is_weight(node: Node, arg: Any) -> bool:
    """
    Return True if `arg` is passed (compared by identity) to the
    `call_function` node `node` at one of the positional indexes that
    WEIGHT_INDEX_DICT lists as weights for `node.target`.
    """
    is_weighted_call = isinstance(node, Node) and \
        node.op == 'call_function' and node.target in WEIGHT_INDEX_DICT
    if not is_weighted_call:
        return False
    weight_idxs = WEIGHT_INDEX_DICT[node.target]  # type: ignore[index]
    return any(
        candidate is arg and idx in weight_idxs
        for idx, candidate in enumerate(node.args))
# functional conv ops whose bias is checked by node_arg_is_bias
CONV_OPS_WITH_BIAS = {
    torch.nn.functional.conv1d,
    torch.nn.functional.conv2d,
    torch.nn.functional.conv3d,
}
# positional index of the bias argument: (input, weight, bias, ...)
CONV_BIAS_ARG_INDEX = 2
def node_arg_is_bias(node: Node, arg: Any) -> bool:
    """
    Return True if `arg` is the bias input of the `call_function` node `node`.

    Handles the functional convs in CONV_OPS_WITH_BIAS and
    `torch.nn.functional.linear`. All of these take `(input, weight, bias)`
    as their leading parameters, so the bias is recognized both when passed
    positionally at CONV_BIAS_ARG_INDEX (e.g. `F.linear(x, w, b)`) and when
    passed as the `bias` keyword (e.g. `F.conv2d(x, w, bias=b)`). Previously
    conv bias was only detected positionally and linear bias only by keyword.
    Arguments are compared by identity.
    """
    if not (isinstance(node, Node) and node.op == 'call_function'):
        return False
    if node.target is not torch.nn.functional.linear and \
            node.target not in CONV_OPS_WITH_BIAS:
        return False
    # bias passed positionally
    if len(node.args) > CONV_BIAS_ARG_INDEX and \
            node.args[CONV_BIAS_ARG_INDEX] is arg:
        return True
    # bias passed by keyword
    return 'bias' in node.kwargs and node.kwargs['bias'] is arg
# weight prepacking ops
# (the quantized linear / conv prepack kernels registered under
# torch.ops.quantized)
WEIGHT_PREPACK_OPS = {
    torch._ops.ops.quantized.linear_prepack,
    torch._ops.ops.quantized.linear_prepack_fp16,
    torch._ops.ops.quantized.conv1d_prepack,
    torch._ops.ops.quantized.conv2d_prepack,
    torch._ops.ops.quantized.conv3d_prepack,
}
def run_weight_observers(observed: GraphModule) -> None:
    r''' Extract the subgraph that produces the weight for dynamic quant
    or weight only quant node and run the subgraph to observe the weight.
    Note that the observers of dynamic quant or weight only quant ops are
    run during the convert step.

    For every `call_function` node whose target is in WEIGHT_INDEX_DICT, the
    producer nodes of each weight argument are collected and, when any exist,
    turned into a GraphModule that is called once.
    '''
    for node in observed.graph.nodes:
        if node.op != 'call_function' or node.target not in WEIGHT_INDEX_DICT:
            continue
        weight_idxs = WEIGHT_INDEX_DICT[node.target]
        for idx, maybe_weight in enumerate(node.args):
            if idx not in weight_idxs:
                continue
            producer_nodes = collect_producer_nodes(maybe_weight)
            if producer_nodes is None:
                continue
            # build the weight-producing subgraph and execute it, so the
            # weight observer in it sees the weight
            weight_subgraph = graph_module_from_producer_nodes(
                observed, producer_nodes)
            weight_subgraph()
class Quantizer:
def __init__(self):
# mapping from a tuple of nodes in reverse order to uninitialized
# QuantizeHandler subclass. For example,
# {
# # match a single node
# (<class 'torch.nn.modules.conv.Conv3d'>:
# <class 'torch.quantization.fx.quantize.ConvRelu'>),
# # match multiple nodes in reverse order
# ((<function relu at 0x7f766a7360d0>, <built-in function add>):
# <class 'torch.quantization.fx.quantize.Add'>),
# }
self.patterns: Dict[Pattern, QuantizeHandler] = {}
self.prepare_custom_config_dict: Dict[str, Any] = {}
# mapping from node name to the scope of the module which contains the node.
self.node_name_to_scope: Dict[str, Tuple[str, type]] = {}
def _prepare(
self,
model: GraphModule,
qconfig_dict: Any,
node_name_to_scope: Dict[str, Tuple[str, type]],
prepare_custom_config_dict: Optional[Dict[str, Any]],
is_standalone_module: bool) -> ObservedGraphModule:
""" standalone_module means it a submodule that is not inlined in
parent module, and will be quantized separately as one unit.
How the standalone module is observed is specified by `input_quantized_idxs` and
`output_quantized_idxs` in the prepare_custom_config for the standalone module
Returns:
model(GraphModule): prepared standalone module
attributes:
_standalone_module_input_quantized_idxs(List[Int]): a list of
indexes for the graph input that is expected to be quantized,
same as input_quantized_idxs configuration provided
for the standalone module
_standalone_module_output_quantized_idxs(List[Int]): a list of
indexs for the graph output that is quantized
same as input_quantized_idxs configuration provided
for the standalone module
"""
if prepare_custom_config_dict is None:
prepare_custom_config_dict = {}
self.prepare_custom_config_dict = prepare_custom_config_dict
additional_quant_patterns = \
prepare_custom_config_dict.get("additional_quant_pattern", {})
self.patterns = get_combined_dict(
get_default_quant_patterns(), additional_quant_patterns)
convert_dict_to_ordered_dict(qconfig_dict)
flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict)
# TODO: support regex as well
propagate_qconfig_(model, flattened_qconfig_dict)
if model.training:
additional_qat_module_mapping = prepare_custom_config_dict.get(
"additional_qat_module_mapping", {})
qat_swap_modules(model, additional_qat_module_mapping)
# mapping from fully qualified module name to module instance
# for example,
# {
# '': Model(...),
# 'linear': Linear(...),
# 'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
# }
modules = dict(model.named_modules())
self.node_name_to_scope = node_name_to_scope
# fill qconfig_map, a map from node name to qconfig, used in _find_matches
qconfig_map = generate_qconfig_map(model, modules, model.graph, qconfig_dict, node_name_to_scope)
# match the patterns that will get quantized
standalone_module_name_configs = prepare_custom_config_dict.get(
"standalone_module_name", [])
standalone_module_class_configs = prepare_custom_config_dict.get(
"standalone_module_class", [])
standalone_module_names = [config[0] for config in standalone_module_name_configs]
standalone_module_classes = [config[0] for config in standalone_module_class_configs]
custom_module_classes = get_custom_module_class_keys(
prepare_custom_config_dict, "float_to_observed_custom_module_class")
matches = self._find_matches(
model.graph, modules, self.patterns, qconfig_map, standalone_module_names,
standalone_module_classes, custom_module_classes)
input_quantized_idxs: List[int] = self.prepare_custom_config_dict.get(
"input_quantized_idxs", [])
output_quantized_idxs: List[int] = self.prepare_custom_config_dict.get(
"output_quantized_idxs", [])
run_prepare_fx_on_standalone_modules(
model, modules, matches, prepare_custom_config_dict)
result_node = insert_observers_for_model(
model, modules, matches, qconfig_map,
model.graph, prepare_custom_config_dict,
input_quantized_idxs, output_quantized_idxs)
self.save_state(model, qconfig_map)
preserved_attributes = set(prepare_custom_config_dict.get("preserved_attributes", []))
model = ObservedGraphModule(model, model.graph, preserved_attributes)
if is_standalone_module:
assert result_node is not None
assert isinstance(result_node.args[0], Node), \
"standalone module only supports returning simple value currently"\
"(not tuple, dict etc.)"
# these inputs are observed in parent
# converting List[int] to Tensor since module attribute is
# Union[Tensor, Module]
model._standalone_module_input_quantized_idxs = \
torch.tensor(input_quantized_idxs)
model._standalone_module_output_quantized_idxs = torch.tensor(output_quantized_idxs)
return model
def save_state(self, observed: GraphModule, qconfig_map: Dict[str, QConfigAny]) -> None:
    """Stash the quantizer state on ``observed`` for a later convert step.

    Stores this quantizer's ``patterns``, ``prepare_custom_config_dict`` and
    ``node_name_to_scope`` (plus the given ``qconfig_map``) as
    underscore-prefixed attributes of ``observed``. These are what
    ``restore_state`` and ``_convert`` read back.
    """
    observed._patterns = self.patterns  # type: ignore[assignment]
    observed._qconfig_map = qconfig_map  # type: ignore[assignment]
    # the remaining state is stored under the same name with a leading '_'
    for attr_name in ('prepare_custom_config_dict', 'node_name_to_scope'):
        setattr(observed, '_' + attr_name, getattr(self, attr_name))
def restore_state(self, observed: GraphModule) -> None:
    """Reload quantizer state that ``save_state`` attached to ``observed``.

    Restores ``patterns``, ``prepare_custom_config_dict`` and
    ``node_name_to_scope`` onto ``self``. The saved qconfig map is not
    restored here; callers read ``observed._qconfig_map`` directly.

    Raises:
        AssertionError: if ``observed`` was not produced by prepare_fx.
    """
    assert is_observed_module(observed), \
        'incoming model must be produced by prepare_fx'
    self.patterns = observed._patterns  # type: ignore[assignment]
    # each attribute lives on the observed module with a leading '_'
    for attr_name in ('prepare_custom_config_dict', 'node_name_to_scope'):
        setattr(self, attr_name, getattr(observed, '_' + attr_name))
def prepare(
        self,
        model: GraphModule,
        qconfig_dict: Any,
        node_name_to_scope: Dict[str, Tuple[str, type]],
        prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
        is_standalone_module: bool = False) -> ObservedGraphModule:
    """Insert observers into ``model`` according to ``qconfig_dict``.

    Public entry point; every argument is forwarded unchanged, in order, to
    ``self._prepare``.

    Args:
        model: the traced GraphModule to prepare.
        qconfig_dict: user qconfig specification.
        node_name_to_scope: mapping from node name to a (module path, type)
            tuple.
        prepare_custom_config_dict: optional custom configuration; ``None``
            is passed through as-is.
        is_standalone_module: whether ``model`` is prepared as a standalone
            submodule.

    Returns:
        Whatever ``self._prepare`` returns (annotated as ObservedGraphModule).
    """
    return self._prepare(
        model, qconfig_dict, node_name_to_scope, prepare_custom_config_dict,
        is_standalone_module)
def _convert(self, model: GraphModule, is_reference: bool = False,
             convert_custom_config_dict: Optional[Dict[str, Any]] = None,
             is_standalone_module: bool = False,
             _remove_qconfig_flag: bool = True) -> QuantizedGraphModule:
    """Convert an observed GraphModule (output of ``prepare``) into a
    QuantizedGraphModule.

    Args:
        model: the observed module. Its saved quantizer state is restored via
            ``restore_state`` and its ``_qconfig_map`` attribute is read. The
            module is also switched to eval mode and moved to CPU in place.
        is_reference: forwarded to each QuantizeHandler's ``convert``.
        convert_custom_config_dict: optional dict (defaults to ``{}``). This
            method reads its "observed_to_quantized_custom_module_class" and
            "preserved_attributes" keys and passes the dict on to the
            QuantizeHandlers.
        is_standalone_module: whether ``model`` is a standalone module, i.e.
            a submodule that is not inlined into its parent and is quantized
            separately as one unit. Whether its inputs/outputs are quantized
            is specified by ``input_quantized_idxs`` / ``output_quantized_idxs``
            in prepare_custom_config_dict (see the prepare_fx docs). This flag
            is not read in this method body.
        _remove_qconfig_flag: if True, call ``_remove_qconfig`` on ``model``
            to remove qconfig and activation_post_process modules.

    Returns:
        A QuantizedGraphModule built from the rewritten graph, in which
        activation_post_process calls are replaced by quantize ops (or
        dropped) and dequantize ops are inserted where float values are
        needed.
    """
    if convert_custom_config_dict is None:
        convert_custom_config_dict = {}
    self.restore_state(model)
    qconfig_map: Dict[str, QConfigAny] = model._qconfig_map  # type: ignore[assignment]
    # always run weight observers in the top level forward method
    # for dynamic quant ops or weight only quant ops
    run_weight_observers(model)

    # move to cpu since we only have quantized cpu kernels
    model.eval().cpu()
    # mapping from fully qualified module name to module instance
    # for example,
    # {
    #   '': Model(...),
    #   'linear': Linear(...),
    #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
    # }
    # We use remove_duplicate=False here because torch.cat uses
    # the same activation_post_process module instance but different names
    modules = dict(model.named_modules(remove_duplicate=False))

    custom_module_classes = get_custom_module_class_keys(
        convert_custom_config_dict,
        "observed_to_quantized_custom_module_class")
    matches = self._find_matches(
        model.graph, modules, self.patterns,
        qconfig_map,
        custom_module_classes=custom_module_classes)

    quantized_graph = Graph()
    # original node name -> (corresponding node in quantized_graph, dtype of
    # the value it produces); dtype is None for non-Tensor values
    env: Dict[str, Tuple[Node, Optional[torch.dtype]]] = {}

    def load_non_quantized(n: Node) -> Node:
        """Return the float version of ``n``, inserting (and caching in env)
        a ``dequantize`` call if env currently holds a non-float dtype."""
        assert n.name in env, \
            'trying to load float node but did not find ' + \
            'node:' + n.name + \
            ' in env: ' + \
            str(env)
        quantized_node, dtype = env[n.name]
        if dtype and dtype != torch.float:
            env[n.name] = Proxy(quantized_node).dequantize().node, torch.float
        return env[n.name][0]

    def load_quantized(n: Node) -> Node:
        """Return ``n`` from env, asserting that it holds a quantized value."""
        assert n.name in env, \
            'trying to load quantized node but did not find node:' + \
            n.name + ' in environment:' + str(env)
        quantized_node, dtype = env[n.name]
        assert dtype in [torch.quint8, torch.qint8, torch.float16], \
            f'Expecting node {quantized_node} to be quantized but got dtype: {dtype}'
        return quantized_node

    def load_x(n: Node) -> Node:
        """Return ``n`` from env as-is, regardless of its dtype."""
        assert n.name in env, \
            'node ' + n.name + ' does not exist in environment'
        return env[n.name][0]

    def load_arg(quantized: Optional[Union[List[int], bool, Tuple[int, ...]]]
                 ) -> Callable[[Node], Argument]:
        """
        Input: quantized, which can be None, list, boolean or tuple
          - if quantized is None, then we'll load the node as long as it
            exists
          - if quantized is a boolean, then all args will be
            quantized/not quantized
          - if quantized is an empty list or tuple, then it is the same as load_arg(quantized=False)
          - if quantized is a list or tuple, then arg should be a list and
            the args with corresponding indexes will be quantized

        Output: fn which takes arg_or_args, and loads them from the
            corresponding environment depending on the value of quantized.
        """
        assert quantized is None or \
            isinstance(quantized, (tuple, list, bool)), type(quantized)
        if isinstance(quantized, (tuple, list)) and len(quantized) == 0:
            # empty tuple or list means nothing is quantized
            quantized = False

        def load_arg_impl(arg_or_args):
            # we'll update the format of `quantized`
            # to better match arg_or_args
            updated_quantized: Optional[Union[List[int], bool, Tuple[int, ...]]] = quantized

            if isinstance(quantized, (tuple, list)) and \
               len(quantized) == 1 and isinstance(arg_or_args, Node):
                # when argument is one Node instead of tuple, we just need to check
                # 0 is in the quantized list
                updated_quantized = 0 in quantized

            if updated_quantized is None:
                return map_arg(arg_or_args, load_x)
            if isinstance(updated_quantized, bool):
                return map_arg(
                    arg_or_args,
                    load_quantized if updated_quantized else load_non_quantized)
            elif isinstance(updated_quantized, (tuple, list)):
                assert isinstance(arg_or_args, (tuple, list)), arg_or_args
                loaded_args = []
                # for now, we only support quantizing positional arguments
                for i, a in enumerate(arg_or_args):
                    if i in updated_quantized:
                        loaded_args.append(map_arg(a, load_quantized))
                    else:
                        loaded_args.append(map_arg(a, load_non_quantized))
                return type(arg_or_args)(loaded_args)
        return load_arg_impl

    def node_arg_is_quantized(node_arg: Any) -> bool:
        """Return whether ``node_arg`` (a Node or a list of them) is quantized.

        Raises:
            Exception: if a list mixes quantized and non-quantized entries.
        """
        if isinstance(node_arg, Node):
            assert node_arg.name in env, \
                'Expecting node_arg to be in the environment'
            # only reachable with asserts disabled (python -O)
            if node_arg.name in env:
                _, dtype = env[node_arg.name]
                return dtype != torch.float
            else:
                return False
        elif isinstance(node_arg, list):
            # Materialize the results: a lazy ``map`` iterator would be
            # partially consumed by ``all()`` and ``any()`` would then only
            # see the remainder, so partially quantized lists could be
            # misreported as non-quantized instead of raising below.
            quantized = [node_arg_is_quantized(a) for a in node_arg]
            if all(quantized):
                return True
            elif not any(quantized):
                return False
            else:
                raise Exception(
                    "partially quantized inputs in list not handled yet")
        else:
            return False

    def is_output_quantized(node: Node, obj: QuantizeHandler, qconfig: QConfigAny, modules: Dict[str, torch.nn.Module]) -> bool:
        """ Check if output node is quantized or not """
        assert modules is not None
        # by default the output for a quantizable node is expected to be quantized
        quantized = True

        # Need to get correct quantized/non-quantized state for the output
        # of FixedQParamsQuantizeHandler
        # TODO: we may want to try to remove the special case here
        # as well
        if obj.should_mark_output_quantized_from_input_quantized_status(qconfig):
            assert node.op in [
                'call_module',
                'call_function',
                'call_method'], \
                'FixedQParamsQuantizeHandler of type ' + node.op + ' is not handled'
            # TODO: need to extend this to consider all relevant args instead of just arg[0]
            quantized = node_arg_is_quantized(node.args[0])

        # the output is unquantized if the activation is not statically
        # quantized or the handler does not observe its input/output
        if not activation_is_statically_quantized(qconfig) or \
           not obj.input_output_observed():
            quantized = False
        # values whose return type is int are never treated as quantized
        if node_return_type_is_int(node):
            quantized = False

        return quantized

    def insert_quantize_node(node: Node, modules: Dict[str, torch.nn.Module]) -> None:
        """ Given an activation_post_process module call node, insert a
        quantize node"""
        assert modules is not None
        assert isinstance(node.target, str)
        observer_module = modules[node.target]
        prev_node = node.args[0]
        if observer_module.dtype == torch.float32:
            # copy the observer for fp32 dtype
            env[node.name] = quantized_graph.node_copy(
                node, load_non_quantized), torch.float
        elif isinstance(prev_node, Node) and prev_node.name in env:
            # the previous node was already converted: reuse it if its dtype
            # matches the observer, otherwise quantize its float value
            _, prev_dtype = env[prev_node.name]
            current_dtype = observer_module.dtype
            if prev_dtype == current_dtype:
                env[node.name] = env[prev_node.name]
            else:
                observer_dtype: torch.dtype = observer_module.dtype  # type: ignore[assignment]
                env[node.name] = (
                    quantize_node(self, load_non_quantized(prev_node),
                                  observer_module, node, modules, quantized_graph, is_input=True),
                    observer_dtype)
        else:
            # replace activation post process with quantization ops
            assert isinstance(prev_node, Node)
            dtype: torch.dtype = observer_module.dtype  # type: ignore[assignment]
            env[node.name] = (
                quantize_node(self, load_non_quantized(prev_node),
                              observer_module, node, modules, quantized_graph, is_input=True),
                dtype)

    # additional state to override inputs to be quantized, if specified
    # by the user
    placeholder_node_seen_cnt = 0
    output_node_seen_cnt = 0
    input_quantized_idxs: List[int] = self.prepare_custom_config_dict.get(
        "input_quantized_idxs", [])
    output_quantized_idxs: List[int] = self.prepare_custom_config_dict.get(
        "output_quantized_idxs", [])

    for node in model.graph.nodes:
        if node.op == "output":
            cur_output_node_idx = output_node_seen_cnt
            output_node_seen_cnt += 1
            if cur_output_node_idx in output_quantized_idxs:
                # Result are kept quantized if the user specified the
                # output_quantized_idxs override.
                graph_output = map_arg(node.args[0], load_x)
            else:
                graph_output = map_arg(node.args[0], load_non_quantized)
            quantized_graph.output(graph_output)
            continue
        root_node, matched, matched_pattern, obj, qconfig = \
            matches.get(node.name, (None, None, None, None, None))
        if root_node is node:
            is_observed_standalone_module_node = (
                node.op == 'call_module' and
                is_observed_standalone_module(
                    modules[node.target])
            )
            if qconfig is None and not is_observed_standalone_module_node:
                result = quantized_graph.node_copy(
                    node, load_non_quantized)
                quantized = False
            else:
                assert obj is not None
                # We will get whether the output is quantized or not before
                # convert for standalone module and after convert
                # for non-standalone module, since _standalone_module_output_quantized_idxs
                # is only available in observed standalone module
                if is_observed_standalone_module_node:
                    out_quant_idxs = modules[node.target]._standalone_module_output_quantized_idxs.tolist()  # type: ignore[operator] # noqa: B950
                    assert len(out_quant_idxs) <= 1, "Currently standalone only support one output"
                    quantized = 0 in out_quant_idxs

                qconfig = qconfig_map[node.name]
                result = obj.convert(
                    self, node, qconfig, modules, quantized_graph, load_arg, is_reference=is_reference,
                    convert_custom_config_dict=convert_custom_config_dict)
                if not is_observed_standalone_module_node:
                    quantized = is_output_quantized(node, obj, qconfig, modules)

            if quantized:
                env[node.name] = result, activation_dtype(qconfig)
            else:
                env[node.name] = result, torch.float
            continue
        elif root_node is not None:
            if qconfig is None:
                # This branch is hit if all of these conditions are met:
                # 1. we are in a fusion pattern of multiple nodes (i.e. add-relu)
                # 2. the current node is not the "root_node" of the pattern
                # 3. quantization for this pattern is disabled
                #
                # In this case, we need to make sure to populate the env with
                # intermediate nodes manually, because the QuantizeHandler.convert
                # function will not be called.
                result = quantized_graph.node_copy(
                    node, load_non_quantized)
                env[node.name] = result, torch.float
            continue

        # handle activation post process calls
        if node.op == 'call_module' and \
                is_activation_post_process(modules[node.target]):
            insert_quantize_node(node, modules)
        elif node.op == 'placeholder':
            cur_placeholder_node_idx = placeholder_node_seen_cnt
            placeholder_node_seen_cnt += 1
            if cur_placeholder_node_idx in input_quantized_idxs:
                env[node.name] = \
                    quantized_graph.node_copy(
                        node, load_non_quantized), torch.quint8
            else:
                env[node.name] = \
                    quantized_graph.node_copy(node, load_non_quantized), torch.float
        else:
            # copy quantized or non-quantized node
            # get_tensor_info_node like shape works for both
            # quantized and non-quantized input and output a non-Tensor
            # (we use None for dtype currently for non-Tensors)
            if is_get_tensor_info_node(node):
                env[node.name] = \
                    quantized_graph.node_copy(node, load_x), None
            else:
                env[node.name] = \
                    quantized_graph.node_copy(node, load_non_quantized), torch.float

    # remove activation post process
    act_post_process_removed_graph = Graph()
    remove_env: Dict[str, Node] = {}

    def load_arg_remove(a: Argument) -> Argument:
        return map_arg(a, lambda node: remove_env[node.name])

    for node in quantized_graph.nodes:
        if node.op == 'output':
            act_post_process_removed_graph.output(
                map_arg(node.args[0], load_arg_remove))
            continue
        if node.op == 'call_module' and \
           is_activation_post_process(modules[node.target]):
            # remove activation post process node: forward its input instead
            remove_env[node.name] = remove_env[node.args[0].name]
        else:
            remove_env[node.name] = act_post_process_removed_graph.node_copy(
                node, load_arg_remove)

    # removes qconfig and activation_post_process modules
    if _remove_qconfig_flag:
        _remove_qconfig(model)
    preserved_attributes = set(convert_custom_config_dict.get("preserved_attributes", []))
    model = QuantizedGraphModule(model, act_post_process_removed_graph, preserved_attributes)
    return model
# Trace back from the weight node until we hit getattr, reconstruct the
# graph module with the traced nodes and run the graph module to pack the
# weight. Then replace the original chain of ops with the packed weight.
def _fold_weight(self, quantized: QuantizedGraphModule) -> QuantizedGraphModule:
    """Precompute packed weights and splice them into the graph as attributes.

    Every ``call_function`` node whose target is in ``WEIGHT_PREPACK_OPS`` is
    traced back with ``collect_producer_nodes``. When that returns a chain,
    the chain is turned into a small module via
    ``graph_module_from_producer_nodes`` and run once. The result is stored
    on ``quantized`` under an attribute name generated from the prefix
    ``<module_path>_packed_weight_``. The prepack node is then replaced by a
    ``get_attr`` of that attribute, and the other nodes of the chain are
    dropped.

    Returns:
        A new QuantizedGraphModule over the folded graph, keeping
        ``quantized.preserved_attr_names``.
    """
    # prepack node name -> value obtained by running its producer chain
    packed_weights: Dict[str, Any] = {}
    # name of each node in a foldable chain -> the prepack node owning it
    folded_nodes: Dict[str, Any] = {}

    prepack_candidates = (
        n for n in quantized.graph.nodes
        if n.op == 'call_function' and n.target in WEIGHT_PREPACK_OPS)
    for prepack in prepack_candidates:
        producers = collect_producer_nodes(prepack)
        if producers is None:
            # chain cannot be folded; its nodes are copied unchanged below
            continue
        folded_nodes.update((p.name, prepack) for p in producers)
        runner = graph_module_from_producer_nodes(quantized, producers)
        packed_weights[prepack.name] = runner()

    # rebuild the graph: drop folded producers, swap prepack ops for get_attr
    folded_graph = Graph()
    env: Dict[Any, Any] = {}

    def load_arg(a):
        return map_arg(a, lambda n: env[n.name])

    root = quantized
    for node in root.graph.nodes:
        owner = folded_nodes.get(node.name)
        if owner is None:
            # node is not part of any folded chain: copy it as-is
            env[node.name] = folded_graph.node_copy(node, load_arg)
        elif owner is node:
            # the prepack op itself: attach the precomputed weight to the
            # root module and read it back through a get_attr node
            weight = packed_weights[node.name]
            first_user = list(owner.users)[0]
            module_path, _ = self.node_name_to_scope[first_user.name]
            make_name = get_new_attr_name_with_prefix(module_path + '_packed_weight_')
            attr_name = make_name(root)
            setattr(root, attr_name, weight)
            env[node.name] = folded_graph.create_node(
                'get_attr', attr_name, (), {})
        # otherwise: an intermediate producer of a folded chain; dropped

    return QuantizedGraphModule(root, folded_graph, root.preserved_attr_names)
def convert(self, model: GraphModule, is_reference: bool = False,
            convert_custom_config_dict: Optional[Dict[str, Any]] = None,
            is_standalone_module: bool = False,
            _remove_qconfig: bool = True) -> QuantizedGraphModule:
    """Convert an observed model to a quantized one.

    Delegates to ``self._convert`` and, unless ``is_reference`` is True,
    post-processes the result with ``self._fold_weight`` to prepack weights.

    Args:
        model: the observed GraphModule produced by ``prepare``.
        is_reference: if True, skip weight folding.
        convert_custom_config_dict: optional custom configuration, forwarded
            unchanged (``None`` included).
        is_standalone_module: forwarded unchanged to ``_convert``.
        _remove_qconfig: forwarded to ``_convert`` as ``_remove_qconfig_flag``.
            NOTE: this parameter name shadows the module-level
            ``_remove_qconfig`` helper inside this method. It is kept for
            backward compatibility with existing callers.

    Returns:
        The converted QuantizedGraphModule.
    """
    quantized = self._convert(
        model, is_reference, convert_custom_config_dict, is_standalone_module,
        _remove_qconfig_flag=_remove_qconfig)
    if not is_reference:
        quantized = self._fold_weight(quantized)
    return quantized
def _find_matches(
self, graph: Graph, modules: Dict[str, torch.nn.Module],
patterns: Dict[Pattern, QuantizeHandler],
qconfig_map: Dict[str, QConfigAny],
standalone_module_names: List[str] = None,
standalone_module_classes: List[Callable] = None,
custom_module_classes: List[Any] = None) -> Dict[str, MatchResult]:
"""
Matches the nodes in the input graph to quantization patterns, and
outputs the information needed to quantize them in future steps.
Inputs:
- graph: an fx.Graph object
- modules: a mapping of fully qualified module name to instance,
for example, {'foo': ModuleFoo, ...}
- patterns: a mapping from a tuple of nodes in reverse order to
uninitialized QuantizeHandler subclass.
Outputs a map of
node_name ->
(node, matched_values, matched_pattern, QuantizeHandler instance,
qconfig)
For example, {
'relu_1': (relu_1, [relu_1], torch.nn.functional.relu,
<CopyNodeQuantizeHandler instance>, QConfig(...)),
...
}
"""
if custom_module_classes is None:
custom_module_classes = []
if standalone_module_classes is None:
standalone_module_classes = []
if standalone_module_names is None:
standalone_module_names = []
match_map: Dict[str, MatchResult] = {}
all_matched : Set[str] = set()
def record_match(pattern, node, matched):
if isinstance(pattern, tuple):
s, *args = pattern
record_match(s, node, matched)
if pattern[0] is not getattr:
for subpattern, arg in zip(args, node.args):
record_match(subpattern, arg, matched)
else:
matched.append(node)
cache_for_no_tensor_check: Dict[Node, bool] = dict()
for node in reversed(graph.nodes):
if node.name not in match_map and node.name not in all_matched:
for pattern, value in patterns.items():
if is_match(modules, node, pattern):
skip_this_match = False
if value is BinaryOpQuantizeHandler:
# to properly check for dtype support, we need to
# navigate to the base node of an add-relu or mul-relu
# pattern
base_node = node
if (
(node.op == 'call_function' and
node.target is torch.nn.functional.relu) or
(node.op == 'call_module' and
isinstance(modules[node.target], torch.nn.ReLU))
):
base_node = node.args[0]
this_node_qconfig = \
qconfig_map[base_node.name]
if this_node_qconfig:
dtypes = get_qconfig_dtypes(this_node_qconfig)
# TODO(future PR): update the pattern to quantize
# handler logic to take this into account.
# This needs to handle 3 cases
# 1) op and dtype is in either [is_ref or non-ref] list -> don't skip
# 2) op is not in either list (i.e. relu) -> don't skip
# 3) op is in non-ref list, but not for dtype, and op+dtype not in is_ref list -> skip
# note: the value of is_reference is unknown at prepare, so we have to cover both cases
# handle is_reference = False
skip_match_not_is_reference = (
(base_node.target in binary_op_supported_dtypes) and
(dtypes not in binary_op_supported_dtypes[base_node.target])
)
# handle is_reference = True
supported_is_reference = (
(base_node.target in binary_reference_op_supported_dtypes) and
(dtypes in binary_reference_op_supported_dtypes[base_node.target])
)
# only skip if not reference says skip and is_reference doesn't support
skip_this_match = skip_match_not_is_reference and not supported_is_reference
if not skip_this_match:
matched: List[Any] = []
record_match(pattern, node, matched)
for n in matched:
match_map[n.name] = (
node, matched, pattern, value(self, node, modules), # type: ignore[operator]
qconfig_map[n.name])
all_matched.add(n.name)
# break after finding the first match
break
# add custom module instances to the match result
assert modules is not None
for node in graph.nodes:
if node.op == 'call_module' and \
type(modules[node.target]) in custom_module_classes:
custom_module_qconfig = qconfig_map[node.name]
match_map[node.name] = (
node, [node], None, CustomModuleQuantizeHandler(self, node, modules),
custom_module_qconfig)
def is_standalone_module(node_target: str, modules: Dict[str, torch.nn.Module]):
assert modules is not None
return (
node_target in standalone_module_names or # type: ignore[operator]
type(modules[node_target]) in standalone_module_classes # type: ignore[operator]
)
# add standalone modules to the match
for node in graph.nodes:
if node.op == 'call_module' and \
(is_standalone_module(node.target, modules) or
is_observed_standalone_module(modules[node.target])):
# add node to matched nodes
custom_module_qconfig = qconfig_map[node.name]
match_map[node.name] = (
node, [node], None,
StandaloneModuleQuantizeHandler(self, node, modules),
custom_module_qconfig)
return match_map