[FSDP] Implement local_state_dict and load_local_state_dict (#72469)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72469
1. Implement the framework that allows users to choose among `state_dict`, `local_state_dict`, and `sharded_state_dict`.
2. Implement ShardedTensor-compatible `local_state_dict()` and `load_local_state_dict()`.
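A minimal usage sketch of the new APIs (not part of the diff below): it assumes the process group has already been initialized and one process per CUDA device is launched; the tiny `nn.Linear` model and the `demo` helper are illustrative only, while `FSDP`, `StateDictType`, `state_dict_type`, `local_state_dict`, and `load_local_state_dict` are the APIs added in this PR.

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType


def demo() -> None:
    # Assumes dist.init_process_group(...) was already called by the launcher.
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    model = FSDP(nn.Linear(4, 4).cuda())

    # Option 1: dedicated APIs; the values are the flattened, sharded
    # parameters (ShardedTensors) local to this rank.
    local_sd = model.local_state_dict()
    model.load_local_state_dict(local_sd)

    # Option 2: temporarily change what state_dict()/load_state_dict()
    # produce and consume via the context manager on the root FSDP module.
    with model.state_dict_type(StateDictType.LOCAL_STATE_DICT):
        local_sd = model.state_dict()
    with model.state_dict_type(StateDictType.LOCAL_STATE_DICT):
        model.load_state_dict(local_sd)
```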
ghstack-source-id: 149559985
Test Plan: CI
Reviewed By: rohan-varma
Differential Revision: D33919683
fbshipit-source-id: c9f1b43ce04da7db65c4aebf6ac2c7a0ac5e9de8
(cherry picked from commit 55fd6230c9656fdf30a70dcd8071d094d2e67022)
diff --git a/test/distributed/fsdp/test_flatten_params_wrapper.py b/test/distributed/fsdp/test_flatten_params_wrapper.py
index c4a7eb6..69c78ee 100644
--- a/test/distributed/fsdp/test_flatten_params_wrapper.py
+++ b/test/distributed/fsdp/test_flatten_params_wrapper.py
@@ -198,7 +198,7 @@
expected,
msg=f"{flat_p.shard_metadata()}, {expected}",
)
- self.assertEqual(flat_p._num_padded, kwargs["num_padded"])
+ self.assertEqual(flat_p.num_padded, kwargs["num_padded"])
_test(
kwargs={"start": -1, "end": -1, "num_padded": 0},
diff --git a/test/distributed/fsdp/test_fsdp_state_dict.py b/test/distributed/fsdp/test_fsdp_state_dict.py
new file mode 100644
index 0000000..00776fe
--- /dev/null
+++ b/test/distributed/fsdp/test_fsdp_state_dict.py
@@ -0,0 +1,138 @@
+# Owner(s): ["oncall: distributed"]
+
+import sys
+from typing import Any, Dict
+
+import torch
+from torch import distributed as dist
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import StateDictType
+from torch.nn import Linear, Module
+from torch.nn.parallel import DistributedDataParallel
+from torch.optim import SGD
+from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
+from torch.testing._internal.common_fsdp import (
+ FSDPTest,
+ get_full_params,
+)
+from torch.testing._internal.common_utils import (
+ instantiate_parametrized_tests,
+ parametrize,
+ run_tests,
+ TEST_WITH_DEV_DBG_ASAN,
+)
+
+
+if not dist.is_available():
+ print("Distributed not available, skipping tests", file=sys.stderr)
+ sys.exit(0)
+
+if TEST_WITH_DEV_DBG_ASAN:
+ print(
+ "Skip dev-asan as torch + multiprocessing spawn have known issues",
+ file=sys.stderr,
+ )
+ sys.exit(0)
+
+INNER_SHAPE = [4, 4]
+OUTER_SHAPE = [4, 5]
+
+
+class Model(Module):
+ def __init__(self, wrap_fsdp):
+ super().__init__()
+ self.inner = Linear(*INNER_SHAPE)
+ if wrap_fsdp:
+ self.inner = FSDP(self.inner)
+ self.outer = Linear(*OUTER_SHAPE)
+
+ def forward(self, x):
+ # Forward twice.
+ i = self.inner(x)
+ j = self.inner(x)
+ return self.outer(i + j)
+
+
+class TestFSDPStateDict(FSDPTest):
+ @property
+ def world_size(self):
+ return 2
+
+ def _initialize_model(self, wrap_fsdp: bool):
+ # keep everything deterministic for input data
+ torch.manual_seed(0)
+
+ model = Model(wrap_fsdp).cuda()
+ if wrap_fsdp:
+ model = FSDP(model)
+ else:
+ model = DistributedDataParallel(model, device_ids=[self.rank])
+ return model
+
+ @staticmethod
+ def _state_dict(model: Module, state_dict_type: str):
+ return getattr(model, state_dict_type)()
+
+ @staticmethod
+ def _load_state_dict(
+ model: Module, state_dict_type: str, state_dict: Dict[str, Any]
+ ):
+ getattr(model, f"load_{state_dict_type}")(state_dict)
+
+ def _dist_train(
+ self, wrap_fsdp: bool, state_dict_type: str = "", with_context: bool = False
+ ):
+ # TODO: Move this test to common_fsdp.
+ model = self._initialize_model(wrap_fsdp)
+ optim = SGD(model.parameters(), lr=0.1)
+
+ in_data = torch.rand(64, 4, requires_grad=True, device=torch.device("cuda"))
+ for _ in range(3):
+ out = model(in_data)
+ out.sum().backward()
+ optim.step()
+ optim.zero_grad()
+
+ if wrap_fsdp:
+ blank_model = FSDP(Model(True).cuda())
+ if with_context:
+ state_dict_type = {
+ "full_state_dict": StateDictType.FULL_STATE_DICT,
+ "local_state_dict": StateDictType.LOCAL_STATE_DICT,
+ "sharded_state_dict": StateDictType.SHARDED_STATE_DICT,
+ }[state_dict_type]
+ with model.state_dict_type(state_dict_type):
+ state_dict = model.state_dict()
+ with blank_model.state_dict_type(state_dict_type):
+ blank_model.load_state_dict(state_dict)
+ else:
+ state_dict = self._state_dict(model, state_dict_type)
+ self._load_state_dict(blank_model, state_dict_type, state_dict)
+ get_full_params(blank_model)
+ model = blank_model
+
+ return list(model.parameters())
+
+ @skip_if_lt_x_gpu(2)
+ @parametrize("state_dict_type", ["local_state_dict"])
+ def test_state_dict_save_load_flow(self, state_dict_type):
+ fsdp_params = self._dist_train(wrap_fsdp=True, state_dict_type=state_dict_type)
+ fsdp_params_using_context = self._dist_train(
+ wrap_fsdp=True, state_dict_type=state_dict_type, with_context=True
+ )
+ ddp_params = self._dist_train(wrap_fsdp=False)
+ self.assertEqual(ddp_params, fsdp_params)
+ self.assertEqual(ddp_params, fsdp_params_using_context)
+
+ @skip_if_lt_x_gpu(2)
+ @parametrize("state_dict_type", ["local_state_dict"])
+ def test_fsdp_state_dict_keys(self, state_dict_type):
+ state_dict = self._state_dict(self._initialize_model(True), state_dict_type)
+ if state_dict_type == "local_state_dict":
+ self.assertEqual(set(["flat_param", "inner.flat_param"]), state_dict.keys())
+
+
+instantiate_parametrized_tests(TestFSDPStateDict)
+
+if __name__ == "__main__":
+ run_tests()
diff --git a/test/distributed/fsdp/test_fsdp_summon_full_params.py b/test/distributed/fsdp/test_fsdp_summon_full_params.py
index fbb2ef2..f6ec725 100644
--- a/test/distributed/fsdp/test_fsdp_summon_full_params.py
+++ b/test/distributed/fsdp/test_fsdp_summon_full_params.py
@@ -1,8 +1,8 @@
# Owner(s): ["oncall: distributed"]
import itertools
-from copy import deepcopy
import math
import sys
+from copy import deepcopy
import torch
import torch.nn as nn
@@ -35,11 +35,10 @@
)
sys.exit(0)
+
def _run_test_summon_full_param_writeback(cls, writeback, cpu_offload, modify_outer):
model = FSDP(
- nn.Sequential(
- FSDP(nn.Linear(5, 5, bias=False)), nn.Linear(5, 3, bias=False)
- )
+ nn.Sequential(FSDP(nn.Linear(5, 5, bias=False)), nn.Linear(5, 3, bias=False))
).cuda(cls.rank)
# set the value
@@ -64,6 +63,7 @@
else:
cls.assertEqual(p.cpu()[0], cls.rank + 2)
+
class TestSummonFullParamsNoShard(FSDPTest):
@property
def world_size(self):
@@ -84,6 +84,7 @@
modify_outer,
)
+
class TestSummonFullParams(FSDPTest):
@property
def world_size(self):
@@ -105,10 +106,7 @@
@parametrize("modify_outer", [True, False])
def test_summon_full_param_writeback(self, writeback, cpu_offload, modify_outer):
return _run_test_summon_full_param_writeback(
- self,
- writeback,
- cpu_offload,
- modify_outer
+ self, writeback, cpu_offload, modify_outer
)
@skip_if_lt_x_gpu(2)
diff --git a/torch/distributed/fsdp/__init__.py b/torch/distributed/fsdp/__init__.py
index d2c311d..7c1f0b3 100644
--- a/torch/distributed/fsdp/__init__.py
+++ b/torch/distributed/fsdp/__init__.py
@@ -1,3 +1,4 @@
from .flatten_params_wrapper import FlatParameter
from .fully_sharded_data_parallel import FullyShardedDataParallel
from .fully_sharded_data_parallel import CPUOffload
+from .fully_sharded_data_parallel import StateDictType
diff --git a/torch/distributed/fsdp/flatten_params_wrapper.py b/torch/distributed/fsdp/flatten_params_wrapper.py
index ef3af64..13be7bd 100644
--- a/torch/distributed/fsdp/flatten_params_wrapper.py
+++ b/torch/distributed/fsdp/flatten_params_wrapper.py
@@ -18,14 +18,60 @@
Optional,
Sequence,
Tuple,
+ TYPE_CHECKING,
+ Union,
)
import torch
import torch.nn as nn
from torch import Tensor
+from .utils import _replace_by_prefix
+
+if TYPE_CHECKING:
+ from collections import OrderedDict # noqa: F401
+
ParamOffset = Tuple[int, int]
SharedParamInfo = Tuple[str, str, nn.Module, str, nn.Module, str]
+FLAT_PARAM = "flat_param"
+FPW_MODULE = "_fpw_module"
+
+
+def _post_state_dict_hook(
+ module: nn.Module, state_dict: "OrderedDict[str, Tensor]", prefix: str, *args: Any
+) -> "OrderedDict[str, Tensor]":
+ """
+    _post_state_dict_hook() is called after state_dict() has executed and
+    before the state_dict is returned to the user.
+ This API post-processes the keys of the state_dict to remove the
+ FlattenParamsWrapper internal prefix.
+ """
+ # Move everything from FPW_MODULE up one level.
+ _replace_by_prefix(state_dict, prefix + f"{FPW_MODULE}.", prefix)
+ return state_dict
+
+
+def _pre_load_state_dict_hook(
+ state_dict: Union[Dict[str, Tensor], "OrderedDict[str, Tensor]"],
+ prefix: str,
+ *args: Any,
+) -> None:
+ """
+    _pre_load_state_dict_hook() is called before _load_from_state_dict() is
+    executed. This API pre-processes the keys of the state_dict to add the
+    FlattenParamsWrapper internal prefix.
+ """
+ # Push everything down to FPW_MODULE level.
+ _replace_by_prefix(state_dict, prefix, prefix + f"{FPW_MODULE}.")
+    # The flat_param_* keys actually need to move one level up.
+ flat_param_key = prefix + f"{FPW_MODULE}.{FLAT_PARAM}"
+ for k in list(state_dict.keys()):
+ if k.startswith(flat_param_key):
+ last_part = k.split(".")[-1]
+ assert last_part.startswith(
+ FLAT_PARAM
+ ), f"Expected key to contain flat_param, but key name is {k}"
+ _replace_by_prefix(state_dict, k, prefix + last_part)
class ParamInfo(NamedTuple):
@@ -98,10 +144,13 @@
def __init__(self, params: Sequence[nn.Parameter], requires_grad: bool = True):
self._is_sharded = False
self._param_numels = [p.numel() for p in params]
- assert self.numel() <= sum(self._param_numels), (
+        # The total number of elements. This is equal to the sum of the
+        # ``numel()`` of all the parameters.
+ self.full_numel = sum(self._param_numels)
+ assert self.numel() <= self.full_numel, (
"Parameter numbers mismatched. "
f"The number of elements in FlatParameter: {self.numel()} vs. "
- f"the number of elements in original parameters: {sum(self._param_numels)}."
+ f"the number of elements in original parameters: {self.full_numel}."
)
# The shapes of each individual parameter.
self._param_shapes = [p.size() for p in params]
@@ -124,7 +173,7 @@
(0, numel) for numel in self._param_numels
]
# The number of padding elements.
- self._num_padded = 0
+ self.num_padded = 0
def shard_by_offsets(self, start: int, end: int, num_padded: int) -> None:
assert self._is_sharded
@@ -133,8 +182,8 @@
f"Shard the flatten parameter with an invalid offset pair {(start, end)}."
)
_shard_size = end - start + 1
- self._num_padded = num_padded
- if self._num_padded > _shard_size:
+ self.num_padded = num_padded
+ if self.num_padded > _shard_size:
raise ValueError("The number of padding is larger than the shard size.")
self._sharded_param_offsets.clear()
@@ -163,13 +212,13 @@
) -> Iterator[Tensor]:
"""Return a generator of views that map to the original parameters."""
# Note, self.data could be sharded, so its numel is <= to the sum.
- assert self.data.numel() <= sum(
- self._param_numels
- ), f"Incorrect internal state {self.data.numel()} vs. {sum(self._param_numels)}"
+ assert (
+ self.data.numel() <= self.full_numel
+ ), f"Incorrect internal state {self.data.numel()} vs. {self.full_numel}"
data = external_data if external_data is not None else self
- if data.numel() != sum(self._param_numels):
+ if data.numel() != self.full_numel:
raise ValueError(
- f"Incorrect numel of supplied data: got {data.numel()} but expected {sum(self._param_numels)}"
+ f"Incorrect numel of supplied data: got {data.numel()} but expected {self.full_numel}"
)
return (
t.view(s)
@@ -252,6 +301,15 @@
self._orig_flat_param: List[Optional[FlatParameter]] = [None]
self._flatten_params()
+ # Sanity check for the string constants.
+ assert getattr(self, FPW_MODULE) is self._fpw_module
+ assert getattr(self, FLAT_PARAM) is self.flat_param
+
+        # Register hooks to remove the "_fpw_module." prefix after
+        # state_dict() and to add it back before load_state_dict().
+ self._register_state_dict_hook(_post_state_dict_hook)
+ self._register_load_state_dict_pre_hook(_pre_load_state_dict_hook)
+
@property
def module(self) -> Any:
"""Support _fsdp_wrapped_module.module in case we are immitating DDP, which has .module
diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py
index fe61684..baea027 100644
--- a/torch/distributed/fsdp/fully_sharded_data_parallel.py
+++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py
@@ -9,9 +9,10 @@
Any,
Callable,
Dict,
- Generator,
List,
Optional,
+ Generator,
+ NamedTuple,
Set,
Tuple,
Union,
@@ -24,17 +25,28 @@
import torch.nn.functional as F
from torch.autograd import Variable
from torch.distributed import ProcessGroup
+from torch.distributed._sharded_tensor import (
+ init_from_local_shards,
+ Shard,
+ ShardedTensor,
+)
from torch.distributed.distributed_c10d import _get_default_group
from torch.nn.parameter import Parameter
-from .flatten_params_wrapper import FlatParameter, FlattenParamsWrapper
-from .utils import _apply_to_tensors
+from .flatten_params_wrapper import FlatParameter, FlattenParamsWrapper, FLAT_PARAM
+from .utils import (
+ _apply_to_tensors,
+ _replace_by_prefix,
+)
from .wrap import _recursive_wrap
if TYPE_CHECKING:
from collections import OrderedDict # noqa: F401
+FSDP_WRAPPED_MODULE = "_fsdp_wrapped_module"
+
+
@dataclass
class CPUOffload:
"""
@@ -98,6 +110,31 @@
SUMMON_FULL_PARAMS = auto()
+class StateDictType(Enum):
+ """
+    This enum indicates which type of ``state_dict`` the FSDP module is
+    currently processing (returning or loading).
+    The default value is FULL_STATE_DICT to comply with the PyTorch convention.
+    .. note::
+        FSDP currently supports three types of ``state_dict``:
+            1. ``state_dict/load_state_dict``: this pair of APIs returns and
+               loads the non-sharded, unflattened parameters. The semantics are
+               the same as using DDP.
+            2. ``local_state_dict/load_local_state_dict``: this pair of APIs
+               returns and loads the local sharded, flattened parameters. The
+               values returned by ``local_state_dict`` can be directly used by
+               FSDP and are only meaningful to FSDP (because parameters are
+               flattened).
+            3. ``sharded_state_dict/load_sharded_state_dict``: this pair of APIs
+               returns and loads sharded, unflattened parameters. The
+               ``state_dict`` returned by ``sharded_state_dict`` can be used by
+               all other parallel schemes (resharding may be required).
+
+ FULL_STATE_DICT = auto()
+ LOCAL_STATE_DICT = auto()
+ SHARDED_STATE_DICT = auto()
+
+
class FullyShardedDataParallel(nn.Module):
"""
A wrapper for sharding Module parameters across data parallel workers. This
@@ -244,6 +281,7 @@
self._fsdp_wrapped_module: FlattenParamsWrapper = FlattenParamsWrapper(
module, param_list=params
)
+ assert getattr(self, FSDP_WRAPPED_MODULE) is self._fsdp_wrapped_module
del module # free original module in case it helps garbage collection
if self._fsdp_wrapped_module.flat_param is not None:
self.params = [self._fsdp_wrapped_module.flat_param]
@@ -268,6 +306,29 @@
# Enum to indicate if we're in the forward/backward pass, idle, etc.
self.training_state = TrainingState_.IDLE
+ self._state_dict_type = StateDictType.FULL_STATE_DICT
+
+        # FSDP currently provides three different state_dicts. The actual
+        # state_dict that will be saved/loaded is decided by
+        # self._state_dict_type, and the main logic of each state_dict is
+        # implemented in the corresponding hook. Therefore, each hook
+        # (post-save and pre-load) has a dispatcher dictionary that routes
+        # the execution flow to the correct implementation.
+ self._register_state_dict_hook(self._post_state_dict_hook)
+ self._post_state_dict_hook_fn = {
+ StateDictType.FULL_STATE_DICT: self._full_post_state_dict_hook,
+ StateDictType.LOCAL_STATE_DICT: self._local_post_state_dict_hook,
+ StateDictType.SHARDED_STATE_DICT: self._sharded_post_state_dict_hook,
+ }
+ self._register_load_state_dict_pre_hook(
+ self._pre_load_state_dict_hook, with_module=True
+ )
+ self._pre_load_state_dict_hook_fn = {
+ StateDictType.FULL_STATE_DICT: self._full_pre_load_state_dict_hook,
+ StateDictType.LOCAL_STATE_DICT: self._local_pre_load_state_dict_hook,
+ StateDictType.SHARDED_STATE_DICT: self._sharded_pre_load_state_dict_hook,
+ }
+
# Flag to guard against preparing gradients multiple times per backward pass.
self._pre_backward_hook_has_run = False
# Used for prefetching all gather full params in post backward hook
@@ -679,6 +740,227 @@
else:
return False
+ @contextlib.contextmanager
+ def state_dict_type(self, state_dict_type: StateDictType) -> Generator:
+ """
+ A context manager to set the state_dict_type of this FSDP module and
+ its descendant FSDP modules.
+        .. note:: This API should be called only on the root FSDP module.
+        .. note:: The default state_dict_type is StateDictType.FULL_STATE_DICT.
+
+ Args:
+ state_dict_type (StateDictType): the desired state_dict_type to set.
+ """
+ self._lazy_init()
+ if not self._is_root:
+ raise RuntimeError(
+ f"state_dict_type context manager can only be called from the root FSDP module. {self._is_root}"
+ )
+ prev_state_dict_type = self._state_dict_type
+ for module in self.modules():
+ if isinstance(module, FullyShardedDataParallel):
+ if module._state_dict_type != prev_state_dict_type:
+ raise RuntimeError(
+ "All FSDP module should the same state_dict_type."
+ )
+ module._state_dict_type = state_dict_type
+ try:
+ yield
+ finally:
+ for module in self.modules():
+ if isinstance(module, FullyShardedDataParallel):
+ module._state_dict_type = prev_state_dict_type
+
+ def _full_post_state_dict_hook(
+ self,
+ state_dict: "OrderedDict[str, torch.Tensor]",
+ prefix: str,
+ ) -> "OrderedDict[str, torch.Tensor]":
+ return state_dict
+
+ def _local_post_state_dict_hook(
+ self,
+ state_dict: "OrderedDict[str, torch.Tensor]",
+ prefix: str,
+ ) -> "OrderedDict[str, torch.Tensor]":
+ """
+        This hook creates a ShardedTensor from the local flat_param and replaces
+        state_dict[f"{prefix}{FLAT_PARAM}"] with the ShardedTensor. No copy
+        happens; the underlying storage is the same.
+ """
+ _replace_by_prefix(state_dict, f"{prefix}{FSDP_WRAPPED_MODULE}.", prefix)
+ # state_dict[f"{prefix}{FLAT_PARAM}"] exists and has the same tensor
+ # value as the flat_param but it is a pure Tensor because
+ # nn.Module.state_dict() will detach the parameter. Therefore, we need
+ # to get flat_param from the FlattenParamsWrapper to get the metadata.
+ flat_param = getattr(self.module, FLAT_PARAM, None)
+ assert (
+ flat_param is not None
+ ), "flat_param cannot be None when doing local_state_dict."
+
+ # Construct a ShardedTensor from the flat_param.
+ full_numel = flat_param.full_numel
+ shard_offset = flat_param.numel() * self.rank
+ valid_data_size = flat_param.numel() - flat_param.num_padded
+ if valid_data_size > 0 and flat_param.num_padded > 0:
+ flat_param = flat_param.narrow(0, 0, valid_data_size)
+ local_shards = [
+ Shard.from_tensor_and_offsets(flat_param, [shard_offset], self.rank)
+ ]
+ state_dict[f"{prefix}{FLAT_PARAM}"] = init_from_local_shards(
+ local_shards, full_numel, process_group=self.process_group
+ ) # type: ignore[assignment]
+
+ return state_dict
+
+ def _sharded_post_state_dict_hook(
+ self,
+ state_dict: "OrderedDict[str, torch.Tensor]",
+ prefix: str,
+ ) -> "OrderedDict[str, torch.Tensor]":
+ raise NotImplementedError("Will be implemented in the next PRs.")
+
+ @staticmethod
+ def _post_state_dict_hook(
+ module: nn.Module,
+ state_dict: "OrderedDict[str, torch.Tensor]",
+ prefix: str,
+ *args: Any,
+ ) -> "OrderedDict[str, torch.Tensor]":
+ """
+ _post_state_dict_hook() is called after the state_dict() of this
+ FSDP module is executed. ``self._state_dict_type`` is used to decide
+ what postprocessing will be done.
+ """
+ self = cast(FullyShardedDataParallel, module)
+ return self._post_state_dict_hook_fn[self._state_dict_type](state_dict, prefix)
+
+ def state_dict(self, destination=None, prefix="", keep_vars=False):
+ """
+ The entry point of all three FSDP state_dict APIs.
+ ``self._state_dict_type`` decides which code path to execute.
+
+ .. warning:: This needs to be called on all ranks, since synchronization
+ primitives may be used.
+ """
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+ if self._state_dict_type == StateDictType.FULL_STATE_DICT:
+ return super().state_dict(destination, prefix, keep_vars)
+ elif self._state_dict_type == StateDictType.LOCAL_STATE_DICT:
+ assert getattr(self.module, FLAT_PARAM, None) is not None
+ assert isinstance(self.module.flat_param, FlatParameter)
+ return super().state_dict(destination, prefix, keep_vars)
+ elif self._state_dict_type == StateDictType.SHARDED_STATE_DICT:
+ raise NotImplementedError("Will be implemented in the next PRs.")
+ else:
+ raise ValueError(f"Unknown StateDictType {self._state_dict_type}.")
+
+ def local_state_dict(self, *args: Any, **kwargs: Any) -> Any:
+ """
+ Returns the local state of the module. Parameters are flattened and
+ sharded, so the resulting state_dict can only be loaded after the module
+ has been wrapped with FSDP.
+ """
+ with self.state_dict_type(StateDictType.LOCAL_STATE_DICT):
+ return self.state_dict(*args, **kwargs)
+
+ def _full_pre_load_state_dict_hook(
+ self,
+ state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"],
+ prefix: str,
+ ) -> None:
+ return
+
+ def _local_pre_load_state_dict_hook(
+ self,
+ state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"],
+ prefix: str,
+ ) -> None:
+ """
+ This hook finds the local flat_param for this FSDP module from the
+ state_dict. The flat_param should be a ShardedTensor. This hook converts
+        the ShardedTensor to a tensor. No copy happens unless padding is required.
+ """
+ _replace_by_prefix(state_dict, prefix, f"{prefix}{FSDP_WRAPPED_MODULE}.")
+ key = f"{prefix}{FSDP_WRAPPED_MODULE}.{FLAT_PARAM}"
+ load_tensor = state_dict[key]
+ assert isinstance(
+ load_tensor, ShardedTensor
+ ), "Tensors in local_state_dict should be ShardedTensor."
+
+ # Convert the ShardedTensor to a Tensor.
+ shards = load_tensor.local_shards()
+        assert len(shards), "load_local_state_dict assumes one shard per ShardedTensor."
+ load_tensor = cast(torch.Tensor, shards[0].tensor)
+
+        # Get the metadata of the flat_param to decide whether to pad the loaded
+ # tensor.
+ flat_param = self.module.flat_param
+ assert flat_param is not None
+ if flat_param.num_padded not in (0, flat_param.numel()):
+ assert load_tensor.numel() < flat_param.numel(), (
+ f"Local shard size = {flat_param.numel()} and the tensor in "
+ f"the state_dict is {load_tensor.numel()}."
+ )
+ load_tensor = F.pad(load_tensor, [0, flat_param.num_padded])
+ state_dict[key] = load_tensor
+
+ def _sharded_pre_load_state_dict_hook(
+ self,
+ state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"],
+ prefix: str,
+ ) -> None:
+ raise NotImplementedError("Will be implemented in the next PRs.")
+
+ @staticmethod
+ def _pre_load_state_dict_hook(
+ module: nn.Module,
+ state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"],
+ prefix: str,
+ *args: Any,
+ ) -> None:
+ """
+        ``_pre_load_state_dict_hook`` is called before ``self._load_from_state_dict()``
+ is called. ``self._state_dict_type`` is used to decide what preprocessing
+ will be done.
+ """
+ self = cast(FullyShardedDataParallel, module)
+ self._pre_load_state_dict_hook_fn[self._state_dict_type](state_dict, prefix)
+
+ def load_state_dict(
+ self,
+ state_dict: "OrderedDict[str, torch.Tensor]",
+ strict: bool = True,
+ ) -> NamedTuple:
+ """
+ The entry point of all three FSDP load_state_dict APIs.
+ ``self._state_dict_type`` decides which code path to execute.
+
+ .. warning:: This needs to be called on all ranks, since synchronization
+ primitives may be used.
+ """
+ torch.cuda.synchronize()
+ if self._state_dict_type == StateDictType.FULL_STATE_DICT:
+ return super().load_state_dict(state_dict, strict)
+ elif self._state_dict_type == StateDictType.LOCAL_STATE_DICT:
+ return super().load_state_dict(state_dict, strict)
+ elif self._state_dict_type == StateDictType.SHARDED_STATE_DICT:
+ raise NotImplementedError("Will be implemented in the next PRs.")
+ else:
+ raise ValueError(f"Unknown StateDictType {self._state_dict_type}.")
+
+ def load_local_state_dict(
+ self,
+ state_dict: "OrderedDict[str, torch.Tensor]",
+ strict: bool = True,
+ ) -> NamedTuple:
+ """
+ Load states from a flatten, sharded state dictionary.
+ """
+ with self.state_dict_type(StateDictType.LOCAL_STATE_DICT):
+ return self.load_state_dict(state_dict, strict)
+
def forward(self, *args: Any, **kwargs: Any) -> Any:
self._lazy_init()
@@ -1110,6 +1392,7 @@
"""
Gather all shards of params.
"""
+ self._lazy_init()
def update_p_data(output_tensor: torch.Tensor) -> None:
"""
@@ -1246,8 +1529,7 @@
until the eventual sync.
"""
self._lazy_init()
- assert self._is_root, \
- "`no_sync()` on inner FSDP instances is not supported"
+ assert self._is_root, "`no_sync()` on inner FSDP instances is not supported"
self._assert_state(TrainingState_.IDLE)
old_flags = []
for m in self.modules():
@@ -1258,9 +1540,10 @@
yield
finally:
for m, old_flag in old_flags:
- assert not m._require_backward_grad_sync, \
- "`_require_backward_grad_sync` was incorrectly set to " \
+ assert not m._require_backward_grad_sync, (
+ "`_require_backward_grad_sync` was incorrectly set to "
"`True` while in the `no_sync()` context manager"
+ )
m._require_backward_grad_sync = old_flag
diff --git a/torch/distributed/fsdp/utils.py b/torch/distributed/fsdp/utils.py
index 3b54967..2b64ab9 100644
--- a/torch/distributed/fsdp/utils.py
+++ b/torch/distributed/fsdp/utils.py
@@ -1,7 +1,9 @@
-from typing import Dict, List, Tuple, Union, Any, Callable, Set
+from typing import Dict, List, Tuple, Union, Any, Callable, Set, TYPE_CHECKING
import torch
+if TYPE_CHECKING:
+ from collections import OrderedDict # noqa: F401
"""Useful functions to deal with tensor types with other python container types."""
@@ -22,3 +24,27 @@
return x
return apply(container)
+
+
+def _replace_by_prefix(
+ state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"],
+ old_prefix: str,
+ new_prefix: str,
+) -> None:
+ """
+ Replace all keys that match a given old_prefix with a new_prefix (in-place).
+
+ Usage::
+
+ state_dict = {"layer.xyz": torch.tensor(1)}
+        _replace_by_prefix(state_dict, "layer.", "module.layer.")
+ assert state_dict == {"module.layer.xyz": torch.tensor(1)}
+ """
+ if old_prefix == new_prefix:
+ raise ValueError("old_prefix and new_prefix must be distinct")
+ for key in list(state_dict.keys()):
+ if not key.startswith(old_prefix):
+ continue
+ new_key = new_prefix + key[len(old_prefix) :]
+ state_dict[new_key] = state_dict[key]
+ del state_dict[key]