[quant][graphmode][fx][refactor] insert_observer_for_output_of_the_node (#47784)

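Summary:
Moves the inline logic in torch/quantization/fx/quantize.py that inserts observers for a node's output (and for the inputs of standalone modules) into a local helper, insert_observer_for_output_of_the_node. No functional change intended.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/47784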

Test Plan:
python test/test_quantization.py TestQuantizeFx

Imported from OSS

Reviewed By: vkuzo

Differential Revision: D24900301

fbshipit-source-id: abaeae1b5747e517adeb0d50cec5998a8a3fc24d
diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py
index 7a37219..4a01c0b 100644
--- a/torch/quantization/fx/quantize.py
+++ b/torch/quantization/fx/quantize.py
@@ -425,6 +425,72 @@
                 self.modules[node.target] = observed_standalone_module
             return standalone_module_input_idxs
 
+        def insert_observer_for_output_of_the_node(
+                node,
+                quantize_handler,
+                qconfig,
+                standalone_module_input_idxs):
+            """ Insert an observer/fake_quantize module for the output of `node`, or mark
+            it as observed, if needed; also observe inputs of a standalone module.
+            """
+            # don't need to insert observer for output if activation does not
+            # need to be statically quantized
+            if activation_is_statically_quantized(qconfig):
+                if isinstance(quantize_handler, FixedQParamsOpQuantizeHandler) and model.training:
+                    # we only insert fake quantize module in qat; `pattern` comes from the enclosing scope
+                    activation_post_process_ctr = \
+                        get_default_output_activation_post_process_map().get(pattern, None)
+                    assert activation_post_process_ctr is not None, \
+                        "activation_post_process constructor not provided for " + \
+                        "pattern:" + str(pattern)
+                    insert_observer(node, activation_post_process_ctr())
+                elif (isinstance(quantize_handler, FixedQParamsOpQuantizeHandler) and
+                      not model.training) or isinstance(quantize_handler, CopyNode):
+                    # no observer is inserted in this branch; the output is only marked
+                    # as observed when its first argument is already observed
+                    assert node.op in [
+                        'call_module',
+                        'call_function',
+                        'call_method'], \
+                        'CopyNode of type ' + node.op + ' is not handled'
+
+                    def is_observed(input_arg):
+                        if isinstance(input_arg, Node):
+                            return input_arg.name in observed_node_names_set
+                        elif isinstance(input_arg, list):
+                            return all(map(is_observed, input_arg))
+                    # propagate observed property from input
+                    if is_observed(node.args[0]):
+                        observed_node_names_set.add(node.name)
+                elif ((isinstance(quantize_handler, Add) or isinstance(quantize_handler, Mul)) and
+                      quantize_handler.num_node_args == 1):
+                    input_node = matched_nodes[-1]  # first node in the sequence
+
+                    def input_is_observed(arg):
+                        return isinstance(arg, Node) and arg.name in observed_node_names_set
+                    # This checks whether at least one of the arguments of add/mul
+                    # is an observed node.
+                    # If both of the inputs are numbers,
+                    # we will not consider the output to be observed
+                    if input_is_observed(input_node.args[0]) or input_is_observed(input_node.args[1]):
+                        observed_node_names_set.add(node.name)
+                elif isinstance(quantize_handler, StandaloneModuleQuantizeHandler):
+                    assert node.op == 'call_module'
+                    output_is_observed = self.modules[node.target]._output_is_observed
+                    if output_is_observed:
+                        observed_node_names_set.add(node.name)
+                elif quantize_handler.all_node_args:
+                    # insert the qconfig activation observer for the output
+                    new_observer = qconfig.activation()
+                    insert_observer(node, new_observer)
+
+            # insert observers for inputs of standalone module that are not yet observed
+            if standalone_module_input_idxs is not None:
+                for idx in standalone_module_input_idxs:
+                    if node.args[idx].name not in observed_node_names_set:
+                        new_observer = qconfig.activation()
+                        insert_observer(node.args[idx], new_observer)
+
         result_node : Optional[Node] = None
         for node in model.graph.nodes:
             if node.op == 'output':
@@ -442,63 +508,8 @@
                 # index for input of custom module that needs to be observed in parent
                 if qconfig is not None:
                     standalone_module_input_idxs = insert_observer_for_special_module(obj)
-
-                    # don't need to insert observer for output if activation does not
-                    # need to be statically quantized
-                    if activation_is_statically_quantized(qconfig):
-                        if isinstance(obj, FixedQParamsOpQuantizeHandler) and model.training:
-                            # we only insert fake quantize module in qat
-                            activation_post_process_ctr = \
-                                get_default_output_activation_post_process_map().get(pattern, None)
-                            assert activation_post_process_ctr is not None, \
-                                "activation_post_process constructor not provided for " + \
-                                "pattern:" + str(pattern)
-                            insert_observer(node, activation_post_process_ctr())
-                        elif (isinstance(obj, FixedQParamsOpQuantizeHandler) and
-                              not model.training) or isinstance(obj, CopyNode):
-                            # inserting observers for output of observed module, or mark the output
-                            # as observed
-                            assert node.op in [
-                                'call_module',
-                                'call_function',
-                                'call_method'], \
-                                'CopyNode of type ' + node.op + ' is not handled'
-
-                            def is_observed(input_arg):
-                                if isinstance(input_arg, Node):
-                                    return input_arg.name in observed_node_names_set
-                                elif isinstance(input_arg, list):
-                                    return all(map(is_observed, input_arg))
-                            # propagate observed property from input
-                            if is_observed(node.args[0]):
-                                observed_node_names_set.add(node.name)
-                        elif (isinstance(obj, Add) or isinstance(obj, Mul)) and obj.num_node_args == 1:
-                            input_node = matched_nodes[-1]  # first node in the sequence
-
-                            def input_is_observed(arg):
-                                return isinstance(arg, Node) and arg.name in observed_node_names_set
-                            # This is checking if one of the argument of add/mul
-                            # is an observed node
-                            # If both of the inputs are number,
-                            # we will not consider the output to be observed
-                            if input_is_observed(input_node.args[0]) or input_is_observed(input_node.args[1]):
-                                observed_node_names_set.add(node.name)
-                        elif isinstance(obj, StandaloneModuleQuantizeHandler):
-                            assert node.op == 'call_module'
-                            output_is_observed = self.modules[node.target]._output_is_observed
-                            if output_is_observed:
-                                observed_node_names_set.add(node.name)
-                        elif obj.all_node_args:
-                            # observer for outputs
-                            new_observer = qconfig.activation()
-                            insert_observer(node, new_observer)
-
-                    # insert observer for input of standalone module
-                    if standalone_module_input_idxs is not None:
-                        for idx in standalone_module_input_idxs:
-                            if node.args[idx].name not in observed_node_names_set:
-                                new_observer = qconfig.activation()
-                                insert_observer(node.args[idx], new_observer)
+                    insert_observer_for_output_of_the_node(
+                        node, obj, qconfig, standalone_module_input_idxs)
             else:
                 env[node.name] = observed_graph.node_copy(node, load_arg)