[FSDP] Consolidate FSDP state_dict offload_to_cpu settings (#86211)
Consolidate the FSDP state_dict offload_to_cpu settings by moving offload_to_cpu from FullStateDictConfig into the base StateDictConfig. All state_dict types (FULL_STATE_DICT, LOCAL_STATE_DICT, and SHARDED_STATE_DICT) now accept an offload_to_cpu option, and ShardedStateDictConfig is now exported from torch.distributed.fsdp.
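
For reference, a minimal usage sketch of the consolidated setting (not part of this PR; it assumes a default process group has already been initialized and uses a toy model as a placeholder):

```python
import torch
from torch.distributed.fsdp import (
    FullStateDictConfig,
    ShardedStateDictConfig,
    StateDictType,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes torch.distributed.init_process_group(...) has already been called
# on every rank; the model here is only a placeholder.
model = torch.nn.Linear(8, 8).cuda()
fsdp_model = FSDP(model, device_id=torch.cuda.current_device())

# Full state dict: gather on rank 0 and offload the gathered tensors to CPU.
with FSDP.state_dict_type(
    fsdp_model,
    StateDictType.FULL_STATE_DICT,
    FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
):
    full_sd = fsdp_model.state_dict()

# Sharded state dict: with this change, the local shards can also be
# offloaded to CPU via the same offload_to_cpu flag.
with FSDP.state_dict_type(
    fsdp_model,
    StateDictType.SHARDED_STATE_DICT,
    ShardedStateDictConfig(offload_to_cpu=True),
):
    sharded_sd = fsdp_model.state_dict()
```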
Differential Revision: [D40065969](https://our.internmc.facebook.com/intern/diff/D40065969/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86211
Approved by: https://github.com/rohan-varma
diff --git a/test/distributed/fsdp/test_fsdp_state_dict.py b/test/distributed/fsdp/test_fsdp_state_dict.py
index 514dc8c..af56ee9 100644
--- a/test/distributed/fsdp/test_fsdp_state_dict.py
+++ b/test/distributed/fsdp/test_fsdp_state_dict.py
@@ -13,11 +13,13 @@
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
)
-from torch.distributed.fsdp import CPUOffload, FullStateDictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import (
+ CPUOffload,
+ FullStateDictConfig,
LocalStateDictConfig,
MixedPrecision,
+ ShardedStateDictConfig,
StateDictType,
)
from torch.distributed.fsdp._shard_utils import _gather_state_dict
@@ -186,8 +188,16 @@
rank0_only=state_dict_rank0_and_offload,
offload_to_cpu=state_dict_rank0_and_offload,
)
+ elif state_dict_type == "local_state_dict":
+ config = LocalStateDictConfig(
+ offload_to_cpu=state_dict_rank0_and_offload,
+ )
+ elif state_dict_type == "sharded_state_dict":
+ config = ShardedStateDictConfig(
+ offload_to_cpu=state_dict_rank0_and_offload,
+ )
else:
- config = None
+ raise ValueError("Unsupported state_dict_type")
return FSDP.state_dict_type(model, _state_dict_type, config)

def _validate_state_dict_contents(
diff --git a/torch/distributed/fsdp/__init__.py b/torch/distributed/fsdp/__init__.py
index e2b39c9..324a344 100644
--- a/torch/distributed/fsdp/__init__.py
+++ b/torch/distributed/fsdp/__init__.py
@@ -7,6 +7,7 @@
LocalStateDictConfig,
MixedPrecision,
OptimStateKeyType,
+ ShardedStateDictConfig,
ShardingStrategy,
StateDictType,
)
diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py
index 38a2b50..76d662b 100644
--- a/torch/distributed/fsdp/fully_sharded_data_parallel.py
+++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py
@@ -307,6 +307,7 @@
LOCAL_STATE_DICT = auto()
SHARDED_STATE_DICT = auto()
+
@dataclass
class StateDictConfig:
"""
@@ -315,7 +316,8 @@
order to configure settings for the particular type of ``state_dict``
implementation FSDP will use.
"""
- pass
+ offload_to_cpu: bool = False
+
@dataclass
class FullStateDictConfig(StateDictConfig):
@@ -345,23 +347,26 @@
>>> fsdp = FSDP(model, device_id=torch.cuda.current_device(), auto_wrap_policy=..., sync_module_states=True)
>>> # After this point, all ranks have FSDP model with loaded checkpoint.
"""
- offload_to_cpu: bool = False
rank0_only: bool = False
+
@dataclass
class LocalStateDictConfig(StateDictConfig):
pass
+
@dataclass
class ShardedStateDictConfig(StateDictConfig):
pass
+
_state_dict_type_to_config = {
StateDictType.FULL_STATE_DICT: FullStateDictConfig,
StateDictType.LOCAL_STATE_DICT: LocalStateDictConfig,
StateDictType.SHARDED_STATE_DICT: ShardedStateDictConfig,
}
+
class OptimStateKeyType(Enum):
PARAM_NAME = auto()
PARAM_ID = auto()
@@ -2327,10 +2332,12 @@
local_shards = [
Shard.from_tensor_and_offsets(flat_param, [shard_offset], self.rank)
]
- state_dict[f"{prefix}{FLAT_PARAM}"] = init_from_local_shards(
+ sharded_tensor = init_from_local_shards(
local_shards, full_numel, process_group=self.process_group
) # type: ignore[assignment]
-
+ if self._state_dict_config.offload_to_cpu:
+ sharded_tensor = sharded_tensor.cpu()
+ state_dict[f"{prefix}{FLAT_PARAM}"] = sharded_tensor
return state_dict

@torch.no_grad()
@@ -2355,13 +2362,16 @@
for fqn, _, _ in self._param_fqns:
# Create a ShardedTensor for the unflattened, non-sharded parameter.
param = functools.reduce(getattr, fqn.split("."), self.module)
- state_dict[f"{prefix}{fqn}"] = _ext_chunk_tensor(
+ sharded_tensor = _ext_chunk_tensor(
tensor=param,
rank=self.rank,
world_size=self.world_size,
num_devices_per_node=torch.cuda.device_count(),
pg=self.process_group
- ) # type: ignore[assignment]
+ )
+ if self._state_dict_config.offload_to_cpu:
+ sharded_tensor = sharded_tensor.cpu()
+ state_dict[f"{prefix}{fqn}"] = sharded_tensor
# For `use_orig_params=True`, the `FlatParameter` is not registered, so
# there is no entry in the state dict for it to pop.
if not self._use_orig_params: