[FSDP] Implement local_state_dict and load_local_state_dict (#72469)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72469

1. Implement the framework that allows users to choose among `state_dict`, `local_state_dict`, and `sharded_state_dict`.
2. Implement ShardedTensor-compatible `local_state_dict()` and `load_local_state_dict()` (usage sketch below).
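
A minimal usage sketch of the new APIs (not taken from this PR; it assumes a process group is already initialized with one CUDA device set per rank):

```
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType

# Assumption: dist.init_process_group(...) has been called and the current
# CUDA device has been set for this rank.
model = FSDP(nn.Linear(4, 4).cuda())

# Explicit APIs: keys map to flattened, sharded parameters (ShardedTensor
# values), so the result can only be loaded back into an FSDP-wrapped module
# with the same wrapping.
local_sd = model.local_state_dict()
model.load_local_state_dict(local_sd)

# Equivalent flow through the state_dict_type context manager on the root
# FSDP module, which re-routes the standard state_dict()/load_state_dict().
with model.state_dict_type(StateDictType.LOCAL_STATE_DICT):
    local_sd = model.state_dict()
with model.state_dict_type(StateDictType.LOCAL_STATE_DICT):
    model.load_state_dict(local_sd)
```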
ghstack-source-id: 149559985

Test Plan: CI
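
For reference, the key layout verified by `test_fsdp_state_dict_keys` in the new test: with a root FSDP wrapper and an inner `Linear` carrying its own FSDP instance, `local_state_dict()` exposes one flattened parameter per FSDP instance, with the internal wrapper prefixes stripped by the new hooks. A sketch mirroring the test's `Model` (assumes a 2-rank CUDA process group, as in the test):

```
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = FSDP(nn.Linear(4, 4))  # inner FSDP instance
        self.outer = nn.Linear(4, 5)

model = FSDP(Model().cuda())
local_sd = model.local_state_dict()
# One flat_param key per FSDP instance; values are ShardedTensors over each
# rank's flattened shard.
assert set(local_sd.keys()) == {"flat_param", "inner.flat_param"}
```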

Reviewed By: rohan-varma

Differential Revision: D33919683

fbshipit-source-id: c9f1b43ce04da7db65c4aebf6ac2c7a0ac5e9de8
(cherry picked from commit 55fd6230c9656fdf30a70dcd8071d094d2e67022)
diff --git a/test/distributed/fsdp/test_flatten_params_wrapper.py b/test/distributed/fsdp/test_flatten_params_wrapper.py
index c4a7eb6..69c78ee 100644
--- a/test/distributed/fsdp/test_flatten_params_wrapper.py
+++ b/test/distributed/fsdp/test_flatten_params_wrapper.py
@@ -198,7 +198,7 @@
                     expected,
                     msg=f"{flat_p.shard_metadata()}, {expected}",
                 )
-                self.assertEqual(flat_p._num_padded, kwargs["num_padded"])
+                self.assertEqual(flat_p.num_padded, kwargs["num_padded"])
 
         _test(
             kwargs={"start": -1, "end": -1, "num_padded": 0},
diff --git a/test/distributed/fsdp/test_fsdp_state_dict.py b/test/distributed/fsdp/test_fsdp_state_dict.py
new file mode 100644
index 0000000..00776fe
--- /dev/null
+++ b/test/distributed/fsdp/test_fsdp_state_dict.py
@@ -0,0 +1,138 @@
+# Owner(s): ["oncall: distributed"]
+
+import sys
+from typing import Any, Dict
+
+import torch
+from torch import distributed as dist
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import StateDictType
+from torch.nn import Linear, Module
+from torch.nn.parallel import DistributedDataParallel
+from torch.optim import SGD
+from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
+from torch.testing._internal.common_fsdp import (
+    FSDPTest,
+    get_full_params,
+)
+from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    parametrize,
+    run_tests,
+    TEST_WITH_DEV_DBG_ASAN,
+)
+
+
+if not dist.is_available():
+    print("Distributed not available, skipping tests", file=sys.stderr)
+    sys.exit(0)
+
+if TEST_WITH_DEV_DBG_ASAN:
+    print(
+        "Skip dev-asan as torch + multiprocessing spawn have known issues",
+        file=sys.stderr,
+    )
+    sys.exit(0)
+
+INNER_SHAPE = [4, 4]
+OUTER_SHAPE = [4, 5]
+
+
+class Model(Module):
+    def __init__(self, wrap_fsdp):
+        super().__init__()
+        self.inner = Linear(*INNER_SHAPE)
+        if wrap_fsdp:
+            self.inner = FSDP(self.inner)
+        self.outer = Linear(*OUTER_SHAPE)
+
+    def forward(self, x):
+        # Forward twice.
+        i = self.inner(x)
+        j = self.inner(x)
+        return self.outer(i + j)
+
+
+class TestFSDPStateDict(FSDPTest):
+    @property
+    def world_size(self):
+        return 2
+
+    def _initialize_model(self, wrap_fsdp: bool):
+        # keep everything deterministic for input data
+        torch.manual_seed(0)
+
+        model = Model(wrap_fsdp).cuda()
+        if wrap_fsdp:
+            model = FSDP(model)
+        else:
+            model = DistributedDataParallel(model, device_ids=[self.rank])
+        return model
+
+    @staticmethod
+    def _state_dict(model: Module, state_dict_type: str):
+        return getattr(model, state_dict_type)()
+
+    @staticmethod
+    def _load_state_dict(
+        model: Module, state_dict_type: str, state_dict: Dict[str, Any]
+    ):
+        getattr(model, f"load_{state_dict_type}")(state_dict)
+
+    def _dist_train(
+        self, wrap_fsdp: bool, state_dict_type: str = "", with_context: bool = False
+    ):
+        # TODO: Move this test to common_fsdp.
+        model = self._initialize_model(wrap_fsdp)
+        optim = SGD(model.parameters(), lr=0.1)
+
+        in_data = torch.rand(64, 4, requires_grad=True, device=torch.device("cuda"))
+        for _ in range(3):
+            out = model(in_data)
+            out.sum().backward()
+            optim.step()
+            optim.zero_grad()
+
+        if wrap_fsdp:
+            blank_model = FSDP(Model(True).cuda())
+            if with_context:
+                state_dict_type = {
+                    "full_state_dict": StateDictType.FULL_STATE_DICT,
+                    "local_state_dict": StateDictType.LOCAL_STATE_DICT,
+                    "sharded_state_dict": StateDictType.SHARDED_STATE_DICT,
+                }[state_dict_type]
+                with model.state_dict_type(state_dict_type):
+                    state_dict = model.state_dict()
+                with blank_model.state_dict_type(state_dict_type):
+                    blank_model.load_state_dict(state_dict)
+            else:
+                state_dict = self._state_dict(model, state_dict_type)
+                self._load_state_dict(blank_model, state_dict_type, state_dict)
+            get_full_params(blank_model)
+            model = blank_model
+
+        return list(model.parameters())
+
+    @skip_if_lt_x_gpu(2)
+    @parametrize("state_dict_type", ["local_state_dict"])
+    def test_state_dict_save_load_flow(self, state_dict_type):
+        fsdp_params = self._dist_train(wrap_fsdp=True, state_dict_type=state_dict_type)
+        fsdp_params_using_context = self._dist_train(
+            wrap_fsdp=True, state_dict_type=state_dict_type, with_context=True
+        )
+        ddp_params = self._dist_train(wrap_fsdp=False)
+        self.assertEqual(ddp_params, fsdp_params)
+        self.assertEqual(ddp_params, fsdp_params_using_context)
+
+    @skip_if_lt_x_gpu(2)
+    @parametrize("state_dict_type", ["local_state_dict"])
+    def test_fsdp_state_dict_keys(self, state_dict_type):
+        state_dict = self._state_dict(self._initialize_model(True), state_dict_type)
+        if state_dict_type == "local_state_dict":
+            self.assertEqual(set(["flat_param", "inner.flat_param"]), state_dict.keys())
+
+
+instantiate_parametrized_tests(TestFSDPStateDict)
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/test/distributed/fsdp/test_fsdp_summon_full_params.py b/test/distributed/fsdp/test_fsdp_summon_full_params.py
index fbb2ef2..f6ec725 100644
--- a/test/distributed/fsdp/test_fsdp_summon_full_params.py
+++ b/test/distributed/fsdp/test_fsdp_summon_full_params.py
@@ -1,8 +1,8 @@
 # Owner(s): ["oncall: distributed"]
 import itertools
-from copy import deepcopy
 import math
 import sys
+from copy import deepcopy
 
 import torch
 import torch.nn as nn
@@ -35,11 +35,10 @@
     )
     sys.exit(0)
 
+
 def _run_test_summon_full_param_writeback(cls, writeback, cpu_offload, modify_outer):
     model = FSDP(
-        nn.Sequential(
-            FSDP(nn.Linear(5, 5, bias=False)), nn.Linear(5, 3, bias=False)
-        )
+        nn.Sequential(FSDP(nn.Linear(5, 5, bias=False)), nn.Linear(5, 3, bias=False))
     ).cuda(cls.rank)
 
     # set the value
@@ -64,6 +63,7 @@
     else:
         cls.assertEqual(p.cpu()[0], cls.rank + 2)
 
+
 class TestSummonFullParamsNoShard(FSDPTest):
     @property
     def world_size(self):
@@ -84,6 +84,7 @@
             modify_outer,
         )
 
+
 class TestSummonFullParams(FSDPTest):
     @property
     def world_size(self):
@@ -105,10 +106,7 @@
     @parametrize("modify_outer", [True, False])
     def test_summon_full_param_writeback(self, writeback, cpu_offload, modify_outer):
         return _run_test_summon_full_param_writeback(
-            self,
-            writeback,
-            cpu_offload,
-            modify_outer
+            self, writeback, cpu_offload, modify_outer
         )
 
     @skip_if_lt_x_gpu(2)
diff --git a/torch/distributed/fsdp/__init__.py b/torch/distributed/fsdp/__init__.py
index d2c311d..7c1f0b3 100644
--- a/torch/distributed/fsdp/__init__.py
+++ b/torch/distributed/fsdp/__init__.py
@@ -1,3 +1,4 @@
 from .flatten_params_wrapper import FlatParameter
 from .fully_sharded_data_parallel import FullyShardedDataParallel
 from .fully_sharded_data_parallel import CPUOffload
+from .fully_sharded_data_parallel import StateDictType
diff --git a/torch/distributed/fsdp/flatten_params_wrapper.py b/torch/distributed/fsdp/flatten_params_wrapper.py
index ef3af64..13be7bd 100644
--- a/torch/distributed/fsdp/flatten_params_wrapper.py
+++ b/torch/distributed/fsdp/flatten_params_wrapper.py
@@ -18,14 +18,60 @@
     Optional,
     Sequence,
     Tuple,
+    TYPE_CHECKING,
+    Union,
 )
 
 import torch
 import torch.nn as nn
 from torch import Tensor
 
+from .utils import _replace_by_prefix
+
+if TYPE_CHECKING:
+    from collections import OrderedDict  # noqa: F401
+
 ParamOffset = Tuple[int, int]
 SharedParamInfo = Tuple[str, str, nn.Module, str, nn.Module, str]
+FLAT_PARAM = "flat_param"
+FPW_MODULE = "_fpw_module"
+
+
+def _post_state_dict_hook(
+    module: nn.Module, state_dict: "OrderedDict[str, Tensor]", prefix: str, *args: Any
+) -> "OrderedDict[str, Tensor]":
+    """
+    _post_state_dict_hook() is called after the state_dict() is executed
+    and before returning the state_dict to the users.
+    This API post-processes the keys of the state_dict to remove the
+    FlattenParamsWrapper internal prefix.
+    """
+    # Move everything from FPW_MODULE up one level.
+    _replace_by_prefix(state_dict, prefix + f"{FPW_MODULE}.", prefix)
+    return state_dict
+
+
+def _pre_load_state_dict_hook(
+    state_dict: Union[Dict[str, Tensor], "OrderedDict[str, Tensor]"],
+    prefix: str,
+    *args: Any,
+) -> None:
+    """
+    _pre_load_state_dict_hook() is called before _load_from_state_dict() is
+    executed. This API pre-processes the keys of the state_dict to add the
+    FlattenParamsWrapper internal prefix.
+    """
+    # Push everything down to FPW_MODULE level.
+    _replace_by_prefix(state_dict, prefix, prefix + f"{FPW_MODULE}.")
+    # The flat_param_* keys actually need to move one level up.
+    flat_param_key = prefix + f"{FPW_MODULE}.{FLAT_PARAM}"
+    for k in list(state_dict.keys()):
+        if k.startswith(flat_param_key):
+            last_part = k.split(".")[-1]
+            assert last_part.startswith(
+                FLAT_PARAM
+            ), f"Expected key to contain flat_param, but key name is {k}"
+            _replace_by_prefix(state_dict, k, prefix + last_part)
 
 
 class ParamInfo(NamedTuple):
@@ -98,10 +144,13 @@
     def __init__(self, params: Sequence[nn.Parameter], requires_grad: bool = True):
         self._is_sharded = False
         self._param_numels = [p.numel() for p in params]
-        assert self.numel() <= sum(self._param_numels), (
+        # The total number of elements. This is equal to the sum of the
+        # ``numel()`` of all the parameters.
+        self.full_numel = sum(self._param_numels)
+        assert self.numel() <= self.full_numel, (
             "Parameter numbers mismatched. "
             f"The number of elements in FlatParameter: {self.numel()} vs. "
-            f"the number of elements in original parameters: {sum(self._param_numels)}."
+            f"the number of elements in original parameters: {self.full_numel}."
         )
         # The shapes of each individual parameter.
         self._param_shapes = [p.size() for p in params]
@@ -124,7 +173,7 @@
             (0, numel) for numel in self._param_numels
         ]
         # The number of padding elements.
-        self._num_padded = 0
+        self.num_padded = 0
 
     def shard_by_offsets(self, start: int, end: int, num_padded: int) -> None:
         assert self._is_sharded
@@ -133,8 +182,8 @@
                 f"Shard the flatten parameter with an invalid offset pair {(start, end)}."
             )
         _shard_size = end - start + 1
-        self._num_padded = num_padded
-        if self._num_padded > _shard_size:
+        self.num_padded = num_padded
+        if self.num_padded > _shard_size:
             raise ValueError("The number of padding is larger than the shard size.")
         self._sharded_param_offsets.clear()
 
@@ -163,13 +212,13 @@
     ) -> Iterator[Tensor]:
         """Return a generator of views that map to the original parameters."""
         # Note, self.data could be sharded, so its numel is <= to the sum.
-        assert self.data.numel() <= sum(
-            self._param_numels
-        ), f"Incorrect internal state {self.data.numel()} vs. {sum(self._param_numels)}"
+        assert (
+            self.data.numel() <= self.full_numel
+        ), f"Incorrect internal state {self.data.numel()} vs. {self.full_numel}"
         data = external_data if external_data is not None else self
-        if data.numel() != sum(self._param_numels):
+        if data.numel() != self.full_numel:
             raise ValueError(
-                f"Incorrect numel of supplied data: got {data.numel()} but expected {sum(self._param_numels)}"
+                f"Incorrect numel of supplied data: got {data.numel()} but expected {self.full_numel}"
             )
         return (
             t.view(s)
@@ -252,6 +301,15 @@
         self._orig_flat_param: List[Optional[FlatParameter]] = [None]
         self._flatten_params()
 
+        # Sanity check for the string constants.
+        assert getattr(self, FPW_MODULE) is self._fpw_module
+        assert getattr(self, FLAT_PARAM) is self.flat_param
+
+        # Register hook to be called after state_dict() to remove the
+        # "_fpw_module." prefix and before load_state_dict() to add it back.
+        self._register_state_dict_hook(_post_state_dict_hook)
+        self._register_load_state_dict_pre_hook(_pre_load_state_dict_hook)
+
     @property
     def module(self) -> Any:
         """Support _fsdp_wrapped_module.module in case we are immitating DDP, which has .module
diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py
index fe61684..baea027 100644
--- a/torch/distributed/fsdp/fully_sharded_data_parallel.py
+++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py
@@ -9,9 +9,10 @@
     Any,
     Callable,
     Dict,
-    Generator,
     List,
     Optional,
+    Generator,
+    NamedTuple,
     Set,
     Tuple,
     Union,
@@ -24,17 +25,28 @@
 import torch.nn.functional as F
 from torch.autograd import Variable
 from torch.distributed import ProcessGroup
+from torch.distributed._sharded_tensor import (
+    init_from_local_shards,
+    Shard,
+    ShardedTensor,
+)
 from torch.distributed.distributed_c10d import _get_default_group
 from torch.nn.parameter import Parameter
 
-from .flatten_params_wrapper import FlatParameter, FlattenParamsWrapper
-from .utils import _apply_to_tensors
+from .flatten_params_wrapper import FlatParameter, FlattenParamsWrapper, FLAT_PARAM
+from .utils import (
+    _apply_to_tensors,
+    _replace_by_prefix,
+)
 from .wrap import _recursive_wrap
 
 if TYPE_CHECKING:
     from collections import OrderedDict  # noqa: F401
 
 
+FSDP_WRAPPED_MODULE = "_fsdp_wrapped_module"
+
+
 @dataclass
 class CPUOffload:
     """
@@ -98,6 +110,31 @@
     SUMMON_FULL_PARAMS = auto()
 
 
+class StateDictType(Enum):
+    """
+    This enum indicates which type of ``state_dict`` the FSDP module is
+    currently processing (returning or loading).
+    The default value is FULL_STATE_DICT to comply with the PyTorch convention.
+    .. note::
+        FSDP currently supports three types of ``state_dict``:
+            1. ``state_dict``/``load_state_dict``: this pair of APIs returns and
+               loads the non-sharded, unflattened parameters. The semantics are
+               the same as using DDP.
+            2. ``local_state_dict``/``load_local_state_dict``: this pair of APIs
+               returns and loads local sharded, flattened parameters. The values
+               returned by ``local_state_dict`` can be directly used by FSDP and
+               are only meaningful to FSDP (because parameters are flattened).
+            3. ``sharded_state_dict``/``load_sharded_state_dict``: this pair of
+               APIs returns and loads sharded, unflattened parameters. The
+               ``state_dict`` returned by ``sharded_state_dict`` can be used by
+               all other parallel schemes (resharding may be required).
+    """
+
+    FULL_STATE_DICT = auto()
+    LOCAL_STATE_DICT = auto()
+    SHARDED_STATE_DICT = auto()
+
+
 class FullyShardedDataParallel(nn.Module):
     """
     A wrapper for sharding Module parameters across data parallel workers. This
@@ -244,6 +281,7 @@
         self._fsdp_wrapped_module: FlattenParamsWrapper = FlattenParamsWrapper(
             module, param_list=params
         )
+        assert getattr(self, FSDP_WRAPPED_MODULE) is self._fsdp_wrapped_module
         del module  # free original module in case it helps garbage collection
         if self._fsdp_wrapped_module.flat_param is not None:
             self.params = [self._fsdp_wrapped_module.flat_param]
@@ -268,6 +306,29 @@
         # Enum to indicate if we're in the forward/backward pass, idle, etc.
         self.training_state = TrainingState_.IDLE
 
+        self._state_dict_type = StateDictType.FULL_STATE_DICT
+
+        # FSDP currently provides three different state_dicts. The actual
+        # state_dict that will be saved/loaded is decided by
+        # self._state_dict_type, and the main logic of each state_dict is
+        # implemented in a hook. Therefore, for each hook (post-state_dict and
+        # pre-load-state_dict), there is a dispatcher dictionary that routes
+        # execution to the correct implementation.
+        self._register_state_dict_hook(self._post_state_dict_hook)
+        self._post_state_dict_hook_fn = {
+            StateDictType.FULL_STATE_DICT: self._full_post_state_dict_hook,
+            StateDictType.LOCAL_STATE_DICT: self._local_post_state_dict_hook,
+            StateDictType.SHARDED_STATE_DICT: self._sharded_post_state_dict_hook,
+        }
+        self._register_load_state_dict_pre_hook(
+            self._pre_load_state_dict_hook, with_module=True
+        )
+        self._pre_load_state_dict_hook_fn = {
+            StateDictType.FULL_STATE_DICT: self._full_pre_load_state_dict_hook,
+            StateDictType.LOCAL_STATE_DICT: self._local_pre_load_state_dict_hook,
+            StateDictType.SHARDED_STATE_DICT: self._sharded_pre_load_state_dict_hook,
+        }
+
         # Flag to guard against preparing gradients multiple times per backward pass.
         self._pre_backward_hook_has_run = False
         # Used for prefetching all gather full params in post backward hook
@@ -679,6 +740,227 @@
         else:
             return False
 
+    @contextlib.contextmanager
+    def state_dict_type(self, state_dict_type: StateDictType) -> Generator:
+        """
+        A context manager to set the state_dict_type of this FSDP module and
+        its descendant FSDP modules.
+        .. note:: This API should only be called on the root FSDP module.
+        .. note:: The default state_dict_type is StateDictType.FULL_STATE_DICT.
+
+        Args:
+            state_dict_type (StateDictType): the desired state_dict_type to set.
+        """
+        self._lazy_init()
+        if not self._is_root:
+            raise RuntimeError(
+                f"state_dict_type context manager can only be called from the root FSDP module.  {self._is_root}"
+            )
+        prev_state_dict_type = self._state_dict_type
+        for module in self.modules():
+            if isinstance(module, FullyShardedDataParallel):
+                if module._state_dict_type != prev_state_dict_type:
+                    raise RuntimeError(
+                        "All FSDP module should the same state_dict_type."
+                    )
+                module._state_dict_type = state_dict_type
+        try:
+            yield
+        finally:
+            for module in self.modules():
+                if isinstance(module, FullyShardedDataParallel):
+                    module._state_dict_type = prev_state_dict_type
+
+    def _full_post_state_dict_hook(
+        self,
+        state_dict: "OrderedDict[str, torch.Tensor]",
+        prefix: str,
+    ) -> "OrderedDict[str, torch.Tensor]":
+        return state_dict
+
+    def _local_post_state_dict_hook(
+        self,
+        state_dict: "OrderedDict[str, torch.Tensor]",
+        prefix: str,
+    ) -> "OrderedDict[str, torch.Tensor]":
+        """
+        This hook creates a ShardedTensor from the local flat_param and replaces
+        state_dict[f"{prefix}{FLAT_PARAM}"] with the ShardedTensor. No copy
+        happens; the underlying storage is the same.
+        """
+        _replace_by_prefix(state_dict, f"{prefix}{FSDP_WRAPPED_MODULE}.", prefix)
+        # state_dict[f"{prefix}{FLAT_PARAM}"] exists and has the same tensor
+        # value as the flat_param but it is a pure Tensor because
+        # nn.Module.state_dict() will detach the parameter. Therefore, we need
+        # to get flat_param from the FlattenParamsWrapper to get the metadata.
+        flat_param = getattr(self.module, FLAT_PARAM, None)
+        assert (
+            flat_param is not None
+        ), "flat_param cannot be None when doing local_state_dict."
+
+        # Construct a ShardedTensor from the flat_param.
+        full_numel = flat_param.full_numel
+        shard_offset = flat_param.numel() * self.rank
+        valid_data_size = flat_param.numel() - flat_param.num_padded
+        if valid_data_size > 0 and flat_param.num_padded > 0:
+            flat_param = flat_param.narrow(0, 0, valid_data_size)
+        local_shards = [
+            Shard.from_tensor_and_offsets(flat_param, [shard_offset], self.rank)
+        ]
+        state_dict[f"{prefix}{FLAT_PARAM}"] = init_from_local_shards(
+            local_shards, full_numel, process_group=self.process_group
+        )  # type: ignore[assignment]
+
+        return state_dict
+
+    def _sharded_post_state_dict_hook(
+        self,
+        state_dict: "OrderedDict[str, torch.Tensor]",
+        prefix: str,
+    ) -> "OrderedDict[str, torch.Tensor]":
+        raise NotImplementedError("Will be implemented in the next PRs.")
+
+    @staticmethod
+    def _post_state_dict_hook(
+        module: nn.Module,
+        state_dict: "OrderedDict[str, torch.Tensor]",
+        prefix: str,
+        *args: Any,
+    ) -> "OrderedDict[str, torch.Tensor]":
+        """
+        _post_state_dict_hook() is called after the state_dict() of this
+        FSDP module is executed. ``self._state_dict_type`` is used to decide
+        what postprocessing will be done.
+        """
+        self = cast(FullyShardedDataParallel, module)
+        return self._post_state_dict_hook_fn[self._state_dict_type](state_dict, prefix)
+
+    def state_dict(self, destination=None, prefix="", keep_vars=False):
+        """
+        The entry point of all three FSDP state_dict APIs.
+        ``self._state_dict_type`` decides which code path to execute.
+
+        .. warning:: This needs to be called on all ranks, since synchronization
+            primitives may be used.
+        """
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+        if self._state_dict_type == StateDictType.FULL_STATE_DICT:
+            return super().state_dict(destination, prefix, keep_vars)
+        elif self._state_dict_type == StateDictType.LOCAL_STATE_DICT:
+            assert getattr(self.module, FLAT_PARAM, None) is not None
+            assert isinstance(self.module.flat_param, FlatParameter)
+            return super().state_dict(destination, prefix, keep_vars)
+        elif self._state_dict_type == StateDictType.SHARDED_STATE_DICT:
+            raise NotImplementedError("Will be implemented in the next PRs.")
+        else:
+            raise ValueError(f"Unknown StateDictType {self._state_dict_type}.")
+
+    def local_state_dict(self, *args: Any, **kwargs: Any) -> Any:
+        """
+        Returns the local state of the module. Parameters are flattened and
+        sharded, so the resulting state_dict can only be loaded after the module
+        has been wrapped with FSDP.
+        """
+        with self.state_dict_type(StateDictType.LOCAL_STATE_DICT):
+            return self.state_dict(*args, **kwargs)
+
+    def _full_pre_load_state_dict_hook(
+        self,
+        state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"],
+        prefix: str,
+    ) -> None:
+        return
+
+    def _local_pre_load_state_dict_hook(
+        self,
+        state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"],
+        prefix: str,
+    ) -> None:
+        """
+        This hook finds the local flat_param for this FSDP module from the
+        state_dict. The flat_param should be a ShardedTensor. This hook converts
+        the ShardedTensor to a tensor. No copy happens unless padding is required.
+        """
+        _replace_by_prefix(state_dict, prefix, f"{prefix}{FSDP_WRAPPED_MODULE}.")
+        key = f"{prefix}{FSDP_WRAPPED_MODULE}.{FLAT_PARAM}"
+        load_tensor = state_dict[key]
+        assert isinstance(
+            load_tensor, ShardedTensor
+        ), "Tensors in local_state_dict should be ShardedTensor."
+
+        # Convert the ShardedTensor to a Tensor.
+        shards = load_tensor.local_shards()
+        assert len(shards), "load_local_state_dict assumes one shard per ShardedTensor."
+        load_tensor = cast(torch.Tensor, shards[0].tensor)
+
+        # Get the metadata of the flat_param to decide whether to pad the loaded
+        # tensor.
+        flat_param = self.module.flat_param
+        assert flat_param is not None
+        if flat_param.num_padded not in (0, flat_param.numel()):
+            assert load_tensor.numel() < flat_param.numel(), (
+                f"Local shard size = {flat_param.numel()} and the tensor in "
+                f"the state_dict is {load_tensor.numel()}."
+            )
+            load_tensor = F.pad(load_tensor, [0, flat_param.num_padded])
+        state_dict[key] = load_tensor
+
+    def _sharded_pre_load_state_dict_hook(
+        self,
+        state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"],
+        prefix: str,
+    ) -> None:
+        raise NotImplementedError("Will be implemented in the next PRs.")
+
+    @staticmethod
+    def _pre_load_state_dict_hook(
+        module: nn.Module,
+        state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"],
+        prefix: str,
+        *args: Any,
+    ) -> None:
+        """
+        ``_pre_load_state_dict_hook`` is called before ``self._load_from_state_dict()``
+        is called. ``self._state_dict_type`` is used to decide what preprocessing
+        will be done.
+        """
+        self = cast(FullyShardedDataParallel, module)
+        self._pre_load_state_dict_hook_fn[self._state_dict_type](state_dict, prefix)
+
+    def load_state_dict(
+        self,
+        state_dict: "OrderedDict[str, torch.Tensor]",
+        strict: bool = True,
+    ) -> NamedTuple:
+        """
+        The entry point of all three FSDP load_state_dict APIs.
+        ``self._state_dict_type`` decides which code path to execute.
+
+        .. warning:: This needs to be called on all ranks, since synchronization
+            primitives may be used.
+        """
+        torch.cuda.synchronize()
+        if self._state_dict_type == StateDictType.FULL_STATE_DICT:
+            return super().load_state_dict(state_dict, strict)
+        elif self._state_dict_type == StateDictType.LOCAL_STATE_DICT:
+            return super().load_state_dict(state_dict, strict)
+        elif self._state_dict_type == StateDictType.SHARDED_STATE_DICT:
+            raise NotImplementedError("Will be implemented in the next PRs.")
+        else:
+            raise ValueError(f"Unknown StateDictType {self._state_dict_type}.")
+
+    def load_local_state_dict(
+        self,
+        state_dict: "OrderedDict[str, torch.Tensor]",
+        strict: bool = True,
+    ) -> NamedTuple:
+        """
+        Load states from a flattened, sharded state dictionary.
+        """
+        with self.state_dict_type(StateDictType.LOCAL_STATE_DICT):
+            return self.load_state_dict(state_dict, strict)
+
     def forward(self, *args: Any, **kwargs: Any) -> Any:
         self._lazy_init()
 
@@ -1110,6 +1392,7 @@
         """
         Gather all shards of params.
         """
+        self._lazy_init()
 
         def update_p_data(output_tensor: torch.Tensor) -> None:
             """
@@ -1246,8 +1529,7 @@
             until the eventual sync.
         """
         self._lazy_init()
-        assert self._is_root, \
-            "`no_sync()` on inner FSDP instances is not supported"
+        assert self._is_root, "`no_sync()` on inner FSDP instances is not supported"
         self._assert_state(TrainingState_.IDLE)
         old_flags = []
         for m in self.modules():
@@ -1258,9 +1540,10 @@
             yield
         finally:
             for m, old_flag in old_flags:
-                assert not m._require_backward_grad_sync, \
-                    "`_require_backward_grad_sync` was incorrectly set to " \
+                assert not m._require_backward_grad_sync, (
+                    "`_require_backward_grad_sync` was incorrectly set to "
                     "`True` while in the `no_sync()` context manager"
+                )
                 m._require_backward_grad_sync = old_flag
 
 
diff --git a/torch/distributed/fsdp/utils.py b/torch/distributed/fsdp/utils.py
index 3b54967..2b64ab9 100644
--- a/torch/distributed/fsdp/utils.py
+++ b/torch/distributed/fsdp/utils.py
@@ -1,7 +1,9 @@
-from typing import Dict, List, Tuple, Union, Any, Callable, Set
+from typing import Dict, List, Tuple, Union, Any, Callable, Set, TYPE_CHECKING
 
 import torch
 
+if TYPE_CHECKING:
+    from collections import OrderedDict  # noqa: F401
 
 """Useful functions to deal with tensor types with other python container types."""
 
@@ -22,3 +24,27 @@
             return x
 
     return apply(container)
+
+
+def _replace_by_prefix(
+    state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"],
+    old_prefix: str,
+    new_prefix: str,
+) -> None:
+    """
+    Replace all keys that match a given old_prefix with a new_prefix (in-place).
+
+    Usage::
+
+        state_dict = {"layer.xyz": torch.tensor(1)}
+        _replace_by_prefix(state_dict, "layer.", "module.layer.")
+        assert state_dict == {"module.layer.xyz": torch.tensor(1)}
+    """
+    if old_prefix == new_prefix:
+        raise ValueError("old_prefix and new_prefix must be distinct")
+    for key in list(state_dict.keys()):
+        if not key.startswith(old_prefix):
+            continue
+        new_key = new_prefix + key[len(old_prefix) :]
+        state_dict[new_key] = state_dict[key]
+        del state_dict[key]