[FSDP] Provide a utility API to allow users easily to set state_dict_type (#73716)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73716
state_dict_type() is now a static method that takes the target module: it traverses the module tree and sets the state_dict_type on every descendant FSDP module, eliminating this burden for users.
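For illustration, here is a minimal usage sketch of the new static API (the toy model, CUDA placement, and process-group setup are assumptions made for this example, not part of the diff):

    import torch.nn as nn
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

    # Assumes the default process group has already been initialized
    # (e.g. torch.distributed.init_process_group("nccl")) and a GPU is available.
    model = FSDP(nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8)).cuda())

    # Temporarily switch every descendant FSDP module to LOCAL_STATE_DICT
    # (the target passed in does not itself have to be an FSDP instance),
    # then restore the previous state_dict_type on exit.
    with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
        checkpoint = model.state_dict()

    with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
        model.load_state_dict(checkpoint)

This mirrors the test changes below, where model.state_dict_type(...) calls become FSDP.state_dict_type(model, ...).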
ghstack-source-id: 150775052
Test Plan: CI
Reviewed By: rohan-varma
Differential Revision: D34532321
fbshipit-source-id: 339dcb14a692d5d83c266892644393464d77db40
(cherry picked from commit bb4634c77d6b46644a7daad005ff52738753221a)
diff --git a/test/distributed/fsdp/test_fsdp_state_dict.py b/test/distributed/fsdp/test_fsdp_state_dict.py
index 86734a1..b9e43ab 100644
--- a/test/distributed/fsdp/test_fsdp_state_dict.py
+++ b/test/distributed/fsdp/test_fsdp_state_dict.py
@@ -185,7 +185,7 @@
except KeyError:
raise ValueError(f"No state_dict type for {state_dict_type}")
- with model.state_dict_type(enum_val):
+ with FSDP.state_dict_type(model, enum_val):
return model.state_dict()
@staticmethod
@@ -197,7 +197,7 @@
except KeyError:
raise ValueError(f"No state_dict for {state_dict_type}")
- with model.state_dict_type(enum_val):
+ with FSDP.state_dict_type(model, enum_val):
return model.load_state_dict(state_dict)
def _dist_train(self, wrap_fsdp: bool, state_dict_type: str = ""):
diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py
index 85b415d..f32fa7b 100644
--- a/torch/distributed/fsdp/fully_sharded_data_parallel.py
+++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py
@@ -766,12 +766,15 @@
else:
return False
+ @staticmethod
@contextlib.contextmanager
- def state_dict_type(self, state_dict_type: StateDictType) -> Generator:
+ def state_dict_type(module: nn.Module, state_dict_type: StateDictType) -> Generator:
"""
- A context manager to set the state_dict_type of this FSDP module and
- its descendant FSDP modules.
- .. note:: This API should be called for only the root FSDP module.
+ A context manager to set the state_dict_type of all the descendant FSDP
+ modules of the target module. The target module does not have to be an
+ FSDP module. If the target module is an FSDP module, its state_dict_type
+ will also be changed.
+ .. note:: This API should only be called on the top-level (root) module.
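+ .. note:: The example below is an illustrative sketch; ``my_module`` is
+ a placeholder for any module to be wrapped with FSDP.
+ Example::
+ >>> sharded_module = FSDP(my_module)
+ >>> with FSDP.state_dict_type(sharded_module, StateDictType.LOCAL_STATE_DICT):
+ >>> checkpoint = sharded_module.state_dict()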
.. note:: The default state_dict_type is StateDictType.FULL_STATE_DICT.
.. note:: This API enables users to transparently use the conventional
``state_dict`` API to take model checkpoints in cases where the root
@@ -785,13 +788,11 @@
Args:
state_dict_type (StateDictType): the desired state_dict_type to set.
"""
- if not self.check_is_root():
- raise RuntimeError(
- "state_dict_type context manager can only be called from the root FSDP module."
- )
- prev_state_dict_type = self._state_dict_type
- for module in self.fsdp_modules(self):
- if module._state_dict_type != prev_state_dict_type:
+ prev_state_dict_type = None
+ # Use a distinct loop variable so that ``module`` still refers to the
+ # target root module when the tree is re-traversed in the finally block.
+ for submodule in FullyShardedDataParallel.fsdp_modules(module):
+ if prev_state_dict_type is None:
+ prev_state_dict_type = submodule._state_dict_type
+ if prev_state_dict_type != submodule._state_dict_type:
raise RuntimeError(
"All FSDP module should the same state_dict_type."
)
@@ -799,7 +800,8 @@
try:
yield
finally:
- for module in self.fsdp_modules(self):
- module._state_dict_type = prev_state_dict_type
+ assert prev_state_dict_type is not None # Avoid mypy warning
+ for submodule in FullyShardedDataParallel.fsdp_modules(module):
+ submodule._state_dict_type = prev_state_dict_type
def _full_post_state_dict_hook(
@@ -891,7 +893,7 @@
``state_dict`` on every rank, which could result in OOM if the model
cannot fit on a single GPU. As a result, :func:`state_dict_type` API is
available to configure between `state_dict` implementations. User can
- thus use `with self.state_dict_type(StateDictType.LOCAL_STATE_DICT)`
+ thus use `with self.state_dict_type(self, StateDictType.LOCAL_STATE_DICT)`
context manager to perform a local checkpoint that will store only local
shards of the module. Currently, the only supported implementations are
``StateDictType.LOCAL_STATE_DICT`` and ``StateDictType.FULL_STATE_DICT``
@@ -948,7 +950,7 @@
sharded, so the resulting state_dict can only be loaded after the module
has been wrapped with FSDP.
"""
- with self.state_dict_type(StateDictType.LOCAL_STATE_DICT):
+ with self.state_dict_type(self, StateDictType.LOCAL_STATE_DICT):
return self.state_dict(*args, **kwargs)
def _full_pre_load_state_dict_hook(
@@ -1027,7 +1029,7 @@
FSDP to load the full parameter context on each rank which could result
in GPU OOM. As a result, :func:`state_dict_type` API is available to
configure between `load_state_dict` implementations. User can thus use
- ``with self.state_dict_type(StateDictType.LOCAL_STATE_DICT)`` context
+ ``with self.state_dict_type(self, StateDictType.LOCAL_STATE_DICT)`` context
manager to load a local state dict checkpoint that will restore only
local shards of the module. Currently, the only supported
implementations are ``StateDictType.LOCAL_STATE_DICT`` and
@@ -1079,7 +1081,7 @@
"""
Load states from a flattened, sharded state dictionary.
"""
- with self.state_dict_type(StateDictType.LOCAL_STATE_DICT):
+ with self.state_dict_type(self, StateDictType.LOCAL_STATE_DICT):
return self.load_state_dict(state_dict, *args)
def forward(self, *args: Any, **kwargs: Any) -> Any: