[pytree] Fix namedtuple serialization (#123388)
Summary:
Previously we were serializing namedtuple treespecs incorrectly:
```python
Point = namedtuple("Point", ["x", "y"])
p = Point(1, 2)
flat, spec = pytree.tree_flatten(p)
print(flat) # [1, 2]
print(spec) # TreeSpec(type=namedtuple, context=Point, children=[*, *])
dumped_spec = pytree.treespec_dumps(spec)
print(dumped_spec)
"""
We only serialize the name of the class and the fields of the namedtuple:
TreeSpec {
type='collections.namedtuple',
context={class_name='Point', class_fields={'x', 'y'}},
children=[Leaf, Leaf]
}
"""
reconstructed_spec = pytree.treespec_loads(dumped_spec)
print(reconstructed_spec)
"""
When we load, we create a new namedtuple class containing the same fields as before,
but the class is now a completely different class from the original one:
TreeSpec(type=namedtuple, context=torch.utils._pytree.Point, children=[*, *])
"""
spec == reconstructed_spec # False
```
So, we introduce a new API called `pytree._register_namedtuple` where users can pass in the serialized name for each namedtuple class:
```python
Point = namedtuple("Point", ["x", "y"])
pytree._register_namedtuple(Point, serialized_type_name="Point")
p = Point(1, 2)
flat, spec = pytree.tree_flatten(p)
print(flat) # [1, 2]
print(spec) # TreeSpec(type=namedtuple, context=Point, children=[*, *])
dumped_spec = pytree.treespec_dumps(spec)
print(dumped_spec)
"""
TreeSpec {
type='collections.namedtuple',
context='Point',
children=[Leaf, Leaf]
}
"""
reconstructed_spec = pytree.treespec_loads(dumped_spec)
print(reconstructed_spec) # TreeSpec(type=namedtuple, context=Point, children=[*, *])
spec == reconstructed_spec # True
```
Test Plan: `python test/test_pytree.py`
Differential Revision: D55771058
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123388
Approved by: https://github.com/zou3519
diff --git a/test/test_pytree.py b/test/test_pytree.py
index cdfa476..caaf4d0 100644
--- a/test/test_pytree.py
+++ b/test/test_pytree.py
@@ -915,15 +915,44 @@
self.assertEqual(spec, py_pytree.treespec_loads(serialized_spec))
def test_pytree_serialize_namedtuple(self):
- Point = namedtuple("Point", ["x", "y"])
- spec = py_pytree.TreeSpec(
- namedtuple, Point, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
+ Point1 = namedtuple("Point1", ["x", "y"])
+ py_pytree._register_namedtuple(
+ Point1,
+ serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point1",
)
+ spec = py_pytree.TreeSpec(
+ namedtuple, Point1, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
+ )
roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec))
- # The context in the namedtuple is different now because we recreated
- # the namedtuple type.
- self.assertEqual(spec.context._fields, roundtrip_spec.context._fields)
+ self.assertEqual(spec, roundtrip_spec)
+
+ class Point2(NamedTuple):
+ x: int
+ y: int
+
+ py_pytree._register_namedtuple(
+ Point2,
+ serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point2",
+ )
+
+ spec = py_pytree.TreeSpec(
+ namedtuple, Point2, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
+ )
+ roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec))
+ self.assertEqual(spec, roundtrip_spec)
+
+ def test_pytree_serialize_namedtuple_bad(self):
+ DummyType = namedtuple("DummyType", ["x", "y"])
+
+ spec = py_pytree.TreeSpec(
+ namedtuple, DummyType, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
+ )
+
+ with self.assertRaisesRegex(
+ NotImplementedError, "Please register using `_register_namedtuple`"
+ ):
+ py_pytree.treespec_dumps(spec)
def test_pytree_custom_type_serialize_bad(self):
class DummyType:
@@ -1015,6 +1044,10 @@
spec = py_pytree.TreeSpec(
namedtuple, Point, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
)
+ py_pytree._register_namedtuple(
+ Point,
+ serialized_type_name="test_pytree.test_pytree_serialize_bad_protocol.Point",
+ )
with self.assertRaisesRegex(ValueError, "Unknown protocol"):
py_pytree.treespec_dumps(spec, -1)
@@ -1296,12 +1329,20 @@
self.assertEqual(spec, cxx_pytree.treespec_loads(serialized_spec))
def test_pytree_serialize_namedtuple(self):
+ py_pytree._register_namedtuple(
+ GlobalPoint,
+ serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.GlobalPoint",
+ )
spec = cxx_pytree.tree_structure(GlobalPoint(0, 1))
roundtrip_spec = cxx_pytree.treespec_loads(cxx_pytree.treespec_dumps(spec))
self.assertEqual(roundtrip_spec.type._fields, spec.type._fields)
LocalPoint = namedtuple("LocalPoint", ["x", "y"])
+ py_pytree._register_namedtuple(
+ LocalPoint,
+ serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.LocalPoint",
+ )
spec = cxx_pytree.tree_structure(LocalPoint(0, 1))
roundtrip_spec = cxx_pytree.treespec_loads(cxx_pytree.treespec_dumps(spec))
diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py
index 861e887..a25d219 100644
--- a/torch/utils/_pytree.py
+++ b/torch/utils/_pytree.py
@@ -221,6 +221,34 @@
)
+def _register_namedtuple(
+ cls: Type[Any],
+ *,
+ serialized_type_name: str,
+) -> None:
+ """
+ Registers a namedtuple as a valid pytree node. By default namedtuples are
+ valid pytree nodes, but they are not serializable. This API provides the
+ argument `serialized_type_name` which allows these namedtuples to be
+ serialized.
+
+ Args:
+ cls: the namedtuple type to register
+ serialized_type_name: The serialized name for the namedtuple. This is
+ required if you want to serialize the pytree TreeSpec containing this
+ namedtuple.
+ """
+ _private_register_pytree_node(
+ cls,
+ _namedtuple_flatten,
+ _namedtuple_unflatten,
+ serialized_type_name=serialized_type_name,
+ to_dumpable_context=_namedtuple_serialize,
+ from_dumpable_context=_namedtuple_deserialize,
+ flatten_with_keys_fn=_namedtuple_flatten_with_keys,
+ )
+
+
def _register_pytree_node(
cls: Type[Any],
flatten_fn: FlattenFunc,
@@ -422,18 +450,34 @@
def _namedtuple_serialize(context: Context) -> DumpableContext:
- json_namedtuple = {
- "class_name": context.__name__,
- "fields": context._fields,
- }
- return json_namedtuple
+ if context not in SUPPORTED_SERIALIZED_TYPES:
+ raise NotImplementedError(
+ f"Can't serialize TreeSpec of namedtuple class {context} because we "
+ "didn't register a serializated_type_name. Please register using "
+ "`_register_namedtuple`."
+ )
+
+ serialize_node_def = SUPPORTED_SERIALIZED_TYPES[context]
+ serialized_type_name = serialize_node_def.serialized_type_name
+
+ if serialized_type_name == NO_SERIALIZED_TYPE_NAME_FOUND:
+ raise NotImplementedError(
+ f"Can't serialize TreeSpec of namedtuple class {context} because we "
+ "couldn't find a serializated_type_name. Please register using "
+ "`_register_namedtuple`."
+ )
+ return serialized_type_name
def _namedtuple_deserialize(dumpable_context: DumpableContext) -> Context:
- class_name = dumpable_context["class_name"]
- assert isinstance(class_name, str)
- context = namedtuple(class_name, dumpable_context["fields"]) # type: ignore[misc]
- return context
+ if dumpable_context not in SERIALIZED_TYPE_TO_PYTHON_TYPE:
+ raise NotImplementedError(
+ f"Can't deserialize TreeSpec of namedtuple class {dumpable_context} "
+ "because we couldn't find a serializated name."
+ )
+
+ typ = SERIALIZED_TYPE_TO_PYTHON_TYPE[dumpable_context]
+ return typ
def _ordereddict_flatten(d: GenericOrderedDict[Any, Any]) -> Tuple[List[Any], Context]: