Improvements to FX Minimizer (#83833)
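Summary: This diff improves the FX Minimizer's error reporting and fixes a few other issues. Each traversal iteration (sequential, accumulate, and binary search) now records a report that is stored in `self.reports` and printed through the new `print_report()` / `print_reports()` helpers. The binary search now works on index bounds into the full node list, continues into the second half when the first half yields no culprits, and reports a submodule it cannot minimize instead of raising `FxNetMinimizerBadModuleError`. Imports are re-sorted and several signatures and calls are reformatted.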
Test Plan: CI
Differential Revision: D38900309
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83833
Approved by: https://github.com/yuhc, https://github.com/Chillee
diff --git a/torch/fx/passes/net_min_base.py b/torch/fx/passes/net_min_base.py
index 26df771..d63abb0 100644
--- a/torch/fx/passes/net_min_base.py
+++ b/torch/fx/passes/net_min_base.py
@@ -1,25 +1,29 @@
-from typing import Any, Callable, Tuple, Dict, Optional
import logging
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
import torch.fx
-from torch.fx.node import map_arg
from torch.fx._compatibility import compatibility
+from torch.fx.node import map_arg
from .shape_prop import ShapeProp
from .split_utils import split_by_tags
from .tools_common import (
- Tensors,
- TensorOrTensors,
- NodeList,
- NodeSet,
CALLABLE_NODE_OPS,
FxNetAccFusionsFinder,
- Names
+ Names,
+ NodeList,
+ NodeSet,
+ TensorOrTensors,
+ Tensors,
)
-from dataclasses import dataclass
-__all__ = ['FxNetMinimizerBadModuleError', 'FxNetMinimizerRunFuncError', 'FxNetMinimizerResultMismatchError']
+__all__ = [
+ "FxNetMinimizerBadModuleError",
+ "FxNetMinimizerRunFuncError",
+ "FxNetMinimizerResultMismatchError",
+]
_LOGGER = logging.getLogger(__name__)
@@ -32,6 +36,7 @@
pass
+
@compatibility(is_backward_compatible=False)
class FxNetMinimizerRunFuncError(Exception):
"""
@@ -40,6 +45,7 @@
pass
+
@compatibility(is_backward_compatible=False)
class FxNetMinimizerResultMismatchError(Exception):
"""
@@ -48,6 +54,7 @@
pass
+
@dataclass
class _MinimizerSettingBase:
"""
@@ -64,6 +71,7 @@
`return_intermediate`: If true, when using `run_nodes()` function to run the
model, intermediate results of all the ops will be returned as output.
"""
+
accumulate_error: bool = False
traverse_method: str = "sequential"
find_all: bool = False
@@ -97,7 +105,9 @@
self,
module: torch.fx.GraphModule,
sample_input: Tensors,
- compare_fn: Callable[[TensorOrTensors, TensorOrTensors, Names], Tuple[float, bool]],
+ compare_fn: Callable[
+ [TensorOrTensors, TensorOrTensors, Names], Tuple[float, bool]
+ ],
settings: _MinimizerSettingBase,
):
assert isinstance(module, torch.fx.GraphModule)
@@ -116,6 +126,12 @@
# Stores the results of compare_fn
self.results: Dict[Any, Any] = {}
+ # Stores the report for the runs
+ self.reports: List[List[str]] = []
+
+ # Current iteration
+ self.iteration: int = 0
+
callable_nodes = {
node for node in self.module.graph.nodes if node.op in CALLABLE_NODE_OPS
}
@@ -293,10 +309,7 @@
return split_module, submodule_name
def _run_and_compare(
- self,
- split_module: torch.fx.GraphModule,
- submod_name: str,
- output_names: Names
+ self, split_module: torch.fx.GraphModule, submod_name: str, output_names: Names
):
"""
Run the submodule in `split_module` that has name `submod_name`
@@ -311,6 +324,13 @@
submodule = getattr(split_module, submod_name)
a_input, b_input = self._get_submod_inputs(split_module, submod_name)
+ if len(self.reports) == 0:
+ self.reports.append([])
+ self.iteration = 1
+
+ report = self.reports[self.iteration - 1]
+ report.append("Run and compare ...")
+
if output_names:
output_nodes: NodeList = []
for node in submodule.graph.nodes:
@@ -339,52 +359,83 @@
names: Names = output_names
if output_names is None:
names = [str(v) for v in result_key]
+
numeric_result, bool_result = self.compare_fn(a_result, b_result, names)
+
self.results[result_key] = numeric_result
+ report.append(f"Numerical accuracy = {numeric_result}")
if not bool_result:
+ report.append(f"Result mismatch for {result_key}")
raise FxNetMinimizerResultMismatchError(f"Result mismatch for {result_key}")
- def _binary_search_impl(self, nodes: NodeList) -> NodeSet:
+ def _binary_search_impl(
+ self, all_nodes: NodeList, start_idx: int, end_idx: int
+ ) -> NodeSet:
"""
Recursive binary search implementation.
"""
+ nodes: NodeList = all_nodes[start_idx:end_idx]
+
+ report: List[str] = []
+ self.reports.append(report)
+ self.iteration += 1
+ report.append(f"Binary search iteration {self.iteration}.")
+ report.append(
+ f"From node index {start_idx} to {end_idx-1}. "
+ f"Size of the interested node list is {len(nodes)}"
+ )
+
cur_nodes: NodeSet = set(nodes)
+
for node in nodes:
if node in self.fusions:
cur_nodes.update(self.fusions[node])
try:
split_module, submod_name = self._build_submodule(cur_nodes)
- self._run_and_compare(
- split_module,
- submod_name,
- []
- )
+ self._run_and_compare(split_module, submod_name, [])
except (FxNetMinimizerRunFuncError, FxNetMinimizerResultMismatchError):
+
if len(nodes) == 1:
+ report.append(
+ f"This is the last node in the sub-module. "
+ f"Search in the current branch is successful with culprit = {cur_nodes}."
+ )
+ self.print_report(report)
return cur_nodes
- mid = len(nodes) // 2
+ report.append(
+ "Proceed to split and lower the halves of the current "
+ "sub-module individually."
+ )
+ self.print_report(report)
- culprits = self._binary_search_impl(nodes[:mid])
- if not self.settings.find_all:
+ mid = len(nodes) // 2
+ culprits = self._binary_search_impl(all_nodes, start_idx, start_idx + mid)
+
+ if len(culprits) != 0 and not self.settings.find_all:
return culprits
- culprits.update(self._binary_search_impl(nodes[mid:]))
+ culprits = self._binary_search_impl(all_nodes, start_idx + mid, end_idx)
+
if len(culprits) == 0:
- raise FxNetMinimizerBadModuleError(
- "Found an error in a group of nodes, but was not able to minimize",
- nodes,
+ report.append(
+ f"Further split and lowering found no errors. "
+ f"Unable to minimize the submodule with list of nodes: {nodes}"
)
+ self.print_report(report)
+
return culprits
else:
+ report.append("No discrepancy found.")
+ self.print_report(report)
return set()
def _binary_traverse(self, nodes: NodeList) -> NodeSet:
"""
Binary search on `nodes` for culprit.
"""
- return self._binary_search_impl(nodes)
+ return self._binary_search_impl(nodes, 0, len(nodes))
def _sequential_traverse(self, nodes: NodeList) -> NodeSet:
"""
@@ -393,6 +444,12 @@
culprits: NodeSet = set()
for node in nodes:
+ report: List[str] = []
+ self.reports.append(report)
+ self.iteration += 1
+ report.append(f"Sequential traverse iteration {self.iteration}.")
+ report.append(f"Visit node: {node.name}")
+
_LOGGER.info(f"Visit node: {node.name}")
cur_nodes: NodeSet = {node}
@@ -401,15 +458,18 @@
try:
split_module, submod_name = self._build_submodule(cur_nodes)
- self._run_and_compare(
- split_module, submod_name, [node.name]
- )
+ self._run_and_compare(split_module, submod_name, [node.name])
+ self.print_report(report)
except (FxNetMinimizerResultMismatchError):
culprits.add(node)
+ report.append(f"Found culprit from numeric error: {node}")
+ self.print_report(report)
if not self.settings.find_all:
return culprits
except (FxNetMinimizerRunFuncError):
culprits.update(cur_nodes)
+ report.append(f"Found culprit from run error: {node}")
+ self.print_report(report)
if not self.settings.find_all:
return culprits
@@ -426,19 +486,30 @@
return culprits
for node in nodes:
+ report: List[str] = []
+ self.reports.append(report)
+ self.iteration += 1
+ report.append(f"Accumulate traverse iteration {self.iteration}.")
+
nodes_to_run.add(node)
node_name = node.name
if node_name is not None and isinstance(node_name, tuple):
node_name = node_name[0]
- assert node_name is not None and isinstance(node_name, str), f"minimize: node_name: {node_name}"
+ assert node_name is not None and isinstance(
+ node_name, str
+ ), f"minimize: node_name: {node_name}"
+
+ report.append(f"Add node: {node_name}")
try:
split_module, submod_name = self._build_submodule(nodes_to_run)
self._run_and_compare(split_module, submod_name, [node_name])
- except (FxNetMinimizerResultMismatchError,
- FxNetMinimizerRunFuncError):
+ self.print_report(report)
+ except (FxNetMinimizerResultMismatchError, FxNetMinimizerRunFuncError):
culprits.add(node)
+ report.append(f"Found culprit {node}")
+ self.print_report(report)
return culprits
return culprits
@@ -500,7 +571,20 @@
) as e:
print(e)
- def minimize(self, start: Optional[str] = None, end: Optional[str] = None) -> NodeSet:
+ def print_report(self, report: List[str]):
+ for i in range(len(report)):
+ if i > 0:
+ print(" . " + report[i])
+ else:
+ print(report[i])
+
+ def print_reports(self):
+ for report in self.reports:
+ self.print_report(report)
+
+ def minimize(
+ self, start: Optional[str] = None, end: Optional[str] = None
+ ) -> NodeSet:
"""
Minimizing the model from node with name `start` to node with name `end` base
on self.settings. Find culprits that causes FxNetMinimizerRunFuncError or
@@ -519,6 +603,7 @@
print(self.settings)
print(self.module.graph)
+
nodes = self._collect_nodes(start, end)
if self.settings.traverse_method == "sequential":