[export][reland] Fix unflattened submodule ordering. (#122341) (#122507)
Summary:
Make sure the order of submodules is the same as the original eager module.
bypass-github-export-checks
Test Plan: buck test mode/opt caffe2/test:test_export -- -r test_unflatten_submodule_ordering
Differential Revision: D55251277
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122507
Approved by: https://github.com/tugsbayasgalan
diff --git a/test/export/test_unflatten.py b/test/export/test_unflatten.py
index 64f1231..30d0a79 100644
--- a/test/export/test_unflatten.py
+++ b/test/export/test_unflatten.py
@@ -615,6 +615,46 @@
torch.export.unflatten(ep)
+ def test_unflatten_submodule_ordering(self):
+ class Module2(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.register_buffer("buffer", torch.rand(3, 4))
+ self.register_parameter("param", torch.nn.Parameter(torch.rand(3, 4)))
+
+ def forward(self, x):
+ return x + self.buffer + self.param
+
+ class Module1(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.register_buffer("buffer", torch.rand(3, 4))
+ self.register_parameter("param", torch.nn.Parameter(torch.rand(3, 4)))
+
+ def forward(self, x):
+ return x + self.buffer + self.param
+
+ class Module(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.mod2 = Module2()
+ self.mod3 = self.mod2
+ self.mod1 = Module1()
+
+ def forward(self, x):
+ return self.mod3(self.mod2(self.mod1(x)))
+
+ mod = Module()
+
+ ep = torch.export.export(mod, (torch.randn(3, 4),))
+
+ unflattened = torch.export.unflatten(ep)
+ fqn_list = [x for x, _ in unflattened.named_modules(remove_duplicate=False)]
+ self.assertEqual(len(fqn_list), 4)
+ self.assertEqual(
+ [x for x, _ in mod.named_modules(remove_duplicate=False)],
+ fqn_list,
+ )
if __name__ == "__main__":
run_tests()
diff --git a/torch/export/_trace.py b/torch/export/_trace.py
index 36e8b52..46acb37 100644
--- a/torch/export/_trace.py
+++ b/torch/export/_trace.py
@@ -315,7 +315,9 @@
def _get_module_hierarchy(mod: torch.nn.Module) -> Dict[str, str]:
- return {name: type(m).__name__ for name, m in mod.named_modules()}
+ return {
+ name: type(m).__name__ for name, m in mod.named_modules(remove_duplicate=False)
+ }
def _make_module_call_graph(
diff --git a/torch/export/unflatten.py b/torch/export/unflatten.py
index b16febd..49f6548 100644
--- a/torch/export/unflatten.py
+++ b/torch/export/unflatten.py
@@ -145,6 +145,8 @@
if export_module.graph_signature.backward_signature is not None:
raise ValueError("Unflattening on JointExportModule NYI")
+ fqn_list = [entry.fqn for entry in export_module.module_call_graph]
+ assert fqn_list[0] == ""
export_graph = deepcopy(export_module.graph)
self.graph_signature = deepcopy(export_module.graph_signature)
self.graph = torch.fx.Graph()
@@ -224,7 +226,11 @@
node for node in self.graph.nodes if node.op == "placeholder"
]
self.check_input_constraints = True
- assert self.module_call_graph[0].fqn == ""
+ # TODO(zhxchen17) We can register modules ahead of time instead of reorder later.
+ _reorder_submodules(self, {fqn: i for i, fqn in enumerate(fqn_list)})
+ assert [
+ fqn for fqn, _ in self.named_modules(remove_duplicate=False)
+ ] == fqn_list
def forward(self, *args, **kwargs):
signature = self.module_call_graph[0].signature
@@ -442,6 +448,23 @@
return gm.graph.call_function(pytree.tree_unflatten, (nodes, spec_node))
+def _get_submodule(mod: torch.nn.Module, target: str):
+ *prefix, field = target.split(".")
+
+ for item in prefix:
+ submod = getattr(mod, item, None)
+
+ if submod is None:
+ return None
+
+ if not isinstance(submod, torch.nn.Module):
+ return None
+
+ mod = submod
+
+ return getattr(mod, field, None)
+
+
def _add_submodule(mod: torch.nn.Module, target: str, module_to_add: torch.nn.Module):
*prefix, field = target.split(".")
@@ -788,6 +811,28 @@
).run_outer()
+def _reorder_submodules(
+ parent: torch.nn.Module, fqn_order: Dict[str, int], prefix: str = ""
+):
+ # TODO Can be optimized by adding submodules ahead of time.
+ if prefix == "":
+ for fqn in list(fqn_order.keys())[1:]:
+ if _get_submodule(parent, fqn) is None:
+ _add_submodule(parent, fqn, torch.nn.Module())
+
+ children = []
+ for name, child in list(parent._modules.items()):
+ if child is None:
+ continue
+ fqn = prefix + name
+ _reorder_submodules(child, fqn_order, prefix=fqn + ".")
+ delattr(parent, name)
+ children.append((fqn_order[fqn], name, child))
+ children.sort(key=lambda x: x[0])
+ for _, name, child in children:
+ parent.register_module(name, child)
+
+
def _sink_params(
module: torch.nn.Module,
inputs_to_state: Dict[str, str],