| import dataclasses |
| |
| from typing import Any, List, Optional, Tuple |
| |
| from torch.utils._pytree import ( |
| _register_pytree_node, |
| Context, |
| FlattenFunc, |
| MaybeFromStrFunc, |
| ToStrFunc, |
| UnflattenFunc, |
| ) |
| |
| |
| def register_dataclass_as_pytree_node( |
| typ: Any, |
| flatten_fn: Optional[FlattenFunc] = None, |
| unflatten_fn: Optional[UnflattenFunc] = None, |
| to_str_fn: Optional[ToStrFunc] = None, |
| maybe_from_str_fn: Optional[MaybeFromStrFunc] = None, |
| *, |
| return_none_fields: bool = False, |
| ) -> None: |
| assert dataclasses.is_dataclass( |
| typ |
| ), f"Only dataclasses can be registered with this function: {typ}" |
| |
| def default_flatten_fn(obj: Any) -> Tuple[List[Any], Context]: |
| flattened = [] |
| flat_names = [] |
| none_names = [] |
| for f in dataclasses.fields(obj): |
| name, val = f.name, getattr(obj, f.name) |
| if val is not None or return_none_fields: |
| flattened.append(val) |
| flat_names.append(name) |
| else: |
| none_names.append(name) |
| return flattened, (typ, flat_names, none_names) |
| |
| def default_unflatten_fn(values: List[Any], context: Context) -> Any: |
| typ, flat_names, none_names = context |
| return typ(**dict(zip(flat_names, values)), **{k: None for k in none_names}) |
| |
| flatten_fn = flatten_fn if flatten_fn is not None else default_flatten_fn |
| unflatten_fn = unflatten_fn if unflatten_fn is not None else default_unflatten_fn |
| |
| _register_pytree_node( |
| typ, |
| flatten_fn, |
| unflatten_fn, |
| None, |
| None, |
| ) |