modify split_by_tags to retain output order (#84136)
Summary: Currently `split_by_tags` determines submodule output order by iterating over `used_in_main`. Since this is a `Set`, insertion order is not retained, so submodule output order is effectively randomized and inconsistent between splits of the same model. By using `Dict[Node, None]` (Python dicts preserve insertion order) we can implement `used_in_main` as an ordered set, so that output order is consistent when splitting the same model.
Test Plan: CI
Differential Revision: D39039268
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84136
Approved by: https://github.com/houseroad
diff --git a/torch/fx/passes/split_utils.py b/torch/fx/passes/split_utils.py
index 0236d92..f3d8608 100644
--- a/torch/fx/passes/split_utils.py
+++ b/torch/fx/passes/split_utils.py
@@ -3,7 +3,7 @@
import torch.fx
from torch.fx.graph import map_arg
-from .tools_common import NodeList, NodeSet
+from .tools_common import NodeList
from torch.fx._compatibility import compatibility
from torch.fx.passes.utils import lift_subgraph_as_module, HolderModule
@@ -130,7 +130,7 @@
all_components: List[Component] = []
# Stores nodes that will be used in main graph.
- used_in_main: NodeSet = set()
+ used_in_main: Dict[torch.fx.Node, None] = {}
# Main graph after split.
main_g = torch.fx.Graph()
@@ -208,7 +208,7 @@
comp.input_placeholders.append(
comp.graph.placeholder(x.name, type_expr=x.type)
)
- used_in_main.add(x)
+ used_in_main[x] = None
return comp.input_placeholders[
next(i for i, y in enumerate(comp.orig_inputs) if x is y)
@@ -231,7 +231,7 @@
else:
# All component results consumed by the output node should be
# marked as "used in main".
- used_in_main.add(x)
+ used_in_main[x] = None
# If a node is used in main graph then we mark it as an output in the component
# it belongs to.