modify split_by_tags to retain output order (#84136)

Summary: Currently `split_by_tags` determines submodule output order by iterating over `used_in_main`. Since this is a `Set`, insertion order is not retained, so we run into problems with submodule output order being "randomized" and inconsistent between splits. By using `Dict[torch.fx.Node, None]` we can implement `used_in_main` as an ordered set, so that output order is consistent when splitting the same model.

Test Plan: CI

Differential Revision: D39039268

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84136
Approved by: https://github.com/houseroad
diff --git a/torch/fx/passes/split_utils.py b/torch/fx/passes/split_utils.py
index 0236d92..f3d8608 100644
--- a/torch/fx/passes/split_utils.py
+++ b/torch/fx/passes/split_utils.py
@@ -3,7 +3,7 @@
 
 import torch.fx
 from torch.fx.graph import map_arg
-from .tools_common import NodeList, NodeSet
+from .tools_common import NodeList
 from torch.fx._compatibility import compatibility
 from torch.fx.passes.utils import lift_subgraph_as_module, HolderModule
 
@@ -130,7 +130,7 @@
     all_components: List[Component] = []
 
     # Stores nodes that will be used in main graph.
-    used_in_main: NodeSet = set()
+    used_in_main: Dict[torch.fx.Node, None] = {}
 
     # Main graph after split.
     main_g = torch.fx.Graph()
@@ -208,7 +208,7 @@
                 comp.input_placeholders.append(
                     comp.graph.placeholder(x.name, type_expr=x.type)
                 )
-                used_in_main.add(x)
+                used_in_main[x] = None
 
             return comp.input_placeholders[
                 next(i for i, y in enumerate(comp.orig_inputs) if x is y)
@@ -231,7 +231,7 @@
         else:
             # All component results consumed by the output node should be
             # marked as "used in main".
-            used_in_main.add(x)
+            used_in_main[x] = None
 
     # If a node is used in main graph then we mark it as an output in the component
     # it belongs to.