[export][reland] Fix unflattened submodule ordering. (#122341) (#122507)

Summary:

Make sure the order of submodules in the unflattened module is the same as in the original eager module.
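
By default, named_modules() deduplicates aliased submodules, so an alias (the same instance registered under a second name) never shows up in the recorded module hierarchy, and unflattening could register children in a different order than the eager module. A minimal sketch of the relevant behavior (hypothetical Leaf/M classes, mirroring the aliasing in the new test below):

    import torch

    class Leaf(torch.nn.Module):
        def forward(self, x):
            return x + 1

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.mod2 = Leaf()
            self.mod3 = self.mod2  # same instance registered twice
            self.mod1 = Leaf()

    m = M()
    print([n for n, _ in m.named_modules()])
    # ['', 'mod2', 'mod1'] -- the alias 'mod3' is deduplicated away
    print([n for n, _ in m.named_modules(remove_duplicate=False)])
    # ['', 'mod2', 'mod3', 'mod1'] -- full registration order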

bypass-github-export-checks

Test Plan: buck test mode/opt caffe2/test:test_export -- -r test_unflatten_submodule_ordering

Differential Revision: D55251277

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122507
Approved by: https://github.com/tugsbayasgalan
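
The fix captures the eager FQN order up front (from module_call_graph, now built with remove_duplicate=False), registers placeholder modules for any FQNs missing from the unflattened module, and then reorders each parent's children to match. Since nn.Module stores children in an insertion-ordered _modules dict, reordering amounts to detaching and re-registering them in the target order; a minimal standalone sketch of that idea (hypothetical reorder helper, not the exact _reorder_submodules in the diff):

    import torch

    def reorder(parent: torch.nn.Module, desired: list) -> None:
        # _modules is insertion-ordered, so deleting the children and
        # re-registering them in the desired order fixes named_modules().
        children = {n: c for n, c in parent._modules.items() if c is not None}
        for name in children:
            delattr(parent, name)
        for name in desired:
            parent.register_module(name, children[name])

    m = torch.nn.Module()
    m.b = torch.nn.Linear(2, 2)
    m.a = torch.nn.Linear(2, 2)
    reorder(m, ["a", "b"])
    print([n for n, _ in m.named_modules()])  # ['', 'a', 'b']
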
diff --git a/test/export/test_unflatten.py b/test/export/test_unflatten.py
index 64f1231..30d0a79 100644
--- a/test/export/test_unflatten.py
+++ b/test/export/test_unflatten.py
@@ -615,6 +615,46 @@
 
         torch.export.unflatten(ep)
 
+    def test_unflatten_submodule_ordering(self):
+        class Module2(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer("buffer", torch.rand(3, 4))
+                self.register_parameter("param", torch.nn.Parameter(torch.rand(3, 4)))
+
+            def forward(self, x):
+                return x + self.buffer + self.param
+
+        class Module1(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer("buffer", torch.rand(3, 4))
+                self.register_parameter("param", torch.nn.Parameter(torch.rand(3, 4)))
+
+            def forward(self, x):
+                return x + self.buffer + self.param
+
+        class Module(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.mod2 = Module2()
+                self.mod3 = self.mod2
+                self.mod1 = Module1()
+
+            def forward(self, x):
+                return self.mod3(self.mod2(self.mod1(x)))
+
+        mod = Module()
+
+        ep = torch.export.export(mod, (torch.randn(3, 4),))
+
+        unflattened = torch.export.unflatten(ep)
+        fqn_list = [x for x, _ in unflattened.named_modules(remove_duplicate=False)]
+        self.assertEqual(len(fqn_list), 4)
+        self.assertEqual(
+            [x for x, _ in mod.named_modules(remove_duplicate=False)],
+            fqn_list,
+        )
 
 if __name__ == "__main__":
     run_tests()
diff --git a/torch/export/_trace.py b/torch/export/_trace.py
index 36e8b52..46acb37 100644
--- a/torch/export/_trace.py
+++ b/torch/export/_trace.py
@@ -315,7 +315,9 @@
 
 
 def _get_module_hierarchy(mod: torch.nn.Module) -> Dict[str, str]:
-    return {name: type(m).__name__ for name, m in mod.named_modules()}
+    return {
+        name: type(m).__name__ for name, m in mod.named_modules(remove_duplicate=False)
+    }
 
 
 def _make_module_call_graph(
diff --git a/torch/export/unflatten.py b/torch/export/unflatten.py
index b16febd..49f6548 100644
--- a/torch/export/unflatten.py
+++ b/torch/export/unflatten.py
@@ -145,6 +145,8 @@
         if export_module.graph_signature.backward_signature is not None:
             raise ValueError("Unflattening on JointExportModule NYI")
 
+        fqn_list = [entry.fqn for entry in export_module.module_call_graph]
+        assert fqn_list[0] == ""
         export_graph = deepcopy(export_module.graph)
         self.graph_signature = deepcopy(export_module.graph_signature)
         self.graph = torch.fx.Graph()
@@ -224,7 +226,11 @@
             node for node in self.graph.nodes if node.op == "placeholder"
         ]
         self.check_input_constraints = True
-        assert self.module_call_graph[0].fqn == ""
+        # TODO(zhxchen17) We can register modules ahead of time instead of reordering them later.
+        _reorder_submodules(self, {fqn: i for i, fqn in enumerate(fqn_list)})
+        assert [
+            fqn for fqn, _ in self.named_modules(remove_duplicate=False)
+        ] == fqn_list
 
     def forward(self, *args, **kwargs):
         signature = self.module_call_graph[0].signature
@@ -442,6 +448,23 @@
     return gm.graph.call_function(pytree.tree_unflatten, (nodes, spec_node))
 
 
+def _get_submodule(mod: torch.nn.Module, target: str):
+    *prefix, field = target.split(".")
+
+    for item in prefix:
+        submod = getattr(mod, item, None)
+
+        if submod is None:
+            return None
+
+        if not isinstance(submod, torch.nn.Module):
+            return None
+
+        mod = submod
+
+    return getattr(mod, field, None)  # None if the target does not exist
+
+
 def _add_submodule(mod: torch.nn.Module, target: str, module_to_add: torch.nn.Module):
     *prefix, field = target.split(".")
 
@@ -788,6 +811,28 @@
     ).run_outer()
 
 
+def _reorder_submodules(
+    parent: torch.nn.Module, fqn_order: Dict[str, int], prefix: str = ""
+):
+    # TODO Can be optimized by adding submodules ahead of time.
+    if prefix == "":
+        for fqn in list(fqn_order.keys())[1:]:
+            if _get_submodule(parent, fqn) is None:
+                _add_submodule(parent, fqn, torch.nn.Module())  # empty placeholder
+
+    children = []
+    for name, child in list(parent._modules.items()):
+        if child is None:
+            continue
+        fqn = prefix + name
+        _reorder_submodules(child, fqn_order, prefix=fqn + ".")
+        delattr(parent, name)
+        children.append((fqn_order[fqn], name, child))
+    children.sort(key=lambda x: x[0])  # restore the eager module's registration order
+    for _, name, child in children:
+        parent.register_module(name, child)
+
+
 def _sink_params(
     module: torch.nn.Module,
     inputs_to_state: Dict[str, str],