ns for fx: make unshadowed activation comparison work for N models (#52357)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/52357

Refactor the NS for FX unshadowed activation comparison API so that it
can work on N models and support arbitrary matching strategies.

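For the existing two-model path, callers now pass a name for each model
alongside the model itself; a rough sketch of the updated call, mirroring
the test changes below (variable names are illustrative):

```python
# mp_ns / mq_ns: prepared fp32 and quantized int8 GraphModules which have had
# loggers added and calibration data run through them (names are illustrative).
act_compare_dict = get_matching_activations(
    'fp32_prepared', mp_ns, 'int8', mq_ns, OutputLogger)
# act_compare_dict: {'<ref_node_name>.stats': {model_name: [logged tensors]}}
```
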
We factor out a util which takes a model and a list of subgraphs to
instrument with loggers, along with reference names to attach to the
logged activations, and a util which collects the logged activations
from a single model into a shared results dict. The user can then call
these utils with subgraphs and names created in any way they want.
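
A rough sketch of how the factored-out utils compose; the matching shown is
the built-in two-model matcher, but any caller-provided list of
(subgraph, ref_name) pairs per model would work (names are illustrative):

```python
# Sketch only: prepare_single_model_output, add_activation_info_to_dict and
# OutputLogger are the internal helpers added/used in
# torch/quantization/ns/numeric_suite_core_apis_fx.py (imports omitted).
matched = get_matching_subgraph_pairs(gm_a, gm_b)
subgraphs_a, subgraphs_b = [], []
for match_name, (subgraph_a, subgraph_b) in matched.items():
    subgraphs_a.append((subgraph_a, match_name))
    subgraphs_b.append((subgraph_b, match_name))

# Instrument each model independently with loggers on the matched end nodes.
gm_a = prepare_single_model_output('fp32_prepared', gm_a, subgraphs_a, OutputLogger)
gm_b = prepare_single_model_output('int8', gm_b, subgraphs_b, OutputLogger)

# ... run the same calibration data through both instrumented models ...

# Collect logged activations from each model into one results dict,
# keyed by '<ref_node_name>.stats' and then by model name.
results = {}
add_activation_info_to_dict('fp32_prepared', gm_a, results, OutputLogger)
add_activation_info_to_dict('int8', gm_b, results, OutputLogger)
```

For N models, the same prepare/collect calls are repeated once per model
against whatever matching the caller produced.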

Test Plan:
```
python test/test_quantization.py TestFXNumericSuiteCoreAPIs
```

Imported from OSS

Reviewed By: raghuramank100

Differential Revision: D26487270

fbshipit-source-id: 1372ef07b5f3ddc7cebdfb2dee0221a2facd0527
diff --git a/test/quantization/test_numeric_suite_fx.py b/test/quantization/test_numeric_suite_fx.py
index 724bd49..de5d422 100644
--- a/test/quantization/test_numeric_suite_fx.py
+++ b/test/quantization/test_numeric_suite_fx.py
@@ -676,7 +676,8 @@
         mq_ns(input_fp32)
 
         # check activation result correctness
-        act_compare_dict = get_matching_activations(mp_ns, mq_ns, OutputLogger)
+        act_compare_dict = get_matching_activations(
+            'fp32_prepared', mp_ns, 'int8', mq_ns, OutputLogger)
         self.assertTrue(len(act_compare_dict) == 2)
         self.assert_ns_logger_act_compare_dict_valid(act_compare_dict)
 
@@ -727,7 +728,8 @@
         mq_ns(input_fp32)
 
         # check activation result correctness
-        act_compare_dict = get_matching_activations(mp_ns, mq_ns, OutputLogger)
+        act_compare_dict = get_matching_activations(
+            'fp32_prepared', mp_ns, 'int8', mq_ns, OutputLogger)
         self.assertTrue(len(act_compare_dict) == 2)
         self.assert_ns_logger_act_compare_dict_valid(act_compare_dict)
 
@@ -834,7 +836,8 @@
         sparse_nn_q(idx, offsets, x)
 
         # inspect results
-        act_compare_dict = get_matching_activations(sparse_nn, sparse_nn_q, OutputLogger)
+        act_compare_dict = get_matching_activations(
+            'fp32_prepared', sparse_nn, 'int8', sparse_nn_q, OutputLogger)
         self.assertTrue(len(act_compare_dict) == 3)
         self.assert_ns_logger_act_compare_dict_valid(act_compare_dict)
 
diff --git a/torch/quantization/ns/graph_passes.py b/torch/quantization/ns/graph_passes.py
index 25f9bb5..3447154 100644
--- a/torch/quantization/ns/graph_passes.py
+++ b/torch/quantization/ns/graph_passes.py
@@ -45,7 +45,7 @@
 
 def remove_observers_add_loggers(
     gm: GraphModule,
-    node_to_instrument_to_other_node_name: Dict[Node, Optional[str]],
+    node_to_instrument_to_ref_node_name: Dict[Node, Optional[str]],
     logger_cls: Callable,
     model_name: str,
 ) -> GraphModule:
@@ -71,8 +71,8 @@
             # remove activation post process node
             env[node.name] = env[node.args[0].name]
 
-        elif node in node_to_instrument_to_other_node_name:
-            other_node_name = node_to_instrument_to_other_node_name[node]
+        elif node in node_to_instrument_to_ref_node_name:
+            other_node_name = node_to_instrument_to_ref_node_name[node]
             # ensure env is populated with base node
             env[node.name] = new_graph.node_copy(node, load_arg)
             # add the logger after the base node
diff --git a/torch/quantization/ns/numeric_suite_core_apis_fx.py b/torch/quantization/ns/numeric_suite_core_apis_fx.py
index a89867e..e7e18c3 100644
--- a/torch/quantization/ns/numeric_suite_core_apis_fx.py
+++ b/torch/quantization/ns/numeric_suite_core_apis_fx.py
@@ -103,7 +103,7 @@
         self,
         node_name: str,
         model_name: str,
-        other_node_name: Optional[str] = None,
+        ref_node_name: Optional[str] = None,
     ):
         super().__init__()
         self.stats: List[torch.Tensor] = []
@@ -111,17 +111,38 @@
         self.node_name = node_name
         # name of the model from which the node originated from
         self.model_name = model_name
-        # name of the other node with a matching Logger
+        # name of the reference node with a matching Logger
         # used to link node_a_copy -> logger_a to node_c -> logger_c
         # in a_shadows_b
-        self.other_node_name = other_node_name
+        self.ref_node_name = ref_node_name
 
     def forward(self, x: torch.Tensor):
         self.stats.append(x.detach())
         return x
 
     def __repr__(self):
-        return f"OutputLogger(node_name={self.node_name}, model_name={self.model_name}, other_node_name={self.other_node_name})"
+        return f"OutputLogger(node_name={self.node_name}, model_name={self.model_name}, ref_node_name={self.ref_node_name})"
+
+def prepare_single_model_output(
+    model_name: str,
+    model: GraphModule,
+    subgraphs_to_instrument: List[Tuple[Tuple[Node, Node], str]],
+    logger_cls: Callable,
+) -> GraphModule:
+
+    # TODO(future PR): do not observe nodes we do not care
+    #   about (both fp32, denylist, etc)
+    # Note: for matching activations we always use the end nodes,
+    # such as observing the output of relu in linear-relu
+    # Note: ref_node_name is the reference name under which this node's
+    # logged activations are grouped when results are collected across models.
+    node_to_instrument_to_ref_node_name: Dict[Node, Optional[str]] = {}
+    for (node_start, node_end), ref_node_name in subgraphs_to_instrument:
+        node_to_instrument_to_ref_node_name[node_end] = ref_node_name
+
+    model = remove_observers_add_loggers(
+        model, node_to_instrument_to_ref_node_name, logger_cls, model_name)
+    return model
 
 # Note: this is not a user facing API
 # TODO(future PR): wrap this in a user facing API which does not
@@ -133,35 +154,49 @@
     gm_b: GraphModule,
     logger_cls: Callable,
 ) -> Tuple[GraphModule, GraphModule]:
-
     matched_subgraph_pairs = get_matching_subgraph_pairs(gm_a, gm_b)
+    subgraphs_to_instrument_a: List[Tuple[Tuple[Node, Node], str]] = []
+    subgraphs_to_instrument_b: List[Tuple[Tuple[Node, Node], str]] = []
+    for match_name, (subgraph_a, subgraph_b) in matched_subgraph_pairs.items():
+        subgraphs_to_instrument_a.append((subgraph_a, match_name))
+        subgraphs_to_instrument_b.append((subgraph_b, match_name))
 
-    node_to_instrument_to_other_node_name_a: Dict[Node, Optional[str]] = {}
-    node_to_instrument_to_other_node_name_b: Dict[Node, Optional[str]] = {}
-    for match_name, match in matched_subgraph_pairs.items():
-        (node_start_a, node_end_a), (node_start_b, node_end_b) = match
-        # TODO(future PR): do not observe pairs of nodes we do not care
-        #   about (both fp32, denylist, etc)
-        # Note: for matching activations we always use the end nodes,
-        # such as observing the output of relu in linear-relu
-        # Note: other_node_name is set to None in model B's loggers,
-        # and set to the corresponding model B's node in model A's loggers.
-        node_to_instrument_to_other_node_name_a[node_end_a] = node_end_b.name
-        node_to_instrument_to_other_node_name_b[node_end_b] = None
-
-    gm_a = remove_observers_add_loggers(
-        gm_a, node_to_instrument_to_other_node_name_a, logger_cls, name_a)
-    gm_b = remove_observers_add_loggers(
-        gm_b, node_to_instrument_to_other_node_name_b, logger_cls, name_b)
+    gm_a = prepare_single_model_output(
+        name_a, gm_a, subgraphs_to_instrument_a, logger_cls)
+    gm_b = prepare_single_model_output(
+        name_b, gm_b, subgraphs_to_instrument_b, logger_cls)
     return (gm_a, gm_b)
 
+def add_activation_info_to_dict(
+    model_name: str,
+    model: GraphModule,
+    results: Dict[str, Dict[str, List[torch.Tensor]]],
+    logger_cls: Callable,
+) -> None:
+    for gm_name, mod in model.named_modules():
+        # TODO(future PR): better check when scripted
+        is_logger = (
+            isinstance(mod, logger_cls)  # type: ignore
+            or (
+                isinstance(mod, torch.jit.RecursiveScriptModule)
+                and mod.original_name == 'OutputLogger'
+            )
+        )
+        if is_logger:
+            key = mod.ref_node_name + '.stats'
+            if key not in results:
+                results[key] = {}
+            results[key][model_name] = mod.stats
+
 # Note: this is not a user facing API
 # TODO(future PR): wrap this in a user facing API which does not
 #   expose FX types.
 # TODO(future PR): align on naming
 # this is equivalent of just the comparison extraction part of `ns.compare_model_outputs`
 def get_matching_activations(
+    model_name_a: str,
     gm_a: GraphModule,
+    model_name_b: str,
     gm_b: GraphModule,
     logger_cls: Callable,
 ) -> Dict[str, Dict[str, List[torch.Tensor]]]:
@@ -188,26 +223,10 @@
        the return type for calibrating with 1 input vs N inputs.
     3. `logger_cls` is included in the API for easy result extraction
     """
-    results: Dict[str, Dict[str, List[torch.Tensor]]] = \
-        collections.defaultdict(dict)
-    for gm in (gm_a, gm_b):
-        for gm_name, mod in gm.named_modules():
-            # TODO(future PR): better check when scripted
-            is_logger = (
-                isinstance(mod, logger_cls)  # type: ignore
-                or (
-                    isinstance(mod, torch.jit.RecursiveScriptModule)
-                    and mod.original_name == 'OutputLogger'
-                )
-            )
-            if is_logger:
-                # If logger_obj.other_node_name is populated, then this logger
-                # is from model A, and other_node_name is the name from model B.
-                if mod.other_node_name is None:
-                    results[mod.node_name + '.stats'][mod.model_name] = mod.stats
-                else:
-                    results[mod.other_node_name + '.stats'][mod.model_name] = mod.stats
-    return dict(results)
+    results: Dict[str, Dict[str, List[torch.Tensor]]] = {}
+    add_activation_info_to_dict(model_name_a, gm_a, results, logger_cls)
+    add_activation_info_to_dict(model_name_b, gm_b, results, logger_cls)
+    return results
 
 # Note: this is not a user facing API
 # TODO(future PR): wrap this in a user facing API which does not
@@ -253,10 +272,10 @@
             )
         )
         if is_logger:
-            # If logger_obj.other_node_name is populated, then this logger
-            # is from model A, and other_node_name is the name from model B.
-            if mod.other_node_name is None:
+            # If logger_obj.ref_node_name is populated, then this logger
+            # is from model A, and ref_node_name is the name from model B.
+            if mod.ref_node_name is None:
                 results[mod.node_name + '.stats'][mod.model_name] = mod.stats
             else:
-                results[mod.other_node_name + '.stats'][mod.model_name] = mod.stats
+                results[mod.ref_node_name + '.stats'][mod.model_name] = mod.stats
     return dict(results)