[pytree] Fix namedtuple serialization (#123388)

Summary:
Previously we were serializing namedtuple treespecs incorrectly:
```python
Point = namedtuple("Point", ["x", "y"])
p = Point(1, 2)
flat, spec = pytree.tree_flatten(p)

print(flat)  # [1, 2]
print(spec)  # TreeSpec(type=namedtuple, context=Point, children=[*, *])

dumped_spec = pytree.treespec_dumps(spec)
print(dumped_spec)
"""
We only serialize the name of the class and the fields of the namedtuple:

TreeSpec {
  type='collections.namedtuple',
  context={class_name='Point', class_fields={'x', 'y'}},
  children=[Leaf, Leaf]
}
"""

reconstructed_spec = pytree.treespec_loads(dumped_spec)
print(reconstructed_spec)
"""
When we load, we create a new namedtuple class containing the same fields as before,
but the is class is now a completely different class than the original one:

TreeSpec(type=namedtuple, context=torch.utils._pytree.Point, children=[*, *])
"""

spec == reconstructed_spec  # False
```

So, we introduce a new API called `pytree._register_namedtuple` where users can pass in the serialized name for each namedtuple class:
```python
Point = namedtuple("Point", ["x", "y"])
pytree._register_namedtuple(Point, "Point")

p = Point(1, 2)
flat, spec = pytree.tree_flatten(p)

print(flat)  # [1, 2]
print(spec)  # TreeSpec(type=namedtuple, context=Point, children=[*, *])

dumped_spec = pytree.treespec_dumps(spec)
print(dumped_spec)
"""
TreeSpec {
  type='collections.namedtuple',
  context='Point',
  children=[Leaf, Leaf]
}
"""

reconstructed_spec = pytree.treespec_loads(dumped_spec)
print(reconstructed_spec)  # TreeSpec(type=namedtuple, context=Point, children=[*, *])

spec == reconstructed_spec  # True
```

Test Plan: `python test/test_pytree.py`

Differential Revision: D55771058

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123388
Approved by: https://github.com/zou3519
diff --git a/test/test_pytree.py b/test/test_pytree.py
index cdfa476..caaf4d0 100644
--- a/test/test_pytree.py
+++ b/test/test_pytree.py
@@ -915,15 +915,44 @@
         self.assertEqual(spec, py_pytree.treespec_loads(serialized_spec))
 
     def test_pytree_serialize_namedtuple(self):
-        Point = namedtuple("Point", ["x", "y"])
-        spec = py_pytree.TreeSpec(
-            namedtuple, Point, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
+        Point1 = namedtuple("Point1", ["x", "y"])
+        py_pytree._register_namedtuple(
+            Point1,
+            serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point1",
         )
 
+        spec = py_pytree.TreeSpec(
+            namedtuple, Point1, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
+        )
         roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec))
-        # The context in the namedtuple is different now because we recreated
-        # the namedtuple type.
-        self.assertEqual(spec.context._fields, roundtrip_spec.context._fields)
+        self.assertEqual(spec, roundtrip_spec)
+
+        class Point2(NamedTuple):
+            x: int
+            y: int
+
+        py_pytree._register_namedtuple(
+            Point2,
+            serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point2",
+        )
+
+        spec = py_pytree.TreeSpec(
+            namedtuple, Point2, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
+        )
+        roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec))
+        self.assertEqual(spec, roundtrip_spec)
+
+    def test_pytree_serialize_namedtuple_bad(self):
+        DummyType = namedtuple("DummyType", ["x", "y"])
+
+        spec = py_pytree.TreeSpec(
+            namedtuple, DummyType, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
+        )
+
+        with self.assertRaisesRegex(
+            NotImplementedError, "Please register using `_register_namedtuple`"
+        ):
+            py_pytree.treespec_dumps(spec)
 
     def test_pytree_custom_type_serialize_bad(self):
         class DummyType:
@@ -1015,6 +1044,10 @@
         spec = py_pytree.TreeSpec(
             namedtuple, Point, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
         )
+        py_pytree._register_namedtuple(
+            Point,
+            serialized_type_name="test_pytree.test_pytree_serialize_bad_protocol.Point",
+        )
 
         with self.assertRaisesRegex(ValueError, "Unknown protocol"):
             py_pytree.treespec_dumps(spec, -1)
@@ -1296,12 +1329,20 @@
         self.assertEqual(spec, cxx_pytree.treespec_loads(serialized_spec))
 
     def test_pytree_serialize_namedtuple(self):
+        py_pytree._register_namedtuple(
+            GlobalPoint,
+            serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.GlobalPoint",
+        )
         spec = cxx_pytree.tree_structure(GlobalPoint(0, 1))
 
         roundtrip_spec = cxx_pytree.treespec_loads(cxx_pytree.treespec_dumps(spec))
         self.assertEqual(roundtrip_spec.type._fields, spec.type._fields)
 
         LocalPoint = namedtuple("LocalPoint", ["x", "y"])
+        py_pytree._register_namedtuple(
+            LocalPoint,
+            serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.LocalPoint",
+        )
         spec = cxx_pytree.tree_structure(LocalPoint(0, 1))
 
         roundtrip_spec = cxx_pytree.treespec_loads(cxx_pytree.treespec_dumps(spec))
diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py
index 861e887..a25d219 100644
--- a/torch/utils/_pytree.py
+++ b/torch/utils/_pytree.py
@@ -221,6 +221,34 @@
         )
 
 
+def _register_namedtuple(
+    cls: Type[Any],
+    *,
+    serialized_type_name: str,
+) -> None:
+    """
+    Registers a namedtuple as a valid pytree node. By default namedtuples are
+    valid pytree nodes, but they are not serializable. This API provides the
+    argument `serialized_type_name` which allows these namedtuples to be
+    serialized.
+
+    Args:
+        cls: the dataclass type to register
+        serialized_type_name: The serialized name for the dataclass. This is
+        required if you want to serialize the pytree TreeSpec containing this
+        namedtuple.
+    """
+    _private_register_pytree_node(
+        cls,
+        _namedtuple_flatten,
+        _namedtuple_unflatten,
+        serialized_type_name=serialized_type_name,
+        to_dumpable_context=_namedtuple_serialize,
+        from_dumpable_context=_namedtuple_deserialize,
+        flatten_with_keys_fn=_namedtuple_flatten_with_keys,
+    )
+
+
 def _register_pytree_node(
     cls: Type[Any],
     flatten_fn: FlattenFunc,
@@ -422,18 +450,34 @@
 
 
 def _namedtuple_serialize(context: Context) -> DumpableContext:
-    json_namedtuple = {
-        "class_name": context.__name__,
-        "fields": context._fields,
-    }
-    return json_namedtuple
+    if context not in SUPPORTED_SERIALIZED_TYPES:
+        raise NotImplementedError(
+            f"Can't serialize TreeSpec of namedtuple class {context} because we "
+            "didn't register a serializated_type_name. Please register using "
+            "`_register_namedtuple`."
+        )
+
+    serialize_node_def = SUPPORTED_SERIALIZED_TYPES[context]
+    serialized_type_name = serialize_node_def.serialized_type_name
+
+    if serialized_type_name == NO_SERIALIZED_TYPE_NAME_FOUND:
+        raise NotImplementedError(
+            f"Can't serialize TreeSpec of namedtuple class {context} because we "
+            "couldn't find a serializated_type_name. Please register using "
+            "`_register_namedtuple`."
+        )
+    return serialized_type_name
 
 
 def _namedtuple_deserialize(dumpable_context: DumpableContext) -> Context:
-    class_name = dumpable_context["class_name"]
-    assert isinstance(class_name, str)
-    context = namedtuple(class_name, dumpable_context["fields"])  # type: ignore[misc]
-    return context
+    if dumpable_context not in SERIALIZED_TYPE_TO_PYTHON_TYPE:
+        raise NotImplementedError(
+            f"Can't deserialize TreeSpec of namedtuple class {dumpable_context} "
+            "because we couldn't find a serializated name."
+        )
+
+    typ = SERIALIZED_TYPE_TO_PYTHON_TYPE[dumpable_context]
+    return typ
 
 
 def _ordereddict_flatten(d: GenericOrderedDict[Any, Any]) -> Tuple[List[Any], Context]: