[easy] refactor signature flattening transform (#101886)

Move `ChangeInputOutputSignature` out of export function to avoid closed over variables that make dependencies hard to understand. Also rename it while we're at it.

Differential Revision: [D46029076](https://our.internmc.facebook.com/intern/diff/D46029076/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101886
Approved by: https://github.com/tugsbayasgalan
diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py
index ee9fc1c..f416780 100644
--- a/torch/_dynamo/eval_frame.py
+++ b/torch/_dynamo/eval_frame.py
@@ -719,6 +719,44 @@
         )
 
 
+class FlattenInputOutputSignature(torch.fx.interpreter.Transformer):
+    def __init__(
+        self,
+        m: torch.fx.GraphModule,
+        arg_len: int,
+        matched_input_elements_positions: List[int],
+        matched_output_elements_positions: List[int],
+    ):
+        super().__init__(m)
+        self.new_args = [
+            super(FlattenInputOutputSignature, self).placeholder(f"arg{i}", (), {})
+            for i in range(0, arg_len)
+        ]
+        self.old_args_gen = (self.new_args[i] for i in matched_input_elements_positions)
+        self.matched_output_elements_positions = matched_output_elements_positions
+
+    def placeholder(self, target, args, kwargs):
+        arg = next(self.old_args_gen)
+        if "val" in self.current_node.meta:
+            arg.node.meta["val"] = self.current_node.meta["val"]
+        if "tensor_dict" in self.current_node.meta:
+            arg.node.meta["tensor_dict"] = self.current_node.meta["tensor_dict"]
+        return arg
+
+    def output(self, target, args, kwargs):
+        dynamo_result_flat = args[0]
+        lookup = [*dynamo_result_flat, *self.new_args]
+        new_result_flat = [lookup[i] for i in self.matched_output_elements_positions]
+        return super().output(target, (new_result_flat,), {})
+
+    def run_node(self, n):
+        self.current_node = n
+        r = super().run_node(n)
+        if "val" in self.current_node.meta:
+            r.node.meta["val"] = self.current_node.meta["val"]
+        return r
+
+
 def export(
     f: Callable[..., Any],
     *args,
@@ -924,42 +962,6 @@
     flat_both = list(graph_captured_result) + flat_args
     matched_output_elements_positions = produce_matching(flat_both, flat_results_traced)
 
-    class ChangeInputOutputSignature(torch.fx.interpreter.Transformer):
-        def __init__(
-            self,
-            m,
-        ):
-            super().__init__(m)
-            arg_len = len(flat_args)
-            self.new_args = [
-                super(ChangeInputOutputSignature, self).placeholder(f"arg{i}", (), {})
-                for i in range(0, arg_len)
-            ]
-            self.old_args_gen = (
-                self.new_args[i] for i in matched_input_elements_positions
-            )
-
-        def placeholder(self, target, args, kwargs):
-            arg = next(self.old_args_gen)
-            if "val" in self.current_node.meta:
-                arg.node.meta["val"] = self.current_node.meta["val"]
-            if "tensor_dict" in self.current_node.meta:
-                arg.node.meta["tensor_dict"] = self.current_node.meta["tensor_dict"]
-            return arg
-
-        def output(self, target, args, kwargs):
-            dynamo_result_flat = args[0]
-            lookup = [*dynamo_result_flat, *self.new_args]
-            new_result_flat = [lookup[i] for i in matched_output_elements_positions]
-            return super().output(target, (new_result_flat,), {})
-
-        def run_node(self, n):
-            self.current_node = n
-            r = super().run_node(n)
-            if "val" in self.current_node.meta:
-                r.node.meta["val"] = self.current_node.meta["val"]
-            return r
-
     # NB: This is mostly hitting the cache; Dynamo already converted these
     example_fake_inputs = [fake_mode.from_tensor(t) for t in example_inputs]
 
@@ -1012,8 +1014,11 @@
                 # Wrap the internal error to the user-facing error
                 raise UserError(UserErrorType.DYNAMIC_CONTROL_FLOW, str(e))
 
-    new_graph = ChangeInputOutputSignature(
+    new_graph = FlattenInputOutputSignature(
         graph,
+        len(flat_args),
+        matched_input_elements_positions,
+        matched_output_elements_positions,
     ).transform()
 
     # Store constraints and inputs as metadata for user passes, e.g. turn constraints to runtime check