[quant][refactor] Refactor find_matches for easier future extension (#74878)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74878
Previously we recorded the matched nodes as a flat list of nodes (`List[Node]`). This does not generalize
to a graph, which is needed for future use cases, so in this PR we change the recorded match to a
NodePattern instead, currently defined as
```
NodePattern = Union[Tuple[Node, Node], Tuple[Node, Tuple[Node, Node]], Any]
```
but it can be made more general in the future.
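For illustration only (not part of the PR), here is a minimal sketch of how the nested shape generalizes the old flat `List[Node]`, with plain strings standing in for `torch.fx` Nodes; the traversal mirrors the `add_matched_node_name_to_set` helper added in this diff:
```python
from typing import Any, Set, Tuple, Union

# Simplified stand-in for the NodePattern alias introduced in this PR
# (plain strings stand in for torch.fx Node objects in this sketch).
NodePattern = Union[Tuple[Any, Any], Tuple[Any, Tuple[Any, Any]], Any]

def collect_node_names(node_pattern: NodePattern, names: Set[str]) -> None:
    # mirrors add_matched_node_name_to_set: recurse into tuples/lists
    # and record the leaves (fx Nodes in the real code, strings here)
    if isinstance(node_pattern, (list, tuple)):
        for sub in node_pattern:
            collect_node_names(sub, names)
    else:
        names.add(node_pattern)

# old representation: a flat list such as ["relu", "linear"]
# new representation: the pattern structure is preserved, e.g.
#   single op:     "linear"
#   fused pattern: ("relu", "linear")
#   nested:        ("relu", ("add", "linear"))
names: Set[str] = set()
collect_node_names(("relu", ("add", "linear")), names)
assert names == {"relu", "add", "linear"}
```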
This will allow us to support more general patterns with the backend_config_dict API, and is also
needed for the BinaryOpQuantizeHandler refactor.
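The diff below also threads a per-pattern `root_node_getter` through `find_matches`; when none is registered, a default that walks into the last element of each tuple is used. A minimal sketch of that default, again with strings standing in for Nodes (the real `_default_root_node_getter` stops when it reaches an fx `Node`):
```python
# Simplified sketch of _default_root_node_getter from this diff
# (strings stand in for torch.fx Node objects, so we stop at non-tuples
# instead of checking isinstance(node_pattern, Node)).
def default_root_node_getter(node_pattern):
    # for sequential fusions the root op is the innermost/last element,
    # e.g. ("relu", ("add", "linear")) -> "linear"
    while isinstance(node_pattern, (list, tuple)):
        node_pattern = node_pattern[-1]
    return node_pattern

assert default_root_node_getter(("relu", ("add", "linear"))) == "linear"
assert default_root_node_getter("linear") == "linear"
```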
Test Plan:
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps
Imported from OSS
Reviewed By: vkuzo
Differential Revision: D35203616
fbshipit-source-id: f4bf5b056cfc0955455eea9c2bf1ac9f6dde3974
(cherry picked from commit b290c047e1861bbb62fb1bb576761e801b210220)
diff --git a/torch/ao/quantization/fx/backend_config/quantize_handler.py b/torch/ao/quantization/fx/backend_config/quantize_handler.py
index fe932e3..1026844 100644
--- a/torch/ao/quantization/fx/backend_config/quantize_handler.py
+++ b/torch/ao/quantization/fx/backend_config/quantize_handler.py
@@ -1,14 +1,18 @@
import torch
-from typing import Dict
-from torch.fx.graph import Node
+from typing import Dict, Callable
from .observation_type import ObservationType
from ..quantization_patterns import QuantizeHandler
+from ..quantization_types import NodePattern
def get_quantize_handler_cls(observation_type, dtype_configs):
class ConfigurableQuantizeHandler(QuantizeHandler):
- def __init__(self, node: Node, modules: Dict[str, torch.nn.Module]):
- super().__init__(node, modules)
+ def __init__(
+ self,
+ node_pattern: NodePattern,
+ modules: Dict[str, torch.nn.Module],
+ root_node_getter: Callable = None):
+ super().__init__(node_pattern, modules, root_node_getter)
self.observation_type = observation_type
self.dtype_configs = dtype_configs
diff --git a/torch/ao/quantization/fx/fuse.py b/torch/ao/quantization/fx/fuse.py
index 6dffcc7..ee28a2a 100644
--- a/torch/ao/quantization/fx/fuse.py
+++ b/torch/ao/quantization/fx/fuse.py
@@ -111,6 +111,7 @@
# a map from node to the matched subpattern
node_to_subpattern: Dict[Node, Any] = {}
+ # TODO: dedup with quantization matching function in match_utils.py
def apply_match(pattern, node, match, matched_node_pattern, node_to_subpattern):
if isinstance(pattern, tuple):
s, *args = pattern
diff --git a/torch/ao/quantization/fx/match_utils.py b/torch/ao/quantization/fx/match_utils.py
index 876bc39..38fb6e8 100644
--- a/torch/ao/quantization/fx/match_utils.py
+++ b/torch/ao/quantization/fx/match_utils.py
@@ -76,6 +76,7 @@
graph: Graph,
modules: Dict[str, torch.nn.Module],
patterns: Dict[Pattern, QuantizeHandler],
+ root_node_getter_mapping: Dict[Pattern, Callable],
qconfig_map: Dict[str, QConfigAny],
standalone_module_names: List[str] = None,
standalone_module_classes: List[Callable] = None,
@@ -114,29 +115,80 @@
match_map: Dict[str, MatchResult] = {}
all_matched : Set[str] = set()
- def record_match(pattern, node, matched):
+ def _recursive_record_node_in_match_map(
+ last_node,
+ match_map,
+ node_pattern,
+ matched_node_pattern,
+ pattern,
+ match_value,
+ qconfig):
+ if isinstance(node_pattern, Node):
+ match_map[node_pattern.name] = (
+ last_node, matched_node_pattern, pattern, match_value, qconfig)
+ else:
+ for n in node_pattern:
+ _recursive_record_node_in_match_map(last_node, match_map, n, matched_node_pattern, pattern, match_value, qconfig)
+
+ # TODO: 1. merge with fuse matcher 2. document the code
+ def record_match(
+ pattern,
+ node,
+ last_node,
+ matched_node_pattern,
+ match_map):
if isinstance(pattern, tuple):
s, *args = pattern
- record_match(s, node, matched)
+ current_node_pattern: List[Node] = []
+ record_match(
+ s,
+ node,
+ last_node,
+ matched_node_pattern,
+ match_map)
if pattern[0] is not getattr:
for subpattern, arg in zip(args, node.args):
- record_match(subpattern, arg, matched)
+ record_match(
+ subpattern,
+ arg,
+ node,
+ current_node_pattern,
+ match_map)
+ if len(current_node_pattern) > 1:
+ matched_node_pattern.append(tuple(current_node_pattern))
+ else:
+ matched_node_pattern.append(current_node_pattern[0])
else:
- matched.append(node)
+ matched_node_pattern.append(node)
- cache_for_no_tensor_check: Dict[Node, bool] = dict()
for node in reversed(graph.nodes):
if node.name not in match_map and node.name not in all_matched:
- for pattern, value in patterns.items():
- if is_match(modules, node, pattern):
- matched: List[Any] = []
- record_match(pattern, node, matched)
- for n in matched:
- match_map[n.name] = (
- node, matched, pattern, value(node, modules), # type: ignore[operator]
- qconfig_map[n.name])
- all_matched.add(n.name)
- # break after finding the first match
+ for pattern, quantize_handler_cls in patterns.items():
+ root_node_getter = root_node_getter_mapping.get(pattern, None)
+ if is_match(modules, node, pattern) and node.name not in match_map:
+ matched_node_pattern: List[Node] = []
+ record_match(
+ pattern,
+ node,
+ node,
+ matched_node_pattern,
+ match_map)
+ quantize_handler = quantize_handler_cls( # type: ignore[operator]
+ matched_node_pattern,
+ modules,
+ root_node_getter)
+ last_node = node
+ # record the match for all nodes in the pattern
+ _recursive_record_node_in_match_map(
+ last_node,
+ match_map,
+ # we need to record all nodes in the matched pattern in the match_map
+ matched_node_pattern,
+ # this is a part of the value corresponding to the node
+ matched_node_pattern,
+ pattern,
+ quantize_handler,
+ qconfig_map[node.name])
break
# add custom module instances to the match result
@@ -146,7 +198,7 @@
type(modules[node.target]) in custom_module_classes:
custom_module_qconfig = qconfig_map[node.name]
match_map[node.name] = (
- node, [node], None, CustomModuleQuantizeHandler(node, modules),
+ node, node, None, CustomModuleQuantizeHandler(node, modules),
custom_module_qconfig)
def is_standalone_module(node_target: str, modules: Dict[str, torch.nn.Module]):
@@ -164,7 +216,7 @@
# add node to matched nodes
custom_module_qconfig = qconfig_map[node.name]
match_map[node.name] = (
- node, [node], None,
+ node, node, None,
StandaloneModuleQuantizeHandler(node, modules),
custom_module_qconfig)
diff --git a/torch/ao/quantization/fx/prepare.py b/torch/ao/quantization/fx/prepare.py
index e23446d..72538e2 100644
--- a/torch/ao/quantization/fx/prepare.py
+++ b/torch/ao/quantization/fx/prepare.py
@@ -34,7 +34,10 @@
StandaloneModuleQuantizeHandler,
)
-from .quantization_types import Pattern
+from .quantization_types import (
+ Pattern,
+ NodePattern
+)
from ._equalize import (
is_equalization_observer,
@@ -90,6 +93,7 @@
get_pattern_to_input_type_to_index,
get_module_to_qat_module,
get_native_quant_patterns,
+ get_fusion_pattern_to_root_node_getter,
)
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Set
@@ -180,7 +184,7 @@
def is_pattern_dtype_config_supported_by_backend(
pattern: Optional[Pattern],
- matched_nodes: Optional[List[Node]],
+ matched_node_pattern: Optional[NodePattern],
node_name_to_target_dtype: Dict[str, Dict[str, Optional[Union[torch.dtype, type]]]],
backend_config_dict: Optional[Dict[str, Any]]
) -> bool:
@@ -189,14 +193,15 @@
"""
if backend_config_dict is None or pattern is None:
return True
- assert matched_nodes is not None and len(matched_nodes) >= 1
+ assert matched_node_pattern is not None and len(matched_node_pattern) >= 1
pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config_dict)
dtype_configs: List[Dict[str, torch.dtype]] = pattern_to_dtype_configs.get(pattern, [])
- # TODO: this only checks one input and one output, need to generalize to multiple
+ # TODO: this only works for one input and one output patterns, need to generalize to multiple
# inputs/output
- input_node = matched_nodes[-1]
- output_node = matched_nodes[0]
+ root_node = _default_root_node_getter(matched_node_pattern)
+ input_node = root_node
+ output_node = matched_node_pattern[0]
for dtype_config in dtype_configs:
# check if arg dtype are supported
supported = True
@@ -247,6 +252,19 @@
module_to_qat_module: Dict[Callable, Callable]) -> None:
convert(root, mapping=module_to_qat_module, inplace=True, remove_qconfig=False)
+def add_matched_node_name_to_set(matched_node_pattern: NodePattern, s: Set[str]):
+ if isinstance(matched_node_pattern, Node):
+ s.add(matched_node_pattern.name)
+ elif isinstance(matched_node_pattern, (list, tuple)):
+ for maybe_node in matched_node_pattern:
+ add_matched_node_name_to_set(maybe_node, s)
+
+# this is temporary, will be removed soon
+def _default_root_node_getter(node_pattern):
+ while not isinstance(node_pattern, Node):
+ node_pattern = node_pattern[-1]
+ return node_pattern
+
# TODO: remove observed_op, looks like it's not used
def insert_observer(
node: Node,
@@ -686,7 +704,7 @@
If `node` does not need an output observer, returns None.
"""
- root_node, matched_nodes, pattern, qhandler, qconfig = matches.get(
+ root_node, _, pattern, qhandler, qconfig = matches.get(
node.name, (None, None, None, None, None))
if qhandler is None:
@@ -839,7 +857,7 @@
node_name_to_target_dtype[node.name]["input_activation_dtype"] = target_dtype
node_name_to_target_dtype[node.name]["output_activation_dtype"] = target_dtype
# if this is a copy node, propagate to first arg
- root_node, matched_nodes, pattern, qhandler, qconfig = matches.get(
+ root_node, _, pattern, qhandler, qconfig = matches.get(
node.name, (None, None, None, None, None))
if qhandler is not None and qhandler.is_general_tensor_value_op():
prev_node = node.args[0]
@@ -1078,7 +1096,7 @@
# other nodes output dtype is specified by the qconfig
modules = dict(model.named_modules(remove_duplicate=False))
for node in model.graph.nodes:
- root_node, matched_nodes, pattern, qhandler, qconfig = matches.get(
+ root_node, _, pattern, qhandler, qconfig = matches.get(
node.name, (None, None, None, None, None))
node_name_to_target_dtype[node.name] = get_target_activation_dtype_for_node(
node, qconfig, inputs_seen_counter, outputs_seen_counter,
@@ -1119,7 +1137,7 @@
elif node.op in ('call_module', 'call_method', 'call_function', 'output'):
# check for matches
- root_node, matched_nodes, pattern, qhandler, qconfig = matches.get(
+ last_node, matched_node_pattern, pattern, qhandler, qconfig = matches.get(
node.name, (None, None, None, None, None))
equalization_qconfig = equalization_config_map.get(node.name, None)
@@ -1138,15 +1156,14 @@
)
is_supported_by_backend = is_pattern_dtype_config_supported_by_backend(
- pattern, matched_nodes, node_name_to_target_dtype, backend_config_dict)
+ pattern, matched_node_pattern, node_name_to_target_dtype, backend_config_dict)
if not skip_inserting_observers and is_supported_by_backend:
modules = dict(model.named_modules(remove_duplicate=False))
if node.op != 'output':
- assert matched_nodes is not None
+ assert matched_node_pattern is not None
# add matched nodes to the observed node name set
- for n in matched_nodes:
- observed_node_names.add(n.name)
+ add_matched_node_name_to_set(matched_node_pattern, observed_node_names)
# This is currently only used for equalization.
# Checks if the current node is in a branch in which the two
@@ -1176,7 +1193,8 @@
# TODO: this only works for sequential fusion right now, extend it
# it to automatically detect all input nodes based on the pattern
# need to change find_matches function to return this information
- is_input_node_of_the_pattern = matched_nodes[-1] is node
+ root_node = _default_root_node_getter(matched_node_pattern)
+ is_input_node_of_the_pattern = node is root_node
if is_input_node_of_the_pattern:
# this modifies node inplace
maybe_insert_input_observers_for_node(
@@ -1191,7 +1209,7 @@
node, equalization_qconfig, model, modules, graph,
node_name_to_target_dtype, is_quantized_branch)
- is_last_node_of_pattern = root_node is node
+ is_last_node_of_pattern = node is last_node
is_general_tensor_value_op = \
(qhandler is not None and qhandler.is_general_tensor_value_op())
@@ -1273,7 +1291,7 @@
"""
for (
node_name,
- (root_node, matched_nodes, pattern, qhandler, qconfig),
+ (root_node, _, pattern, qhandler, qconfig),
) in matches.items():
if qhandler is None:
continue
@@ -1375,6 +1393,7 @@
patterns: Dict[Pattern, QuantizeHandler] = {}
if backend_config_dict is None:
patterns = get_native_quant_patterns(additional_quant_patterns)
+ root_node_getter_mapping = {}
else:
patterns = get_pattern_to_quantize_handlers(backend_config_dict)
patterns = sorted_patterns_dict(patterns)
@@ -1397,6 +1416,9 @@
else:
index_dict[pattern] = [index] # type: ignore[index]
+ root_node_getter_mapping = \
+ get_fusion_pattern_to_root_node_getter(backend_config_dict)
+
convert_dict_to_ordered_dict(qconfig_dict)
convert_dict_to_ordered_dict(equalization_qconfig_dict)
qconfig_dict = update_qconfig_for_fusion(model, qconfig_dict)
@@ -1443,8 +1465,8 @@
custom_module_classes = get_custom_module_class_keys(
prepare_custom_config_dict, "float_to_observed_custom_module_class")
matches = find_matches(
- model.graph, modules, patterns, qconfig_map, standalone_module_names,
- standalone_module_classes, custom_module_classes)
+ model.graph, modules, patterns, root_node_getter_mapping, qconfig_map,
+ standalone_module_names, standalone_module_classes, custom_module_classes)
input_quantized_idxs: List[int] = prepare_custom_config_dict.get(
"input_quantized_idxs", [])
diff --git a/torch/ao/quantization/fx/quantization_patterns.py b/torch/ao/quantization/fx/quantization_patterns.py
index 44f5f83..f15794d 100644
--- a/torch/ao/quantization/fx/quantization_patterns.py
+++ b/torch/ao/quantization/fx/quantization_patterns.py
@@ -19,11 +19,20 @@
from .utils import (
all_node_args_have_no_tensors,
)
+from .quantization_types import NodePattern
from abc import ABC
import operator
from typing import Any, Callable, Dict, Optional
+# this is temporary, will be removed soon
+def _default_root_node_getter(node_pattern):
+ if node_pattern is None:
+ return node_pattern
+ while not isinstance(node_pattern, Node):
+ node_pattern = node_pattern[-1]
+ return node_pattern
+
# -------------------------
# Pattern Registrations
# -------------------------
@@ -34,18 +43,24 @@
class QuantizeHandler(ABC):
""" Base handler class for the quantizer patterns
"""
- def __init__(self, node: Node, modules: Dict[str, torch.nn.Module]):
+ def __init__(
+ self,
+ node_pattern: NodePattern,
+ modules: Dict[str, torch.nn.Module],
+ root_node_getter: Callable = None):
""" Records pattern information in __init__, which will be used
in convert
"""
# this is an indicator of whether all the inputs are Node or not
# since some op might be quantized differently depending on whether
# all inputs are tensors or not, e.g. add/mul
- if isinstance(node, Node):
- self.num_tensor_args = len(node.args)
+ if root_node_getter is None:
+ root_node_getter = _default_root_node_getter
+ self.root_node = root_node_getter(node_pattern)
+ if isinstance(self.root_node, Node):
+ self.num_tensor_args = len(self.root_node.args)
else:
self.num_tensor_args = 0
- self.all_node_args_are_tensors = True
# TODO: can remove after the is_dynamic flag is defined, so that we can
# move embedding op to backend_config_dict
@@ -109,32 +124,19 @@
class BinaryOpQuantizeHandler(QuantizeHandler):
def __init__(
self,
- node: Node,
- modules: Dict[str, torch.nn.Module]):
- super().__init__(node, modules)
- self.relu_node = None
- if (
- node.op == 'call_function' and
- node.target in (torch.nn.functional.relu, torch.relu)
- ) or (
- node.op == 'call_module' and
- isinstance(modules[str(node.target)], torch.nn.ReLU)
- ):
- self.relu_node = node
- node = node.args[0] # type: ignore[assignment]
- self.binary_op_node = node
- self.binary_op = node.target
+ node_pattern: NodePattern,
+ modules: Dict[str, torch.nn.Module],
+ root_node_getter: Callable = None):
+ super().__init__(node_pattern, modules, root_node_getter)
# determine how many of the first two args are Tensors (versus scalars)
# this distinguishes things like "x + y" from "x + 2" or "2 + x"
self.num_tensor_args = 0
cache_for_no_tensor_check: Dict[Node, bool] = dict()
- for arg_idx in range(len(self.binary_op_node.args)):
- arg = self.binary_op_node.args[arg_idx]
+ for arg_idx in range(len(self.root_node.args)):
+ arg = self.root_node.args[arg_idx]
if isinstance(arg, Node) and (not all_node_args_have_no_tensors(arg, modules, cache_for_no_tensor_check)):
self.num_tensor_args += 1
- self.all_node_args_are_tensors = \
- (self.num_tensor_args == len(self.binary_op_node.args))
def is_general_tensor_value_op(self) -> bool:
return self.num_tensor_args == 1
@@ -201,12 +203,6 @@
@register_quant_pattern('tanh', default_symmetric_fixed_qparams_observer)
@register_quant_pattern('tanh_', default_symmetric_fixed_qparams_observer)
class FixedQParamsOpQuantizeHandler(QuantizeHandler):
- def __init__(self,
- node: Node,
- modules: Dict[str, torch.nn.Module]):
- super().__init__(node, modules)
- self.node = node
-
# some qhandlers override the activations constructor
def get_activation_ctr(self, qconfig, pattern, is_training) -> Optional[Callable]:
act_dtype = activation_dtype(qconfig)