[quant][be] Move some helper functions to the top level to reduce function length (#89246)

Summary:
As titled: move the nested helpers replace_observer_with_quantize_dequantize_node and replace_observer_or_dequant_stub_with_dequantize_node out of the enclosing function in torch/ao/quantization/fx/convert.py to module level (renamed with a leading underscore to mark them private), reducing the length of the enclosing function. No functional change.
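
For context, the first helper rewrites each observer call into a quantize/dequantize pair. A minimal standalone sketch of the pattern it emits (the scale/zero_point values below are placeholders; in convert.py the real qparams come from the observer module and are registered as buffers on the root module):

    import torch

    x = torch.randn(4)
    # placeholder qparams; convert.py derives these from the observer
    x_q = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.quint8)
    x_dq = x_q.dequantize()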

Test Plan:
python test/test_quantization.py TestQuantizeFx

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89246
Approved by: https://github.com/vkuzo
diff --git a/torch/ao/quantization/fx/convert.py b/torch/ao/quantization/fx/convert.py
index 0c1249b..ca6ae61 100644
--- a/torch/ao/quantization/fx/convert.py
+++ b/torch/ao/quantization/fx/convert.py
@@ -88,6 +88,83 @@
     "run_weight_observers",
 ]
 
+def _replace_observer_with_quantize_dequantize_node(
+        model: torch.nn.Module,
+        graph: Graph,
+        node: Node,
+        modules: Dict[str, torch.nn.Module],
+        node_name_to_scope: Dict[str, Tuple[str, type]],
+        node_name_to_qconfig: Dict[str, QConfigAny],
+        is_decomposed: bool) -> None:
+    """ Replace activation_post_process module call node with quantize and
+    dequantize node
+
+    Before:
+    ... -> observer_0(x) -> ...
+    After:
+    ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...
+    """
+    assert modules is not None
+    assert isinstance(node.target, str)
+    module_path, prefix = get_module_path_and_prefix(node, node_name_to_scope, node_name_to_qconfig)
+    observer_module = modules[node.target]
+    maybe_quantize_node_info = get_quantize_node_info(observer_module, is_decomposed)
+    # Skip replacing observers with quant/dequant nodes if the qconfigs of all
+    # consumers and producers of this observer are None
+    skip_replacement = all([
+        has_none_qconfig(n, node_name_to_qconfig) for n in
+        list(node.args) + list(node.users.keys())])
+    if skip_replacement or maybe_quantize_node_info is None:
+        # didn't find a corresponding quantize op and info for the observer_module
+        # so we just remove the observer
+        with graph.inserting_before(node):
+            node.replace_all_uses_with(node.args[0])
+            graph.erase_node(node)
+    else:
+        # otherwise, we can convert the observer module call to quantize/dequantize nodes
+        node_type, quantize_op, qparams = maybe_quantize_node_info
+        # replace observer node with quant - dequant node
+        with graph.inserting_before(node):
+            input_node = node.args[0]
+            quantize_op_inputs = [input_node]
+            for key, value in qparams.items():
+                # TODO: the qparams dict itself could record whether a value
+                # needs to be registered as an attribute
+                if key in ['_scale_', '_zero_point_']:
+                    # For scale and zero_point values we register them as buffers in the root module.
+                    # TODO: maybe need more complex attr name here
+                    qparam_node = create_getattr_from_value(model, graph, module_path + prefix + key, value)
+                    quantize_op_inputs.append(qparam_node)
+                else:
+                    # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
+                    quantize_op_inputs.append(value)
+
+            quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {})
+            if is_decomposed:
+                # use the same qparams from quantize op
+                dq_inputs = [quantized_node] + quantize_op_inputs[1:]
+                dequantized_node = graph.call_function(
+                    torch.ops.quantized_decomposed.dequantize_per_tensor,
+                    tuple(dq_inputs),
+                    {}
+                )
+            else:
+                dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
+            node.replace_all_uses_with(dequantized_node)
+            graph.erase_node(node)
+
+# This is a temporary hack for custom modules; we may want to implement
+# this properly after the custom module class design is finalized.
+# TODO: DeQuantStubs are currently inserted only after custom module LSTM, while observers are inserted
+# after all other custom modules. In the future, we should simply insert QuantStubs before and DeQuantStubs
+# after custom modules in general, and replace these with "quantize" and "dequantize" nodes respectively.
+def _replace_observer_or_dequant_stub_with_dequantize_node(node: Node, graph: Graph):
+    call_custom_module_node = node.args[0]
+    assert isinstance(call_custom_module_node, Node), \
+        f"Expecting the for call custom module node to be a Node, but got {call_custom_module_node}"
+    node.replace_all_uses_with(call_custom_module_node)
+    graph.erase_node(node)
+    insert_dequantize_node(call_custom_module_node, graph)
 
 def restore_state(
         observed: torch.nn.Module
@@ -599,85 +676,6 @@
         if node.op == 'placeholder':
             graph_inputs.append(node.name)
 
-    # TODO: move this outside of this function
-    def replace_observer_with_quantize_dequantize_node(
-            model: torch.nn.Module,
-            graph: Graph,
-            node: Node,
-            modules: Dict[str, torch.nn.Module],
-            node_name_to_scope: Dict[str, Tuple[str, type]],
-            node_name_to_qconfig: Dict[str, QConfigAny],
-            is_decomposed: bool) -> None:
-        """ Replace activation_post_process module call node with quantize and
-        dequantize node
-
-        Before:
-        ... -> observer_0(x) -> ...
-        After:
-        ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...
-        """
-        assert modules is not None
-        assert isinstance(node.target, str)
-        module_path, prefix = get_module_path_and_prefix(node, node_name_to_scope, node_name_to_qconfig)
-        observer_module = modules[node.target]
-        maybe_quantize_node_info = get_quantize_node_info(observer_module, is_decomposed)
-        # Skip replacing observers with quant/dequant nodes if the qconfigs of all
-        # consumers and producers of this observer are None
-        skip_replacement = all([
-            has_none_qconfig(n, node_name_to_qconfig) for n in
-            list(node.args) + list(node.users.keys())])
-        if skip_replacement or maybe_quantize_node_info is None:
-            # didn't find a corresponding quantize op and info for the observer_module
-            # so we just remove the observer
-            with graph.inserting_before(node):
-                node.replace_all_uses_with(node.args[0])
-                graph.erase_node(node)
-        else:
-            # otherwise, we can convert the observer module call to quantize/dequantize nodes
-            node_type, quantize_op, qparams = maybe_quantize_node_info
-            # replace observer node with quant - dequant node
-            with graph.inserting_before(node):
-                input_node = node.args[0]
-                quantize_op_inputs = [input_node]
-                for key, value in qparams.items():
-                    # TODO: the qparams dict itself could record whether a value
-                    # needs to be registered as an attribute
-                    if key in ['_scale_', '_zero_point_']:
-                        # For scale and zero_point values we register them as buffers in the root module.
-                        # TODO: maybe need more complex attr name here
-                        qparam_node = create_getattr_from_value(model, graph, module_path + prefix + key, value)
-                        quantize_op_inputs.append(qparam_node)
-                    else:
-                        # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
-                        quantize_op_inputs.append(value)
-
-                quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {})
-                if is_decomposed:
-                    # use the same qparams from quantize op
-                    dq_inputs = [quantized_node] + quantize_op_inputs[1:]
-                    dequantized_node = graph.call_function(
-                        torch.ops.quantized_decomposed.dequantize_per_tensor,
-                        tuple(dq_inputs),
-                        {}
-                    )
-                else:
-                    dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
-                node.replace_all_uses_with(dequantized_node)
-                graph.erase_node(node)
-
-    # This is a temporary hack for custom modules; we may want to implement
-    # this properly after the custom module class design is finalized.
-    # TODO: DeQuantStubs are currently inserted only after custom module LSTM, while observers are inserted
-    # after all other custom modules. In the future, we should simply insert QuantStubs before and DeQuantStubs
-    # after custom modules in general, and replace these with "quantize" and "dequantize" nodes respectively.
-    def replace_observer_or_dequant_stub_with_dequantize_node(node: Node, graph: Graph):
-        call_custom_module_node = node.args[0]
-        assert isinstance(call_custom_module_node, Node), \
-            f"Expecting the for call custom module node to be a Node, but got {call_custom_module_node}"
-        node.replace_all_uses_with(call_custom_module_node)
-        graph.erase_node(node)
-        insert_dequantize_node(call_custom_module_node, graph)
-
     # additional state to override inputs to be quantized, if specified
     # by the user
     placeholder_node_seen_cnt = 0
@@ -728,13 +726,13 @@
             if _is_activation_post_process(mod):
                 observed_node = node.args[0]
                 if observed_node in statically_quantized_custom_module_nodes:
-                    replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph)
+                    _replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph)
                 else:
-                    replace_observer_with_quantize_dequantize_node(
+                    _replace_observer_with_quantize_dequantize_node(
                         model, model.graph, node, modules, node_name_to_scope,
                         node_name_to_qconfig, is_decomposed)
             elif isinstance(mod, DeQuantStub):
-                replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph)
+                _replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph)
             elif is_observed_standalone_module(mod):
                 convert_standalone_module(
                     node, modules, model, is_reference, backend_config)