Recursively print graph module and its submodules (#81080)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81080
Approved by: https://github.com/ezyang
diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py
index 36daa56..6628b63 100644
--- a/torch/fx/graph_module.py
+++ b/torch/fx/graph_module.py
@@ -706,9 +706,27 @@
     def __copy__(self):
         return GraphModule(self, self.graph)
 
+    def __nested_code(self) -> str:
+        """
+        Return the Python code generated for the current GraphModule and its child GraphModules
+        """
+        module_code = self.code
+        module_code = module_code.lstrip('\n')
+        module_code = f"class {self._get_name()}(torch.nn.Module):\n" + module_code
+        module_code = _addindent(module_code, 4)
+
+        submodule_code_list = [""]
+        for submodule in self.children():
+            if isinstance(submodule, GraphModule):
+                submodule_code_list.append(submodule.__nested_code())
+        submodule_code = "\n".join(submodule_code_list)
+        submodule_code = _addindent(submodule_code, 4)
+
+        return module_code + submodule_code
+
     def __str__(self) -> str:
         orig_str = super().__str__()
-        return '\n'.join([orig_str, self._code])
+        return '\n'.join([orig_str, self.__nested_code()])
 
     def _replicate_for_data_parallel(self):
         new_gm = self.__copy__()
diff --git a/torch/fx/passes/utils/fuser_utils.py b/torch/fx/passes/utils/fuser_utils.py
index aa626f7..f3d5f02 100644
--- a/torch/fx/passes/utils/fuser_utils.py
+++ b/torch/fx/passes/utils/fuser_utils.py
@@ -167,13 +167,15 @@
 
     return fused_gm, original_inputs, original_outputs
 
+
 def insert_subgm(gm: GraphModule, sub_gm: GraphModule, orig_inputs: Tuple[Node, ...], orig_outputs: Tuple[Node, ...]):
-    # assign sub_gm into gm
-    setattr(gm, sub_gm.__class__.__name__, sub_gm)
+    # add sub_gm into gm
+    submodule_name = sub_gm.__class__.__name__
+    gm.add_submodule(submodule_name, sub_gm)
 
     # Create a call_module node in main graph.
     module_node = gm.graph.call_module(
-        sub_gm.__class__.__name__,
+        submodule_name,
         args=orig_inputs,
         kwargs=None)
 
@@ -185,8 +187,6 @@
             # Use Proxy to record getitem access.
             proxy_out = torch.fx.Proxy(module_node)[i].node  # type: ignore[index]
             orig_output.replace_all_uses_with(proxy_out)
-
-
     return gm
 
 def erase_nodes(gm: GraphModule, nodes: NodeList):
@@ -196,7 +196,6 @@
         gm.graph.erase_node(node)
 
 
-
 def fuse_by_partitions(gm: GraphModule, partitions: List[NodeList]) -> GraphModule:
     for partition_id, nodes in enumerate(partitions):
         sorted_nodes = topo_sort(nodes)