[easy] refactor signature flattening transform (#101886)
Move `ChangeInputOutputSignature` out of the export function to avoid closed-over variables that make its dependencies hard to understand. Also rename it while we're at it.
Differential Revision: [D46029076](https://our.internmc.facebook.com/intern/diff/D46029076/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101886
Approved by: https://github.com/tugsbayasgalan
diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py
index ee9fc1c..f416780 100644
--- a/torch/_dynamo/eval_frame.py
+++ b/torch/_dynamo/eval_frame.py
@@ -719,6 +719,44 @@
)
+class FlattenInputOutputSignature(torch.fx.interpreter.Transformer):
+ def __init__(
+ self,
+ m: torch.fx.GraphModule,
+ arg_len: int,
+ matched_input_elements_positions: List[int],
+ matched_output_elements_positions: List[int],
+ ):
+ super().__init__(m)
+ self.new_args = [
+ super(FlattenInputOutputSignature, self).placeholder(f"arg{i}", (), {})
+ for i in range(0, arg_len)
+ ]
+ self.old_args_gen = (self.new_args[i] for i in matched_input_elements_positions)
+ self.matched_output_elements_positions = matched_output_elements_positions
+
+ def placeholder(self, target, args, kwargs):
+ arg = next(self.old_args_gen)
+ if "val" in self.current_node.meta:
+ arg.node.meta["val"] = self.current_node.meta["val"]
+ if "tensor_dict" in self.current_node.meta:
+ arg.node.meta["tensor_dict"] = self.current_node.meta["tensor_dict"]
+ return arg
+
+ def output(self, target, args, kwargs):
+ dynamo_result_flat = args[0]
+ lookup = [*dynamo_result_flat, *self.new_args]
+ new_result_flat = [lookup[i] for i in self.matched_output_elements_positions]
+ return super().output(target, (new_result_flat,), {})
+
+ def run_node(self, n):
+ self.current_node = n
+ r = super().run_node(n)
+ if "val" in self.current_node.meta:
+ r.node.meta["val"] = self.current_node.meta["val"]
+ return r
+
+
def export(
f: Callable[..., Any],
*args,
@@ -924,42 +962,6 @@
flat_both = list(graph_captured_result) + flat_args
matched_output_elements_positions = produce_matching(flat_both, flat_results_traced)
- class ChangeInputOutputSignature(torch.fx.interpreter.Transformer):
- def __init__(
- self,
- m,
- ):
- super().__init__(m)
- arg_len = len(flat_args)
- self.new_args = [
- super(ChangeInputOutputSignature, self).placeholder(f"arg{i}", (), {})
- for i in range(0, arg_len)
- ]
- self.old_args_gen = (
- self.new_args[i] for i in matched_input_elements_positions
- )
-
- def placeholder(self, target, args, kwargs):
- arg = next(self.old_args_gen)
- if "val" in self.current_node.meta:
- arg.node.meta["val"] = self.current_node.meta["val"]
- if "tensor_dict" in self.current_node.meta:
- arg.node.meta["tensor_dict"] = self.current_node.meta["tensor_dict"]
- return arg
-
- def output(self, target, args, kwargs):
- dynamo_result_flat = args[0]
- lookup = [*dynamo_result_flat, *self.new_args]
- new_result_flat = [lookup[i] for i in matched_output_elements_positions]
- return super().output(target, (new_result_flat,), {})
-
- def run_node(self, n):
- self.current_node = n
- r = super().run_node(n)
- if "val" in self.current_node.meta:
- r.node.meta["val"] = self.current_node.meta["val"]
- return r
-
# NB: This is mostly hitting the cache; Dynamo already converted these
example_fake_inputs = [fake_mode.from_tensor(t) for t in example_inputs]
@@ -1012,8 +1014,11 @@
# Wrap the internal error to the user-facing error
raise UserError(UserErrorType.DYNAMIC_CONTROL_FLOW, str(e))
- new_graph = ChangeInputOutputSignature(
+ new_graph = FlattenInputOutputSignature(
graph,
+ len(flat_args),
+ matched_input_elements_positions,
+ matched_output_elements_positions,
).transform()
# Store constraints and inputs as metadata for user passes, e.g. turn constraints to runtime check