[quant][graphmode][fx][refactor] insert_observer_for_output_of_the_node (#47784)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47784
Test Plan:
python test/test_quantization.py TestQuantizeFx
Imported from OSS
Reviewed By: vkuzo
Differential Revision: D24900301
fbshipit-source-id: abaeae1b5747e517adeb0d50cec5998a8a3fc24d
diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py
index 7a37219..4a01c0b 100644
--- a/torch/quantization/fx/quantize.py
+++ b/torch/quantization/fx/quantize.py
@@ -425,6 +425,72 @@
self.modules[node.target] = observed_standalone_module
return standalone_module_input_idxs
+ def insert_observer_for_output_of_the_node(
+ node,
+ quantize_handler,
+ qconfig,
+ standalone_module_input_idxs):
+ """ Insert observer/fake_quantize module for output of the observed module
+ if needed
+ """
+ # don't need to insert observer for output if activation does not
+ # need to be statically quantized
+ if activation_is_statically_quantized(qconfig):
+ if isinstance(quantize_handler, FixedQParamsOpQuantizeHandler) and model.training:
+ # we only insert fake quantize module in qat
+ activation_post_process_ctr = \
+ get_default_output_activation_post_process_map().get(pattern, None)
+ assert activation_post_process_ctr is not None, \
+ "activation_post_process constructor not provided for " + \
+ "pattern:" + str(pattern)
+ insert_observer(node, activation_post_process_ctr())
+ elif (isinstance(quantize_handler, FixedQParamsOpQuantizeHandler) and
+ not model.training) or isinstance(quantize_handler, CopyNode):
+ # inserting observers for output of observed module, or mark the output
+ # as observed
+ assert node.op in [
+ 'call_module',
+ 'call_function',
+ 'call_method'], \
+ 'CopyNode of type ' + node.op + ' is not handled'
+
+ def is_observed(input_arg):
+ if isinstance(input_arg, Node):
+ return input_arg.name in observed_node_names_set
+ elif isinstance(input_arg, list):
+ return all(map(is_observed, input_arg))
+ # propagate observed property from input
+ if is_observed(node.args[0]):
+ observed_node_names_set.add(node.name)
+ elif ((isinstance(quantize_handler, Add) or isinstance(quantize_handler, Mul)) and
+ quantize_handler.num_node_args == 1):
+ input_node = matched_nodes[-1] # first node in the sequence
+
+ def input_is_observed(arg):
+ return isinstance(arg, Node) and arg.name in observed_node_names_set
+ # This is checking if one of the argument of add/mul
+ # is an observed node
+ # If both of the inputs are number,
+ # we will not consider the output to be observed
+ if input_is_observed(input_node.args[0]) or input_is_observed(input_node.args[1]):
+ observed_node_names_set.add(node.name)
+ elif isinstance(quantize_handler, StandaloneModuleQuantizeHandler):
+ assert node.op == 'call_module'
+ output_is_observed = self.modules[node.target]._output_is_observed
+ if output_is_observed:
+ observed_node_names_set.add(node.name)
+ elif quantize_handler.all_node_args:
+ # observer for outputs
+ new_observer = qconfig.activation()
+ insert_observer(node, new_observer)
+
+ # insert observer for input of standalone module
+ if standalone_module_input_idxs is not None:
+ for idx in standalone_module_input_idxs:
+ if node.args[idx].name not in observed_node_names_set:
+ new_observer = qconfig.activation()
+ insert_observer(node.args[idx], new_observer)
+
result_node : Optional[Node] = None
for node in model.graph.nodes:
if node.op == 'output':
@@ -442,63 +508,8 @@
# index for input of custom module that needs to be observed in parent
if qconfig is not None:
standalone_module_input_idxs = insert_observer_for_special_module(obj)
-
- # don't need to insert observer for output if activation does not
- # need to be statically quantized
- if activation_is_statically_quantized(qconfig):
- if isinstance(obj, FixedQParamsOpQuantizeHandler) and model.training:
- # we only insert fake quantize module in qat
- activation_post_process_ctr = \
- get_default_output_activation_post_process_map().get(pattern, None)
- assert activation_post_process_ctr is not None, \
- "activation_post_process constructor not provided for " + \
- "pattern:" + str(pattern)
- insert_observer(node, activation_post_process_ctr())
- elif (isinstance(obj, FixedQParamsOpQuantizeHandler) and
- not model.training) or isinstance(obj, CopyNode):
- # inserting observers for output of observed module, or mark the output
- # as observed
- assert node.op in [
- 'call_module',
- 'call_function',
- 'call_method'], \
- 'CopyNode of type ' + node.op + ' is not handled'
-
- def is_observed(input_arg):
- if isinstance(input_arg, Node):
- return input_arg.name in observed_node_names_set
- elif isinstance(input_arg, list):
- return all(map(is_observed, input_arg))
- # propagate observed property from input
- if is_observed(node.args[0]):
- observed_node_names_set.add(node.name)
- elif (isinstance(obj, Add) or isinstance(obj, Mul)) and obj.num_node_args == 1:
- input_node = matched_nodes[-1] # first node in the sequence
-
- def input_is_observed(arg):
- return isinstance(arg, Node) and arg.name in observed_node_names_set
- # This is checking if one of the argument of add/mul
- # is an observed node
- # If both of the inputs are number,
- # we will not consider the output to be observed
- if input_is_observed(input_node.args[0]) or input_is_observed(input_node.args[1]):
- observed_node_names_set.add(node.name)
- elif isinstance(obj, StandaloneModuleQuantizeHandler):
- assert node.op == 'call_module'
- output_is_observed = self.modules[node.target]._output_is_observed
- if output_is_observed:
- observed_node_names_set.add(node.name)
- elif obj.all_node_args:
- # observer for outputs
- new_observer = qconfig.activation()
- insert_observer(node, new_observer)
-
- # insert observer for input of standalone module
- if standalone_module_input_idxs is not None:
- for idx in standalone_module_input_idxs:
- if node.args[idx].name not in observed_node_names_set:
- new_observer = qconfig.activation()
- insert_observer(node.args[idx], new_observer)
+ insert_observer_for_output_of_the_node(
+ node, obj, qconfig, standalone_module_input_idxs)
else:
env[node.name] = observed_graph.node_copy(node, load_arg)