Serialize empty pytree cases (#105159)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105159
Approved by: https://github.com/zhxchen17
diff --git a/test/test_pytree.py b/test/test_pytree.py
index 2608b55..4ec8576 100644
--- a/test/test_pytree.py
+++ b/test/test_pytree.py
@@ -275,6 +275,9 @@
             self.assertEqual(result, expected, msg=str([pytree, to_spec, expected]))
 
     @parametrize("spec, str_spec", [
+        (TreeSpec(list, None, []), "L()"),
+        (TreeSpec(tuple, None, []), "T()"),
+        (TreeSpec(dict, [], []), "D()"),
         (TreeSpec(list, None, [LeafSpec()]), "L(*)"),
         (TreeSpec(list, None, [LeafSpec(), LeafSpec()]), "L(*,*)"),
         (TreeSpec(tuple, None, [LeafSpec(), LeafSpec(), LeafSpec()]), "T(*,*,*)"),
diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py
index 7a56e88..3ecdc06 100644
--- a/torch/utils/_pytree.py
+++ b/torch/utils/_pytree.py
@@ -62,12 +62,16 @@
     to_str_fn: Optional[ToStrFunc] = None,
     maybe_from_str_fn: Optional[MaybeFromStrFunc] = None,
 ) -> None:
-    def _raise_error(_):  # type: ignore[no-untyped-def]
-        raise NotImplementedError(f"Serializing {typ} not implemented")
     if to_str_fn is None:
-        to_str_fn = _raise_error   # type: ignore[assignment, return-value]
+        def _raise_error(spec: "TreeSpec", child_strings: List[str]) -> str:
+            raise NotImplementedError(f"Serializing {typ} not implemented")
+        to_str_fn = _raise_error
+
     if maybe_from_str_fn is None:
-        maybe_from_str_fn = _raise_error  # type: ignore[assignment, return-value]
+        def dummy_to_str(str_spec: str) -> Optional[Tuple[Any, Context, str]]:
+            return None
+        maybe_from_str_fn = dummy_to_str
+
     assert to_str_fn is not None
     assert maybe_from_str_fn is not None
     node_def = NodeDef(typ, flatten_fn, unflatten_fn, to_str_fn, maybe_from_str_fn)
@@ -484,10 +488,11 @@
         res = node_def.maybe_from_str_fn(str_spec)
         if res is not None:
             typ, context, child_strings = res
-            children_spec = [
-                str_to_pytree(child_string)
-                for child_string in _split_nested(child_strings)
-            ]
+            children_spec = []
+            for child_string in _split_nested(child_strings):
+                if child_string == "":
+                    continue
+                children_spec.append(str_to_pytree(child_string))
             return TreeSpec(typ, context, children_spec)
     raise NotImplementedError(f"Deserializing {str_spec} in pytree not supported yet")