| import contextlib | 
 | import functools | 
 | import gc | 
 | from dataclasses import asdict, dataclass, field | 
 | from itertools import chain | 
 | from typing import ( | 
 |     Any, | 
 |     Callable, | 
 |     cast, | 
 |     Dict, | 
 |     Generator, | 
 |     Iterable, | 
 |     List, | 
 |     no_type_check, | 
 |     Optional, | 
 |     Set, | 
 |     Tuple, | 
 |     Union, | 
 | ) | 
 |  | 
 | import torch | 
 | import torch.distributed as dist | 
 | import torch.nn as nn | 
 | from torch.distributed._shard.sharded_tensor import ShardedTensor | 
 | from torch.distributed._state_dict_utils import ( | 
 |     _broadcast_state_dict, | 
 |     _flatten_state_dict, | 
 |     _gather_state_dict, | 
 |     _offload_state_dict_to_cpu, | 
 |     _unflatten_state_dict, | 
 | ) | 
 | from torch.distributed._tensor import DTensor | 
 | from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( | 
 |     _CHECKPOINT_PREFIX, | 
 | ) | 
 | from torch.distributed.fsdp import ( | 
 |     FullOptimStateDictConfig, | 
 |     FullStateDictConfig, | 
 |     FullyShardedDataParallel as FSDP, | 
 |     OptimStateDictConfig, | 
 |     ShardedOptimStateDictConfig, | 
 |     ShardedStateDictConfig, | 
 |     StateDictConfig, | 
 |     StateDictType, | 
 | ) | 
 | from torch.distributed.fsdp._common_utils import ( | 
 |     _get_module_fsdp_state_if_fully_sharded_module, | 
 |     FSDP_WRAPPED_MODULE, | 
 | ) | 
 | from torch.nn.modules.module import _IncompatibleKeys | 
 | from torch.nn.parallel import DistributedDataParallel as DDP | 
 | from torch.utils._pytree import tree_map_only | 
 |  | 
 | __all__ = [ | 
 |     "FLAT_PARAM", | 
 |     "PG", | 
 |     "PG_PREFIX", | 
 |     "STATE", | 
 |     "STATE_PREFIX", | 
 |     "PARAMS", | 
 |     "FQNS_T", | 
 |     "PrimitiveType", | 
 |     "ValueType", | 
 |     "DictValueType", | 
 |     "ListDictValueType", | 
 |     "OptimizerStateType", | 
 |     "gc_context", | 
 |     "StateDictOptions", | 
 |     "get_model_state_dict", | 
 |     "get_optimizer_state_dict", | 
 |     "get_state_dict", | 
 |     "set_model_state_dict", | 
 |     "set_optimizer_state_dict", | 
 |     "set_state_dict", | 
 | ] | 
 |  | 
 | FLAT_PARAM = "_flat_param" | 
 | PG = "param_groups" | 
 | PG_PREFIX = f"{PG}." | 
 | STATE = "state" | 
 | STATE_PREFIX = f"{STATE}." | 
 | PARAMS = "params" | 
 | FQNS_T = Set[str] | 
 |  | 
 | _patched_state_dict: Set[Callable] = set() | 
 |  | 
 |  | 
 | PrimitiveType = Union[DTensor, ShardedTensor, torch.Tensor, int, float, str] | 
 | ValueType = Union[ | 
 |     PrimitiveType, List[PrimitiveType], Tuple[PrimitiveType], Dict[str, "ValueType"] | 
 | ] | 
 | DictValueType = Dict[str, ValueType] | 
 | ListDictValueType = List[DictValueType] | 
 | OptimizerStateType = Dict[str, Union[DictValueType, ListDictValueType]] | 
 |  | 
 |  | 
 | @contextlib.contextmanager | 
 | def gc_context(): | 
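    """Disable the Python garbage collector within the context.

    GC is re-enabled on exit only if it was enabled when the context was entered.
    """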
 |     is_enabled = gc.isenabled() | 
 |     gc.disable() | 
 |     try: | 
 |         yield | 
 |     finally: | 
 |         # TODO: add logging for the gc details/time | 
 |         if is_enabled: | 
 |             gc.enable() | 
 |  | 
 |  | 
 | @dataclass | 
 | class StateDictOptions: | 
 |     """ | 
 |     This dataclass specifies how get_state_dict/set_state_dict will work. | 
 |  | 
    - ``full_state_dict``: if this is set to True, all the tensors in the
      returned state_dict will be gathered. No ShardedTensor or DTensor
      will be in the returned state_dict.

    - ``cpu_offload``: offload all the tensors to CPU. To prevent CPU OOM, if
      ``full_state_dict`` is also True, then only rank0 will get the
      state_dict and all other ranks will get an empty state_dict.
 |  | 
    - ``ignore_frozen_params``: if the value is True, the returned state_dict
      won't contain any frozen parameters -- parameters whose ``requires_grad``
      is False. The default value is False.
 |  | 
    - ``keep_submodule_prefixes``: when ``submodules`` is not None, this option
      indicates whether to keep the submodule prefixes in the state_dict keys.
      For example, if the submodule is ``module.pretrain`` and the full FQN of
      the parameter is ``pretrain.layer1.weight``. When this option is True,
      the parameter's key in the returned state_dict will be
      ``pretrain.layer1.weight``. If the option is False, the key will be
      ``layer1.weight``.
      Note that if ``keep_submodule_prefixes`` is False, there may be
      conflicting FQNs, hence there should be only one submodule in
      ``submodules``.
 |  | 
    - ``strict``: the ``strict`` option when ``set_state_dict`` calls
      model.load_state_dict().
      The default value is True.
 |  | 
    - ``broadcast_from_rank0``: when the option is True, rank0 should receive a
      full state_dict and will broadcast the tensors in the state_dict/
      optim_state_dict one by one to other ranks. Other ranks will receive
      the tensors and shard according to the local shards in the model and
      optimizer. ``full_state_dict`` must be set to True when using this option.
      This option currently only supports DTensor, not the legacy ShardedTensor.
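
    Example:
        >>> # xdoctest: +SKIP
        >>> # A minimal usage sketch; ``fsdp_model`` and ``optim`` are assumed
        >>> # to be an FSDP-wrapped module and its optimizer.
        >>> from torch.distributed.checkpoint.state_dict import (
        ...     StateDictOptions,
        ...     get_state_dict,
        ... )
        >>> options = StateDictOptions(full_state_dict=True, cpu_offload=True)
        >>> model_sd, optim_sd = get_state_dict(fsdp_model, optim, options=options)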
 |     """ | 
 |  | 
 |     full_state_dict: bool = False | 
 |     cpu_offload: bool = False | 
 |     ignore_frozen_params: bool = False | 
 |     keep_submodule_prefixes: bool = True | 
 |     strict: bool = True | 
 |     broadcast_from_rank0: bool = False | 
 |  | 
 |  | 
 | @dataclass | 
 | class _StateDictInfo(StateDictOptions): | 
 |     fqn_param_mapping: Dict[ | 
 |         Union[str, torch.Tensor], Union[FQNS_T, torch.Tensor] | 
 |     ] = field(default_factory=dict) | 
 |     all_fqns: Set[str] = field(default_factory=set) | 
 |     submodule_prefixes: Set[str] = field(default_factory=set) | 
 |     handle_model: bool = True | 
 |     handle_optim: bool = True | 
 |     fsdp_context: Callable = contextlib.nullcontext | 
 |     fsdp_modules: List[nn.Module] = field(default_factory=list) | 
 |  | 
 |  | 
 | @functools.lru_cache(maxsize=None) | 
 | def _get_fqns( | 
 |     model: nn.Module, | 
 |     name: str, | 
 |     skip_ddp_prefix: bool = True, | 
 |     skip_compiler_prefix: bool = True, | 
 | ) -> FQNS_T: | 
 |     """ | 
 |     This API is used to convert the name of a parameter to the FQNs. For FSDP | 
 |     without `use_orig_params`, the name of FlatParameter can be mapped to | 
 |     multiple original parameters. As a result, the return type of this function | 
 |     is `Set[str]`. | 
 |  | 
    Args:
        model (nn.Module): the root model.
        name (str): the name of the parameter.
        skip_ddp_prefix (bool): whether to skip DDP's `module` prefix.
        skip_compiler_prefix (bool): whether to skip the compiler's `_orig_mod` prefix.
 |  | 
 |     Returns: | 
 |         The canonical FQNs based on the model traversal. | 
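
    Example:
        A sketch of the conversion (assuming a DDP-wrapped model whose inner
        module has a ``linear`` submodule)::

            _get_fqns(ddp_model, "module.linear.weight")  # -> {"linear.weight"}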
 |     """ | 
 |  | 
 |     # Remove the checkpoint prefix, if it exists. | 
 |     name = name.replace(_CHECKPOINT_PREFIX, "") | 
 |     if "." not in name: | 
 |         return {name} | 
 |  | 
 |     obj_names = name.split(".") | 
 |     fqn_obj_names = [] | 
 |     curr_obj = model | 
 |     for i, curr_obj_name in enumerate(obj_names): | 
 |         if isinstance(curr_obj, DDP): | 
 |             assert curr_obj_name == "module" | 
 |             curr_obj = curr_obj.module | 
 |             if not skip_ddp_prefix: | 
 |                 fqn_obj_names.append(curr_obj_name) | 
 |         elif isinstance(curr_obj, FSDP): | 
 |             if i < len(obj_names) - 1 and obj_names[i + 1] == FLAT_PARAM: | 
 |                 prefix = ".".join(fqn_obj_names) | 
 |                 flat_param = getattr(curr_obj, FLAT_PARAM) | 
 |                 if prefix: | 
 |                     prefix = f"{prefix}." | 
 |                 return {f"{prefix}{fqn}" for fqn in flat_param._fqns} | 
 |             curr_obj = getattr(curr_obj, FSDP_WRAPPED_MODULE) | 
 |             if curr_obj_name != FSDP_WRAPPED_MODULE: | 
 |                 fqn_obj_names.append(curr_obj_name) | 
 |                 curr_obj = getattr(curr_obj, curr_obj_name) | 
 |         elif isinstance(curr_obj, torch._dynamo.eval_frame.OptimizedModule): | 
 |             assert curr_obj_name == "_orig_mod" | 
 |             curr_obj = curr_obj._orig_mod | 
 |             if not skip_compiler_prefix: | 
 |                 fqn_obj_names.append(curr_obj_name) | 
 |         else: | 
 |             fqn_obj_names.append(curr_obj_name) | 
 |             if curr_obj_name == nn.modules.module._EXTRA_STATE_KEY_SUFFIX: | 
 |                 if i != len(obj_names) - 1: | 
 |                     raise RuntimeError("Expect `_extra_state` to be the last obj name") | 
 |             else: | 
 |                 curr_obj = getattr(curr_obj, curr_obj_name) | 
 |  | 
 |     return {".".join(fqn_obj_names).replace(_CHECKPOINT_PREFIX, "")} | 
 |  | 
 |  | 
 | class _EXTRA_STATE: | 
 |     pass | 
 |  | 
 |  | 
 | def _iterate_valid_model_state(model): | 
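    """Yield ``(fqn, value)`` pairs for the parameters, persistent buffers, and
    extra state of ``model``, traversing submodules depth-first and visiting
    each module only once.
    """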
 |     visited_modules: Set[nn.Module] = set() | 
 |  | 
 |     def recurse(module: nn.Module, curr_fqn: str) -> Generator: | 
 |         visited_modules.add(module) | 
 |  | 
 |         curr_fqn = f"{curr_fqn}." if curr_fqn else "" | 
 |         for name, submodule in module.named_children(): | 
 |             if submodule in visited_modules: | 
 |                 continue | 
 |             new_fqn = f"{curr_fqn}{name}" | 
 |             yield from recurse(submodule, new_fqn) | 
 |  | 
 |         for name, obj in chain( | 
 |             module.named_buffers(recurse=False), module.named_parameters(recurse=False) | 
 |         ): | 
 |             if name in module._non_persistent_buffers_set: | 
 |                 continue | 
 |             new_fqn = f"{curr_fqn}{name}" | 
 |             yield new_fqn, obj | 
 |  | 
 |         if ( | 
 |             getattr(module.__class__, "get_extra_state", nn.Module.get_extra_state) | 
 |             != nn.Module.get_extra_state | 
 |         ): | 
 |             new_fqn = f"{curr_fqn}{nn.modules.module._EXTRA_STATE_KEY_SUFFIX}" | 
 |             yield new_fqn, _EXTRA_STATE() | 
 |  | 
 |     yield from recurse(model, "") | 
 |  | 
 |  | 
 | def _verify_options( | 
 |     model: nn.Module, | 
 |     optims: Tuple[torch.optim.Optimizer, ...], | 
 |     optim_only: bool, | 
 |     *, | 
 |     submodules: Optional[Set[nn.Module]] = None, | 
 |     options: Optional[StateDictOptions] = None, | 
 | ) -> _StateDictInfo: | 
 |     """ | 
    Verify the model and options passed by the user and generate a ``_StateDictInfo``.
 |     """ | 
 |     if optim_only and not optims: | 
 |         raise RuntimeError( | 
 |             "Optimizers are not passed in but optim_only is set to True." | 
 |         ) | 
 |  | 
 |     options = options or StateDictOptions() | 
 |  | 
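    # Build a bidirectional mapping: each parameter maps to its set of FQNs and
    # each FQN maps back to its parameter (extra state entries only contribute
    # to ``all_fqns``).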
 |     fqn_param_mapping: Dict[ | 
 |         Union[str, torch.Tensor], Union[Set[str], torch.Tensor] | 
 |     ] = {} | 
 |     all_fqns = set() | 
 |     for name, param in _iterate_valid_model_state(model): | 
 |         fqns = _get_fqns(model, name) | 
 |         if not isinstance(param, _EXTRA_STATE): | 
 |             fqn_param_mapping[param] = fqns | 
 |         for fqn in fqns: | 
 |             if not isinstance(param, _EXTRA_STATE): | 
 |                 fqn_param_mapping[fqn] = param | 
 |             all_fqns.add(fqn) | 
 |  | 
 |     submodule_prefixes: Set[str] = set() | 
 |     if submodules: | 
 |         submodules = set(submodules) | 
 |         for name, module in model.named_modules(): | 
 |             if module not in submodules: | 
 |                 continue | 
 |             fqns = _get_fqns(model, name) | 
 |             assert len(fqns) == 1, "Submodule FQN should only have 1 instance" | 
 |             submodule_prefixes.update(f"{fqn}." for fqn in fqns) | 
 |  | 
 |     if options.broadcast_from_rank0 and not options.full_state_dict: | 
 |         raise ValueError( | 
 |             "full_state_dict must be True when broadcast_from_rank0 is True." | 
 |         ) | 
 |     fsdp_modules = FSDP.fsdp_modules(model) | 
 |     state_dict_config: StateDictConfig | 
 |     optim_state_dict_config: OptimStateDictConfig | 
 |     fsdp_context: Callable | 
 |     if fsdp_modules: | 
        # FSDP APIs only work if at least one FSDP instance exists.
 |         if options.full_state_dict: | 
 |             state_dict_config = FullStateDictConfig( | 
 |                 offload_to_cpu=options.cpu_offload, rank0_only=options.cpu_offload | 
 |             ) | 
 |             optim_state_dict_config = FullOptimStateDictConfig( | 
 |                 offload_to_cpu=options.cpu_offload, | 
 |                 rank0_only=(options.cpu_offload or options.broadcast_from_rank0), | 
 |             ) | 
 |             state_dict_type = StateDictType.FULL_STATE_DICT | 
 |         else: | 
 |             state_dict_config = ShardedStateDictConfig( | 
 |                 offload_to_cpu=options.cpu_offload, | 
 |             ) | 
 |             optim_state_dict_config = ShardedOptimStateDictConfig( | 
 |                 offload_to_cpu=options.cpu_offload, | 
 |             ) | 
 |             state_dict_type = StateDictType.SHARDED_STATE_DICT | 
 |  | 
 |         fsdp_context = functools.partial( | 
 |             FSDP.state_dict_type, | 
 |             module=model, | 
 |             state_dict_type=state_dict_type, | 
 |             state_dict_config=state_dict_config, | 
 |             optim_state_dict_config=optim_state_dict_config, | 
 |         ) | 
 |     else: | 
 |         fsdp_context = contextlib.nullcontext | 
 |  | 
 |     return _StateDictInfo( | 
 |         **asdict(options), | 
 |         fqn_param_mapping=fqn_param_mapping, | 
 |         all_fqns=all_fqns, | 
 |         submodule_prefixes=submodule_prefixes, | 
 |         fsdp_context=fsdp_context, | 
 |         fsdp_modules=cast(List[nn.Module], fsdp_modules), | 
 |         handle_model=not optim_only, | 
 |         handle_optim=(len(optims) > 0), | 
 |     ) | 
 |  | 
 |  | 
 | def _verify_state_dict( | 
 |     model_state_dict: Dict[str, ValueType], | 
 |     optim_state_dict: OptimizerStateType, | 
 |     info: _StateDictInfo, | 
 | ) -> None: | 
 |     for module in info.fsdp_modules: | 
 |         fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module) | 
        assert fsdp_state is not None, "Expected an fsdp_state for an FSDP module."
 |  | 
    # Verify that the model_state_dict and optim_state_dict are valid. This API
    # should give the users an explicit error message to debug or report.
 |     if ( | 
 |         info.handle_model | 
 |         and not model_state_dict | 
 |         and not info.submodule_prefixes | 
 |         and not info.ignore_frozen_params | 
 |         and not (info.cpu_offload and info.full_state_dict) | 
 |         and info.strict | 
 |         and not info.broadcast_from_rank0 | 
 |     ): | 
 |         raise RuntimeError( | 
 |             "The option indicates that model state_dict is required to save " | 
 |             "or load, but model state_dict is empty." | 
 |             f"rank = {dist.get_rank()=}." | 
 |         ) | 
 |  | 
 |     if info.handle_optim: | 
 |         if ( | 
 |             not (optim_state_dict and optim_state_dict[STATE]) | 
 |             and not (info.cpu_offload and info.full_state_dict) | 
 |             and (not info.broadcast_from_rank0) | 
 |         ): | 
 |             raise RuntimeError( | 
 |                 "The option indicates that model state_dict is required to save, " | 
 |                 f"or load but optim state_dict is empty. {optim_state_dict}" | 
 |             ) | 
 |  | 
 |     for key in model_state_dict.keys(): | 
 |         if FLAT_PARAM in key: | 
 |             raise RuntimeError( | 
 |                 f"{key} contains {FLAT_PARAM}. This can happen if the model " | 
 |                 "is not the root module." | 
 |             ) | 
 |  | 
 |  | 
 | def _state_dict_fn(obj: Union[nn.Module, torch.optim.Optimizer], api: str) -> Callable: | 
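    # Return the bound ``api`` callable. If the instance attribute has been
    # patched by ``_patch_*_state_dict``, fall back to the original class
    # implementation to avoid calling back into the patched function.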
 |     call = getattr(obj, api) | 
 |     if call in _patched_state_dict: | 
 |         call = functools.partial(getattr(obj.__class__, api), self=obj) | 
 |     return call | 
 |  | 
 |  | 
 | def _get_model_state_dict( | 
 |     model: nn.Module, info: _StateDictInfo | 
 | ) -> Dict[str, ValueType]: | 
 |     if not info.handle_model: | 
 |         return {} | 
 |  | 
 |     with info.fsdp_context(): | 
 |         state_dict = _state_dict_fn(model, "state_dict")() | 
 |  | 
 |     for key in list(state_dict.keys()): | 
 |         fqns = _get_fqns(model, key) | 
 |         assert len(fqns) == 1 | 
 |         fqn = next(iter(fqns)) | 
 |         if fqn != key: | 
            # As we only support FSDP, DDP, and TP, the only prefixes a key can
            # have that an FQN lacks are DDP's `module` and the compiler's
            # `_orig_mod`. Verify that this assumption holds.
 |             def verify(key, fqn) -> bool: | 
 |                 if len(fqn) >= len(key): | 
 |                     return False | 
 |                 fqn_split = fqn.split(".") | 
 |                 key_split = key.split(".") | 
 |                 fqn_idx = 0 | 
 |                 for key_idx, key_name in enumerate(key_split): | 
 |                     if key_name == fqn_split[fqn_idx]: | 
 |                         fqn_idx += 1 | 
 |                         if fqn_idx == len(fqn_split): | 
 |                             return key_idx == len(key_split) - 1 | 
 |                     elif key_name in ("module", "_orig_mod"): | 
 |                         continue | 
 |                     else: | 
 |                         return False | 
 |                 return True | 
 |  | 
 |             if not verify(key, fqn): | 
 |                 raise RuntimeError(f"An unexpected key, {key}, exists. FQN is {fqn}") | 
 |             state_dict[fqn] = state_dict.pop(key) | 
 |  | 
 |     if info.submodule_prefixes: | 
 |         new_state_dict: Dict[str, ValueType] = {} | 
 |         # TODO: make this faster. | 
 |         for fqn in state_dict.keys(): | 
 |             for prefix in info.submodule_prefixes: | 
 |                 if not fqn.startswith(prefix): | 
 |                     continue | 
 |                 if info.keep_submodule_prefixes: | 
 |                     new_state_dict[fqn] = state_dict[fqn] | 
 |                 else: | 
 |                     new_fqn = fqn[len(prefix) :] | 
 |                     new_state_dict[new_fqn] = state_dict[fqn] | 
 |         state_dict = new_state_dict | 
 |  | 
 |     if info.ignore_frozen_params: | 
 |         for key, param in model.named_parameters(): | 
 |             if param.requires_grad: | 
 |                 continue | 
 |             fqns = _get_fqns(model, key) | 
 |             for fqn in fqns: | 
 |                 state_dict.pop(fqn) | 
 |  | 
 |     for key, p in list(state_dict.items()): | 
 |         if torch.is_tensor(p) and p.is_meta: | 
 |             state_dict.pop(key) | 
 |  | 
 |     if info.full_state_dict: | 
 |         ranks_only = tuple() if not info.cpu_offload else (0,) | 
 |         return _gather_state_dict( | 
 |             state_dict, cpu_offload=info.cpu_offload, ranks_only=ranks_only | 
 |         ) | 
 |     elif info.cpu_offload: | 
 |         return _offload_state_dict_to_cpu(state_dict) | 
 |     else: | 
 |         return state_dict | 
 |  | 
 |  | 
 | def _load_model_state_dict( | 
 |     model: nn.Module, | 
 |     state_dict: Dict[str, ValueType], | 
 |     info: _StateDictInfo, | 
 | ) -> _IncompatibleKeys: | 
 |     if not info.handle_model or (not state_dict and not info.broadcast_from_rank0): | 
 |         return _IncompatibleKeys({}, {}) | 
 |  | 
 |     local_state_dict = {} | 
 |     for key, value in _iterate_valid_model_state(model): | 
 |         fqns = _get_fqns(model, key) | 
 |         fqns_with_prefix = _get_fqns( | 
 |             model, key, skip_ddp_prefix=False, skip_compiler_prefix=False | 
 |         ) | 
 |  | 
 |         for fqn, fqn_with_prefix in zip(fqns, fqns_with_prefix): | 
 |             if ( | 
 |                 not info.broadcast_from_rank0 or dist.get_rank() == 0 | 
 |             ) and fqn != fqn_with_prefix: | 
 |                 state_dict[fqn_with_prefix] = state_dict.pop(fqn) | 
 |             local_state_dict[fqn_with_prefix] = value | 
 |  | 
 |     if info.broadcast_from_rank0: | 
 |         device = None | 
 |         for key, value in local_state_dict.items(): | 
 |             if torch.is_tensor(value) and value.dim() > 0: | 
 |                 if device is None: | 
 |                     device = value.device | 
 |                 else: | 
 |                     assert device == value.device | 
 |         assert device is not None | 
 |         _broadcast_state_dict(state_dict, local_state_dict, device=device) | 
 |         for fqn, local_state in local_state_dict.items(): | 
 |             state_dict[fqn] = local_state | 
 |  | 
 |     with info.fsdp_context(): | 
 |         return cast( | 
 |             _IncompatibleKeys, | 
 |             _state_dict_fn(model, "load_state_dict")( | 
 |                 state_dict=state_dict, strict=info.strict | 
 |             ), | 
 |         ) | 
 |  | 
 |  | 
 | def _init_optim_state(optim: torch.optim.Optimizer) -> None: | 
 |     """ | 
    Initialize the optimizer states by calling ``step()`` with zero gradients.
 |     """ | 
 |     if optim.state: | 
 |         # The optimizer state is initialized. | 
 |         return | 
 |  | 
 |     for param_group in optim.param_groups: | 
 |         for param in param_group[PARAMS]: | 
 |             if param.grad is not None: | 
 |                 raise RuntimeError( | 
 |                     "state_dict can only be used if the optimizer " | 
 |                     "states are initialized (usually after one step() with " | 
 |                     "gradients) or gradients are None. For the later case, " | 
 |                     "state_dict will fake the gradients as zero " | 
 |                     "to initialize the optimizer states. However, the " | 
 |                     "gradients are not None." | 
 |                 ) | 
 |             if param.requires_grad: | 
 |                 param.grad = torch.zeros_like(param) | 
 |  | 
 |     # Some optimizers will update parameters regardless of grads due to lr, so | 
    # set lr to zero when calling `step()`.
 |     lrs = [] | 
 |     for param_group in optim.param_groups: | 
 |         if "lr" in param_group: | 
 |             lrs.append(param_group["lr"]) | 
 |             param_group["lr"] = 0.0 | 
 |     optim.step(closure=None) | 
    # Whether we restore the original "lr" should not matter too much as we
    # will load the checkpointed values later anyway.
 |     for param_group in optim.param_groups: | 
 |         if "lr" in param_group: | 
 |             param_group["lr"] = lrs.pop(0) | 
 |     optim.zero_grad(set_to_none=True) | 
 |  | 
 |  | 
 | def _get_optim_state_dict( | 
 |     model: nn.Module, | 
 |     optimizers: Tuple[torch.optim.Optimizer, ...], | 
 |     info: _StateDictInfo, | 
 | ) -> OptimizerStateType: | 
 |     if not info.handle_optim: | 
 |         return {} | 
 |  | 
 |     optim_state_dict: OptimizerStateType = {STATE: {}, PG: []} | 
 |     for optim in optimizers: | 
 |         _init_optim_state(optim) | 
 |         osd = _state_dict_fn(optim, "state_dict")() | 
 |         if info.fsdp_modules: | 
 |             with info.fsdp_context(): | 
 |                 osd = FSDP.optim_state_dict(model, optim, osd) | 
 |  | 
 |             # We need to specially handle FlatParameter FSDP as | 
 |             # FlatParameter FSDP converts the FQNs. | 
 |             # There are no easy ways to do this conversion systematically. | 
            # We can only use a string replacement without a correctness check.
 |             if not osd: | 
 |                 continue | 
 |             for k in list(osd[STATE].keys()): | 
 |                 if "_orig_mod" in k: | 
 |                     osd[STATE][k.replace("_orig_mod.", "")] = osd[STATE].pop(k) | 
 |             for g in osd[PG]: | 
 |                 params = [k.replace("_orig_mod.", "") for k in g[PARAMS]] | 
 |                 g[PARAMS] = params | 
 |         else: | 
 |             params = list(chain.from_iterable(g[PARAMS] for g in optim.param_groups)) | 
 |             param_pid_mapping = dict(zip(params, range(len(params)))) | 
 |             fqn_pid_mapping = {} | 
 |             for key, param in model.named_parameters(): | 
 |                 fqns = _get_fqns(model, key) | 
 |                 assert len(fqns) == 1 | 
 |                 fqn = next(iter(fqns)) | 
 |                 if param not in param_pid_mapping: | 
 |                     continue | 
 |                 pid = param_pid_mapping[param] | 
 |                 fqn_pid_mapping[fqn] = pid | 
 |                 fqn_pid_mapping[pid] = fqn | 
 |  | 
 |             for key in list(osd[STATE].keys()): | 
 |                 fqn = fqn_pid_mapping[key] | 
 |                 osd[STATE][fqn] = osd[STATE].pop(key) | 
 |  | 
 |             for group in osd[PG]: | 
 |                 group[PARAMS] = [fqn_pid_mapping[pid] for pid in group[PARAMS]] | 
 |  | 
 |         if not osd: | 
 |             continue | 
 |  | 
 |         cast(DictValueType, optim_state_dict[STATE]).update(osd[STATE]) | 
 |         cast(ListDictValueType, optim_state_dict[PG]).extend(osd[PG]) | 
 |  | 
 |     if info.full_state_dict: | 
 |         ranks_only = tuple() if not info.cpu_offload else (0,) | 
 |         return _gather_state_dict( | 
 |             optim_state_dict, cpu_offload=info.cpu_offload, ranks_only=ranks_only | 
 |         ) | 
 |     elif info.cpu_offload: | 
 |         return _offload_state_dict_to_cpu(optim_state_dict) | 
 |     else: | 
 |         return optim_state_dict | 
 |  | 
 |  | 
 | def _split_optim_state_dict( | 
 |     model: nn.Module, | 
 |     optim: torch.optim.Optimizer, | 
 |     optim_state_dict: OptimizerStateType, | 
 |     info: _StateDictInfo, | 
 | ) -> OptimizerStateType: | 
 |     """ | 
    Extract the corresponding optim state_dict from ``optim_state_dict`` for
    ``optim`` and return the resulting optim state_dict.
 |  | 
 |     Args: | 
 |         model (nn.Module): the root model. | 
 |         optim (torch.optim.Optimizer): the optimizer. | 
        optim_state_dict (OptimizerStateType): the superset optim state_dict that
            contains the optim state_dict of ``optim``.
 |         info (_StateDictInfo): state dict information. | 
 |  | 
 |     Returns: | 
 |         The optim state_dict of ``optim``. | 
 |     """ | 
 |  | 
 |     state: DictValueType = {} | 
 |     pg_state: ListDictValueType = [] | 
 |     return_osd: OptimizerStateType = {STATE: state, PG: pg_state} | 
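    # pg_mapping maps id() of a param_group in ``optim_state_dict`` to the index
    # of the corresponding param_group appended to ``pg_state``.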
 |     pg_mapping: Dict[int, int] = {} | 
 |  | 
 |     for param_group in optim.param_groups: | 
 |         pg_state.append({PARAMS: []}) | 
 |         for param in param_group[PARAMS]: | 
 |             for fqn in info.fqn_param_mapping[param]: | 
 |                 params = pg_state[-1][PARAMS] | 
 |                 assert isinstance(params, list) | 
 |                 params.append(fqn) | 
 |                 if param.requires_grad: | 
 |                     state[fqn] = cast(DictValueType, optim_state_dict[STATE])[fqn] | 
 |                 for loaded_param_group in cast(ListDictValueType, optim_state_dict[PG]): | 
 |                     params = loaded_param_group[PARAMS] | 
 |                     assert isinstance(params, list) | 
 |                     if fqn in params: | 
 |                         pg_mapping[id(loaded_param_group)] = len(return_osd[PG]) - 1 | 
 |  | 
 |     for param_group in cast(ListDictValueType, optim_state_dict[PG]): | 
 |         idx = pg_mapping.get(id(param_group), -1) | 
 |         if idx == -1: | 
 |             continue | 
 |         for key, value in param_group.items(): | 
 |             if key == PARAMS: | 
 |                 continue | 
 |             # TODO: check if value is the same if exists. | 
 |             pg_state[idx][key] = value | 
 |  | 
 |     return return_osd | 
 |  | 
 |  | 
 | def _load_optim_state_dict( | 
 |     model: nn.Module, | 
 |     optimizers: Tuple[torch.optim.Optimizer, ...], | 
 |     state_dict: OptimizerStateType, | 
 |     info: _StateDictInfo, | 
 | ) -> None: | 
 |     if not info.handle_optim: | 
 |         return | 
 |  | 
 |     for optim in optimizers: | 
 |         _init_optim_state(optim) | 
 |         if state_dict: | 
 |             optim_state_dict = _split_optim_state_dict(model, optim, state_dict, info) | 
 |         else: | 
 |             optim_state_dict = {} | 
 |         if info.fsdp_modules: | 
 |             # We need to specially handle FlatParameter FSDP as | 
 |             # FlatParameter FSDP converts the FQNs. | 
 |             for original_fqn, _ in model.named_parameters(): | 
 |                 fqns = _get_fqns(model, original_fqn) | 
 |                 fqns_with_compiler = _get_fqns( | 
 |                     model, original_fqn, skip_compiler_prefix=False | 
 |                 ) | 
 |                 if fqns == fqns_with_compiler: | 
 |                     continue | 
 |  | 
 |                 assert len(fqns) == 1 | 
 |                 fqn = fqns.pop() | 
 |                 fqn_with_compiler = fqns_with_compiler.pop() | 
 |                 for g in optim_state_dict[PG]: | 
 |                     val = cast(Dict[str, Any], g) | 
 |                     params = [ | 
 |                         key.replace(fqn, fqn_with_compiler) for key in val[PARAMS] | 
 |                     ] | 
 |                     val[PARAMS] = params | 
 |                 osd_state = cast(DictValueType, optim_state_dict[STATE]) | 
 |                 for k in list(osd_state.keys()): | 
 |                     if fqn in k: | 
 |                         osd_state[k.replace(fqn, fqn_with_compiler)] = osd_state.pop(k) | 
 |  | 
 |             with info.fsdp_context(): | 
 |                 optim_state_dict = FSDP.optim_state_dict_to_load( | 
 |                     model, optim, optim_state_dict | 
 |                 ) | 
 |         elif info.broadcast_from_rank0: | 
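            # Temporarily build a local (sharded) optim state_dict to use as the
            # receive buffer, broadcast the full state_dict from rank0 into it
            # tensor by tensor, then unflatten the result back into the
            # optimizer state_dict format.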
 |             info.full_state_dict = False | 
 |             local_state_dict = _get_optim_state_dict(model, (optim,), info) | 
 |             info.full_state_dict = True | 
 |             device = None | 
 |  | 
 |             def _device(t): | 
 |                 if t.dim() > 0: | 
 |                     nonlocal device | 
 |                     if device is None: | 
 |                         device = t.device | 
 |                     elif device != t.device: | 
 |                         raise ValueError("Device mismatch") | 
 |                 return t | 
 |  | 
 |             _ = tree_map_only(torch.Tensor, _device, local_state_dict) | 
 |             assert device is not None | 
 |             flatten_osd, osd_mapping = _flatten_state_dict(optim_state_dict) | 
 |             flatten_local_osd, local_osd_mapping = _flatten_state_dict(local_state_dict) | 
 |             _broadcast_state_dict(flatten_osd, flatten_local_osd, device=device) | 
 |             optim_state_dict = _unflatten_state_dict( | 
 |                 flatten_local_osd, local_osd_mapping | 
 |             ) | 
 |  | 
        # Note that we do not have to convert the FQNs back to param ids here as
        # long as the order in optim.param_groups[idx][PARAMS] is the same as the
        # one in optim_state_dict[PG][idx][PARAMS].
 |         _state_dict_fn(optim, "load_state_dict")(state_dict=optim_state_dict) | 
 |  | 
 |  | 
 | def get_model_state_dict( | 
 |     model: nn.Module, | 
 |     *, | 
 |     submodules: Optional[Set[nn.Module]] = None, | 
 |     options: Optional[StateDictOptions] = None, | 
 | ) -> Dict[str, ValueType]: | 
 |     """ | 
 |     Return the model state_dict of ``model``. | 
 |  | 
 |     See ``get_state_dict`` for the detail usage. | 
 |  | 
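    Example:
        >>> # xdoctest: +SKIP
        >>> # A minimal sketch; ``model`` is assumed to be an already
        >>> # constructed (optionally FSDP/DDP-wrapped) nn.Module.
        >>> from torch.distributed.checkpoint.state_dict import get_model_state_dict
        >>> model_state_dict = get_model_state_dict(model)
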
 |     Args: | 
        model (nn.Module): the nn.Module for the model.
        submodules (Optional[Set[nn.Module]]): only return the model parameters
            that belong to the submodules.
 |         options (StateDictOptions): the options to control how | 
 |             model state_dict and optimizer state_dict should be returned. See | 
 |             `StateDictOptions` for the details. | 
 |  | 
 |     Returns: | 
 |         The state_dict for ``model``. | 
 |  | 
 |     :rtype: typing.Dict[str, ValueType] | 
 |     """ | 
 |     with gc_context(): | 
 |         info = _verify_options( | 
 |             model, | 
 |             tuple(), | 
 |             optim_only=False, | 
 |             submodules=submodules, | 
 |             options=options, | 
 |         ) | 
 |         model_state_dict = _get_model_state_dict(model, info) | 
 |         _verify_state_dict(model_state_dict, {}, info) | 
 |         return model_state_dict | 
 |  | 
 |  | 
 | def get_optimizer_state_dict( | 
 |     model: nn.Module, | 
 |     optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]], | 
 |     *, | 
 |     submodules: Optional[Set[nn.Module]] = None, | 
 |     options: Optional[StateDictOptions] = None, | 
 | ) -> OptimizerStateType: | 
 |     """ | 
 |     Return the combined state_dict for optimizers. | 
 |  | 
 |     See ``get_state_dict`` for the detail usage. | 
 |  | 
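    Example:
        >>> # xdoctest: +SKIP
        >>> # A minimal sketch; ``model`` and ``optim`` are assumed to be an
        >>> # (optionally parallelized) nn.Module and its optimizer.
        >>> from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict
        >>> optim_state_dict = get_optimizer_state_dict(model, optim)
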
 |     Args: | 
        model (nn.Module): the nn.Module for the model.
        optimizers (Union[Optimizer, Iterable[Optimizer]]):
            The optimizers that are used to optimize ``model``.
        submodules (Optional[Set[nn.Module]]): only return the model parameters
            that belong to the submodules.
 |         options (StateDictOptions): the options to control how | 
 |             model state_dict and optimizer state_dict should be returned. See | 
 |             `StateDictOptions` for the details. | 
 |  | 
 |     Returns: | 
 |         The state_dict for ``optimizers``. | 
 |  | 
 |     :rtype: OptimizerStateType | 
 |     """ | 
 |     with gc_context(): | 
 |         optimizers = ( | 
 |             (optimizers,) | 
 |             if isinstance(optimizers, torch.optim.Optimizer) | 
 |             else tuple(optimizers) | 
 |         ) | 
 |         info = _verify_options( | 
 |             model, | 
 |             optimizers, | 
 |             optim_only=True, | 
 |             submodules=submodules, | 
 |             options=options, | 
 |         ) | 
 |         optim_state_dict = _get_optim_state_dict(model, optimizers, info) | 
 |         _verify_state_dict({}, optim_state_dict, info) | 
 |         return optim_state_dict | 
 |  | 
 |  | 
 | def get_state_dict( | 
 |     model: nn.Module, | 
 |     optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]], | 
 |     *, | 
 |     submodules: Optional[Set[nn.Module]] = None, | 
 |     options: Optional[StateDictOptions] = None, | 
 | ) -> Tuple[Dict[str, ValueType], OptimizerStateType]: | 
 |     """ | 
 |     Return the model state_dict and optimizers state_dict. | 
 |  | 
 |     ``get_state_dict`` can process any module that is parallelized by PyTorch | 
 |     FSDP/fully_shard, DDP/replicate, tensor_parallel/parallelize_module, and any | 
 |     combination of these parallelisms. The main functions of ``get_state_dict`` | 
 |     are: 1.) returning a model and optimizer state_dict that can be resharded | 
 |     with a different number of trainers and/or different parallelisms. | 
    2.) hiding the parallelism-specific state_dict APIs so users don't have to
    call these APIs directly.
 |     3.) sanity checking the result state_dict. | 
 |  | 
 |     The keys of the result state dictionary are the canonical FQNs (Fully | 
 |     Qualified Names).  A canonical FQN refers to the FQN based on a parameter's | 
 |     position in an nn.Module hierarchy. More specifically, a canonical FQN to a | 
 |     parameter is the FQN returned by ``module.named_parameters()`` or | 
 |     ``module.named_buffers()`` when the module is not distributed by any | 
 |     parallelisms. Since the optimizer internally uses parameter IDs to represent | 
 |     a parameter, there will be a conversion from the parameter IDs to the | 
 |     canonical FQNs when calling this API. | 
 |  | 
 |     ``get_state_dict`` can also process a module that is not parallelized. In | 
 |     such a case, ``get_state_dict`` only performs one function -- converting the | 
 |     optimizer parameter IDs to the canonical FQNs. | 
 |  | 
 |     Example: | 
 |         >>> # xdoctest: +SKIP | 
        >>> import copy
        >>> import torch
 |         >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP | 
 |         >>> from torch.nn.parallel import DistributedDataParallel as DDP | 
 |         >>> from torch.distributed.checkpoint.state_dict import get_state_dict | 
 |  | 
        >>> fsdp_model = FSDP(copy.deepcopy(model))
        >>> fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-3)
        >>> ddp_model = DDP(copy.deepcopy(model))
        >>> ddp_optim = torch.optim.Adam(ddp_model.parameters(), lr=1e-3)

        >>> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim)
        >>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict(fsdp_model, fsdp_optim)

        >>> # if we simply call ddp_model.state_dict() and fsdp_model.state_dict(),
        >>> # the asserts will fail.
        >>> assert ddp_state_dict == fsdp_state_dict
        >>> assert ddp_optim_state_dict == fsdp_optim_state_dict
 |  | 
 |  | 
 |     Args: | 
        model (nn.Module): the nn.Module for the model.
        optimizers (Union[Optimizer, Iterable[Optimizer]]):
            The optimizers that are used to optimize ``model``.
        submodules (Optional[Set[nn.Module]]): only return the model parameters
            that belong to the submodules.
 |         options (StateDictOptions): the options to control how | 
 |             model state_dict and optimizer state_dict should be returned. See | 
 |             `StateDictOptions` for the details. | 
 |  | 
 |     Returns: | 
        ``Tuple`` that contains the model state_dict and the optimizer state_dict.
 |  | 
 |     :rtype: typing.Tuple[typing.Dict[str, ValueType], OptimizerStateType] | 
 |     """ | 
 |  | 
 |     with gc_context(): | 
 |         optimizers = ( | 
 |             (optimizers,) | 
 |             if isinstance(optimizers, torch.optim.Optimizer) | 
 |             else tuple(optimizers) | 
 |         ) | 
 |         info = _verify_options( | 
 |             model, | 
 |             optimizers, | 
 |             optim_only=False, | 
 |             submodules=submodules, | 
 |             options=options, | 
 |         ) | 
 |         model_state_dict = _get_model_state_dict(model, info) | 
 |         optim_state_dict = _get_optim_state_dict(model, optimizers, info) | 
 |         _verify_state_dict(model_state_dict, optim_state_dict, info) | 
 |         return model_state_dict, optim_state_dict | 
 |  | 
 |  | 
 | def _unflatten_model_state_dict( | 
 |     model: nn.Module, | 
 |     state_dict: Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]], | 
 | ) -> Dict[str, ValueType]: | 
 |     if not state_dict: | 
 |         return {} | 
 |  | 
 |     if isinstance(next(iter(state_dict.keys())), nn.Module): | 
 |         cast_state_dict = cast(Dict[nn.Module, Dict[str, ValueType]], state_dict) | 
 |         new_state_dict: Dict[str, ValueType] = {} | 
 |         for submodule, sub_state_dict in cast_state_dict.items(): | 
 |             for name, m in model.named_modules(): | 
 |                 if m != submodule: | 
 |                     continue | 
 |  | 
 |                 fqns = _get_fqns(model, name) | 
 |                 assert len(fqns) == 1, "FQNs for a submodule should only have 1 element" | 
 |                 prefix = f"{next(iter(fqns))}." | 
 |                 new_state_dict.update( | 
 |                     {prefix + subfqn: value for subfqn, value in sub_state_dict.items()} | 
 |                 ) | 
 |         return new_state_dict | 
 |     else: | 
 |         return cast(Dict[str, ValueType], state_dict) | 
 |  | 
 |  | 
 | def set_model_state_dict( | 
 |     model: nn.Module, | 
 |     model_state_dict: Dict[str, ValueType], | 
 |     *, | 
 |     options: Optional[StateDictOptions] = None, | 
 | ) -> _IncompatibleKeys: | 
 |     """Load the model state_dict. | 
 |  | 
 |     The counterpart of ``get_model_state_dict`` to set the state_dict to the | 
 |     model. See ``set_state_dict`` for the detail usage. | 
 |  | 
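    Example:
        >>> # xdoctest: +SKIP
        >>> # A minimal sketch; ``model`` is assumed to be the same (optionally
        >>> # parallelized) module that produced ``model_state_dict``.
        >>> from torch.distributed.checkpoint.state_dict import (
        ...     get_model_state_dict,
        ...     set_model_state_dict,
        ... )
        >>> model_state_dict = get_model_state_dict(model)
        >>> set_model_state_dict(model, model_state_dict)
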
 |     Args: | 
        model (nn.Module): the nn.Module for the model.
        model_state_dict (Dict[str, ValueType]):
           the model state_dict to load. If the key of the ``model_state_dict``
           is nn.Module, the key is a submodule of ``model`` and the value should
           be the state_dict of the submodule. When loading the state_dict,
           the prefix of the submodule will be prepended to the keys of the
           state_dict.
 |         options (StateDictOptions): the options to control how | 
 |             model state_dict and optimizer state_dict should be loaded. See | 
 |             `StateDictOptions` for the details. | 
 |  | 
 |     Returns: | 
 |         ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: | 
 |             * **missing_keys** is a list of str containing the missing keys | 
 |             * **unexpected_keys** is a list of str containing the unexpected keys | 
 |  | 
 |     :type model_state_dict: typing.Dict[str, ValueType] | 
 |     """ | 
 |     model_state_dict: Dict[str, ValueType] = _unflatten_model_state_dict( | 
 |         model, model_state_dict | 
 |     ) | 
 |     with gc_context(): | 
 |         info = _verify_options(model, tuple(), optim_only=False, options=options) | 
 |  | 
 |         _verify_state_dict(model_state_dict, {}, info) | 
 |         return _load_model_state_dict(model, model_state_dict, info) | 
 |  | 
 |  | 
 | def set_optimizer_state_dict( | 
 |     model: nn.Module, | 
 |     optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]], | 
 |     *, | 
 |     optim_state_dict: OptimizerStateType, | 
 |     options: Optional[StateDictOptions] = None, | 
 | ) -> None: | 
 |     """Load the optimizers state_dict. | 
 |  | 
 |     The counterpart of ``get_optimizer_state_dict`` to set the state_dict to the | 
 |     optimizers. See ``set_state_dict`` for the detail usage. | 
 |  | 
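    Example:
        >>> # xdoctest: +SKIP
        >>> # A minimal sketch; ``model`` and ``optim`` are assumed to be an
        >>> # (optionally parallelized) nn.Module and its optimizer.
        >>> from torch.distributed.checkpoint.state_dict import (
        ...     get_optimizer_state_dict,
        ...     set_optimizer_state_dict,
        ... )
        >>> optim_state_dict = get_optimizer_state_dict(model, optim)
        >>> set_optimizer_state_dict(model, optim, optim_state_dict=optim_state_dict)
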
 |     Args: | 
        model (nn.Module): the nn.Module for the model.
        optimizers (Union[Optimizer, Iterable[Optimizer]]):
            The optimizers that are used to optimize ``model``.
        optim_state_dict (OptimizerStateType):
            the optimizer state_dict to load.
 |         options (StateDictOptions): the options to control how | 
 |             model state_dict and optimizer state_dict should be loaded. See | 
 |             `StateDictOptions` for the details. | 
 |  | 
 |     Returns: | 
 |         None | 
 |  | 
    :type optim_state_dict: OptimizerStateType
 |     """ | 
 |     with gc_context(): | 
 |         optimizers = ( | 
 |             (optimizers,) | 
 |             if isinstance(optimizers, torch.optim.Optimizer) | 
 |             else tuple(optimizers) | 
 |         ) | 
 |         info = _verify_options(model, optimizers, optim_only=True, options=options) | 
 |  | 
 |         _verify_state_dict({}, optim_state_dict, info) | 
 |         _load_optim_state_dict(model, optimizers, optim_state_dict, info) | 
 |  | 
 |  | 
 | def set_state_dict( | 
 |     model: nn.Module, | 
 |     optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]], | 
 |     *, | 
 |     model_state_dict: Dict[str, ValueType], | 
 |     optim_state_dict: OptimizerStateType, | 
 |     options: Optional[StateDictOptions] = None, | 
 | ) -> _IncompatibleKeys: | 
 |     """Load the model state_dict and optimizers state_dict. | 
 |  | 
 |     The counterpart of ``get_state_dict`` to set the state_dict to the model and | 
 |     optimizers.  The given ``model_state_dict`` and ``optim_state_dict`` do not | 
 |     have to be returned by ``get_state_dict`` but must meet the following | 
 |     requirements: 1) all FQNs are canonical FQNs as defined in ``get_state_dict``, | 
 |     2) if a tensor is sharded, it must be either a ShardedTensor or DTensor, | 
 |     3) optimizer state_dict cannot contain the parameter IDs; the keys should be | 
 |     the canonical FQNs. | 
 |  | 
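    Example:
        >>> # xdoctest: +SKIP
        >>> # A minimal sketch; ``model`` and ``optim`` are assumed to be an
        >>> # (optionally parallelized) nn.Module and its optimizer.
        >>> from torch.distributed.checkpoint.state_dict import (
        ...     get_state_dict,
        ...     set_state_dict,
        ... )
        >>> model_state_dict, optim_state_dict = get_state_dict(model, optim)
        >>> set_state_dict(
        ...     model,
        ...     optim,
        ...     model_state_dict=model_state_dict,
        ...     optim_state_dict=optim_state_dict,
        ... )
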
 |     Args: | 
        model (nn.Module): the nn.Module for the model.
        optimizers (Union[Optimizer, Iterable[Optimizer]]):
            The optimizers that are used to optimize ``model``.
        model_state_dict (Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]]):
           the model state_dict to load. If the key of the ``model_state_dict``
           is nn.Module, the key is a submodule of ``model`` and the value should
           be the state_dict of the submodule. When loading the state_dict,
           the prefix of the submodule will be prepended to the keys of the
           state_dict.
        optim_state_dict (OptimizerStateType):
            the optimizer state_dict to load.
 |         options (StateDictOptions): the options to control how | 
 |             model state_dict and optimizer state_dict should be loaded. See | 
 |             `StateDictOptions` for the details. | 
 |  | 
 |     Returns: | 
 |         ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: | 
 |             * **missing_keys** is a list of str containing the missing keys of the model state_dict. | 
 |             * **unexpected_keys** is a list of str containing the unexpected keys of the model state_dict. | 
 |  | 
 |     :type model_state_dict: typing.Dict[str, ValueType] | 
    :type optim_state_dict: OptimizerStateType
 |     """ | 
 |  | 
 |     model_state_dict: Dict[str, ValueType] = _unflatten_model_state_dict( | 
 |         model, model_state_dict | 
 |     ) | 
 |     with gc_context(): | 
 |         optimizers = ( | 
 |             (optimizers,) | 
 |             if isinstance(optimizers, torch.optim.Optimizer) | 
 |             else tuple(optimizers) | 
 |         ) | 
 |         info = _verify_options( | 
 |             model, optimizers, optim_only=not model_state_dict, options=options | 
 |         ) | 
 |  | 
 |         _verify_state_dict(model_state_dict, optim_state_dict, info) | 
 |         _load_optim_state_dict(model, optimizers, optim_state_dict, info) | 
 |         return _load_model_state_dict(model, model_state_dict, info) | 
 |  | 
 |  | 
 | # TODO: correct the state_dict function signature. | 
 | # TODO: this API is not yet fully tested. Make it private | 
 | @no_type_check | 
 | def _patch_model_state_dict( | 
 |     model: nn.Module, | 
 |     *, | 
 |     options: Optional[StateDictOptions] = None, | 
 | ) -> None: | 
 |     """Patch the ``state_dict`` and ``load_state_dict`` attributes of ``model``. | 
 |  | 
    Patch the ``state_dict`` and ``load_state_dict`` attributes of ``model`` to
    be partial functions that call ``get_model_state_dict`` and ``set_model_state_dict``.
 |  | 
 |     Example: | 
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        from torch.distributed.checkpoint.state_dict import _patch_model_state_dict

        model = FSDP(model)
        _patch_model_state_dict(model)
 |  | 
 |     Args: | 
        model (nn.Module): the nn.Module for the model.
 |         options (StateDictOptions): the options to control how | 
 |             model state_dict and optimizer state_dict should be loaded. See | 
 |             `StateDictOptions` for the details. | 
 |     Returns: | 
 |         None | 
 |     """ | 
 |  | 
 |     _state_dict_call = functools.partial( | 
 |         get_model_state_dict, | 
 |         model=model, | 
 |         options=options, | 
 |     ) | 
 |  | 
 |     def state_dict_call(): | 
 |         return _state_dict_call() | 
 |  | 
 |     model.state_dict = state_dict_call | 
 |  | 
 |     _load_state_dict_call = functools.partial( | 
 |         set_model_state_dict, | 
 |         model=model, | 
 |         options=options, | 
 |     ) | 
 |  | 
 |     def load_state_dict_call(state_dict: Dict[str, Any]): | 
 |         _load_state_dict_call(model_state_dict=state_dict) | 
 |  | 
 |     model.load_state_dict = load_state_dict_call | 
 |  | 
 |     _patched_state_dict.add(state_dict_call) | 
 |     _patched_state_dict.add(load_state_dict_call) | 
 |  | 
 |  | 
 | # TODO: correct the load_state_dict function signature. | 
 | # TODO: this API is not yet fully tested. Make it private | 
 | @no_type_check | 
 | def _patch_optimizer_state_dict( | 
 |     model: nn.Module, | 
 |     *, | 
 |     optimizers: Tuple[torch.optim.Optimizer, ...], | 
 |     options: Optional[StateDictOptions] = None, | 
 | ) -> None: | 
 |     """Patch the ``state_dict`` and ``load_state_dict`` attributes of ``optimizers``. | 
 |  | 
    Patch the ``state_dict`` and ``load_state_dict`` attributes of ``optimizers`` to
    be partial functions that call ``get_optimizer_state_dict`` and ``set_optimizer_state_dict``.
 |  | 
    Note that if there are multiple optimizers, all of the optimizers will be
    patched, so users only need to call ``state_dict()`` on one of them to get
    the full result.
 |  | 
 |     Example: | 
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        from torch.distributed.checkpoint.state_dict import _patch_optimizer_state_dict

        model = FSDP(model)
        optim = torch.optim.Adam(model.parameters())
        _patch_optimizer_state_dict(model, optimizers=(optim,))
 |  | 
 |     Args: | 
        model (nn.Module): the nn.Module for the model.
        optimizers (Tuple[torch.optim.Optimizer, ...]): the optimizers whose
            ``state_dict`` and ``load_state_dict`` attributes will be patched.
 |         options (StateDictOptions): the options to control how | 
 |             model state_dict and optimizer state_dict should be loaded. See | 
 |             `StateDictOptions` for the details. | 
 |     Returns: | 
 |         None | 
 |     """ | 
 |  | 
 |     _state_dict_call = functools.partial( | 
 |         get_optimizer_state_dict, | 
 |         model=model, | 
 |         optimizers=optimizers, | 
 |         options=options, | 
 |     ) | 
 |  | 
 |     def state_dict_call(): | 
 |         return _state_dict_call() | 
 |  | 
 |     _load_state_dict_call = functools.partial( | 
 |         set_optimizer_state_dict, | 
 |         model=model, | 
 |         optimizers=optimizers, | 
 |         options=options, | 
 |     ) | 
 |  | 
 |     def load_state_dict_call(state_dict: Dict[str, Any]): | 
 |         _load_state_dict_call(optim_state_dict=state_dict) | 
 |  | 
 |     _patched_state_dict.add(state_dict_call) | 
 |     _patched_state_dict.add(load_state_dict_call) | 
 |     optimizers = ( | 
 |         (optimizers,) | 
 |         if isinstance(optimizers, torch.optim.Optimizer) | 
 |         else tuple(optimizers) | 
 |     ) | 
 |     for optim in optimizers: | 
 |         optim.state_dict = state_dict_call | 
 |         optim.load_state_dict = load_state_dict_call |