Serialize empty pytree cases (#105159)
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105159
Approved by: https://github.com/zhxchen17
diff --git a/test/test_pytree.py b/test/test_pytree.py
index 2608b55..4ec8576 100644
--- a/test/test_pytree.py
+++ b/test/test_pytree.py
@@ -275,6 +275,9 @@
self.assertEqual(result, expected, msg=str([pytree, to_spec, expected]))
@parametrize("spec, str_spec", [
+ (TreeSpec(list, None, []), "L()"),
+ (TreeSpec(tuple, None, []), "T()"),
+ (TreeSpec(dict, [], []), "D()"),
(TreeSpec(list, None, [LeafSpec()]), "L(*)"),
(TreeSpec(list, None, [LeafSpec(), LeafSpec()]), "L(*,*)"),
(TreeSpec(tuple, None, [LeafSpec(), LeafSpec(), LeafSpec()]), "T(*,*,*)"),
diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py
index 7a56e88..3ecdc06 100644
--- a/torch/utils/_pytree.py
+++ b/torch/utils/_pytree.py
@@ -62,12 +62,16 @@
to_str_fn: Optional[ToStrFunc] = None,
maybe_from_str_fn: Optional[MaybeFromStrFunc] = None,
) -> None:
- def _raise_error(_): # type: ignore[no-untyped-def]
- raise NotImplementedError(f"Serializing {typ} not implemented")
if to_str_fn is None:
- to_str_fn = _raise_error # type: ignore[assignment, return-value]
+ def _raise_error(spec: "TreeSpec", child_strings: List[str]) -> str:
+ raise NotImplementedError(f"Serializing {typ} not implemented")
+ to_str_fn = _raise_error
+
if maybe_from_str_fn is None:
- maybe_from_str_fn = _raise_error # type: ignore[assignment, return-value]
+ def dummy_to_str(str_spec: str) -> Optional[Tuple[Any, Context, str]]:
+ return None
+ maybe_from_str_fn = dummy_to_str
+
assert to_str_fn is not None
assert maybe_from_str_fn is not None
node_def = NodeDef(typ, flatten_fn, unflatten_fn, to_str_fn, maybe_from_str_fn)
@@ -484,10 +488,11 @@
res = node_def.maybe_from_str_fn(str_spec)
if res is not None:
typ, context, child_strings = res
- children_spec = [
- str_to_pytree(child_string)
- for child_string in _split_nested(child_strings)
- ]
+ children_spec = []
+ for child_string in _split_nested(child_strings):
+ if child_string == "":
+ continue
+ children_spec.append(str_to_pytree(child_string))
return TreeSpec(typ, context, children_spec)
raise NotImplementedError(f"Deserializing {str_spec} in pytree not supported yet")