fx quant: move more functions to utils (#48908)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48908
No logic change, improving readability
Test Plan:
CI
Imported from OSS
Reviewed By: jerryzh168
Differential Revision: D25363080
fbshipit-source-id: 1d73a875bd7abf671b544ebc835432fea5306dc3
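For context, a minimal usage sketch of two of the relocated helpers, importable from `torch.quantization.fx.utils` once this lands (the toy module and observer choice below are illustrative, not part of this diff):

    import torch
    from torch.quantization.fx.utils import (
        get_new_attr_name_with_prefix,
        assert_and_get_unique_device,
    )

    class Toy(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

    m = Toy()

    # returns a callable that finds an unused attribute name on the
    # module: '_observer_0', then '_observer_1' if that is taken, etc.
    get_observer_name = get_new_attr_name_with_prefix('_observer_')
    name = get_observer_name(m)
    setattr(m, name, torch.quantization.MinMaxObserver())

    # returns the single device all parameters and buffers live on
    # (here cpu), or None for a module with neither; asserts on a mix
    device = assert_and_get_unique_device(m)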
diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py
index 73604bc..7da165b 100644
--- a/torch/quantization/fx/quantize.py
+++ b/torch/quantization/fx/quantize.py
@@ -51,6 +51,10 @@
     _parent_name,
     quantize_node,
     get_custom_module_class_keys,
+    get_new_attr_name_with_prefix,
+    collect_producer_nodes,
+    graph_module_from_producer_nodes,
+    assert_and_get_unique_device,
 )
 from .qconfig_utils import *
@@ -70,93 +74,6 @@
 # Helper Functions
 # ------------------------
-# Returns a function that can get a new attribute name for a module with a
-# given prefix, for example,
-# >> get_new_observer_name = get_new_attr_name_with_prefix('_observer_')
-# >> new_name = get_new_observer_name(module)
-# new_name will be an unused attribute name on module, e.g. `_observer_1`
-def get_new_attr_name_with_prefix(prefix: str) -> Callable:
-    def get_new_attr_name(module: torch.nn.Module):
-        def get_attr_name(i: int):
-            return prefix + str(i)
-        i = 0
-        attr_name = get_attr_name(i)
-        while hasattr(module, attr_name):
-            i += 1
-            attr_name = get_attr_name(i)
-        return attr_name
-    return get_new_attr_name
-
-def collect_producer_nodes(node: Node) -> Optional[List[Node]]:
-    r''' Starting from a target node, trace back until we hit an input
-    or a getattr node. This is used to extract the chain of operators
-    starting from getattr to the target node, for example
-    def forward(self, x):
-        observed = self.observer(self.weight)
-        return F.linear(x, observed)
-    collect_producer_nodes(observed) will either return a list of nodes that
-    produce the observed node, or None if we can't extract a self-contained
-    graph without free variables (inputs of the forward function).
-    '''
-    nodes = [node]
-    frontier = [node]
-    while frontier:
-        node = frontier.pop()
-        all_args = list(node.args) + list(node.kwargs.values())
-        for arg in all_args:
-            if not isinstance(arg, Node):
-                continue
-            if arg.op == 'placeholder':
-                # hit an input, can't fold in this case
-                return None
-            nodes.append(arg)
-            if not (arg.op == 'call_function' and arg.target == getattr):
-                frontier.append(arg)
-    return nodes
-
-def graph_module_from_producer_nodes(
-        root: GraphModule, producer_nodes: List[Node]) -> GraphModule:
-    r''' Construct a graph module from producer nodes extracted by the
-    `collect_producer_nodes` function
-    Args:
-        root: the root module for the original graph
-        producer_nodes: a list of nodes we use to construct the graph
-    Return:
-        A graph module constructed from the producer nodes
-    '''
-    assert len(producer_nodes) > 0, 'list of producer nodes cannot be empty'
-    # the nodes were collected by tracing back from the target node to
-    # getattr, so reverse them into topological (execution) order
-    producer_nodes.reverse()
-    graph = Graph()
-    env: Dict[Any, Any] = {}
-
-    def load_arg(a):
-        return map_arg(a, lambda node: env[node])
-    for producer_node in producer_nodes:
-        env[producer_node] = graph.node_copy(producer_node, load_arg)
-    graph.output(load_arg(producer_nodes[-1]))
-    graph_module = GraphModule(root, graph)
-    return graph_module
-
-def assert_and_get_unique_device(module: torch.nn.Module) -> Any:
-    """
-    Returns the unique device for a module, or None if no device is found.
-    Throws an error if multiple devices are detected.
-    """
-    devices = {p.device for p in module.parameters()} | \
-        {p.device for p in module.buffers()}
-    assert len(devices) <= 1, (
-        "prepare only works with cpu or single-device CUDA modules, "
-        "but got devices {}".format(devices)
-    )
-    device = next(iter(devices)) if len(devices) > 0 else None
-    return device
-
-def is_observed_standalone_module_node(
-        node: Node, modules: Dict[str, torch.nn.Module]) -> bool:
-    return node.op == 'call_module' and \
-        is_observed_standalone_module(modules[node.target])  # type: ignore
-
 def insert_observer(
         node: Node, observer: torch.quantization.ObserverBase,
         model: torch.nn.Module,
@@ -764,8 +681,11 @@
                 quantized = False
             else:
                 assert obj is not None
-                is_standalone_module_node = is_observed_standalone_module_node(
-                    node, self.modules)
+                is_standalone_module_node = (
+                    node.op == 'call_module' and
+                    is_observed_standalone_module(
+                        self.modules[node.target])  # type: ignore
+                )
                 result = obj.convert(
                     self, node, load_arg, debug=debug,
                     convert_custom_config_dict=convert_custom_config_dict)
diff --git a/torch/quantization/fx/utils.py b/torch/quantization/fx/utils.py
index a07cbc6..c1f8498 100644
--- a/torch/quantization/fx/utils.py
+++ b/torch/quantization/fx/utils.py
@@ -2,6 +2,15 @@
 import torch
 from ..utils import is_per_tensor, is_per_channel
+from torch.fx import GraphModule, map_arg
+
+from torch.fx.graph import (
+    Graph,
+    Node,
+)
+
+from typing import Callable, Optional, List, Dict, Any
+
 # turn foo.bar -> ['foo', 'bar']
 def _parent_name(target):
     r = target.rsplit('.', 1)
@@ -169,3 +178,85 @@
         return torch.ops.quantized.linear_prepack
     else:
         raise Exception("can't get linear prepack op for dtype:", dtype)
+
+# Returns a function that can get a new attribute name for a module with a
+# given prefix, for example,
+# >> get_new_observer_name = get_new_attr_name_with_prefix('_observer_')
+# >> new_name = get_new_observer_name(module)
+# new_name will be an unused attribute name on module, e.g. `_observer_1`
+def get_new_attr_name_with_prefix(prefix: str) -> Callable:
+    def get_new_attr_name(module: torch.nn.Module):
+        def get_attr_name(i: int):
+            return prefix + str(i)
+        i = 0
+        attr_name = get_attr_name(i)
+        while hasattr(module, attr_name):
+            i += 1
+            attr_name = get_attr_name(i)
+        return attr_name
+    return get_new_attr_name
+
+def collect_producer_nodes(node: Node) -> Optional[List[Node]]:
+    r''' Starting from a target node, trace back until we hit an input
+    or a getattr node. This is used to extract the chain of operators
+    starting from getattr to the target node, for example
+    def forward(self, x):
+        observed = self.observer(self.weight)
+        return F.linear(x, observed)
+    collect_producer_nodes(observed) will either return a list of nodes that
+    produce the observed node, or None if we can't extract a self-contained
+    graph without free variables (inputs of the forward function).
+    '''
+    nodes = [node]
+    frontier = [node]
+    while frontier:
+        node = frontier.pop()
+        all_args = list(node.args) + list(node.kwargs.values())
+        for arg in all_args:
+            if not isinstance(arg, Node):
+                continue
+            if arg.op == 'placeholder':
+                # hit an input, can't fold in this case
+                return None
+            nodes.append(arg)
+            if not (arg.op == 'call_function' and arg.target == getattr):
+                frontier.append(arg)
+    return nodes
+
+def graph_module_from_producer_nodes(
+        root: GraphModule, producer_nodes: List[Node]) -> GraphModule:
+    r''' Construct a graph module from producer nodes extracted by the
+    `collect_producer_nodes` function
+    Args:
+        root: the root module for the original graph
+        producer_nodes: a list of nodes we use to construct the graph
+    Return:
+        A graph module constructed from the producer nodes
+    '''
+    assert len(producer_nodes) > 0, 'list of producer nodes cannot be empty'
+    # the nodes were collected by tracing back from the target node to
+    # getattr, so reverse them into topological (execution) order
+    producer_nodes.reverse()
+    graph = Graph()
+    env: Dict[Any, Any] = {}
+
+    def load_arg(a):
+        return map_arg(a, lambda node: env[node])
+    for producer_node in producer_nodes:
+        env[producer_node] = graph.node_copy(producer_node, load_arg)
+    graph.output(load_arg(producer_nodes[-1]))
+    graph_module = GraphModule(root, graph)
+    return graph_module
+
+def assert_and_get_unique_device(module: torch.nn.Module) -> Any:
+    """
+    Returns the unique device for a module, or None if no device is found.
+    Throws an error if multiple devices are detected.
+    """
+    devices = {p.device for p in module.parameters()} | \
+        {p.device for p in module.buffers()}
+    assert len(devices) <= 1, (
+        "prepare only works with cpu or single-device CUDA modules, "
+        "but got devices {}".format(devices)
+    )
+    device = next(iter(devices)) if len(devices) > 0 else None
+    return device
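
As a usage sketch of the producer-node helpers: they extract the self-contained subgraph that computes a value and replay it as its own GraphModule (the toy module and the weight-transposing forward below are illustrative, not part of this diff):

    import torch
    import torch.nn.functional as F
    from torch.fx import symbolic_trace
    from torch.quantization.fx.utils import (
        collect_producer_nodes,
        graph_module_from_producer_nodes,
    )

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.randn(4, 4))

        def forward(self, x):
            return F.linear(x, self.weight.t())

    traced = symbolic_trace(M())

    # trace back from the weight argument of the linear call; the chain
    # ends at the get_attr node for self.weight, so it is self-contained
    linear_node = next(n for n in traced.graph.nodes if n.target is F.linear)
    producers = collect_producer_nodes(linear_node.args[1])

    # producers would be None if the chain reached a placeholder (a
    # forward input); here it can be replayed with no inputs at all
    if producers is not None:
        folded = graph_module_from_producer_nodes(traced, producers)
        folded_weight = folded()  # runs just the extracted subgraph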