[export] Serialize metadata (#103274)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103274
Approved by: https://github.com/zhxchen17
diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py
index 57dd40e..bdcc412 100644
--- a/test/export/test_serialize.py
+++ b/test/export/test_serialize.py
@@ -220,6 +220,30 @@
                 # For expressions like 's0 < 10' can only compare through string
                 self.assertEqual(str(val1), str(val2))
 
+            # Check "stack_trace" metadata
+            if "None" in node1.meta.get("stack_trace"):
+                self.assertTrue(
+                    node2.meta.get("stack_trace") is None
+                    or "None" in node2.meta.get("stack_trace")
+                )
+            else:
+                self.assertEqual(
+                    node1.meta.get("stack_trace", None),
+                    node2.meta.get("stack_trace", None),
+                )
+
+            # Check "nn_module_stack" metadata
+            self.assertEqual(
+                node1.meta.get("nn_module_stack", None),
+                node2.meta.get("nn_module_stack", None),
+            )
+
+            # Check "source_fn" metadata
+            self.assertEqual(
+                node1.meta.get("source_fn", None),
+                node2.meta.get("source_fn", None),
+            )
+
     def test_multi_return(self) -> None:
         """
         Test multiple return from a single node (ex. layer_norm has 2 outputs)
diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py
index f4f9110..67be23f 100644
--- a/torch/_export/serde/serialize.py
+++ b/torch/_export/serde/serialize.py
@@ -178,13 +178,20 @@
     ret = {}
     if stack_trace := node.meta.get("stack_trace"):
         ret["stack_trace"] = stack_trace
-    module_fqn = node.meta.get("module_fqn")
-    # Need an explicit None check instead of walrus operator, because
-    # module_fqn can be the empty string if the node belongs to the root.
-    # The walrus operator returns False on an empty string :(
-    if module_fqn is not None:
-        ret["module_fqn"] = module_fqn
-    # TODO(angelayi) add nn_module_stack and source_fn
+
+    if nn_module_stack := node.meta.get("nn_module_stack"):
+        # Serialize to "fx_node_name:(orig_ref,type_str)"
+        nn_module_list = [
+            f"{k}:({v[0]},{serialize_operator(v[1])})"
+            for k, v in nn_module_stack.items()
+        ]
+        ret["nn_module_stack"] = ";".join(nn_module_list)
+
+    if source_fn := node.meta.get("source_fn"):
+        # Serialize to "fx_node_name,op_str"
+        op = serialize_operator(source_fn[1])
+        ret["source_fn"] = f"{source_fn[0]},{op}"
+
     return ret
 
 
@@ -192,13 +199,47 @@
     ret = {}
     if stack_trace := metadata.get("stack_trace"):
         ret["stack_trace"] = stack_trace
-    # Need an explicit None check instead of walrus operator, because
-    # module_fqn can be the empty string if the node belongs to the root.
-    # The walrus operator returns False on an empty string :(
-    module_fqn = metadata.get("module_fqn")
-    if module_fqn is not None:
-        ret["module_fqn"] = module_fqn
-    # TODO(angelayi) add nn_module_stack and source_fn
+
+    def deserialize_meta_func(serialized_target: str):
+        module = None
+        if serialized_target.startswith("torch.nn"):
+            module = torch.nn
+            serialized_target_names = serialized_target.split(".")[2:]
+        elif serialized_target.startswith("torch"):
+            module = torch
+            serialized_target_names = serialized_target.split(".")[1:]
+        else:
+            return deserialize_operator(serialized_target)
+
+        target = module
+        for name in serialized_target_names:
+            if not hasattr(target, name):
+                return serialized_target
+            else:
+                target = getattr(target, name)
+        return target
+
+    if nn_module_stack_str := metadata.get("nn_module_stack"):
+        # Originally serialized to "fx_node_name:(orig_ref,type_str)"
+        nn_module_stack_list = nn_module_stack_str.split(";")
+        nn_module_stack = {}
+        for kv in nn_module_stack_list:
+            key_idx = kv.find(":")
+            key = kv[:key_idx]
+            assert kv[key_idx + 1] == "("
+            assert kv[-1] == ")"
+            values = kv[key_idx + 2: -1].split(",")
+            assert len(values) == 2
+            module = deserialize_meta_func(values[1])
+            nn_module_stack[key] = (values[0], module)
+        ret["nn_module_stack"] = nn_module_stack
+
+    if source_fn_str := metadata.get("source_fn"):
+        # Originally serializes to "fx_node_name,op_str"
+        source_fn = source_fn_str.split(",")
+        op = deserialize_meta_func(source_fn[1])
+        ret["source_fn"] = (source_fn[0], op)
+
     return ret