[mtia] add module exporter to net minimizer (#115687)
Summary: add module exporter to net minimizer
Reviewed By: amylittleyang
Differential Revision: D52086699
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115687
Approved by: https://github.com/jfix71
diff --git a/torch/fx/passes/net_min_base.py b/torch/fx/passes/net_min_base.py
index f61682c..3fcd096 100644
--- a/torch/fx/passes/net_min_base.py
+++ b/torch/fx/passes/net_min_base.py
@@ -109,12 +109,19 @@
[TensorOrTensors, TensorOrTensors, Names], Tuple[float, bool]
],
settings: _MinimizerSettingBase,
+ module_exporter: Optional[
+ Callable[
+ [List[torch.Tensor], torch.fx.GraphModule, str],
+ None
+ ]
+ ] = None,
):
assert isinstance(module, torch.fx.GraphModule)
self.module = module
self.sample_input = sample_input
self.compare_fn = compare_fn
+ self.module_exporter = module_exporter
self.settings = settings
# Stores outputs of run_a function
@@ -351,9 +358,15 @@
if node.op == "output":
result_key = map_arg(node.args, lambda x: x.name)
- a_result = self.run_a(submodule, a_input)
- b_result = self.run_b(submodule, b_input)
- self._store_outputs(a_result, b_result, submodule)
+ try:
+ a_result = self.run_a(submodule, a_input)
+ b_result = self.run_b(submodule, b_input)
+ self._store_outputs(a_result, b_result, submodule)
+ except Exception as e:
+ report.append(f"Exception raised when running {submod_name}: {e}")
+ raise FxNetMinimizerRunFuncError( # noqa: TRY200
+ f"Exception raised when running {submod_name}: {e}"
+ )
# Compare results
names: Names = output_names
@@ -366,6 +379,13 @@
report.append(f"Numerical accuracy = {numeric_result}")
if not bool_result:
report.append(f"Result mismatch for {result_key}")
+ if self.module_exporter:
+ self.module_exporter(
+ List[torch.Tensor](a_input), submodule, str(result_key[0]) + "_cpu",
+ )
+ self.module_exporter(
+ List[torch.Tensor](b_input), submodule, str(result_key[0]) + "_acc",
+ )
raise FxNetMinimizerResultMismatchError(f"Result mismatch for {result_key}")
def _binary_search_impl(