Recursively print graph module and its submodules (#81080)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81080
Approved by: https://github.com/ezyang
diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py
index 36daa56..6628b63 100644
--- a/torch/fx/graph_module.py
+++ b/torch/fx/graph_module.py
@@ -706,9 +706,27 @@
def __copy__(self):
return GraphModule(self, self.graph)
+ def __nested_code(self) -> str:
+ """
+ Return the Python code generated for the current GraphModule and its child GraphModules
+ """
+ module_code = self.code
+ module_code = module_code.lstrip('\n')
+ module_code = f"class {self._get_name()}(torch.nn.Module):\n" + module_code
+ module_code = _addindent(module_code, 4)
+
+ submodule_code_list = [""]
+ for submodule in self.children():
+ if isinstance(submodule, GraphModule):
+ submodule_code_list.append(submodule.__nested_code())
+ submodule_code = "\n".join(submodule_code_list)
+ submodule_code = _addindent(submodule_code, 4)
+
+ return module_code + submodule_code
+
def __str__(self) -> str:
orig_str = super().__str__()
- return '\n'.join([orig_str, self._code])
+ return '\n'.join([orig_str, self.__nested_code()])
def _replicate_for_data_parallel(self):
new_gm = self.__copy__()
diff --git a/torch/fx/passes/utils/fuser_utils.py b/torch/fx/passes/utils/fuser_utils.py
index aa626f7..f3d5f02 100644
--- a/torch/fx/passes/utils/fuser_utils.py
+++ b/torch/fx/passes/utils/fuser_utils.py
@@ -167,13 +167,15 @@
return fused_gm, original_inputs, original_outputs
+
def insert_subgm(gm: GraphModule, sub_gm: GraphModule, orig_inputs: Tuple[Node, ...], orig_outputs: Tuple[Node, ...]):
- # assign sub_gm into gm
- setattr(gm, sub_gm.__class__.__name__, sub_gm)
+ # add sub_gm into gm
+ submodule_name = sub_gm.__class__.__name__
+ gm.add_submodule(submodule_name, sub_gm)
# Create a call_module node in main graph.
module_node = gm.graph.call_module(
- sub_gm.__class__.__name__,
+ submodule_name,
args=orig_inputs,
kwargs=None)
@@ -185,8 +187,6 @@
# Use Proxy to record getitem access.
proxy_out = torch.fx.Proxy(module_node)[i].node # type: ignore[index]
orig_output.replace_all_uses_with(proxy_out)
-
-
return gm
def erase_nodes(gm: GraphModule, nodes: NodeList):
@@ -196,7 +196,6 @@
gm.graph.erase_node(node)
-
def fuse_by_partitions(gm: GraphModule, partitions: List[NodeList]) -> GraphModule:
for partition_id, nodes in enumerate(partitions):
sorted_nodes = topo_sort(nodes)