[quant][refactor] Refactor find_matches for easier future extension (#74878)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74878

Previously we recorded the matched nodes as a flat list (`List[Node]`). This does not generalize
to a graph, which is needed for future use cases. In this PR we changed the recorded match to a
`NodePattern` instead, currently defined as
```
NodePattern = Union[Tuple[Node, Node], Tuple[Node, Tuple[Node, Node]], Any]
```
but it can be made more general in the future.
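
For illustration, the recorded match mirrors the nesting of the pattern itself rather than flattening it into a list. Below is a minimal, self-contained sketch of the shapes a `NodePattern` can take; `FakeNode` is a stand-in for `torch.fx.Node` used only so the example runs without torch, and the node names are hypothetical:
```
# Sketch only: FakeNode stands in for torch.fx.Node. Patterns are written
# outside-in, so (ReLU, (BN, Conv)) matches relu(bn(conv(x))).
from typing import Any, Tuple, Union

class FakeNode:
    def __init__(self, name: str):
        self.name = name
    def __repr__(self) -> str:
        return f"Node({self.name})"

NodePattern = Union[
    Tuple[FakeNode, FakeNode],
    Tuple[FakeNode, Tuple[FakeNode, FakeNode]],
    Any,
]

conv, bn, relu = FakeNode("conv"), FakeNode("bn"), FakeNode("relu")

single: NodePattern = conv                # single-op match: just a Node
pair: NodePattern = (relu, conv)          # relu(conv(x))
nested: NodePattern = (relu, (bn, conv))  # relu(bn(conv(x)))
```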

This will allow us to support more general patterns with the backend_config_dict API, and it is
also needed for the BinaryOpQuantizeHandler refactor.
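
For reference, the temporary `_default_root_node_getter` introduced in this PR treats the innermost last element of a sequential pattern as the root (input) node. A sketch of its behavior, reusing `FakeNode` from the example above:
```
# Sketch of the default root node getter: walk the last element of each
# nested tuple until a node is reached.
def default_root_node_getter(node_pattern):
    while not isinstance(node_pattern, FakeNode):
        node_pattern = node_pattern[-1]
    return node_pattern

assert default_root_node_getter((relu, (bn, conv))) is conv
assert default_root_node_getter(conv) is conv
```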

Test Plan:
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps

Imported from OSS

Reviewed By: vkuzo

Differential Revision: D35203616

fbshipit-source-id: f4bf5b056cfc0955455eea9c2bf1ac9f6dde3974
(cherry picked from commit b290c047e1861bbb62fb1bb576761e801b210220)
diff --git a/torch/ao/quantization/fx/backend_config/quantize_handler.py b/torch/ao/quantization/fx/backend_config/quantize_handler.py
index fe932e3..1026844 100644
--- a/torch/ao/quantization/fx/backend_config/quantize_handler.py
+++ b/torch/ao/quantization/fx/backend_config/quantize_handler.py
@@ -1,14 +1,18 @@
 import torch
-from typing import Dict
-from torch.fx.graph import Node
+from typing import Dict, Callable
 from .observation_type import ObservationType
 from ..quantization_patterns import QuantizeHandler
+from ..quantization_types import NodePattern
 
 def get_quantize_handler_cls(observation_type, dtype_configs):
 
     class ConfigurableQuantizeHandler(QuantizeHandler):
-        def __init__(self, node: Node, modules: Dict[str, torch.nn.Module]):
-            super().__init__(node, modules)
+        def __init__(
+                self,
+                node_pattern: NodePattern,
+                modules: Dict[str, torch.nn.Module],
+                root_node_getter: Callable = None):
+            super().__init__(node_pattern, modules, root_node_getter)
             self.observation_type = observation_type
             self.dtype_configs = dtype_configs
 
diff --git a/torch/ao/quantization/fx/fuse.py b/torch/ao/quantization/fx/fuse.py
index 6dffcc7..ee28a2a 100644
--- a/torch/ao/quantization/fx/fuse.py
+++ b/torch/ao/quantization/fx/fuse.py
@@ -111,6 +111,7 @@
     # a map from node to the matched subpattern
     node_to_subpattern: Dict[Node, Any] = {}
 
+    # TODO: dedup with quantization matching function in match_utils.py
     def apply_match(pattern, node, match, matched_node_pattern, node_to_subpattern):
         if isinstance(pattern, tuple):
             s, *args = pattern
diff --git a/torch/ao/quantization/fx/match_utils.py b/torch/ao/quantization/fx/match_utils.py
index 876bc39..38fb6e8 100644
--- a/torch/ao/quantization/fx/match_utils.py
+++ b/torch/ao/quantization/fx/match_utils.py
@@ -76,6 +76,7 @@
         graph: Graph,
         modules: Dict[str, torch.nn.Module],
         patterns: Dict[Pattern, QuantizeHandler],
+        root_node_getter_mapping: Dict[Pattern, Callable],
         qconfig_map: Dict[str, QConfigAny],
         standalone_module_names: List[str] = None,
         standalone_module_classes: List[Callable] = None,
@@ -114,29 +115,80 @@
     match_map: Dict[str, MatchResult] = {}
     all_matched : Set[str] = set()
 
-    def record_match(pattern, node, matched):
+    def _recursive_record_node_in_match_map(
+            last_node,
+            match_map,
+            node_pattern,
+            matched_node_pattern,
+            pattern,
+            match_value,
+            qconfig):
+        if isinstance(node_pattern, Node):
+            match_map[node_pattern.name] = (
+                last_node, matched_node_pattern, pattern, match_value, qconfig)
+        else:
+            for n in node_pattern:
+                _recursive_record_node_in_match_map(last_node, match_map, n, matched_node_pattern, pattern, match_value, qconfig)
+
+    # TODO: 1. merge with fuse matcher 2. document the code
+    def record_match(
+            pattern,
+            node,
+            last_node,
+            matched_node_pattern,
+            match_map):
         if isinstance(pattern, tuple):
             s, *args = pattern
-            record_match(s, node, matched)
+            current_node_pattern: List[Node] = []
+            record_match(
+                s,
+                node,
+                last_node,
+                matched_node_pattern,
+                match_map)
             if pattern[0] is not getattr:
                 for subpattern, arg in zip(args, node.args):
-                    record_match(subpattern, arg, matched)
+                    record_match(
+                        subpattern,
+                        arg,
+                        node,
+                        current_node_pattern,
+                        match_map)
+            if len(current_node_pattern) > 1:
+                matched_node_pattern.append(tuple(current_node_pattern))
+            else:
+                matched_node_pattern.append(current_node_pattern[0])
         else:
-            matched.append(node)
+            matched_node_pattern.append(node)
 
-    cache_for_no_tensor_check: Dict[Node, bool] = dict()
     for node in reversed(graph.nodes):
         if node.name not in match_map and node.name not in all_matched:
-            for pattern, value in patterns.items():
-                if is_match(modules, node, pattern):
-                    matched: List[Any] = []
-                    record_match(pattern, node, matched)
-                    for n in matched:
-                        match_map[n.name] = (
-                            node, matched, pattern, value(node, modules),  # type: ignore[operator]
-                            qconfig_map[n.name])
-                        all_matched.add(n.name)
-                    # break after finding the first match
+            for pattern, quantize_handler_cls in patterns.items():
+                root_node_getter = root_node_getter_mapping.get(pattern, None)
+                if is_match(modules, node, pattern) and node.name not in match_map:
+                    matched_node_pattern: List[Node] = []
+                    record_match(
+                        pattern,
+                        node,
+                        node,
+                        matched_node_pattern,
+                        match_map)
+                    quantize_handler = quantize_handler_cls(  # type: ignore[operator]
+                        matched_node_pattern,
+                        modules,
+                        root_node_getter)
+                    last_node = node
+                    # record the match for all nodes in the pattern
+                    _recursive_record_node_in_match_map(
+                        last_node,
+                        match_map,
+                        # we need to record all nodes in the matched pattern in the match_map
+                        matched_node_pattern,
+                        # this is a part of the value corresponding to the node
+                        matched_node_pattern,
+                        pattern,
+                        quantize_handler,
+                        qconfig_map[node.name])
                     break
 
     # add custom module instances to the match result
@@ -146,7 +198,7 @@
            type(modules[node.target]) in custom_module_classes:
             custom_module_qconfig = qconfig_map[node.name]
             match_map[node.name] = (
-                node, [node], None, CustomModuleQuantizeHandler(node, modules),
+                node, node, None, CustomModuleQuantizeHandler(node, modules),
                 custom_module_qconfig)
 
     def is_standalone_module(node_target: str, modules: Dict[str, torch.nn.Module]):
@@ -164,7 +216,7 @@
             # add node to matched nodes
             custom_module_qconfig = qconfig_map[node.name]
             match_map[node.name] = (
-                node, [node], None,
+                node, node, None,
                 StandaloneModuleQuantizeHandler(node, modules),
                 custom_module_qconfig)
 
diff --git a/torch/ao/quantization/fx/prepare.py b/torch/ao/quantization/fx/prepare.py
index e23446d..72538e2 100644
--- a/torch/ao/quantization/fx/prepare.py
+++ b/torch/ao/quantization/fx/prepare.py
@@ -34,7 +34,10 @@
     StandaloneModuleQuantizeHandler,
 )
 
-from .quantization_types import Pattern
+from .quantization_types import (
+    Pattern,
+    NodePattern
+)
 
 from ._equalize import (
     is_equalization_observer,
@@ -90,6 +93,7 @@
     get_pattern_to_input_type_to_index,
     get_module_to_qat_module,
     get_native_quant_patterns,
+    get_fusion_pattern_to_root_node_getter,
 )
 
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Set
@@ -180,7 +184,7 @@
 
 def is_pattern_dtype_config_supported_by_backend(
     pattern: Optional[Pattern],
-    matched_nodes: Optional[List[Node]],
+    matched_node_pattern: Optional[NodePattern],
     node_name_to_target_dtype: Dict[str, Dict[str, Optional[Union[torch.dtype, type]]]],
     backend_config_dict: Optional[Dict[str, Any]]
 ) -> bool:
@@ -189,14 +193,15 @@
     """
     if backend_config_dict is None or pattern is None:
         return True
-    assert matched_nodes is not None and len(matched_nodes) >= 1
+    assert matched_node_pattern is not None and len(matched_node_pattern) >= 1
     pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config_dict)
     dtype_configs: List[Dict[str, torch.dtype]] = pattern_to_dtype_configs.get(pattern, [])
 
-    # TODO: this only checks one input and one output, need to generalize to multiple
+    # TODO: this only works for patterns with one input and one output, need to generalize to multiple
     # inputs/output
-    input_node = matched_nodes[-1]
-    output_node = matched_nodes[0]
+    root_node = _default_root_node_getter(matched_node_pattern)
+    input_node = root_node
+    output_node = matched_node_pattern[0]
     for dtype_config in dtype_configs:
         # check if arg dtype are supported
         supported = True
@@ -247,6 +252,19 @@
         module_to_qat_module: Dict[Callable, Callable]) -> None:
     convert(root, mapping=module_to_qat_module, inplace=True, remove_qconfig=False)
 
+def add_matched_node_name_to_set(matched_node_pattern: NodePattern, s: Set[str]):
+    if isinstance(matched_node_pattern, Node):
+        s.add(matched_node_pattern.name)
+    elif isinstance(matched_node_pattern, (list, tuple)):
+        for maybe_node in matched_node_pattern:
+            add_matched_node_name_to_set(maybe_node, s)
+
+# this is temporary, will be removed soon
+def _default_root_node_getter(node_pattern):
+    while not isinstance(node_pattern, Node):
+        node_pattern = node_pattern[-1]
+    return node_pattern
+
 # TODO: remove observed_op, looks like it's not used
 def insert_observer(
     node: Node,
@@ -686,7 +704,7 @@
 
     If `node` does not need an output observer, returns None.
     """
-    root_node, matched_nodes, pattern, qhandler, qconfig = matches.get(
+    root_node, _, pattern, qhandler, qconfig = matches.get(
         node.name, (None, None, None, None, None))
 
     if qhandler is None:
@@ -839,7 +857,7 @@
     node_name_to_target_dtype[node.name]["input_activation_dtype"] = target_dtype
     node_name_to_target_dtype[node.name]["output_activation_dtype"] = target_dtype
     # if this is a copy node, propagate to first arg
-    root_node, matched_nodes, pattern, qhandler, qconfig = matches.get(
+    root_node, _, pattern, qhandler, qconfig = matches.get(
         node.name, (None, None, None, None, None))
     if qhandler is not None and qhandler.is_general_tensor_value_op():
         prev_node = node.args[0]
@@ -1078,7 +1096,7 @@
     # other nodes output dtype is specified by the qconfig
     modules = dict(model.named_modules(remove_duplicate=False))
     for node in model.graph.nodes:
-        root_node, matched_nodes, pattern, qhandler, qconfig = matches.get(
+        root_node, _, pattern, qhandler, qconfig = matches.get(
             node.name, (None, None, None, None, None))
         node_name_to_target_dtype[node.name] = get_target_activation_dtype_for_node(
             node, qconfig, inputs_seen_counter, outputs_seen_counter,
@@ -1119,7 +1137,7 @@
 
         elif node.op in ('call_module', 'call_method', 'call_function', 'output'):
             # check for matches
-            root_node, matched_nodes, pattern, qhandler, qconfig = matches.get(
+            last_node, matched_node_pattern, pattern, qhandler, qconfig = matches.get(
                 node.name, (None, None, None, None, None))
             equalization_qconfig = equalization_config_map.get(node.name, None)
 
@@ -1138,15 +1156,14 @@
             )
 
             is_supported_by_backend = is_pattern_dtype_config_supported_by_backend(
-                pattern, matched_nodes, node_name_to_target_dtype, backend_config_dict)
+                pattern, matched_node_pattern, node_name_to_target_dtype, backend_config_dict)
 
             if not skip_inserting_observers and is_supported_by_backend:
                 modules = dict(model.named_modules(remove_duplicate=False))
                 if node.op != 'output':
-                    assert matched_nodes is not None
+                    assert matched_node_pattern is not None
                     # add matched nodes to the observed node name set
-                    for n in matched_nodes:
-                        observed_node_names.add(n.name)
+                    add_matched_node_name_to_set(matched_node_pattern, observed_node_names)
 
                     # This is currently only used for equalization.
                     # Checks if the current node is in a branch in which the two
@@ -1176,7 +1193,8 @@
                     # TODO: this only works for sequential fusion right now, extend it
                     # to automatically detect all input nodes based on the pattern
                     # need to change find_matches function to return this information
-                    is_input_node_of_the_pattern = matched_nodes[-1] is node
+                    root_node = _default_root_node_getter(matched_node_pattern)
+                    is_input_node_of_the_pattern = node is root_node
                     if is_input_node_of_the_pattern:
                         # this modifies node inplace
                         maybe_insert_input_observers_for_node(
@@ -1191,7 +1209,7 @@
                             node, equalization_qconfig, model, modules, graph,
                             node_name_to_target_dtype, is_quantized_branch)
 
-                    is_last_node_of_pattern = root_node is node
+                    is_last_node_of_pattern = node is last_node
                     is_general_tensor_value_op = \
                         (qhandler is not None and qhandler.is_general_tensor_value_op())
 
@@ -1273,7 +1291,7 @@
     """
     for (
         node_name,
-        (root_node, matched_nodes, pattern, qhandler, qconfig),
+        (root_node, _, pattern, qhandler, qconfig),
     ) in matches.items():
         if qhandler is None:
             continue
@@ -1375,6 +1393,7 @@
     patterns: Dict[Pattern, QuantizeHandler] = {}
     if backend_config_dict is None:
         patterns = get_native_quant_patterns(additional_quant_patterns)
+        root_node_getter_mapping = {}
     else:
         patterns = get_pattern_to_quantize_handlers(backend_config_dict)
         patterns = sorted_patterns_dict(patterns)
@@ -1397,6 +1416,9 @@
                 else:
                     index_dict[pattern] = [index]  # type: ignore[index]
 
+        root_node_getter_mapping = \
+            get_fusion_pattern_to_root_node_getter(backend_config_dict)
+
     convert_dict_to_ordered_dict(qconfig_dict)
     convert_dict_to_ordered_dict(equalization_qconfig_dict)
     qconfig_dict = update_qconfig_for_fusion(model, qconfig_dict)
@@ -1443,8 +1465,8 @@
     custom_module_classes = get_custom_module_class_keys(
         prepare_custom_config_dict, "float_to_observed_custom_module_class")
     matches = find_matches(
-        model.graph, modules, patterns, qconfig_map, standalone_module_names,
-        standalone_module_classes, custom_module_classes)
+        model.graph, modules, patterns, root_node_getter_mapping, qconfig_map,
+        standalone_module_names, standalone_module_classes, custom_module_classes)
 
     input_quantized_idxs: List[int] = prepare_custom_config_dict.get(
         "input_quantized_idxs", [])
diff --git a/torch/ao/quantization/fx/quantization_patterns.py b/torch/ao/quantization/fx/quantization_patterns.py
index 44f5f83..f15794d 100644
--- a/torch/ao/quantization/fx/quantization_patterns.py
+++ b/torch/ao/quantization/fx/quantization_patterns.py
@@ -19,11 +19,20 @@
 from .utils import (
     all_node_args_have_no_tensors,
 )
+from .quantization_types import NodePattern
 
 from abc import ABC
 import operator
 from typing import Any, Callable, Dict, Optional
 
+# this is temporary, will be removed soon
+def _default_root_node_getter(node_pattern):
+    if node_pattern is None:
+        return node_pattern
+    while not isinstance(node_pattern, Node):
+        node_pattern = node_pattern[-1]
+    return node_pattern
+
 # -------------------------
 # Pattern Registrations
 # -------------------------
@@ -34,18 +43,24 @@
 class QuantizeHandler(ABC):
     """ Base handler class for the quantizer patterns
     """
-    def __init__(self, node: Node, modules: Dict[str, torch.nn.Module]):
+    def __init__(
+            self,
+            node_pattern: NodePattern,
+            modules: Dict[str, torch.nn.Module],
+            root_node_getter: Callable = None):
         """ Records pattern information in __init__, which will be used
         in convert
         """
         # this is an indicator of whether all the inputs are Node or not
         # since some op might be quantized differently depending on whether
         # all inputs are tensors or not, e.g. add/mul
-        if isinstance(node, Node):
-            self.num_tensor_args = len(node.args)
+        if root_node_getter is None:
+            root_node_getter = _default_root_node_getter
+        self.root_node = root_node_getter(node_pattern)
+        if isinstance(self.root_node, Node):
+            self.num_tensor_args = len(self.root_node.args)
         else:
             self.num_tensor_args = 0
-        self.all_node_args_are_tensors = True
 
     # TODO: can remove after the is_dynamic flag is defined, so that we can
     # move embedding op to backend_config_dict
@@ -109,32 +124,19 @@
 class BinaryOpQuantizeHandler(QuantizeHandler):
     def __init__(
             self,
-            node: Node,
-            modules: Dict[str, torch.nn.Module]):
-        super().__init__(node, modules)
-        self.relu_node = None
-        if (
-            node.op == 'call_function' and
-                node.target in (torch.nn.functional.relu, torch.relu)
-        ) or (
-            node.op == 'call_module' and
-                isinstance(modules[str(node.target)], torch.nn.ReLU)
-        ):
-            self.relu_node = node
-            node = node.args[0]  # type: ignore[assignment]
-        self.binary_op_node = node
-        self.binary_op = node.target
+            node_pattern: NodePattern,
+            modules: Dict[str, torch.nn.Module],
+            root_node_getter: Callable = None):
+        super().__init__(node_pattern, modules, root_node_getter)
 
         # determine how many of the first two args are Tensors (versus scalars)
         # this distinguishes things like "x + y" from "x + 2" or "2 + x"
         self.num_tensor_args = 0
         cache_for_no_tensor_check: Dict[Node, bool] = dict()
-        for arg_idx in range(len(self.binary_op_node.args)):
-            arg = self.binary_op_node.args[arg_idx]
+        for arg_idx in range(len(self.root_node.args)):
+            arg = self.root_node.args[arg_idx]
             if isinstance(arg, Node) and (not all_node_args_have_no_tensors(arg, modules, cache_for_no_tensor_check)):
                 self.num_tensor_args += 1
-        self.all_node_args_are_tensors = \
-            (self.num_tensor_args == len(self.binary_op_node.args))
 
     def is_general_tensor_value_op(self) -> bool:
         return self.num_tensor_args == 1
@@ -201,12 +203,6 @@
 @register_quant_pattern('tanh', default_symmetric_fixed_qparams_observer)
 @register_quant_pattern('tanh_', default_symmetric_fixed_qparams_observer)
 class FixedQParamsOpQuantizeHandler(QuantizeHandler):
-    def __init__(self,
-                 node: Node,
-                 modules: Dict[str, torch.nn.Module]):
-        super().__init__(node, modules)
-        self.node = node
-
     # some qhandlers override the activations constructor
     def get_activation_ctr(self, qconfig, pattern, is_training) -> Optional[Callable]:
         act_dtype = activation_dtype(qconfig)