[export] Serialize metadata (#103274)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103274
Approved by: https://github.com/zhxchen17
diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py
index 57dd40e..bdcc412 100644
--- a/test/export/test_serialize.py
+++ b/test/export/test_serialize.py
@@ -220,6 +220,30 @@
# For expressions like 's0 < 10' can only compare through string
self.assertEqual(str(val1), str(val2))
+ # Check "stack_trace" metadata
+ if "None" in node1.meta.get("stack_trace"):
+ self.assertTrue(
+ node2.meta.get("stack_trace") is None
+ or "None" in node2.meta.get("stack_trace")
+ )
+ else:
+ self.assertEqual(
+ node1.meta.get("stack_trace", None),
+ node2.meta.get("stack_trace", None),
+ )
+
+ # Check "nn_module_stack" metadata
+ self.assertEqual(
+ node1.meta.get("nn_module_stack", None),
+ node2.meta.get("nn_module_stack", None),
+ )
+
+ # Check "source_fn" metadata
+ self.assertEqual(
+ node1.meta.get("source_fn", None),
+ node2.meta.get("source_fn", None),
+ )
+
def test_multi_return(self) -> None:
"""
Test multiple return from a single node (ex. layer_norm has 2 outputs)
diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py
index f4f9110..67be23f 100644
--- a/torch/_export/serde/serialize.py
+++ b/torch/_export/serde/serialize.py
@@ -178,13 +178,20 @@
ret = {}
if stack_trace := node.meta.get("stack_trace"):
ret["stack_trace"] = stack_trace
- module_fqn = node.meta.get("module_fqn")
- # Need an explicit None check instead of walrus operator, because
- # module_fqn can be the empty string if the node belongs to the root.
- # The walrus operator returns False on an empty string :(
- if module_fqn is not None:
- ret["module_fqn"] = module_fqn
- # TODO(angelayi) add nn_module_stack and source_fn
+
+ if nn_module_stack := node.meta.get("nn_module_stack"):
+ # Serialize to "fx_node_name:(orig_ref,type_str)"
+ nn_module_list = [
+ f"{k}:({v[0]},{serialize_operator(v[1])})"
+ for k, v in nn_module_stack.items()
+ ]
+ ret["nn_module_stack"] = ";".join(nn_module_list)
+
+ if source_fn := node.meta.get("source_fn"):
+ # Serialize to "fx_node_name,op_str"
+ op = serialize_operator(source_fn[1])
+ ret["source_fn"] = f"{source_fn[0]},{op}"
+
return ret
@@ -192,13 +199,47 @@
ret = {}
if stack_trace := metadata.get("stack_trace"):
ret["stack_trace"] = stack_trace
- # Need an explicit None check instead of walrus operator, because
- # module_fqn can be the empty string if the node belongs to the root.
- # The walrus operator returns False on an empty string :(
- module_fqn = metadata.get("module_fqn")
- if module_fqn is not None:
- ret["module_fqn"] = module_fqn
- # TODO(angelayi) add nn_module_stack and source_fn
+
+ def deserialize_meta_func(serialized_target: str):
+ module = None
+ if serialized_target.startswith("torch.nn"):
+ module = torch.nn
+ serialized_target_names = serialized_target.split(".")[2:]
+ elif serialized_target.startswith("torch"):
+ module = torch
+ serialized_target_names = serialized_target.split(".")[1:]
+ else:
+ return deserialize_operator(serialized_target)
+
+ target = module
+ for name in serialized_target_names:
+ if not hasattr(target, name):
+ return serialized_target
+ else:
+ target = getattr(target, name)
+ return target
+
+ if nn_module_stack_str := metadata.get("nn_module_stack"):
+ # Originally serialized to "fx_node_name:(orig_ref,type_str)"
+ nn_module_stack_list = nn_module_stack_str.split(";")
+ nn_module_stack = {}
+ for kv in nn_module_stack_list:
+ key_idx = kv.find(":")
+ key = kv[:key_idx]
+ assert kv[key_idx + 1] == "("
+ assert kv[-1] == ")"
+ values = kv[key_idx + 2: -1].split(",")
+ assert len(values) == 2
+ module = deserialize_meta_func(values[1])
+ nn_module_stack[key] = (values[0], module)
+ ret["nn_module_stack"] = nn_module_stack
+
+ if source_fn_str := metadata.get("source_fn"):
+ # Originally serializes to "fx_node_name,op_str"
+ source_fn = source_fn_str.split(",")
+ op = deserialize_meta_func(source_fn[1])
+ ret["source_fn"] = (source_fn[0], op)
+
return ret