blob: 6ccc4bf564db4a9f2cdd190185ed3b5201b03e30 [file] [log] [blame]
import torch
from torch.fx import ( # type: ignore
GraphModule,
Proxy,
map_arg
)
from torch.fx.graph import (
Graph,
Node,
)
from torch.quantization import (
propagate_qconfig_,
convert,
)
from ..quantization_mappings import (
get_default_qat_module_mappings,
)
from ..quantize import (
_remove_qconfig,
is_activation_post_process
)
from ..utils import (
get_combined_dict,
get_swapped_custom_module_class,
activation_is_statically_quantized,
)
from .pattern_utils import (
is_match,
get_default_quant_patterns,
get_default_output_activation_post_process_map,
input_output_observed,
Pattern,
)
from .observed_module import (
mark_observed_module,
is_observed_module,
mark_observed_standalone_module,
is_observed_standalone_module,
)
from .quantization_patterns import *
from .utils import (
_parent_name,
quantize_node,
get_custom_module_class_keys,
)
from collections import OrderedDict
import warnings
import re
from typing import Optional, Dict, Any, List, Union, Tuple
# Define helper types
# A qconfig is either a static QConfig or a QConfigDynamic
QConfigAny = Union[torch.quantization.QConfig,
                   torch.quantization.QConfigDynamic]
# One entry of the map returned by Quantizer._find_matches:
# (root node of the match, list of matched nodes, matched pattern (None for
#  custom and standalone modules), QuantizeHandler instance, qconfig)
MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler,
                    QConfigAny]
# ------------------------
# Helper Functions
# ------------------------
def get_new_attr_name_with_prefix(prefix):
    """Build a helper that picks an unused attribute name on a module.

    The returned callable takes a ``module`` and returns ``prefix + str(i)``
    for the smallest non-negative integer ``i`` such that ``module`` has no
    attribute of that name yet. For example::

        get_new_observer_name = get_new_attr_name_with_prefix('_observer')
        new_name = get_new_observer_name(module)
        # new_name is '_observer0' if unused, otherwise '_observer1', ...
    """
    def get_new_attr_name(module):
        suffix = 0
        while hasattr(module, prefix + str(suffix)):
            suffix += 1
        return prefix + str(suffix)
    return get_new_attr_name
def collect_producer_nodes(node):
    r''' Starting from a target node, trace back until we hit input or
    getattr node. This is used to extract the chain of operators
    starting from getattr to the target node, for example
    def forward(self, x):
        observed = self.observer(self.weight)
        return F.linear(x, observed)
    collect_producer_nodes(observed) will either return a list of nodes that
    produces the observed node or None if we can't extract a self contained
    graph without free variables(inputs of the forward function).

    The returned list starts with ``node`` and lists producers in traversal
    order (consumers before the nodes they read). A node reachable along
    several paths is appended once per path; the list is meant to be
    reversed by the caller (see ``graph_module_from_producer_nodes``).
    Calls whose target is the builtin ``getattr`` are collected but not
    traversed further.
    '''
    collected = [node]
    pending = [node]
    while pending:
        current = pending.pop()
        for inp in [*current.args, *current.kwargs.values()]:
            if not isinstance(inp, Node):
                # constants are not part of the producer chain
                continue
            if inp.op == 'placeholder':
                # hit input, can't fold in this case
                return None
            collected.append(inp)
            stops_here = inp.op == 'call_function' and inp.target == getattr
            if not stops_here:
                pending.append(inp)
    return collected
def graph_module_from_producer_nodes(root, producer_nodes):
    r''' Construct a graph module from extracted producer nodes
    from `collect_producer_nodes` function
    Args:
      root: the root module for the original graph; it is passed to
        ``GraphModule`` together with the new graph
      producer_nodes: a non-empty list of nodes ordered consumer-first, as
        returned by ``collect_producer_nodes``. Its first entry is the node
        whose value becomes the output of the new graph. The list is not
        modified.
    Return:
      A graph module constructed from the producer nodes
    '''
    assert len(producer_nodes) > 0, 'list of producer nodes can not be empty'
    # since we traced back from node to getattr, copy the nodes in reverse
    # order so producers are created before their consumers. Build a reversed
    # copy instead of calling .reverse() so the caller's list is untouched.
    ordered_nodes = list(reversed(producer_nodes))
    graph = Graph()
    env: Dict[Any, Any] = {}

    def load_arg(a):
        # rewrite references to original nodes into their copies
        return map_arg(a, lambda node: env[node])

    for producer_node in ordered_nodes:
        env[producer_node] = graph.node_copy(producer_node, load_arg)
    graph.output(load_arg(ordered_nodes[-1]))
    graph_module = GraphModule(root, graph)
    return graph_module
def assert_and_get_unique_device(module):
    """
    Return the single device shared by all parameters and buffers of
    ``module``, or None if the module has neither.

    Raises AssertionError if tensors live on more than one device.
    """
    devices = {param.device for param in module.parameters()}
    devices.update(buf.device for buf in module.buffers())
    assert len(devices) <= 1, (
        "prepare only works with cpu or single-device CUDA modules, "
        "but got devices {}".format(devices)
    )
    if not devices:
        return None
    (device,) = devices
    return device
def is_submodule_of_fake_quant(name, module, named_modules):
    """Return whether the parent of submodule ``name`` satisfies
    ``is_activation_post_process``.

    ``named_modules`` maps fully qualified module names to modules; the
    parent name is computed with ``_parent_name``. ``module`` is accepted
    for interface compatibility but is not used.
    """
    parent_fqn = _parent_name(name)[0]
    parent_module = named_modules[parent_fqn]
    return is_activation_post_process(parent_module)
def get_flattened_qconfig_dict(qconfig_dict):
    """ flatten the global, object_type and module_name qconfig
    to the same qconfig_dict so that it can be used by
    propagate_qconfig_ function.
    "module_name_regex" is ignored for now since it's not supported
    in propagate_qconfig_, but it can be fixed later.

    The "object_type" and "module_name" entries may be given either as a
    list of (key, qconfig) pairs or as a mapping (e.g. the OrderedDict that
    ``convert_dict_to_ordered_dict`` writes back into the same dict), so a
    qconfig_dict that was already used for one prepare call can be reused.

    For example:
        Input: {
          "": qconfig,
          "object_type": [
            (torch.add, qconfig)
          ],
          "module_name": [
            ("conv", qconfig)
          ]
        }

        Output: {
          "": qconfig,
          torch.add: qconfig,
          "conv": qconfig
        }
    """
    flattened = dict()
    if '' in qconfig_dict:
        flattened[''] = qconfig_dict['']

    def flatten_key(key):
        if key in qconfig_dict:
            entries = qconfig_dict[key]
            # iterating a dict yields only its keys, so unpacking
            # (obj, qconfig) would fail; use its (key, value) pairs instead
            if isinstance(entries, dict):
                entries = entries.items()
            for obj, qconfig in entries:
                flattened[obj] = qconfig

    flatten_key('object_type')
    flatten_key('module_name')
    return flattened
def convert_dict_to_ordered_dict(qconfig_dict):
    """ Convert the list-valued entries of qconfig_dict to OrderedDicts,
    in place.

    The "object_type", "module_name_regex" and "module_name" entries are
    each replaced by an OrderedDict built from their (key, qconfig) pairs;
    a missing entry becomes an empty OrderedDict. Returns None.
    """
    for entry_key in ('object_type', 'module_name_regex', 'module_name'):
        pairs = qconfig_dict.get(entry_key, [])
        qconfig_dict[entry_key] = OrderedDict(pairs)
# A dictionary for querying the weight index for a given op:
# maps a functional op to the positional-argument indexes (into node.args)
# that hold its weight; used to pick qconfig.weight instead of
# qconfig.activation for those arguments and to find weights to observe
WEIGHT_INDEX_DICT = {
    torch.nn.functional.conv2d : [1],
    torch.nn.functional.linear : [1],
}

# weight prepacking ops; Quantizer._fold_weight evaluates calls to these ops
# ahead of time and replaces them with a get_attr of the packed result
WEIGHT_PREPACK_OPS = {
    torch._ops.ops.quantized.linear_prepack,
    torch._ops.ops.quantized.linear_prepack_fp16,
    torch._ops.ops.quantized.conv2d_prepack,
}
class Quantizer:
    """Drives FX graph mode quantization.

    ``prepare`` produces an observed GraphModule with observer /
    fake_quantize modules inserted (see ``_prepare``). ``convert`` turns an
    observed GraphModule into a quantized one (see ``_convert``) and, unless
    ``debug`` is set, folds weight prepacking ops (see ``_fold_weight``).
    The state ``convert`` needs is attached to the observed model by
    ``save_state`` and read back by ``restore_state``.
    """
    def __init__(self):
        # mapping from node name to the activation_post_process
        # (observer / fake_quantize module) inserted for it by _prepare;
        # must be filled before convert (restored by restore_state)
        self.activation_post_process_map: Optional[
            Dict[str, torch.quantization.observer.ObserverBase]] = None
        # mapping from node name to qconfig that should be used for that node
        # filled out for a model during _generate_qconfig_map
        self.qconfig_map: Optional[Dict[str, QConfigAny]] = None
        # mapping from fully qualified module name to module instance
        # for example,
        # {
        #   '': Model(...),
        #   'linear': Linear(...),
        #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
        # }
        self.modules: Optional[Dict[str, torch.nn.Module]] = None
        # mapping from a tuple of nodes in reverse order to uninitialized
        # QuantizeHandler subclass. For example,
        # {
        #   # match a single node
        #   (<class 'torch.nn.modules.conv.Conv3d'>:
        #     <class 'torch.quantization.fx.quantize.ConvRelu'>),
        #   # match multiple nodes in reverse order
        #   ((<function relu at 0x7f766a7360d0>, <built-in function add>):
        #     <class 'torch.quantization.fx.quantize.Add'>),
        # }
        self.patterns: Optional[Dict[Pattern, QuantizeHandler]] = None

    def _qat_swap_modules(self, root, additional_qat_module_mapping):
        """Swap submodules of ``root`` in place for their QAT counterparts.

        The default QAT module mappings are combined with
        ``additional_qat_module_mapping`` and passed to eager-mode
        ``convert``; qconfig attributes are kept on the modules.
        """
        all_mappings = get_combined_dict(
            get_default_qat_module_mappings(), additional_qat_module_mapping)
        convert(root, mapping=all_mappings, inplace=True, remove_qconfig=False)

    def _generate_qconfig_map(self,
                              root,
                              input_graph,
                              qconfig_dict):
        """Fill ``self.qconfig_map`` with a qconfig for every node.

        Args:
            root: the model being prepared (not used here)
            input_graph: the fx Graph whose nodes get qconfigs
            qconfig_dict: qconfig dict whose 'object_type',
                'module_name_regex' and 'module_name' entries must already be
                mappings (see ``convert_dict_to_ordered_dict``)

        For call_module nodes the chosen qconfig is also set as the module's
        ``qconfig`` attribute, because the eager-mode propagate_qconfig_
        does not handle regex entries.
        """
        global_qconfig = qconfig_dict.get('', None)

        def get_module_type_qconfig(
                module_type, fallback_qconfig=global_qconfig):
            return qconfig_dict['object_type'].get(
                module_type, fallback_qconfig)

        def get_function_qconfig(
                function, fallback_qconfig=global_qconfig):
            return qconfig_dict['object_type'].get(function, fallback_qconfig)

        def get_module_name_regex_qconfig(
                module_name, fallback_qconfig=global_qconfig):
            for regex_pattern, qconfig in \
                    qconfig_dict['module_name_regex'].items():
                if re.match(regex_pattern, module_name):
                    # first match wins
                    return qconfig
            return fallback_qconfig

        def get_module_name_qconfig(
                module_name, fallback_qconfig=global_qconfig):
            if module_name == '':
                # module name qconfig not found
                return fallback_qconfig
            if module_name in qconfig_dict['module_name']:
                return qconfig_dict['module_name'][module_name]
            else:
                # not configured by name: retry with the parent module name
                parent, _ = _parent_name(module_name)
                return get_module_name_qconfig(parent, fallback_qconfig)

        # get qconfig for module_name,
        # fallback to module_name_regex_qconfig, module_type_qconfig,
        # global_qconfig if necessary
        def get_qconfig(module_name):
            assert self.modules is not None
            module_type_qconfig = \
                get_module_type_qconfig(type(self.modules[module_name]))
            module_name_regex_qconfig = \
                get_module_name_regex_qconfig(module_name, module_type_qconfig)
            module_name_qconfig = \
                get_module_name_qconfig(module_name, module_name_regex_qconfig)
            return module_name_qconfig

        self.qconfig_map = dict()
        for node in input_graph.nodes:
            if node.op == 'get_attr':
                module_name, _ = _parent_name(node.target)
                self.qconfig_map[node.name] = get_qconfig(module_name)
            elif node.op == 'call_function':
                # precedence: [TODO] module_name_qconfig (need scope support
                # from fx)
                # > function_qconfig > global_qconfig
                function_qconfig = get_function_qconfig(node.target)
                self.qconfig_map[node.name] = function_qconfig
            elif node.op == 'call_method':
                self_obj = node.args[0]
                # qconfig for call_method should be the same as the `self`
                # object for the call
                if self_obj.name in self.qconfig_map:
                    qconfig = self.qconfig_map[self_obj.name]
                else:
                    # need scope info for each node to support this
                    warnings.warn(
                        "Scope info is not yet supported, taking default " +
                        "qconfig for value {}".format(node.name))
                    qconfig = get_qconfig('')
                self.qconfig_map[node.name] = qconfig
            elif node.op == 'call_module':
                module_qconfig = get_qconfig(node.target)
                # regex is not supported eager mode propagate_qconfig_, we'll
                # need to set the qconfig explicitly here in case regex
                # is used
                assert self.modules is not None
                self.modules[node.target].qconfig = module_qconfig
                self.qconfig_map[node.name] = module_qconfig

    def _prepare(self, model, qconfig_dict, prepare_custom_config_dict,
                 is_standalone_module):
        """ standalone_module means it is a submodule that is not inlined in
        parent module, and will be quantized separately as one unit.

        When we are preparing a standalone module:
        input of the module is observed in parent module, output of the module
        is observed in the standalone module.
        Returns:
            model(GraphModule): prepared standalone module with following
            attributes:
                _standalone_module_observed_input_idxs(List[Int]): a list of
                    indexes for the graph inputs that needs to be observed in
                    parent module
                _output_is_observed(Bool): a boolean variable indicate whether
                    the output of the custom module is observed or not
        """
        if prepare_custom_config_dict is None:
            prepare_custom_config_dict = {}

        additional_quant_patterns = \
            prepare_custom_config_dict.get("additional_quant_pattern", {})
        self.patterns = get_combined_dict(
            get_default_quant_patterns(), additional_quant_patterns)

        flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict)
        # TODO: support regex as well
        propagate_qconfig_(model, flattened_qconfig_dict)
        if model.training:
            additional_qat_module_mapping = prepare_custom_config_dict.get(
                "additional_qat_module_mapping", {})
            self._qat_swap_modules(model, additional_qat_module_mapping)

        self.modules = dict(model.named_modules())

        # note: this rewrites the entries of the caller's qconfig_dict
        convert_dict_to_ordered_dict(qconfig_dict)
        # map from node name to qconfig, used in _find_matches
        self._generate_qconfig_map(model, model.graph, qconfig_dict)

        # match the patterns that will get quantized
        standalone_module_names = prepare_custom_config_dict.get(
            "standalone_module_name", None)
        standalone_module_classes = prepare_custom_config_dict.get(
            "standalone_module_class", None)
        custom_module_classes = get_custom_module_class_keys(
            prepare_custom_config_dict, "float_to_observed_custom_module_class")
        matches = self._find_matches(
            model.graph, self.modules, self.patterns, standalone_module_names,
            standalone_module_classes, custom_module_classes)

        # find _inputs_ to matched nodes that are not quantized, these
        # have to be quantized, which requires measuring stats,
        # initialize an DefaultQuantizeHandler object for each
        quants = self._find_quants(model.graph, matches)

        self.activation_post_process_map = dict()
        # env maps an original node name to its node in observed_graph
        env: Dict[Any, Any] = {}
        observed_graph = Graph()
        # names of nodes whose output is already observed
        observed_node_names_set = set()

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        # indexes for the inputs that needs to be observed
        standalone_module_observed_input_idxs = []
        graph_inputs = []
        for node in model.graph.nodes:
            if node.op == 'placeholder':
                graph_inputs.append(node.name)

        model_device = assert_and_get_unique_device(model)

        def insert_observer(node, observer):
            """Insert observer for node by modifying the observed_graph and
               attach observer module to the model
               Args:
                 node: Node
                 observer: observer/fake_quantize module instance
            """
            # respect device affinity when adding observers
            if model_device:
                observer.to(model_device)
            # add observer module as attribute
            prefix = node.name + '_activation_post_process_'
            get_new_observer_name = get_new_attr_name_with_prefix(prefix)
            observer_name = get_new_observer_name(model)
            setattr(model, observer_name, observer)
            # put observer instance activation_post_process map
            assert self.activation_post_process_map is not None
            self.activation_post_process_map[node.name] = observer
            # insert observer call
            env[node.name] = observed_graph.create_node(
                'call_module', observer_name, (load_arg(node),), {})
            observed_node_names_set.add(node.name)

        def insert_observer_for_special_module(node, quantize_handler, qconfig):
            """ Insert observer for custom module and standalone module
              Args:
                node: the call_module node of the custom/standalone module
                quantize_handler: QuantizeHandler instance for the match
                qconfig: qconfig for the match
              Returns: standalone_module_input_idxs: the indexs for inputs that
              needs to be observed by parent module (None for anything but a
              standalone module)
            """
            standalone_module_input_idxs = None
            assert self.modules is not None
            if isinstance(quantize_handler, CustomModuleQuantizeHandler):
                custom_module = self.modules[node.target]
                custom_module_class_mapping = prepare_custom_config_dict.get(
                    "float_to_observed_custom_module_class", {})
                observed_custom_module_class = \
                    get_swapped_custom_module_class(
                        custom_module, custom_module_class_mapping, qconfig)
                observed_custom_module = \
                    observed_custom_module_class.from_float(custom_module)
                parent_name, name = _parent_name(node.target)
                setattr(self.modules[parent_name], name, observed_custom_module)
            elif isinstance(quantize_handler, StandaloneModuleQuantizeHandler):
                # observe standalone module
                standalone_module = self.modules[node.target]
                prepare = \
                    torch.quantization.quantize_fx._prepare_standalone_module_fx  # type: ignore
                observed_standalone_module = \
                    prepare(standalone_module, {"": qconfig})
                observed_standalone_module.qconfig = qconfig
                standalone_module_input_idxs = observed_standalone_module.\
                    _standalone_module_observed_input_idxs
                observed_standalone_module = mark_observed_standalone_module(
                    observed_standalone_module)
                parent_name, name = _parent_name(node.target)
                setattr(self.modules[parent_name], name,
                        observed_standalone_module)
                self.modules[node.target] = observed_standalone_module
            return standalone_module_input_idxs

        def insert_observer_for_output_of_the_node(
                node,
                quantize_handler,
                qconfig,
                standalone_module_input_idxs,
                pattern,
                matched_nodes):
            """ Insert observer/fake_quantize module for output of the observed
            module if needed

            Args:
                node: the root node of the match
                quantize_handler: QuantizeHandler instance for the match
                qconfig: qconfig for the match
                standalone_module_input_idxs: indexes into ``node.args`` that
                    must be observed for a standalone module, or None
                pattern: the matched pattern, used to look up a fixed
                    activation_post_process constructor
                matched_nodes: nodes of the match; the last entry is the
                    first op in the sequence
            """
            # don't need to insert observer for output if activation does not
            # need to be statically quantized
            assert self.modules is not None
            if activation_is_statically_quantized(qconfig):
                if isinstance(quantize_handler, FixedQParamsOpQuantizeHandler) \
                        and model.training:
                    # we only insert fake quantize module in qat
                    assert pattern is not None
                    activation_post_process_ctr = \
                        get_default_output_activation_post_process_map().get(
                            pattern, None)
                    assert activation_post_process_ctr is not None, \
                        "activation_post_process constructor not provided " + \
                        "for pattern:" + str(pattern)
                    insert_observer(node, activation_post_process_ctr())
                elif (isinstance(quantize_handler,
                                 FixedQParamsOpQuantizeHandler) and
                      not model.training) or \
                        isinstance(quantize_handler, CopyNode):
                    # inserting observers for output of observed module, or
                    # mark the output as observed
                    assert node.op in [
                        'call_module',
                        'call_function',
                        'call_method'], \
                        'CopyNode of type ' + node.op + ' is not handled'

                    def is_observed(input_arg):
                        if isinstance(input_arg, Node):
                            return input_arg.name in observed_node_names_set
                        elif isinstance(input_arg, list):
                            return all(map(is_observed, input_arg))

                    # propagate observed property from input
                    if is_observed(node.args[0]):
                        observed_node_names_set.add(node.name)
                elif ((isinstance(quantize_handler, Add) or
                       isinstance(quantize_handler, Mul)) and
                      quantize_handler.num_node_args == 1):
                    assert matched_nodes is not None
                    input_node = matched_nodes[-1]  # first node in the sequence

                    def input_is_observed(arg):
                        return (isinstance(arg, Node) and
                                arg.name in observed_node_names_set)
                    # This is checking if one of the argument of add/mul
                    # is an observed node
                    # If both of the inputs are number,
                    # we will not consider the output to be observed
                    if (input_is_observed(input_node.args[0]) or
                            input_is_observed(input_node.args[1])):
                        observed_node_names_set.add(node.name)
                elif isinstance(quantize_handler,
                                StandaloneModuleQuantizeHandler):
                    assert node.op == 'call_module'
                    output_is_observed = \
                        self.modules[node.target]._output_is_observed
                    if output_is_observed:
                        observed_node_names_set.add(node.name)
                elif (quantize_handler.all_node_args and
                      input_output_observed(quantize_handler)):
                    # observer for outputs
                    new_observer = qconfig.activation()
                    insert_observer(node, new_observer)

                # insert observer for input of standalone module
                if standalone_module_input_idxs is not None:
                    for idx in standalone_module_input_idxs:
                        if node.args[idx].name not in observed_node_names_set:
                            new_observer = qconfig.activation()
                            insert_observer(node.args[idx], new_observer)

        def insert_observer_for_input_arg_of_observed_node(arg):
            """
               Input:
                 arg: input arg node for another observed node, e.g.
                 input activaiton for functional linear node

            Inserts the observer recorded in ``quants`` for ``arg`` unless it
            is already observed; for a standalone module's graph inputs the
            index is recorded instead, so the parent module observes them.
            """
            if arg.name not in observed_node_names_set and arg.name in quants:
                if is_standalone_module and arg.name in graph_inputs:
                    # we'll insert observer for input of standalone module
                    # in parent graph
                    standalone_module_observed_input_idxs.append(
                        graph_inputs.index(arg.name))
                    return
                _, activation_post_process_ctr = quants[arg.name]
                if activation_post_process_ctr is not None:
                    insert_observer(arg, activation_post_process_ctr())

        result_node: Optional[Node] = None
        for node in model.graph.nodes:
            if node.op == 'output':
                observed_graph.output(load_arg(node.args[0]))
                result_node = node
                continue
            if node.name in observed_node_names_set:
                continue

            root_node, matched_nodes, pattern, obj, qconfig = matches.get(
                node.name, (None, None, None, None, None))
            if root_node is None:
                env[node.name] = observed_graph.node_copy(node, load_arg)
            elif root_node is node:
                env[node.name] = observed_graph.node_copy(node, load_arg)
                # index for input of custom module that needs to be observed in
                # parent
                if qconfig is not None:
                    standalone_module_input_idxs = \
                        insert_observer_for_special_module(node, obj, qconfig)
                    insert_observer_for_output_of_the_node(
                        node, obj, qconfig, standalone_module_input_idxs,
                        pattern, matched_nodes)
            else:
                env[node.name] = observed_graph.node_copy(node, load_arg)
            insert_observer_for_input_arg_of_observed_node(node)

        model = GraphModule(model, observed_graph)
        self.save_state(model)
        model = mark_observed_module(model)
        if is_standalone_module:
            assert result_node is not None
            assert isinstance(result_node.args[0], Node), \
                'standalone module returning dict is not yet supported'
            # indicator for whether output is observed or not.
            # This used for correctly quantize standalone modules
            output_is_observed = \
                result_node.args[0].name in observed_node_names_set
            model._standalone_module_observed_input_idxs = \
                standalone_module_observed_input_idxs
            model._output_is_observed = output_is_observed
        return model

    def save_state(self, observed):
        """Attach the state needed by ``_convert`` to the observed model."""
        observed._activation_post_process_map = self.activation_post_process_map
        observed._patterns = self.patterns
        observed._qconfig_map = self.qconfig_map

    def restore_state(self, observed):
        """Read back the state stored on ``observed`` by ``save_state``.

        Raises AssertionError if ``observed`` is not marked as an observed
        module.
        """
        assert is_observed_module(observed), \
            'incoming model must be produced by prepare_fx'
        self.activation_post_process_map = observed._activation_post_process_map
        self.patterns = observed._patterns
        self.qconfig_map = observed._qconfig_map

    def prepare(self, model, qconfig_dict, prepare_custom_config_dict=None,
                is_standalone_module=False):
        """Public entry point for ``_prepare``; see that method."""
        return self._prepare(
            model, qconfig_dict, prepare_custom_config_dict,
            is_standalone_module)

    def _run_weight_observers(self, observed):
        r''' Extract the subgraph that produces the weight for dynamic quant
        or weight only quant node and run the subgraph to observe the weight.
        Note that the observers of dynamic quant or weight only quant ops are
        run during the convert step.
        '''
        for node in observed.graph.nodes:
            if node.op == 'call_function' and node.target in WEIGHT_INDEX_DICT:
                for i, node_arg in enumerate(node.args):
                    if i in WEIGHT_INDEX_DICT[node.target]:
                        # node_arg is weight
                        weight_observer_nodes = collect_producer_nodes(node_arg)
                        if weight_observer_nodes is not None:
                            weight_observer_module = \
                                graph_module_from_producer_nodes(
                                    observed, weight_observer_nodes)
                            # run the weight observer
                            weight_observer_module()
        return

    def _convert(self, model, debug=False, convert_custom_config_dict=None,
                 is_standalone_module=False):
        """ standalone_module means it is a submodule that is not inlined in
        parent module, and will be quantized separately as one unit.

        For standalone module: the inputs will be quantized by parent module,
        checks `_standalone_module_observed_input_idxs` of
        input observed model and will treat these inputs as quantized
        also will not dequantize the final output.
        Returns a quantized standalone module which accepts quantized input
        (if needed) and produces quantized output (if needed).
        """
        if convert_custom_config_dict is None:
            convert_custom_config_dict = {}
        self.restore_state(model)
        # always run weight observers in the top level forward method
        # for dynamic quant ops or weight only quant ops
        self._run_weight_observers(model)

        # move to cpu since we only have quantized cpu kernels
        model.eval().cpu()
        self.modules = dict(model.named_modules())

        custom_module_classes = get_custom_module_class_keys(
            convert_custom_config_dict,
            "observed_to_quantized_custom_module_class")
        matches = self._find_matches(
            model.graph, self.modules, self.patterns,
            custom_module_classes=custom_module_classes)

        quants = self._find_quants(model.graph, matches)

        self.quantized_graph = Graph()
        # env maps a node name to its float (non-quantized) value in
        # quantized_graph, quant_env to its quantized value
        env: Dict[Any, Any] = {}
        quant_env: Dict[Any, Any] = {}

        graph_inputs = []
        for node in model.graph.nodes:
            if node.op == 'placeholder':
                graph_inputs.append(node.name)

        def load_non_quantized(n):
            """Return the float value of ``n``, inserting a dequantize if
            only a quantized value exists."""
            if n.name not in env:
                assert n.name in quant_env, \
                    'trying to load float node but did not find ' + \
                    'node:' + n.name + \
                    ' in quantized or non quantized environment, env: ' + \
                    str(env) + ' quant_env:' + str(quant_env)
                env[n.name] = Proxy(quant_env[n.name]).dequantize().node
            return env[n.name]

        def load_quantized(n):
            """Return the quantized value of ``n``, quantizing the float value
            with the handler recorded in ``quants`` if needed."""
            if n.name not in quant_env:
                assert n.name in env, \
                    'trying to load quantized node but did not find node:' + \
                    n.name + ' in float environment:' + str(env)
                assert n.name in quants, \
                    'did not find quant object for node:' + n.name
                quant = quants[n.name][0]
                quant_env[n.name] = quant.convert(self, env[n.name])
            return quant_env[n.name]

        def load_x(n):
            """Return the existing value of ``n``, preferring quant_env."""
            assert n.name in env or n.name in quant_env, \
                'node ' + n.name + ' does not exist in either environment'
            if n.name in quant_env:
                return quant_env[n.name]
            else:
                return env[n.name]

        def load_arg(quantized):
            """
            Input: quantized, which can be None, list, boolean or tuple
              - if quantized is a list or tuple, then arg should be a list and
                the args with corresponding indexes will be quantized
              - if quantized is a boolean, then all args will be
                quantized/not quantized
              - if quantized is None, then we'll load the node as long as it
                exists
            Output: fn which takes arg_or_args, and loads them from the
                corresponding environment depending on the value of quantized.
            """
            assert quantized is None or \
                isinstance(quantized, (tuple, list, bool)), type(quantized)

            def load_arg_impl(arg_or_args):
                if quantized is None:
                    return map_arg(arg_or_args, load_x)
                if isinstance(quantized, bool):
                    return map_arg(
                        arg_or_args,
                        load_quantized if quantized else load_non_quantized)
                elif isinstance(quantized, (tuple, list)):
                    assert isinstance(arg_or_args, (tuple, list)), arg_or_args
                    loaded_args = []
                    # for now, we only support quantizing positional arguments
                    for i, a in enumerate(arg_or_args):
                        if i in quantized:
                            loaded_args.append(map_arg(a, load_quantized))
                        else:
                            loaded_args.append(map_arg(a, load_non_quantized))
                    return type(arg_or_args)(loaded_args)
            return load_arg_impl

        def is_quantized(node):
            """Return True if ``node`` (or every node of a list) is quantized,
            False if it is (or all of them are) not.

            Raises if a list mixes quantized and non-quantized entries.
            Returns None for any other argument type.
            """
            if isinstance(node, Node):
                assert node.name in env or node.name in quant_env, \
                    'Expecting node to be in the environment'
                # there might be nodes appearing in both environemnts, but
                # quant_env will take precedence
                if node.name in quant_env:
                    return True
                elif node.name in env:
                    return False
            elif isinstance(node, list):
                # materialize the results: a ``map`` object is a one-shot
                # iterator, so ``any`` would only see what ``all`` left
                # behind and a partially quantized list could be misreported
                # as not quantized instead of raising
                quantized = [is_quantized(n) for n in node]
                if all(quantized):
                    return True
                elif not any(quantized):
                    return False
                else:
                    raise Exception(
                        "partially quantized inputs in list not handled yet")

        def is_output_quantized(node, obj, qconfig) -> bool:
            """ Check if output node is quantized or not

            Args:
                node: root node of the match
                obj: QuantizeHandler instance for the match
                qconfig: qconfig for the match
            """
            assert self.modules is not None
            if node.op == 'call_module' and \
               is_observed_standalone_module(self.modules[node.target]):
                quantized = bool(self.modules[node.target]._output_is_observed)
            else:
                quantized = True

            # Need to get correct quantized/non-quantized state for the output
            # of CopyNode
            if type(obj) in [
                    CopyNode,
                    FixedQParamsOpQuantizeHandler
            ]:
                assert node.op in [
                    'call_module',
                    'call_function',
                    'call_method'], \
                    'CopyNode of type ' + node.op + ' is not handled'
                quantized = is_quantized(node.args[0])

            if not activation_is_statically_quantized(qconfig) or \
               not input_output_observed(obj):
                quantized = False

            return quantized

        def insert_quantize_node(node):
            """ Given a activation_post_process module call node, insert a
            quantize node"""
            assert self.modules is not None
            observer_module = self.modules[node.target]
            prev_node = node.args[0]
            if observer_module.dtype == torch.float16:
                # activations are not quantized for
                # fp16 dynamic quantization
                # copy the activaiton_post_process node here
                # since we may need it when we insert prepack
                # op for weight of linear, this will be removed
                # later in a separate pass
                env[node.name] = self.quantized_graph.node_copy(
                    node, load_non_quantized)
            elif prev_node.name in quant_env:
                # if previous node is already quantized, we'll just remove the
                # activation_post_process
                quant_env[node.name] = quant_env[prev_node.name]
            else:
                # replace activation post process with quantization ops
                root_module = self.modules[""]
                quant_env[node.name] = quantize_node(
                    root_module, self.quantized_graph,
                    load_non_quantized(node.args[0]), observer_module)

        for node in model.graph.nodes:
            if node.op == 'output':
                if is_standalone_module:
                    # result are kept quantized in the quantized standalone
                    # module
                    graph_output = map_arg(node.args[0], load_x)
                else:
                    graph_output = map_arg(node.args[0], load_non_quantized)
                self.quantized_graph.output(graph_output)
                continue
            root_node, matched, matched_pattern, obj, qconfig = \
                matches.get(node.name, (None, None, None, None, None))
            if root_node is node:
                if qconfig is None:
                    result = self.quantized_graph.node_copy(
                        node, load_non_quantized)
                    quantized = False
                else:
                    assert obj is not None
                    result = obj.convert(
                        self, node, load_arg, debug=debug,
                        convert_custom_config_dict=convert_custom_config_dict)
                    quantized = is_output_quantized(node, obj, qconfig)

                if quantized:
                    quant_env[node.name] = result
                else:
                    env[node.name] = result
                continue
            elif root_node is not None:
                # non-root nodes of a match are handled by the root's convert
                continue

            # handle activation post process calls
            if node.op == 'call_module' and \
                    is_activation_post_process(self.modules[node.target]):
                insert_quantize_node(node)
            elif (is_standalone_module and node.op == 'placeholder' and
                  graph_inputs.index(node.name) in
                  model._standalone_module_observed_input_idxs):
                # the node is quantized in parent module
                quant_env[node.name] = \
                    self.quantized_graph.node_copy(node, load_non_quantized)
            else:
                # copy quantized or non-quantized node
                env[node.name] = \
                    self.quantized_graph.node_copy(node, load_non_quantized)

        # remove activation post process
        act_post_process_removed_graph = Graph()
        # use fresh names here instead of rebinding ``env`` / ``load_arg``,
        # which the helper closures above refer to
        removed_env: Dict[Any, Any] = {}

        def load_removed_arg(a):
            return map_arg(a, lambda node: removed_env[node.name])

        for node in self.quantized_graph.nodes:
            if node.op == 'output':
                act_post_process_removed_graph.output(
                    map_arg(node.args[0], load_removed_arg))
                continue
            if node.op == 'call_module' and \
               is_activation_post_process(self.modules[node.target]):
                # remove activation post process node
                removed_env[node.name] = removed_env[node.args[0].name]
            else:
                removed_env[node.name] = \
                    act_post_process_removed_graph.node_copy(
                        node, load_removed_arg)

        # removes qconfig and activation_post_process modules
        _remove_qconfig(model)
        model = GraphModule(model, act_post_process_removed_graph)
        return model

    def _fold_weight(self, quantized):
        """Trace back from each weight prepacking node until getattr,
        reconstruct a graph module from the traced nodes and run it to pack
        the weight, then replace the original chain of ops with a get_attr of
        the packed weight.

        Args:
            quantized: the quantized GraphModule; packed weights are set on it
                as new attributes
        Returns:
            a new GraphModule over ``quantized`` with the folded graph
        """
        packed_weights = dict()
        # map from folded node name to the prepack node it belongs to
        folded_nodes = dict()
        # get packed weights
        for node in quantized.graph.nodes:
            if node.op == 'call_function' and node.target in WEIGHT_PREPACK_OPS:
                nodes_to_fold = collect_producer_nodes(node)
                if nodes_to_fold is not None:
                    for node_to_fold in nodes_to_fold:
                        folded_nodes[node_to_fold.name] = node

                    prepacking_module = graph_module_from_producer_nodes(
                        quantized, nodes_to_fold)
                    packed_weight = prepacking_module()
                    packed_weights[node.name] = packed_weight

        # remove folded nodes and replace the prepacking node with getattr
        folded_graph = Graph()
        env: Dict[Any, Any] = {}

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])
        get_new_packed_weight_name = \
            get_new_attr_name_with_prefix('_fx_pass_packed_weight_')
        quantized_root = quantized
        quantized_graph = quantized.graph
        for node in quantized_graph.nodes:
            prepack_node = folded_nodes.get(node.name, None)
            if prepack_node is node:
                packed_weight = packed_weights[node.name]
                # add a prepacked attribute to root
                packed_weight_name = get_new_packed_weight_name(quantized_root)
                setattr(quantized_root, packed_weight_name, packed_weight)
                # replace prepack node with a getattr node
                env[node.name] = folded_graph.create_node(
                    'get_attr', packed_weight_name, (), {})
            elif prepack_node is not None:
                # remove the folded node
                continue
            else:
                # copy other nodes
                env[node.name] = folded_graph.node_copy(node, load_arg)
        quantized = GraphModule(quantized_root, folded_graph)
        return quantized

    def convert(self, model, debug=False, convert_custom_config_dict=None,
                is_standalone_module=False):
        """Convert an observed model (see ``_convert``); weight prepacking
        ops are folded unless ``debug`` is True."""
        quantized = self._convert(
            model, debug, convert_custom_config_dict, is_standalone_module)
        if not debug:
            quantized = self._fold_weight(quantized)
        return quantized

    def _find_matches(
            self, graph, modules, patterns,
            standalone_module_names=None,
            standalone_module_classes=None,
            custom_module_classes=None) -> Dict[str, MatchResult]:
        """
        Matches the nodes in the input graph to quantization patterns, and
        outputs the information needed to quantize them in future steps.

        Inputs:
          - graph: an fx.Graph object
          - modules: a mapping of fully qualified module name to instance,
              for example, {'foo': ModuleFoo, ...}
          - patterns: a mapping from a tuple of nodes in reverse order to
              uninitialized QuantizeHandler subclass.

        Outputs a map of
          node_name ->
            (node, matched_values, matched_pattern, QuantizeHandler instance,
             qconfig)

        For example, {
          'relu_1': (relu_1, [relu_1], torch.nn.functional.relu,
                     <CopyNode instance>, QConfig(...)),
          ...
        }
        """
        if custom_module_classes is None:
            custom_module_classes = []

        if standalone_module_classes is None:
            standalone_module_classes = []

        if standalone_module_names is None:
            standalone_module_names = []

        match_map: Dict[str, MatchResult] = {}
        all_matched = set()

        def record_match(pattern, node, matched):
            if isinstance(pattern, tuple):
                s, *args = pattern
                record_match(s, node, matched)
                if pattern[0] is not getattr:
                    for subpattern, arg in zip(args, node.args):
                        record_match(subpattern, arg, matched)
            else:
                matched.append(node)

        assert self.qconfig_map is not None
        for node in reversed(graph.nodes):
            if node.name not in match_map and node.name not in all_matched:
                for pattern, value in patterns.items():
                    if is_match(modules, node, pattern):
                        matched: List[Any] = []
                        record_match(pattern, node, matched)
                        for n in matched:
                            match_map[n.name] = (
                                node, matched, pattern, value(self, node),
                                self.qconfig_map[n.name])
                            all_matched.add(n.name)
                        # break after finding the first match
                        break

        # add custom module instances to the match result
        assert self.modules is not None
        for node in graph.nodes:
            if node.op == 'call_module' and \
               type(self.modules[node.target]) in custom_module_classes:
                custom_module_qconfig = self.qconfig_map[node.name]
                match_map[node.name] = (
                    node, [node], None, CustomModuleQuantizeHandler(self, node),
                    custom_module_qconfig)

        def is_standalone_module(node_target):
            assert self.modules is not None
            return node_target in standalone_module_names or \
                type(self.modules[node_target]) in standalone_module_classes

        # add standalone modules to the match
        for node in graph.nodes:
            if node.op == 'call_module' and \
               (is_standalone_module(node.target) or
                    is_observed_standalone_module(self.modules[node.target])):
                # add node to matched nodes
                custom_module_qconfig = self.qconfig_map[node.name]
                match_map[node.name] = (
                    node, [node], None,
                    StandaloneModuleQuantizeHandler(self, node),
                    custom_module_qconfig)

        return match_map

    def _find_quants(self, graph, matches):
        """
        Takes the nodes in the input graph and pending matches, and finds and
        returns the input and output nodes which need to be quantized.

        Inputs:
          - graph: an fx.Graph object
          - matches: output of self._find_matches function

        Outputs a map of
         node_name -> (QuantizeHandler instance (always DefaultQuantizeHandler),
         activation_post_process (observer/fake_quantize module) constructor)
        """
        quants: Dict[Any, Any] = {}

        def visit(node, matched_pattern, qconfig):
            """Return a callback recording a quants entry for each arg of
            ``node`` (or for the matched output when ``node`` is None)."""
            def visit_arg(arg):
                is_weight = False
                if isinstance(node, Node) and node.op == 'call_function' and \
                   node.target in WEIGHT_INDEX_DICT:
                    for i, node_arg in enumerate(node.args):
                        if arg is node_arg and i in \
                           WEIGHT_INDEX_DICT[node.target]:  # type: ignore
                            is_weight = True
                if qconfig is not None and \
                   (activation_is_statically_quantized(qconfig) or is_weight):
                    act_post_process_ctr = qconfig.weight if is_weight else \
                        qconfig.activation
                    # overwrite the constructor from qconfig
                    act_post_process_ctr = \
                        get_default_output_activation_post_process_map().get(
                            matched_pattern,
                            act_post_process_ctr)
                    # overwrite previous activation post process constructor if
                    # necessary
                    quants[arg.name] = (
                        DefaultQuantizeHandler(self, arg), act_post_process_ctr)
            return visit_arg

        for node in graph.nodes:
            if node.name in matches:
                root_node, matched_nodes, matched_pattern, quantize_handler, \
                    qconfig = matches[node.name]
                # don't attach observer/fake_quant for CopyNode
                if isinstance(quantize_handler, CopyNode):
                    qconfig = None
                if root_node is node and \
                        input_output_observed(quantize_handler):
                    # matched_nodes[-1] is the first op in the sequence and
                    # matched_nodes[0] is the last op in the sequence
                    # inputs
                    # matched_pattern is set to None for inputs because
                    # we only want to select QuantizeHandler object based
                    # on pattern for output, inputs will always use
                    # DefaultQuantizeHandler
                    map_arg(matched_nodes[-1].args, visit(matched_nodes[-1],
                            None, qconfig))
                    map_arg(matched_nodes[-1].kwargs, visit(matched_nodes[-1],
                            None, qconfig))

                    # output
                    # we don't insert observer for output of standalone module
                    if not isinstance(
                            quantize_handler, StandaloneModuleQuantizeHandler):
                        # passing in matched_pattern here so that we can
                        # customize activation_post_process constructor for
                        # output based on the pattern, e.g.
                        # for sigmoid op we'll use
                        # default_affine_fixed_qparam_fake_quant
                        map_arg(matched_nodes[0],
                                visit(None, matched_pattern, qconfig))
        return quants