fx quant: move more functions to utils (#48908)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48908

No logic change; this moves the helper functions get_new_attr_name_with_prefix, collect_producer_nodes, graph_module_from_producer_nodes, and assert_and_get_unique_device from quantize.py to utils.py, and inlines is_observed_standalone_module_node at its single call site, to improve readability.
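
For context, a minimal sketch of how the moved get_new_attr_name_with_prefix helper is typically used when attaching an observer to a module; the model and observer below are illustrative stand-ins, not code from this PR:

    import torch
    from torch.quantization.fx.utils import get_new_attr_name_with_prefix

    model = torch.nn.Linear(4, 4)
    # Build a name generator for the '_observer_' prefix, then pick the first
    # attribute name that is not already taken on `model`.
    get_new_observer_name = get_new_attr_name_with_prefix('_observer_')
    name = get_new_observer_name(model)   # '_observer_0' on a fresh module
    setattr(model, name, torch.quantization.MinMaxObserver())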

Test Plan:
CI

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D25363080

fbshipit-source-id: 1d73a875bd7abf671b544ebc835432fea5306dc3
diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py
index 73604bc..7da165b 100644
--- a/torch/quantization/fx/quantize.py
+++ b/torch/quantization/fx/quantize.py
@@ -51,6 +51,10 @@
     _parent_name,
     quantize_node,
     get_custom_module_class_keys,
+    get_new_attr_name_with_prefix,
+    collect_producer_nodes,
+    graph_module_from_producer_nodes,
+    assert_and_get_unique_device,
 )
 
 from .qconfig_utils import *
@@ -70,93 +74,6 @@
 # Helper Functions
 # ------------------------
 
-# Returns a function that generates a new, unused attribute name for a module
-# with the given prefix, for example:
-# >> get_new_observer_name = get_new_attr_name_with_prefix('_observer_')
-# >> new_name = get_new_observer_name(module)
-# new_name will be an unused attribute name on module, e.g. `_observer_1`
-def get_new_attr_name_with_prefix(prefix: str) -> Callable:
-    def get_new_attr_name(module: torch.nn.Module):
-        def get_attr_name(i: int):
-            return prefix + str(i)
-        i = 0
-        attr_name = get_attr_name(i)
-        while hasattr(module, attr_name):
-            i += 1
-            attr_name = get_attr_name(i)
-        return attr_name
-    return get_new_attr_name
-
-def collect_producer_nodes(node: Node) -> Optional[List[Node]]:
-    r''' Starting from a target node, trace back until we hit an input or a
-    getattr node. This is used to extract the chain of operators
-    starting from a getattr to the target node, for example:
-    def forward(self, x):
-      observed = self.observer(self.weight)
-      return F.linear(x, observed)
-    collect_producer_nodes(observed) will either return a list of the nodes
-    that produce the observed node, or None if we can't extract a self-contained
-    graph without free variables (inputs of the forward function).
-    '''
-    nodes = [node]
-    frontier = [node]
-    while frontier:
-        node = frontier.pop()
-        all_args = list(node.args) + list(node.kwargs.values())
-        for arg in all_args:
-            if not isinstance(arg, Node):
-                continue
-            if arg.op == 'placeholder':
-                # hit input, can't fold in this case
-                return None
-            nodes.append(arg)
-            if not (arg.op == 'call_function' and arg.target == getattr):
-                frontier.append(arg)
-    return nodes
-
-def graph_module_from_producer_nodes(
-        root: GraphModule, producer_nodes: List[Node]) -> GraphModule:
-    r''' Construct a graph module from the producer nodes extracted by the
-    `collect_producer_nodes` function
-    Args:
-      root: the root module for the original graph
-      producer_nodes: a list of nodes we use to construct the graph
-    Returns:
-      A graph module constructed from the producer nodes
-    '''
-    assert len(producer_nodes) > 0, 'list of producer nodes cannot be empty'
-    # since we traced back from the node to getattr, reverse to restore order
-    producer_nodes.reverse()
-    graph = Graph()
-    env: Dict[Any, Any] = {}
-
-    def load_arg(a):
-        return map_arg(a, lambda node: env[node])
-    for producer_node in producer_nodes:
-        env[producer_node] = graph.node_copy(producer_node, load_arg)
-    graph.output(load_arg(producer_nodes[-1]))
-    graph_module = GraphModule(root, graph)
-    return graph_module
-
-def assert_and_get_unique_device(module: torch.nn.Module) -> Any:
-    """
-    Returns the unique device for a module, or None if no device is found.
-    Throws an error if multiple devices are detected.
-    """
-    devices = {p.device for p in module.parameters()} | \
-        {p.device for p in module.buffers()}
-    assert len(devices) <= 1, (
-        "prepare only works with cpu or single-device CUDA modules, "
-        "but got devices {}".format(devices)
-    )
-    device = next(iter(devices)) if len(devices) > 0 else None
-    return device
-
-def is_observed_standalone_module_node(
-        node: Node, modules: Dict[str, torch.nn.Module]) -> bool:
-    return node.op == 'call_module' and \
-        is_observed_standalone_module(modules[node.target])  # type: ignore
-
 def insert_observer(
         node: Node, observer: torch.quantization.ObserverBase,
         model: torch.nn.Module,
@@ -764,8 +681,11 @@
                     quantized = False
                 else:
                     assert obj is not None
-                    is_standalone_module_node = is_observed_standalone_module_node(
-                        node, self.modules)
+                    is_standalone_module_node = (
+                        node.op == 'call_module' and
+                        is_observed_standalone_module(
+                            self.modules[node.target])  # type: ignore
+                    )
                     result = obj.convert(
                         self, node, load_arg, debug=debug,
                         convert_custom_config_dict=convert_custom_config_dict)
diff --git a/torch/quantization/fx/utils.py b/torch/quantization/fx/utils.py
index a07cbc6..c1f8498 100644
--- a/torch/quantization/fx/utils.py
+++ b/torch/quantization/fx/utils.py
@@ -2,6 +2,15 @@
 import torch
 from ..utils import is_per_tensor, is_per_channel
 
+from torch.fx import GraphModule, map_arg
+
+from torch.fx.graph import (
+    Graph,
+    Node,
+)
+
+from typing import Callable, Optional, List, Dict, Any
+
 # turn foo.bar -> ['foo', 'bar']
 def _parent_name(target):
     r = target.rsplit('.', 1)
@@ -169,3 +178,85 @@
         return torch.ops.quantized.linear_prepack
     else:
         raise Exception("can't get linear prepack op for dtype:", dtype)
+
+# Returns a function that generates a new, unused attribute name for a module
+# with the given prefix, for example:
+# >> get_new_observer_name = get_new_attr_name_with_prefix('_observer_')
+# >> new_name = get_new_observer_name(module)
+# new_name will be an unused attribute name on module, e.g. `_observer_1`
+def get_new_attr_name_with_prefix(prefix: str) -> Callable:
+    def get_new_attr_name(module: torch.nn.Module):
+        def get_attr_name(i: int):
+            return prefix + str(i)
+        i = 0
+        attr_name = get_attr_name(i)
+        while hasattr(module, attr_name):
+            i += 1
+            attr_name = get_attr_name(i)
+        return attr_name
+    return get_new_attr_name
+
+def collect_producer_nodes(node: Node) -> Optional[List[Node]]:
+    r''' Starting from a target node, trace back until we hit an input or a
+    getattr node. This is used to extract the chain of operators
+    starting from a getattr to the target node, for example:
+    def forward(self, x):
+      observed = self.observer(self.weight)
+      return F.linear(x, observed)
+    collect_producer_nodes(observed) will either return a list of the nodes
+    that produce the observed node, or None if we can't extract a self-contained
+    graph without free variables (inputs of the forward function).
+    '''
+    nodes = [node]
+    frontier = [node]
+    while frontier:
+        node = frontier.pop()
+        all_args = list(node.args) + list(node.kwargs.values())
+        for arg in all_args:
+            if not isinstance(arg, Node):
+                continue
+            if arg.op == 'placeholder':
+                # hit input, can't fold in this case
+                return None
+            nodes.append(arg)
+            if not (arg.op == 'call_function' and arg.target == getattr):
+                frontier.append(arg)
+    return nodes
+
+def graph_module_from_producer_nodes(
+        root: GraphModule, producer_nodes: List[Node]) -> GraphModule:
+    r''' Construct a graph module from the producer nodes extracted by the
+    `collect_producer_nodes` function
+    Args:
+      root: the root module for the original graph
+      producer_nodes: a list of nodes we use to construct the graph
+    Returns:
+      A graph module constructed from the producer nodes
+    '''
+    assert len(producer_nodes) > 0, 'list of producer nodes cannot be empty'
+    # since we traced back from the node to getattr, reverse to restore order
+    producer_nodes.reverse()
+    graph = Graph()
+    env: Dict[Any, Any] = {}
+
+    def load_arg(a):
+        return map_arg(a, lambda node: env[node])
+    for producer_node in producer_nodes:
+        env[producer_node] = graph.node_copy(producer_node, load_arg)
+    graph.output(load_arg(producer_nodes[-1]))
+    graph_module = GraphModule(root, graph)
+    return graph_module
+
+def assert_and_get_unique_device(module: torch.nn.Module) -> Any:
+    """
+    Returns the unique device for a module, or None if no device is found.
+    Throws an error if multiple devices are detected.
+    """
+    devices = {p.device for p in module.parameters()} | \
+        {p.device for p in module.buffers()}
+    assert len(devices) <= 1, (
+        "prepare only works with cpu or single-device CUDA modules, "
+        "but got devices {}".format(devices)
+    )
+    device = next(iter(devices)) if len(devices) > 0 else None
+    return device
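
For reference, a minimal sketch of how the moved producer-node helpers fit together; the module `M` and the `weight * 0.5` producer chain below are illustrative stand-ins, not code from this PR:

    import torch
    import torch.nn.functional as F
    from torch.fx import symbolic_trace
    from torch.quantization.fx.utils import (
        assert_and_get_unique_device,
        collect_producer_nodes,
        graph_module_from_producer_nodes,
    )

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.randn(4, 4))

        def forward(self, x):
            scaled = self.weight * 0.5      # producer chain: get_attr -> mul
            return F.linear(x, scaled)

    traced = symbolic_trace(M())
    # Single-device check: returns torch.device('cpu') here, and raises if the
    # module's parameters/buffers live on more than one device.
    assert_and_get_unique_device(traced)

    linear_node = next(n for n in traced.graph.nodes
                       if n.op == 'call_function' and n.target == F.linear)
    # Trace back from the weight argument of F.linear to its getattr producer.
    producers = collect_producer_nodes(linear_node.args[1])
    if producers is not None:               # None means the chain reached a forward input
        folded = graph_module_from_producer_nodes(traced, producers)
        scaled_weight = folded()            # re-runs only the producer subgraph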