blob: 098ee9d774f9dea5383728b764bea2345af7bda4 [file] [log] [blame]
import dataclasses
from typing import Any, List, Optional, Tuple
from torch.utils._pytree import (
_register_pytree_node,
Context,
FlattenFunc,
MaybeFromStrFunc,
ToStrFunc,
UnflattenFunc,
)
def register_dataclass_as_pytree_node(
typ: Any,
flatten_fn: Optional[FlattenFunc] = None,
unflatten_fn: Optional[UnflattenFunc] = None,
to_str_fn: Optional[ToStrFunc] = None,
maybe_from_str_fn: Optional[MaybeFromStrFunc] = None,
*,
return_none_fields: bool = False,
) -> None:
assert dataclasses.is_dataclass(
typ
), f"Only dataclasses can be registered with this function: {typ}"
def default_flatten_fn(obj: Any) -> Tuple[List[Any], Context]:
flattened = []
flat_names = []
none_names = []
for f in dataclasses.fields(obj):
name, val = f.name, getattr(obj, f.name)
if val is not None or return_none_fields:
flattened.append(val)
flat_names.append(name)
else:
none_names.append(name)
return flattened, (typ, flat_names, none_names)
def default_unflatten_fn(values: List[Any], context: Context) -> Any:
typ, flat_names, none_names = context
return typ(**dict(zip(flat_names, values)), **{k: None for k in none_names})
flatten_fn = flatten_fn if flatten_fn is not None else default_flatten_fn
unflatten_fn = unflatten_fn if unflatten_fn is not None else default_unflatten_fn
_register_pytree_node(
typ,
flatten_fn,
unflatten_fn,
None,
None,
)