quant: make each line of fx/quantize.py <=80 chars (#48357)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48357
Cleans up the long lines in `torch/quantization/fx/quantize.py`
to fit the 80 character limit, so it's easier to read and looks
better on FB's tools.
In the future we can consider adding a linter for this.
Test Plan:
CI
Imported from OSS
Reviewed By: jerryzh168
Differential Revision: D25140833
fbshipit-source-id: 78605d58eda0184eb82f510baec26685a34870e2
diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py
index 2e8eb24..6ccc4bf 100644
--- a/torch/quantization/fx/quantize.py
+++ b/torch/quantization/fx/quantize.py
@@ -61,7 +61,8 @@
# Define helper types
-QConfigAny = Union[torch.quantization.QConfig, torch.quantization.QConfigDynamic]
+QConfigAny = Union[torch.quantization.QConfig,
+ torch.quantization.QConfigDynamic]
MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler,
QConfigAny]
@@ -69,8 +70,8 @@
# Helper Functions
# ------------------------
-# Returns a function that can get a new attribute name for module with given prefix
-# for example,
+# Returns a function that can get a new attribute name for module with given
+# prefix, for example,
# >> get_new_observer_name = get_new_attr_name_with_prefix('_observer')
# >> new_name = get_new_observer_name(module)
# new_name will be an unused attribute name on module, e.g. `_observer_1`
@@ -93,9 +94,9 @@
def forward(self, x):
observed = self.observer(self.weight)
return F.linear(x, observed)
- collect_producer_nodes(observed) will either return a list of nodes that produces
- the observed node or None if we can't extract a self contained graph without
- free variables(inputs of the forward function).
+ collect_producer_nodes(observed) will either return a list of nodes that
+ produces the observed node or None if we can't extract a self contained
+ graph without free variables(inputs of the forward function).
'''
nodes = [node]
frontier = [node]
@@ -219,7 +220,8 @@
def __init__(self):
# mapping from matched node to activation_post_process
# must be filled before convert
- self.activation_post_process_map: Optional[Dict[str, torch.quantization.observer.ObserverBase]] = None
+ self.activation_post_process_map: Optional[
+ Dict[str, torch.quantization.observer.ObserverBase]] = None
# mapping from node name to qconfig that should be used for that node
# filled out for a model during _generate_qconfig_map
self.qconfig_map: Optional[Dict[str, QConfigAny]] = None
@@ -241,11 +243,12 @@
# ((<function relu at 0x7f766a7360d0>, <built-in function add>):
# <class 'torch.quantization.fx.quantize.Add'>),
# }
- self.patterns: Optional[Dict[Pattern, torch.quantization.fx.quantization_patterns.QuantizeHandler]] = None
+ self.patterns: Optional[Dict[Pattern, QuantizeHandler]] = None
def _qat_swap_modules(self, root, additional_qat_module_mapping):
- all_mappings = get_combined_dict(get_default_qat_module_mappings(), additional_qat_module_mapping)
+ all_mappings = get_combined_dict(
+ get_default_qat_module_mappings(), additional_qat_module_mapping)
convert(root, mapping=all_mappings, inplace=True, remove_qconfig=False)
def _generate_qconfig_map(self,
@@ -256,7 +259,8 @@
def get_module_type_qconfig(
module_type, fallback_qconfig=global_qconfig):
- return qconfig_dict['object_type'].get(module_type, fallback_qconfig)
+ return qconfig_dict['object_type'].get(
+ module_type, fallback_qconfig)
def get_function_qconfig(
function, fallback_qconfig=global_qconfig):
@@ -264,7 +268,8 @@
def get_module_name_regex_qconfig(
module_name, fallback_qconfig=global_qconfig):
- for regex_pattern, qconfig in qconfig_dict['module_name_regex'].items():
+ for regex_pattern, qconfig in \
+ qconfig_dict['module_name_regex'].items():
if re.match(regex_pattern, module_name):
# first match wins
return qconfig
@@ -282,8 +287,8 @@
return get_module_name_qconfig(parent, fallback_qconfig)
# get qconfig for module_name,
- # fallback to module_name_regex_qconfig, module_type_qconfig, global_qconfig
- # if necessary
+ # fallback to module_name_regex_qconfig, module_type_qconfig,
+ # global_qconfig if necessary
def get_qconfig(module_name):
assert self.modules is not None
module_type_qconfig = \
@@ -300,54 +305,64 @@
module_name, _ = _parent_name(node.target)
self.qconfig_map[node.name] = get_qconfig(module_name)
elif node.op == 'call_function':
- # precedence: [TODO] module_name_qconfig (need scope support from fx)
+ # precedence: [TODO] module_name_qconfig (need scope support
+ # from fx)
# > function_qconfig > global_qconfig
function_qconfig = get_function_qconfig(node.target)
self.qconfig_map[node.name] = function_qconfig
elif node.op == 'call_method':
self_obj = node.args[0]
- # qconfig for call_method should be the same as the `self` object for the call
+ # qconfig for call_method should be the same as the `self`
+ # object for the call
if self_obj.name in self.qconfig_map:
qconfig = self.qconfig_map[self_obj.name]
else:
# need scope info for each node to support this
- warnings.warn("Scope info is not yet supported, taking default qconfig for value {}".format(node.name))
+ warnings.warn(
+ "Scope info is not yet supported, taking default " +
+ "qconfig for value {}".format(node.name))
qconfig = get_qconfig('')
self.qconfig_map[node.name] = qconfig
elif node.op == 'call_module':
module_qconfig = get_qconfig(node.target)
- # regex is not supported eager mode propagate_qconfig_, we'll need to
- # set the qconfig explicitly here in case regex
+ # regex is not supported eager mode propagate_qconfig_, we'll
+ # need to set the qconfig explicitly here in case regex
# is used
assert self.modules is not None
self.modules[node.target].qconfig = module_qconfig
self.qconfig_map[node.name] = module_qconfig
- def _prepare(self, model, qconfig_dict, prepare_custom_config_dict, is_standalone_module):
- """ standalone_module means it a submodule that is not inlined in parent module,
- and will be quantized separately as one unit.
+ def _prepare(self, model, qconfig_dict, prepare_custom_config_dict,
+ is_standalone_module):
+ """ standalone_module means it a submodule that is not inlined in
+ parent module, and will be quantized separately as one unit.
When we are preparing a standalone module:
input of the module is observed in parent module, output of the module
is observed in the standalone module.
Returns:
- model(GraphModule): prepared standalone module with following attributes:
- _standalone_module_observed_input_idxs(List[Int]): a list of indexs for the graph inputs that
- needs to be observed in parent module
- _output_is_observed(Bool): a boolean variable indicate whether the output of the
- custom module is observed or not
+ model(GraphModule): prepared standalone module with following
+ attributes:
+ _standalone_module_observed_input_idxs(List[Int]): a list of
+ indexes for the graph inputs that needs to be observed in
+ parent module
+ _output_is_observed(Bool): a boolean variable indicate whether
+ the output of the custom module is observed or not
"""
if prepare_custom_config_dict is None:
prepare_custom_config_dict = {}
- additional_quant_patterns = prepare_custom_config_dict.get("additional_quant_pattern", {})
- self.patterns = get_combined_dict(get_default_quant_patterns(), additional_quant_patterns)
+ additional_quant_patterns = \
+ prepare_custom_config_dict.get("additional_quant_pattern", {})
+ self.patterns = get_combined_dict(
+ get_default_quant_patterns(), additional_quant_patterns)
flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict)
# TODO: support regex as well
propagate_qconfig_(model, flattened_qconfig_dict)
if model.training:
- additional_qat_module_mapping = prepare_custom_config_dict.get("additional_qat_module_mapping", {})
+ additional_qat_module_mapping = prepare_custom_config_dict.get(
+ "additional_qat_module_mapping", {})
self._qat_swap_modules(model, additional_qat_module_mapping)
self.modules = dict(model.named_modules())
@@ -357,11 +372,15 @@
self._generate_qconfig_map(model, model.graph, qconfig_dict)
# match the patterns that will get quantized
- standalone_module_names = prepare_custom_config_dict.get("standalone_module_name", None)
- standalone_module_classes = prepare_custom_config_dict.get("standalone_module_class", None)
- custom_module_classes = get_custom_module_class_keys(prepare_custom_config_dict, "float_to_observed_custom_module_class")
+ standalone_module_names = prepare_custom_config_dict.get(
+ "standalone_module_name", None)
+ standalone_module_classes = prepare_custom_config_dict.get(
+ "standalone_module_class", None)
+ custom_module_classes = get_custom_module_class_keys(
+ prepare_custom_config_dict, "float_to_observed_custom_module_class")
matches = self._find_matches(
- model.graph, self.modules, self.patterns, standalone_module_names, standalone_module_classes, custom_module_classes)
+ model.graph, self.modules, self.patterns, standalone_module_names,
+ standalone_module_classes, custom_module_classes)
# find _inputs_ to matched nodes that are not quantized, these
# have to be quantized, which requires measuring stats,
@@ -383,7 +402,8 @@
if node.op == 'placeholder':
graph_inputs.append(node.name)
- get_new_observer_name = get_new_attr_name_with_prefix('activation_post_process_')
+ get_new_observer_name = get_new_attr_name_with_prefix(
+ 'activation_post_process_')
model_device = assert_and_get_unique_device(model)
def insert_observer(node, observer):
@@ -405,21 +425,24 @@
assert self.activation_post_process_map is not None
self.activation_post_process_map[node.name] = observer
# insert observer call
- env[node.name] = observed_graph.create_node('call_module', observer_name, (load_arg(node),), {})
+ env[node.name] = observed_graph.create_node(
+ 'call_module', observer_name, (load_arg(node),), {})
observed_node_names_set.add(node.name)
def insert_observer_for_special_module(quantize_handler):
""" Insert observer for custom module and standalone module
- Returns: standalone_module_input_idxs: the indexs for inputs that needs
- to be observed by parent module
+ Returns: standalone_module_input_idxs: the indexs for inputs that
+ needs to be observed by parent module
"""
standalone_module_input_idxs = None
assert self.modules is not None
if isinstance(quantize_handler, CustomModuleQuantizeHandler):
custom_module = self.modules[node.target]
- custom_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {})
+ custom_module_class_mapping = prepare_custom_config_dict.get(
+ "float_to_observed_custom_module_class", {})
observed_custom_module_class = \
- get_swapped_custom_module_class(custom_module, custom_module_class_mapping, qconfig)
+ get_swapped_custom_module_class(
+ custom_module, custom_module_class_mapping, qconfig)
observed_custom_module = \
observed_custom_module_class.from_float(custom_module)
parent_name, name = _parent_name(node.target)
@@ -427,13 +450,18 @@
elif isinstance(quantize_handler, StandaloneModuleQuantizeHandler):
# observe standalone module
standalone_module = self.modules[node.target]
- prepare = torch.quantization.quantize_fx._prepare_standalone_module_fx # type: ignore
- observed_standalone_module = prepare(standalone_module, {"": qconfig})
+ prepare = \
+ torch.quantization.quantize_fx._prepare_standalone_module_fx # type: ignore
+ observed_standalone_module = \
+ prepare(standalone_module, {"": qconfig})
observed_standalone_module.qconfig = qconfig
- standalone_module_input_idxs = observed_standalone_module._standalone_module_observed_input_idxs
- observed_standalone_module = mark_observed_standalone_module(observed_standalone_module)
+ standalone_module_input_idxs = observed_standalone_module.\
+ _standalone_module_observed_input_idxs
+ observed_standalone_module = mark_observed_standalone_module(
+ observed_standalone_module)
parent_name, name = _parent_name(node.target)
- setattr(self.modules[parent_name], name, observed_standalone_module)
+ setattr(self.modules[parent_name], name,
+ observed_standalone_module)
self.modules[node.target] = observed_standalone_module
return standalone_module_input_idxs
@@ -442,26 +470,30 @@
quantize_handler,
qconfig,
standalone_module_input_idxs):
- """ Insert observer/fake_quantize module for output of the observed module
- if needed
+ """ Insert observer/fake_quantize module for output of the observed
+ module if needed
"""
# don't need to insert observer for output if activation does not
# need to be statically quantized
assert self.modules is not None
if activation_is_statically_quantized(qconfig):
- if isinstance(quantize_handler, FixedQParamsOpQuantizeHandler) and model.training:
+ if isinstance(quantize_handler, FixedQParamsOpQuantizeHandler) \
+ and model.training:
# we only insert fake quantize module in qat
assert pattern is not None
activation_post_process_ctr = \
- get_default_output_activation_post_process_map().get(pattern, None)
+ get_default_output_activation_post_process_map().get(
+ pattern, None)
assert activation_post_process_ctr is not None, \
- "activation_post_process constructor not provided for " + \
- "pattern:" + str(pattern)
+ "activation_post_process constructor not provided " + \
+ "for pattern:" + str(pattern)
insert_observer(node, activation_post_process_ctr())
- elif (isinstance(quantize_handler, FixedQParamsOpQuantizeHandler) and
- not model.training) or isinstance(quantize_handler, CopyNode):
- # inserting observers for output of observed module, or mark the output
- # as observed
+ elif (isinstance(quantize_handler,
+ FixedQParamsOpQuantizeHandler) and
+ not model.training) or \
+ isinstance(quantize_handler, CopyNode):
+ # inserting observers for output of observed module, or
+ # mark the output as observed
assert node.op in [
'call_module',
'call_function',
@@ -476,25 +508,31 @@
# propagate observed property from input
if is_observed(node.args[0]):
observed_node_names_set.add(node.name)
- elif ((isinstance(quantize_handler, Add) or isinstance(quantize_handler, Mul)) and
+ elif ((isinstance(quantize_handler, Add) or
+ isinstance(quantize_handler, Mul)) and
quantize_handler.num_node_args == 1):
assert matched_nodes is not None
input_node = matched_nodes[-1] # first node in the sequence
def input_is_observed(arg):
- return isinstance(arg, Node) and arg.name in observed_node_names_set
+ return (isinstance(arg, Node) and
+ arg.name in observed_node_names_set)
# This is checking if one of the argument of add/mul
# is an observed node
# If both of the inputs are number,
# we will not consider the output to be observed
- if input_is_observed(input_node.args[0]) or input_is_observed(input_node.args[1]):
+ if (input_is_observed(input_node.args[0]) or
+ input_is_observed(input_node.args[1])):
observed_node_names_set.add(node.name)
- elif isinstance(quantize_handler, StandaloneModuleQuantizeHandler):
+ elif isinstance(quantize_handler,
+ StandaloneModuleQuantizeHandler):
assert node.op == 'call_module'
- output_is_observed = self.modules[node.target]._output_is_observed
+ output_is_observed = \
+ self.modules[node.target]._output_is_observed
if output_is_observed:
observed_node_names_set.add(node.name)
- elif quantize_handler.all_node_args and input_output_observed(quantize_handler):
+ elif (quantize_handler.all_node_args and
+ input_output_observed(quantize_handler)):
# observer for outputs
new_observer = qconfig.activation()
insert_observer(node, new_observer)
@@ -516,7 +554,8 @@
if is_standalone_module and node.name in graph_inputs:
# we'll insert observer for input of standalone module
# in parent graph
- standalone_module_observed_input_idxs.append(graph_inputs.index(node.name))
+ standalone_module_observed_input_idxs.append(
+ graph_inputs.index(node.name))
return
_, activation_post_process_ctr = quants[node.name]
if activation_post_process_ctr is not None:
@@ -531,14 +570,17 @@
if node.name in observed_node_names_set:
continue
- root_node, matched_nodes, pattern, obj, qconfig = matches.get(node.name, (None, None, None, None, None))
+ root_node, matched_nodes, pattern, obj, qconfig = matches.get(
+ node.name, (None, None, None, None, None))
if root_node is None:
env[node.name] = observed_graph.node_copy(node, load_arg)
elif root_node is node:
env[node.name] = observed_graph.node_copy(node, load_arg)
- # index for input of custom module that needs to be observed in parent
+ # index for input of custom module that needs to be observed in
+ # parent
if qconfig is not None:
- standalone_module_input_idxs = insert_observer_for_special_module(obj)
+ standalone_module_input_idxs = \
+ insert_observer_for_special_module(obj)
insert_observer_for_output_of_the_node(
node, obj, qconfig, standalone_module_input_idxs)
else:
@@ -555,8 +597,10 @@
'standalone module returning dict is not yet supported'
# indicator for whether output is observed or not.
# This used for correctly quantize standalone modules
- output_is_observed = result_node.args[0].name in observed_node_names_set
- model._standalone_module_observed_input_idxs = standalone_module_observed_input_idxs
+ output_is_observed = \
+ result_node.args[0].name in observed_node_names_set
+ model._standalone_module_observed_input_idxs = \
+ standalone_module_observed_input_idxs
model._output_is_observed = output_is_observed
return model
@@ -566,19 +610,23 @@
observed._qconfig_map = self.qconfig_map
def restore_state(self, observed):
- assert is_observed_module(observed), 'incoming model must be produced by prepare_fx'
+ assert is_observed_module(observed), \
+ 'incoming model must be produced by prepare_fx'
self.activation_post_process_map = observed._activation_post_process_map
self.patterns = observed._patterns
self.qconfig_map = observed._qconfig_map
- def prepare(self, model, qconfig_dict, prepare_custom_config_dict=None, is_standalone_module=False):
- return self._prepare(model, qconfig_dict, prepare_custom_config_dict, is_standalone_module)
+ def prepare(self, model, qconfig_dict, prepare_custom_config_dict=None,
+ is_standalone_module=False):
+ return self._prepare(
+ model, qconfig_dict, prepare_custom_config_dict,
+ is_standalone_module)
def _run_weight_observers(self, observed):
r''' Extract the subgraph that produces the weight for dynamic quant
or weight only quant node and run the subgraph to observe the weight.
- Note that the observers of dynamic quant or weight only quant ops are run during
- the convert step.
+ Note that the observers of dynamic quant or weight only quant ops are
+ run during the convert step.
'''
for node in observed.graph.nodes:
if node.op == 'call_function' and node.target in WEIGHT_INDEX_DICT:
@@ -587,21 +635,23 @@
# node_arg is weight
weight_observer_nodes = collect_producer_nodes(node_arg)
if weight_observer_nodes is not None:
- weight_observer_module = graph_module_from_producer_nodes(
- observed, weight_observer_nodes)
+ weight_observer_module = \
+ graph_module_from_producer_nodes(
+ observed, weight_observer_nodes)
# run the weight observer
weight_observer_module()
return
- def _convert(self, model, debug=False, convert_custom_config_dict=None, is_standalone_module=False):
- """ standalone_module means it a submodule that is not inlined in parent module,
- and will be quantized separately as one unit.
+ def _convert(self, model, debug=False, convert_custom_config_dict=None,
+ is_standalone_module=False):
+ """ standalone_module means it a submodule that is not inlined in
+ parent module, and will be quantized separately as one unit.
For standalone module: the inputs will be quantized by parent module,
checks `_standalone_module_observed_input_idxs` of
input observed model and will treat these inputs as quantized
also will not dequantize the final output.
- Returns a quantized standalone module which accepts quantized input(if needed)
- and produces quantized output (if needed).
+ Returns a quantized standalone module which accepts quantized input
+ (if needed) and produces quantized output (if needed).
"""
if convert_custom_config_dict is None:
convert_custom_config_dict = {}
@@ -615,7 +665,8 @@
self.modules = dict(model.named_modules())
custom_module_classes = get_custom_module_class_keys(
- convert_custom_config_dict, "observed_to_quantized_custom_module_class")
+ convert_custom_config_dict,
+ "observed_to_quantized_custom_module_class")
matches = self._find_matches(
model.graph, self.modules, self.patterns,
custom_module_classes=custom_module_classes)
@@ -634,18 +685,20 @@
def load_non_quantized(n):
if n.name not in env:
assert n.name in quant_env, \
- 'trying to load float node but did not find node:' + n.name + \
- ' in quantized or non quantized environment, env: ' + str(env) + \
- ' quant_env:' + str(quant_env)
+ 'trying to load float node but did not find ' + \
+ 'node:' + n.name + \
+ ' in quantized or non quantized environment, env: ' + \
+ str(env) + ' quant_env:' + str(quant_env)
env[n.name] = Proxy(quant_env[n.name]).dequantize().node
return env[n.name]
def load_quantized(n):
if n.name not in quant_env:
assert n.name in env, \
- 'trying to load quantized node but did not find node:' + n.name + \
- ' in float environment:' + str(env)
- assert n.name in quants, 'did not find quant object for node:' + n.name
+ 'trying to load quantized node but did not find node:' + \
+ n.name + ' in float environment:' + str(env)
+ assert n.name in quants, \
+ 'did not find quant object for node:' + n.name
quant = quants[n.name][0]
quant_env[n.name] = quant.convert(self, env[n.name])
return quant_env[n.name]
@@ -661,21 +714,26 @@
def load_arg(quantized):
"""
Input: quantized, which can be None, list, boolean or tuple
- - if quantized is a list or tuple, then arg should be a list and the args with corresponding
- indexes will be quantized
- - if quantized is a boolean, then all args will be quantized/not quantized
- - if quantized is None, then we'll load the node as long as it exists
+ - if quantized is a list or tuple, then arg should be a list and
+ the args with corresponding indexes will be quantized
+ - if quantized is a boolean, then all args will be
+ quantized/not quantized
+ - if quantized is None, then we'll load the node as long as it
+ exists
- Output: fn which takes arg_or_args, and loads them from the corresponding
- environment depending on the value of quantized.
+ Output: fn which takes arg_or_args, and loads them from the
+ corresponding environment depending on the value of quantized.
"""
- assert quantized is None or isinstance(quantized, (tuple, list, bool)), type(quantized)
+ assert quantized is None or \
+ isinstance(quantized, (tuple, list, bool)), type(quantized)
def load_arg_impl(arg_or_args):
if quantized is None:
return map_arg(arg_or_args, load_x)
if isinstance(quantized, bool):
- return map_arg(arg_or_args, load_quantized if quantized else load_non_quantized)
+ return map_arg(
+ arg_or_args,
+ load_quantized if quantized else load_non_quantized)
elif isinstance(quantized, (tuple, list)):
assert isinstance(arg_or_args, (tuple, list)), arg_or_args
loaded_args = []
@@ -690,9 +748,10 @@
def is_quantized(node):
if isinstance(node, Node):
- assert node.name in env or node.name in quant_env, 'Expecting node to be in the environment'
- # there might be nodes appearing in both environemnts, but quant_env will take
- # precedence
+ assert node.name in env or node.name in quant_env, \
+ 'Expecting node to be in the environment'
+ # there might be nodes appearing in both environemnts, but
+ # quant_env will take precedence
if node.name in quant_env:
return True
elif node.name in env:
@@ -704,17 +763,20 @@
elif not any(quantized):
return False
else:
- raise Exception("partially quantized inputs in list not handled yet")
+ raise Exception(
+ "partially quantized inputs in list not handled yet")
def is_output_quantized(node) -> bool:
""" Check if output node is quantized or not """
assert self.modules is not None
- if node.op == 'call_module' and is_observed_standalone_module(self.modules[node.target]):
+ if node.op == 'call_module' and \
+ is_observed_standalone_module(self.modules[node.target]):
quantized = bool(self.modules[node.target]._output_is_observed)
else:
quantized = True
- # Need to get correct quantized/non-quantized state for the output of CopyNode
+ # Need to get correct quantized/non-quantized state for the output
+ # of CopyNode
if type(obj) in [
CopyNode,
FixedQParamsOpQuantizeHandler
@@ -733,7 +795,8 @@
return quantized
def insert_quantize_node(node):
- """ Given a activation_post_process module call node, insert a quantize node"""
+ """ Given a activation_post_process module call node, insert a
+ quantize node"""
assert self.modules is not None
observer_module = self.modules[node.target]
prev_node = node.args[0]
@@ -744,9 +807,11 @@
# since we may need it when we insert prepack
# op for weight of linear, this will be removed
# later in a separate pass
- env[node.name] = self.quantized_graph.node_copy(node, load_non_quantized)
+ env[node.name] = self.quantized_graph.node_copy(
+ node, load_non_quantized)
elif prev_node.name in quant_env:
- # if previous node is already quantized, we'll just remove the activation_post_process
+ # if previous node is already quantized, we'll just remove the
+ # activation_post_process
quant_env[node.name] = quant_env[prev_node.name]
else:
# replace activation post process with quantization ops
@@ -758,20 +823,25 @@
for node in model.graph.nodes:
if node.op == 'output':
if is_standalone_module:
- # result are kept quantized in the quantized standalone module
+ # result are kept quantized in the quantized standalone
+ # module
graph_output = map_arg(node.args[0], load_x)
else:
graph_output = map_arg(node.args[0], load_non_quantized)
self.quantized_graph.output(graph_output)
continue
- root_node, matched, matched_pattern, obj, qconfig = matches.get(node.name, (None, None, None, None, None))
+ root_node, matched, matched_pattern, obj, qconfig = \
+ matches.get(node.name, (None, None, None, None, None))
if root_node is node:
if qconfig is None:
- result = self.quantized_graph.node_copy(node, load_non_quantized)
+ result = self.quantized_graph.node_copy(
+ node, load_non_quantized)
quantized = False
else:
assert obj is not None
- result = obj.convert(self, node, load_arg, debug=debug, convert_custom_config_dict=convert_custom_config_dict)
+ result = obj.convert(
+ self, node, load_arg, debug=debug,
+ convert_custom_config_dict=convert_custom_config_dict)
quantized = is_output_quantized(node)
if quantized:
@@ -783,16 +853,19 @@
continue
# handle activation post process calls
- if node.op == 'call_module' and is_activation_post_process(self.modules[node.target]):
+ if node.op == 'call_module' and \
+ is_activation_post_process(self.modules[node.target]):
insert_quantize_node(node)
elif (is_standalone_module and node.op == 'placeholder' and
graph_inputs.index(node.name) in
model._standalone_module_observed_input_idxs):
# the node is quantized in parent module
- quant_env[node.name] = self.quantized_graph.node_copy(node, load_non_quantized)
+ quant_env[node.name] = \
+ self.quantized_graph.node_copy(node, load_non_quantized)
else:
# copy quantized or non-quantized node
- env[node.name] = self.quantized_graph.node_copy(node, load_non_quantized)
+ env[node.name] = \
+ self.quantized_graph.node_copy(node, load_non_quantized)
# remove activation post process
act_post_process_removed_graph = Graph()
@@ -802,23 +875,25 @@
return map_arg(a, lambda node: env[node.name])
for node in self.quantized_graph.nodes:
if node.op == 'output':
- act_post_process_removed_graph.output(map_arg(node.args[0], load_arg))
+ act_post_process_removed_graph.output(
+ map_arg(node.args[0], load_arg))
continue
if node.op == 'call_module' and \
is_activation_post_process(self.modules[node.target]):
# remove activation post process node
env[node.name] = env[node.args[0].name]
else:
- env[node.name] = act_post_process_removed_graph.node_copy(node, load_arg)
+ env[node.name] = act_post_process_removed_graph.node_copy(
+ node, load_arg)
# removes qconfig and activation_post_process modules
_remove_qconfig(model)
model = GraphModule(model, act_post_process_removed_graph)
return model
- # Trace back from the weight node util we hit getattr, reconstruct the graph module
- # with the traced nodes and run the graph module to pack the weight. then replace
- # the original chain of ops with the packed weight.
+ # Trace back from the weight node util we hit getattr, reconstruct the
+ # graph module with the traced nodes and run the graph module to pack the
+ # weight. then replace the original chain of ops with the packed weight.
def _fold_weight(self, quantized):
packed_weights = dict()
# map from folded node name to the prepacked weight name
@@ -842,7 +917,8 @@
def load_arg(a):
return map_arg(a, lambda node: env[node.name])
- get_new_packed_weight_name = get_new_attr_name_with_prefix('_fx_pass_packed_weight_')
+ get_new_packed_weight_name = \
+ get_new_attr_name_with_prefix('_fx_pass_packed_weight_')
quantized_root = quantized
quantized_graph = quantized.graph
for node in quantized_graph.nodes:
@@ -864,8 +940,10 @@
quantized = GraphModule(quantized_root, folded_graph)
return quantized
- def convert(self, model, debug=False, convert_custom_config_dict=None, is_standalone_module=False):
- quantized = self._convert(model, debug, convert_custom_config_dict, is_standalone_module)
+ def convert(self, model, debug=False, convert_custom_config_dict=None,
+ is_standalone_module=False):
+ quantized = self._convert(
+ model, debug, convert_custom_config_dict, is_standalone_module)
if not debug:
quantized = self._fold_weight(quantized)
return quantized
@@ -888,10 +966,12 @@
Outputs a map of
node_name ->
- (node, matched_values, matched_pattern, QuantizeHandler instance, qconfig)
+ (node, matched_values, matched_pattern, QuantizeHandler instance,
+ qconfig)
For example, {
- 'relu_1': (relu_1, [relu_1], torch.nn.functional.relu, <CopyNode instance>, QConfig(...)),
+ 'relu_1': (relu_1, [relu_1], torch.nn.functional.relu,
+ <CopyNode instance>, QConfig(...)),
...
}
"""
@@ -925,7 +1005,9 @@
matched: List[Any] = []
record_match(pattern, node, matched)
for n in matched:
- match_map[n.name] = (node, matched, pattern, value(self, node), self.qconfig_map[n.name])
+ match_map[n.name] = (
+ node, matched, pattern, value(self, node),
+ self.qconfig_map[n.name])
all_matched.add(n.name)
# break after finding the first match
break
@@ -937,7 +1019,8 @@
type(self.modules[node.target]) in custom_module_classes:
custom_module_qconfig = self.qconfig_map[node.name]
match_map[node.name] = (
- node, [node], None, CustomModuleQuantizeHandler(self, node), custom_module_qconfig)
+ node, [node], None, CustomModuleQuantizeHandler(self, node),
+ custom_module_qconfig)
def is_standalone_module(node_target):
assert self.modules is not None
@@ -952,7 +1035,9 @@
# add node to matched nodes
custom_module_qconfig = self.qconfig_map[node.name]
match_map[node.name] = (
- node, [node], None, StandaloneModuleQuantizeHandler(self, node), custom_module_qconfig)
+ node, [node], None,
+ StandaloneModuleQuantizeHandler(self, node),
+ custom_module_qconfig)
return match_map
@@ -974,30 +1059,38 @@
def visit(node, matched_pattern, qconfig):
def visit_arg(arg):
is_weight = False
- if isinstance(node, Node) and node.op == 'call_function' and node.target in WEIGHT_INDEX_DICT:
+ if isinstance(node, Node) and node.op == 'call_function' and \
+ node.target in WEIGHT_INDEX_DICT:
for i, node_arg in enumerate(node.args):
- if arg is node_arg and i in WEIGHT_INDEX_DICT[node.target]: # type: ignore
+ if arg is node_arg and i in \
+ WEIGHT_INDEX_DICT[node.target]: # type: ignore
is_weight = True
if qconfig is not None and \
(activation_is_statically_quantized(qconfig) or is_weight):
- act_post_process_ctr = qconfig.weight if is_weight else qconfig.activation
- quants[arg.name] = (DefaultQuantizeHandler(self, arg), qconfig, is_weight)
+ act_post_process_ctr = qconfig.weight if is_weight else \
+ qconfig.activation
+ quants[arg.name] = (
+ DefaultQuantizeHandler(self, arg), qconfig, is_weight)
# overwrite the constructor from qconfig
act_post_process_ctr = \
get_default_output_activation_post_process_map().get(
matched_pattern,
act_post_process_ctr)
- # overwrite previous activation post process constructor if necessary
- quants[arg.name] = (DefaultQuantizeHandler(self, arg), act_post_process_ctr)
+ # overwrite previous activation post process constructor if
+ # necessary
+ quants[arg.name] = (
+ DefaultQuantizeHandler(self, arg), act_post_process_ctr)
return visit_arg
for node in graph.nodes:
if node.name in matches:
- root_node, matched_nodes, matched_pattern, quantize_handler, qconfig = matches[node.name]
+ root_node, matched_nodes, matched_pattern, quantize_handler, \
+ qconfig = matches[node.name]
# don't attach observer/fake_quant for CopyNode
if isinstance(quantize_handler, CopyNode):
qconfig = None
- if root_node is node and input_output_observed(quantize_handler):
+ if root_node is node and \
+ input_output_observed(quantize_handler):
# matched_nodes[-1] is the first op in the sequence and
# matched_nodes[0] is the last op in the sequence
# inputs
@@ -1005,14 +1098,20 @@
# we only want to select QuantizeHandler object based
# on pattern for output, inputs will always use
# DefaultQuantizeHandler
- map_arg(matched_nodes[-1].args, visit(matched_nodes[-1], None, qconfig))
- map_arg(matched_nodes[-1].kwargs, visit(matched_nodes[-1], None, qconfig))
+ map_arg(matched_nodes[-1].args, visit(matched_nodes[-1],
+ None, qconfig))
+ map_arg(matched_nodes[-1].kwargs, visit(matched_nodes[-1],
+ None, qconfig))
# output
# we don't insert observer for output of standalone module
- if not isinstance(quantize_handler, StandaloneModuleQuantizeHandler):
- # passing in matched_pattern here so that we can customize
- # activation_post_process constructor for output based on the pattern, e.g.
- # for sigmoid op we'll use default_affine_fixed_qparam_fake_quant
- map_arg(matched_nodes[0], visit(None, matched_pattern, qconfig))
+ if not isinstance(
+ quantize_handler, StandaloneModuleQuantizeHandler):
+ # passing in matched_pattern here so that we can
+ # customize activation_post_process constructor for
+ # output based on the pattern, e.g.
+ # for sigmoid op we'll use
+ # default_affine_fixed_qparam_fake_quant
+ map_arg(matched_nodes[0],
+ visit(None, matched_pattern, qconfig))
return quants