[FSDP] Consolidate FSDP state_dict offload_to_cpu settings (#86211)

Consolidate the FSDP state_dict offload_to_cpu settings by moving the flag from
FullStateDictConfig into the base StateDictConfig. All state_dict_types
(FULL_STATE_DICT, LOCAL_STATE_DICT, SHARDED_STATE_DICT) now accept an
offload_to_cpu option, and the local and sharded post-state_dict hooks move the
resulting ShardedTensors to CPU when it is set.
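
As a rough usage sketch (not part of this diff; it assumes a process group is
already initialized and `model` is an FSDP-wrapped module), every config type
now takes `offload_to_cpu`:

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import (
    FullStateDictConfig,
    LocalStateDictConfig,
    ShardedStateDictConfig,
    StateDictType,
)

# `model` is assumed to be an already-constructed FSDP instance.

# FULL_STATE_DICT: unsharded state_dict, optionally gathered on rank 0 only.
with FSDP.state_dict_type(
    model,
    StateDictType.FULL_STATE_DICT,
    FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
):
    full_sd = model.state_dict()

# SHARDED_STATE_DICT: per-rank ShardedTensor shards, now offloadable to CPU.
with FSDP.state_dict_type(
    model,
    StateDictType.SHARDED_STATE_DICT,
    ShardedStateDictConfig(offload_to_cpu=True),
):
    sharded_sd = model.state_dict()

# LOCAL_STATE_DICT: the flat, sharded parameters, also offloadable to CPU.
with FSDP.state_dict_type(
    model,
    StateDictType.LOCAL_STATE_DICT,
    LocalStateDictConfig(offload_to_cpu=True),
):
    local_sd = model.state_dict()
```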

Differential Revision: [D40065969](https://our.internmc.facebook.com/intern/diff/D40065969/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86211
Approved by: https://github.com/rohan-varma
diff --git a/test/distributed/fsdp/test_fsdp_state_dict.py b/test/distributed/fsdp/test_fsdp_state_dict.py
index 514dc8c..af56ee9 100644
--- a/test/distributed/fsdp/test_fsdp_state_dict.py
+++ b/test/distributed/fsdp/test_fsdp_state_dict.py
@@ -13,11 +13,13 @@
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper,
 )
-from torch.distributed.fsdp import CPUOffload, FullStateDictConfig
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp import (
+    CPUOffload,
+    FullStateDictConfig,
     LocalStateDictConfig,
     MixedPrecision,
+    ShardedStateDictConfig,
     StateDictType,
 )
 from torch.distributed.fsdp._shard_utils import _gather_state_dict
@@ -186,8 +188,16 @@
                 rank0_only=state_dict_rank0_and_offload,
                 offload_to_cpu=state_dict_rank0_and_offload,
             )
+        elif state_dict_type == "local_state_dict":
+            config = LocalStateDictConfig(
+                offload_to_cpu=state_dict_rank0_and_offload,
+            )
+        elif state_dict_type == "sharded_state_dict":
+            config = ShardedStateDictConfig(
+                offload_to_cpu=state_dict_rank0_and_offload,
+            )
         else:
-            config = None
+            raise ValueError("Unsupported state_dict_type")
         return FSDP.state_dict_type(model, _state_dict_type, config)
 
     def _validate_state_dict_contents(
diff --git a/torch/distributed/fsdp/__init__.py b/torch/distributed/fsdp/__init__.py
index e2b39c9..324a344 100644
--- a/torch/distributed/fsdp/__init__.py
+++ b/torch/distributed/fsdp/__init__.py
@@ -7,6 +7,7 @@
     LocalStateDictConfig,
     MixedPrecision,
     OptimStateKeyType,
+    ShardedStateDictConfig,
     ShardingStrategy,
     StateDictType,
 )
diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py
index 38a2b50..76d662b 100644
--- a/torch/distributed/fsdp/fully_sharded_data_parallel.py
+++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py
@@ -307,6 +307,7 @@
     LOCAL_STATE_DICT = auto()
     SHARDED_STATE_DICT = auto()
 
+
 @dataclass
 class StateDictConfig:
     """
@@ -315,7 +316,8 @@
     order to configure settings for the particular type of ``state_dict``
     implementation FSDP will use.
     """
-    pass
+    offload_to_cpu: bool = False
+
 
 @dataclass
 class FullStateDictConfig(StateDictConfig):
@@ -345,23 +347,26 @@
         >>> fsdp = FSDP(model, device_id=torch.cuda.current_device(), auto_wrap_policy=..., sync_module_states=True)
         >>> # After this point, all ranks have FSDP model with loaded checkpoint.
     """
-    offload_to_cpu: bool = False
     rank0_only: bool = False
 
+
 @dataclass
 class LocalStateDictConfig(StateDictConfig):
     pass
 
+
 @dataclass
 class ShardedStateDictConfig(StateDictConfig):
     pass
 
+
 _state_dict_type_to_config = {
     StateDictType.FULL_STATE_DICT: FullStateDictConfig,
     StateDictType.LOCAL_STATE_DICT: LocalStateDictConfig,
     StateDictType.SHARDED_STATE_DICT: ShardedStateDictConfig,
 }
 
+
 class OptimStateKeyType(Enum):
     PARAM_NAME = auto()
     PARAM_ID = auto()
@@ -2327,10 +2332,12 @@
         local_shards = [
             Shard.from_tensor_and_offsets(flat_param, [shard_offset], self.rank)
         ]
-        state_dict[f"{prefix}{FLAT_PARAM}"] = init_from_local_shards(
+        sharded_tensor = init_from_local_shards(
             local_shards, full_numel, process_group=self.process_group
         )  # type: ignore[assignment]
-
+        if self._state_dict_config.offload_to_cpu:
+            sharded_tensor = sharded_tensor.cpu()
+        state_dict[f"{prefix}{FLAT_PARAM}"] = sharded_tensor
         return state_dict
 
     @torch.no_grad()
@@ -2355,13 +2362,16 @@
             for fqn, _, _ in self._param_fqns:
                 # Create a ShardedTensor for the unflattened, non-sharded parameter.
                 param = functools.reduce(getattr, fqn.split("."), self.module)
-                state_dict[f"{prefix}{fqn}"] = _ext_chunk_tensor(
+                sharded_tensor = _ext_chunk_tensor(
                     tensor=param,
                     rank=self.rank,
                     world_size=self.world_size,
                     num_devices_per_node=torch.cuda.device_count(),
                     pg=self.process_group
-                )  # type: ignore[assignment]
+                )
+                if self._state_dict_config.offload_to_cpu:
+                    sharded_tensor = sharded_tensor.cpu()
+                state_dict[f"{prefix}{fqn}"] = sharded_tensor
         # For `use_orig_params=True`, the `FlatParameter` is not registered, so
         # there is no entry in the state dict for it to pop.
         if not self._use_orig_params: