Improvements to FX Minimizer (#83833)

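Summary: This diff improves error reporting in the FX Minimizer: every traversal iteration (binary search, sequential, accumulate) now records a step-by-step report and prints it. Binary search now recurses on index ranges into the full node list, and it reports the failure, instead of raising FxNetMinimizerBadModuleError, when a failing group of nodes cannot be minimized further. The diff also reformats imports and long signatures.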

Test Plan: CI

Differential Revision: D38900309

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83833
Approved by: https://github.com/yuhc, https://github.com/Chillee
diff --git a/torch/fx/passes/net_min_base.py b/torch/fx/passes/net_min_base.py
index 26df771..d63abb0 100644
--- a/torch/fx/passes/net_min_base.py
+++ b/torch/fx/passes/net_min_base.py
@@ -1,25 +1,29 @@
-from typing import Any, Callable, Tuple, Dict, Optional
 import logging
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import torch
 import torch.fx
-from torch.fx.node import map_arg
 from torch.fx._compatibility import compatibility
+from torch.fx.node import map_arg
 
 from .shape_prop import ShapeProp
 from .split_utils import split_by_tags
 from .tools_common import (
-    Tensors,
-    TensorOrTensors,
-    NodeList,
-    NodeSet,
     CALLABLE_NODE_OPS,
     FxNetAccFusionsFinder,
-    Names
+    Names,
+    NodeList,
+    NodeSet,
+    TensorOrTensors,
+    Tensors,
 )
-from dataclasses import dataclass
 
-__all__ = ['FxNetMinimizerBadModuleError', 'FxNetMinimizerRunFuncError', 'FxNetMinimizerResultMismatchError']
+__all__ = [
+    "FxNetMinimizerBadModuleError",
+    "FxNetMinimizerRunFuncError",
+    "FxNetMinimizerResultMismatchError",
+]
 
 _LOGGER = logging.getLogger(__name__)
 
@@ -32,6 +36,7 @@
 
     pass
 
+
 @compatibility(is_backward_compatible=False)
 class FxNetMinimizerRunFuncError(Exception):
     """
@@ -40,6 +45,7 @@
 
     pass
 
+
 @compatibility(is_backward_compatible=False)
 class FxNetMinimizerResultMismatchError(Exception):
     """
@@ -48,6 +54,7 @@
 
     pass
 
+
 @dataclass
 class _MinimizerSettingBase:
     """
@@ -64,6 +71,7 @@
     `return_intermediate`: If true, when using `run_nodes()` function to run the
     model, intermediate results of all the ops will be returned as output.
     """
+
     accumulate_error: bool = False
     traverse_method: str = "sequential"
     find_all: bool = False
@@ -97,7 +105,9 @@
         self,
         module: torch.fx.GraphModule,
         sample_input: Tensors,
-        compare_fn: Callable[[TensorOrTensors, TensorOrTensors, Names], Tuple[float, bool]],
+        compare_fn: Callable[
+            [TensorOrTensors, TensorOrTensors, Names], Tuple[float, bool]
+        ],
         settings: _MinimizerSettingBase,
     ):
         assert isinstance(module, torch.fx.GraphModule)
@@ -116,6 +126,12 @@
         # Stores the results of compare_fn
         self.results: Dict[Any, Any] = {}
 
+        # Stores the report for the runs
+        self.reports: List[List[str]] = []
+
+        # Current iteration
+        self.iteration: int = 0
+
         callable_nodes = {
             node for node in self.module.graph.nodes if node.op in CALLABLE_NODE_OPS
         }
@@ -293,10 +309,7 @@
         return split_module, submodule_name
 
     def _run_and_compare(
-        self,
-        split_module: torch.fx.GraphModule,
-        submod_name: str,
-        output_names: Names
+        self, split_module: torch.fx.GraphModule, submod_name: str, output_names: Names
     ):
         """
         Run the submodule in `split_module` that has name `submod_name`
@@ -311,6 +324,13 @@
         submodule = getattr(split_module, submod_name)
         a_input, b_input = self._get_submod_inputs(split_module, submod_name)
 
+        if len(self.reports) == 0:
+            self.reports.append([])
+            self.iteration = 1
+
+        report = self.reports[self.iteration - 1]
+        report.append("Run and compare ...")
+
         if output_names:
             output_nodes: NodeList = []
             for node in submodule.graph.nodes:
@@ -339,52 +359,83 @@
         names: Names = output_names
         if output_names is None:
             names = [str(v) for v in result_key]
+
         numeric_result, bool_result = self.compare_fn(a_result, b_result, names)
+
         self.results[result_key] = numeric_result
+        report.append(f"Numerical accuracy = {numeric_result}")
         if not bool_result:
+            report.append(f"Result mismatch for {result_key}")
             raise FxNetMinimizerResultMismatchError(f"Result mismatch for {result_key}")
 
-    def _binary_search_impl(self, nodes: NodeList) -> NodeSet:
+    def _binary_search_impl(
+        self, all_nodes: NodeList, start_idx: int, end_idx: int
+    ) -> NodeSet:
         """
         Recursive binary search implementation.
         """
+        # Only the half-open range [start_idx, end_idx) of all_nodes is searched.
+        nodes: NodeList = all_nodes[start_idx:end_idx]
+        report: List[str] = []
+        self.reports.append(report)
+        self.iteration += 1
+        report.append(f"Binary search iteration {self.iteration}.")
+        report.append(
+            f"From node index {start_idx} to {end_idx-1}. "
+            f"Size of the interested node list is {len(nodes)}"
+        )
+
         cur_nodes: NodeSet = set(nodes)
+        # Also test every node that self.fusions associates with a selected node.
         for node in nodes:
             if node in self.fusions:
                 cur_nodes.update(self.fusions[node])
 
         try:
             split_module, submod_name = self._build_submodule(cur_nodes)
-            self._run_and_compare(
-                split_module,
-                submod_name,
-                []
-            )
+            self._run_and_compare(split_module, submod_name, [])
         except (FxNetMinimizerRunFuncError, FxNetMinimizerResultMismatchError):
+            # The combined sub-module failed: stop at a single node, else bisect.
             if len(nodes) == 1:
+                report.append(
+                    "This is the last node in the sub-module. "
+                    f"Search in the current branch is successful with culprit = {cur_nodes}."
+                )
+                self.print_report(report)
                 return cur_nodes
 
-            mid = len(nodes) // 2
+            report.append(
+                "Proceed to split and lower the halves of the current "
+                "sub-module individually."
+            )
+            self.print_report(report)
 
-            culprits = self._binary_search_impl(nodes[:mid])
-            if not self.settings.find_all:
+            mid = len(nodes) // 2
+            culprits = self._binary_search_impl(all_nodes, start_idx, start_idx + mid)
+
+            if len(culprits) != 0 and not self.settings.find_all:
                 return culprits
 
-            culprits.update(self._binary_search_impl(nodes[mid:]))
+            # Merge rather than overwrite so find_all keeps the first-half culprits.
+            culprits.update(self._binary_search_impl(all_nodes, start_idx + mid, end_idx))
             if len(culprits) == 0:
-                raise FxNetMinimizerBadModuleError(
-                    "Found an error in a group of nodes, but was not able to minimize",
-                    nodes,
+                report.append(
+                    "Further split and lowering found no errors. "
+                    f"Unable to minimize the submodule with list of nodes: {nodes}"
                 )
+                self.print_report(report)
+
             return culprits
         else:
+            report.append("No discrepancy found.")
+            self.print_report(report)
             return set()
 
     def _binary_traverse(self, nodes: NodeList) -> NodeSet:
         """
         Binary search on `nodes` for culprit.
         """
-        return self._binary_search_impl(nodes)
+        return self._binary_search_impl(nodes, 0, len(nodes))  # whole list: indices [0, len(nodes))
 
     def _sequential_traverse(self, nodes: NodeList) -> NodeSet:
         """
@@ -393,6 +444,12 @@
         culprits: NodeSet = set()
 
         for node in nodes:
+            report: List[str] = []
+            self.reports.append(report)
+            self.iteration += 1
+            report.append(f"Sequential traverse iteration {self.iteration}.")
+            report.append(f"Visit node: {node.name}")
+
             _LOGGER.info(f"Visit node: {node.name}")
             cur_nodes: NodeSet = {node}
 
@@ -401,15 +458,18 @@
 
             try:
                 split_module, submod_name = self._build_submodule(cur_nodes)
-                self._run_and_compare(
-                    split_module, submod_name, [node.name]
-                )
+                self._run_and_compare(split_module, submod_name, [node.name])
+                self.print_report(report)
             except (FxNetMinimizerResultMismatchError):
                 culprits.add(node)
+                report.append(f"Found culprit from numeric error: {node}")
+                self.print_report(report)
                 if not self.settings.find_all:
                     return culprits
             except (FxNetMinimizerRunFuncError):
                 culprits.update(cur_nodes)
+                report.append(f"Found culprit from run error: {node}")
+                self.print_report(report)
                 if not self.settings.find_all:
                     return culprits
 
@@ -426,19 +486,30 @@
             return culprits
 
         for node in nodes:
+            report: List[str] = []
+            self.reports.append(report)
+            self.iteration += 1
+            report.append(f"Accumulate traverse iteration {self.iteration}.")
+
             nodes_to_run.add(node)
 
             node_name = node.name
             if node_name is not None and isinstance(node_name, tuple):
                 node_name = node_name[0]
-            assert node_name is not None and isinstance(node_name, str), f"minimize: node_name: {node_name}"
+            assert node_name is not None and isinstance(
+                node_name, str
+            ), f"minimize: node_name: {node_name}"
+
+            report.append(f"Add node: {node_name}")
 
             try:
                 split_module, submod_name = self._build_submodule(nodes_to_run)
                 self._run_and_compare(split_module, submod_name, [node_name])
-            except (FxNetMinimizerResultMismatchError,
-                    FxNetMinimizerRunFuncError):
+                self.print_report(report)
+            except (FxNetMinimizerResultMismatchError, FxNetMinimizerRunFuncError):
                 culprits.add(node)
+                report.append(f"Found culprit {node}")
+                self.print_report(report)
                 return culprits
 
         return culprits
@@ -500,7 +571,20 @@
         ) as e:
             print(e)
 
-    def minimize(self, start: Optional[str] = None, end: Optional[str] = None) -> NodeSet:
+    def print_report(self, report: List[str]):
+        """Print one report; every entry after the first gets a " . " prefix."""
+        for i, line in enumerate(report):
+            # The first entry is printed unprefixed so it reads as the heading.
+            prefix = " . " if i > 0 else ""
+            print(prefix + line)
+
+    def print_reports(self):  # print every report stored in self.reports, in order
+        for report in self.reports:
+            self.print_report(report)
+
+    def minimize(
+        self, start: Optional[str] = None, end: Optional[str] = None
+    ) -> NodeSet:
         """
         Minimizing the model from node with name `start` to node with name `end` base
         on self.settings. Find culprits that causes FxNetMinimizerRunFuncError or
@@ -519,6 +603,7 @@
 
         print(self.settings)
         print(self.module.graph)
+
         nodes = self._collect_nodes(start, end)
 
         if self.settings.traverse_method == "sequential":