| import contextlib |
| from dataclasses import dataclass |
| from enum import auto, Enum |
| from itertools import accumulate, chain |
| from typing import ( |
| cast, |
| Dict, |
| Generator, |
| Iterator, |
| List, |
| NamedTuple, |
| Optional, |
| Sequence, |
| Set, |
| Tuple, |
| Union, |
| ) |
| |
| import torch |
| import torch.distributed as dist |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch import Tensor |
| |
| from ._utils import _alloc_storage, _free_storage, _set_fsdp_flattened, p_assert |
| |
| __all__ = [ |
| "FlatParameter", |
| "FlatParamHandle", |
| "FlatParamShardMetadata", |
| "ParamInfo", |
| "SharedParamInfo", |
| "HandleConfig", |
| "HandleShardingStrategy", |
| "HandleTrainingState", |
| ] |
| |
| |
| class ParamInfo(NamedTuple): |
| """Information for an original module parameter.""" |
| |
| param_name: str # unprefixed |
| module: nn.Module |
| module_name: str |
| |
| |
| class SharedParamInfo(NamedTuple): |
| """ |
| Additional information for a shared parameter. |
| |
| For each shared parameter, we designate one module and its parameter |
| variable to be the primary owner, determined as the first one encountered |
| in the parameter walk. These are prefixed with "prim". The primary module |
| and parameter do not have their own :class:`SharedParamInfo` instance. |
| """ |
| |
| param_name: str # unprefixed |
| module: nn.Module |
| module_name: str |
| prim_param_name: str # unprefixed |
| prim_module: nn.Module |
| prim_module_name: str |
| |
| |
| class FlatParamShardMetadata(NamedTuple): |
| """ |
| This holds metadata specific to this rank's shard of the flattened |
| parameter. |
| |
| Attributes: |
| param_names (Tuple[str, ...]): Prefixed parameter names of this rank's |
| shard of the parameters; see :class:`FlatParameter`. |
| param_shapes (Tuple[torch.Size, ...]): Parameter shapes of this rank's |
| shard of the parameters; see :class:`FlatParameter`. |
| param_numels (Tuple[int, ...]): Parameter numels of this rank's shard |
| of the parameters; see :class:`FlatParameter`. |
| param_offsets (Tuple[Tuple[int, int], ...]): [start, end] offsets (in |
| units of numels) giving this rank's part of each flattened |
| original module parameter. |
| """ |
| |
| param_names: Tuple[str, ...] |
| param_shapes: Tuple[torch.Size, ...] |
| param_numels: Tuple[int, ...] |
| param_offsets: Tuple[Tuple[int, int], ...] |
| |
| |
| # TODO (awgu): Prefix these with "Handle" for now to avoid circular imports and |
| # inadvertent misuses; coalesce with those in fully_sharded_data_parallel.py |
| # later |
| class HandleShardingStrategy(Enum): |
| FULL_SHARD = auto() |
| SHARD_GRAD_OP = auto() |
| NO_SHARD = auto() |
| |
| |
| class HandleTrainingState(Enum): |
| IDLE = auto() |
| FORWARD = auto() |
| BACKWARD_PRE = auto() |
| BACKWARD_POST = auto() |
| SUMMON_FULL_PARAMS = auto() |
| |
| |
| @dataclass |
| class HandleConfig: |
| sharding_strategy: HandleShardingStrategy |
| offload_params: bool |
| param_dtype: Optional[torch.dtype] |
| reduce_dtype: Optional[torch.dtype] |
| |
| |
| class FlatParameter(nn.Parameter): |
| """ |
| This is the flattened parameter used by :class:`FullyShardedDataParallel`. |
| It is comprised of one or more original parameters, which are flattened |
| and concatenated to construct the flattened parameter. |
| |
| Under the current design, this parameter logically represents both the |
| unsharded and sharded flattened parameter, and its data changes storages |
| dynamically. |
| - In the :class:`FullyShardedDataParallel` constructor, the parameter |
| is initialized as unsharded and then sharded in-place. |
| - At runtime, the parameter is lazily (re)-initialized. The sharded |
| parameter data is saved in ``self._local_shard``, and a new ``Tensor`` |
| ``self._full_param_padded`` is created, which is the all-gather |
| destination and owns the unsharded parameter storage thereafter. (See |
| :meth:`FullyShardedDataParallel._init_param_attributes`.) |
| - Throughout runtime, the parameter data changes storages as needed, |
| e.g. to the sharded flattened parameter, reduced-precision sharded |
| flattened parameter, or the unsharded flattened parameter. |
| |
| Attributes: |
| _unpadded_unsharded_size (torch.Size): Unsharded flattened parameter's |
| size without padding. |
| _padded_unsharded_size (torch.Size): Unsharded flattened parameter's |
| size with padding. This is only set for sharded strategies since |
| they require padding for the all-gather. |
| |
| _param_infos (Tuple[ParamInfo, ...]): Each parameter's parameter info |
| entry; see :class:`ParamInfo`. |
| _numels (Tuple[int, ...]): Each parameter's numel. |
| _shapes (Tuple[torch.Size, ...]): Each parameter's shape. |
| _prefixed_param_names (Tuple[str, ...]): Each parameter's name prefixed |
| with the parent module names starting from the module passed to |
| construct this flattened parameter via :class:`FlatParamHandle`; |
| the prefixed names are guaranteed to be unique within the subtree |
| rooted in that module. |
| _num_params (int): Number of original parameters flattened into this |
| flattened parameter; this is the length of ``_param_infos``, |
| ``_numels``, ``_shapes``, and ``_prefixed_param_names``. |
| _shared_param_infos (Tuple[SharedParamInfo, ...]): Shared parameter |
| info entries; see :class:`SharedParamInfo`. |
| |
| _shard_param_offsets (List[Tuple[int, int])): [start, end] offsets (in |
| units of numel) giving this rank's part of each flattened original |
| module parameter; for any parameter ``p`` that is not sharded |
| across ranks, this will be [0, ``p.numel()``-1]. |
| _shard_indices (Tuple[int, int]): [start, end] indices (in units of |
| parameters) for this rank's shard of the original model parameters, |
| where the parameters follow the order in which they were originally |
| flattened; this indexes appropriately into any data structure that |
| follows the flattening order (e.g. ``_param_infos``, ``_numels``, |
| etc.). |
| _shard_numel_padded (int): Numel padded for this rank's sharded |
| flattened parameter. |
| |
| _local_shard (Tensor): Sharded flattened parameter with padding if |
| using a sharded strategy. If using ``NO_SHARD``, then this is the |
| unpadded unsharded flattened parameter, and there is no notion of a |
| sharded flattened parameter or padded unsharded flattened |
| parameter. |
| _full_param_padded (Tensor): Unsharded flattened parameter with |
| padding. This is not defined for ``NO_SHARD``. When using mixed |
| precision for parameters, this has the low precision. |
| _full_prec_full_param_padded (Tensor): Full precision unsharded |
| flattened parameter with padding. This is used for unsharding |
| outside of computation when using mixed precision for parameters. |
| This is never defined for ``NO_SHARD``. |
| _post_backward_hook_state (Tuple[AccumulateGrad, RemovableHandle]): |
| Flattened parameter's :class:`AccumulateGrad` object and |
| post-backward hook handle. |
| _mp_shard (Tensor): Low precision sharded flattened parameter with |
| padding. This is only defined when parameter mixed precision is |
| enabled. For ``NO_SHARD``, this is used for computation. |
| _cpu_grad (Tensor): Sharded gradient with padding stored on CPU. |
| This is only defined when offloading parameters is enabled. |
| _saved_grad_shard (Tensor): Sharded gradient with padding from previous |
| iterations for gradient accumulation without :meth:`no_sync`. |
| """ |
| |
| def _init_metadata( |
| self, |
| param_infos: List[ParamInfo], |
| numels: List[int], |
| shapes: List[torch.Size], |
| prefixed_param_names: List[str], |
| shared_param_infos: List[SharedParamInfo], |
| ) -> None: |
| """ |
| Initializes attributes holding metadata about the original parameters |
| comprising the flattened parameter. |
| |
| We expose this method separate from the constructor to keep the |
| constructor only responsible for the flattened parameter's tensor data. |
| This method should only be called once per model, while the constructor |
| may be called multiple times, e.g. when reloading from a checkpoint, in |
| which case only the tensor data needs to be passed to the constructor. |
| Since :meth:`load_state_dict` is implemented via :meth:`copy_`, the |
| metadata is correctly assumed to be unchanged. |
| |
| Args: |
| See the Attributes in the class docstring. |
| """ |
| assert len(param_infos) == len(numels) |
| assert len(param_infos) == len(shapes) |
| assert len(param_infos) == len(prefixed_param_names) |
| self._num_params = len(param_infos) |
| self._param_infos = tuple(param_infos) |
| self._numels = tuple(numels) |
| self._shapes = tuple(shapes) |
| self._prefixed_param_names = tuple(prefixed_param_names) |
| self._shared_param_infos = tuple(shared_param_infos) |
| self._unpadded_unsharded_size = self.size() |
| _set_fsdp_flattened(self) |
| |
| |
| class FlatParamHandle: |
| """ |
| This handle manages a flattened parameter (:class:`FlatParameter`). This |
| includes sharding and view management. |
| |
| Args: |
| params (Sequence[nn.Parameter]): The parameters to use for the |
| flattened parameter. |
| module (nn.Module): A module that is the root of the subtree containing |
| all parameters in ``params``; for non-recursive wrapping, this must |
| be the top-level module, while for recursive wrapping, this may not |
| necessarily be the top-level module. |
| device (torch.device): The compute and communication device, which |
| should be a non-CPU device. We refer to it as the compute device. |
| config (HandleConfig): A config customizing the handle based on FSDP's |
| available features. |
| """ |
| |
| ################## |
| # INITIALIZATION # |
| ################## |
| def __init__( |
| self, |
| params: Sequence[nn.Parameter], |
| module: nn.Module, |
| device: torch.device, |
| config: HandleConfig, |
| ) -> None: |
| super().__init__() |
| self.device = device |
| self._config = config |
| self._training_state = HandleTrainingState.IDLE |
| self._init_flat_param(params, module) |
| self._unflatten(as_params=False) |
| |
| def _init_flat_param( |
| self, |
| params: Sequence[Optional[nn.Parameter]], |
| module: nn.Module, |
| ) -> None: |
| """ |
| Initializes the flattened parameter ``self.flat_param`` by flattening |
| the parameters in ``params`` into a single :class:`FlatParameter` and |
| saves relevant metadata. Shared parameters are only included in the |
| flattened parameter once. |
| |
| This checks that all comprising parameters have the same dtype and |
| ``requires_grad`` and does not support nested construction of |
| :class:`FlatParameter` s. |
| |
| Args: |
| See the Args in the class docstring. |
| """ |
| params_set = set(params) |
| params_set.discard(None) |
| assert ( |
| len(params_set) > 0 |
| ), "Cannot initialize a `FlatParameter` from an empty parameter list" |
| param_infos: List[ParamInfo] = [] |
| numels: List[int] = [] |
| shapes: List[torch.Size] = [] |
| prefixed_param_names: List[str] = [] |
| shared_param_infos: List[SharedParamInfo] = [] |
| shared_param_memo: Dict[nn.Parameter, Tuple[nn.Module, str, str]] = {} |
| params_to_flatten: List[nn.Parameter] = [] |
| dtype: Optional[torch.dtype] = None |
| requires_grad: Optional[bool] = None |
| for submodule_name, submodule in module.named_modules(): |
| for param_name, param in submodule.named_parameters(recurse=False): |
| if param not in params_set: |
| continue |
| if param in shared_param_memo: |
| prim_module, prim_module_name, prim_param_name = shared_param_memo[ |
| param |
| ] |
| shared_param_infos.append( |
| SharedParamInfo( |
| param_name, |
| submodule, |
| submodule_name, |
| prim_param_name, |
| prim_module, |
| prim_module_name, |
| ) |
| ) |
| else: |
| if type(param) is FlatParameter: |
| raise ValueError("`FlatParameter` does not support nesting") |
| if dtype is not None and param.dtype != dtype: |
| raise ValueError( |
| "`FlatParameter` requires uniform dtype but got " |
| f"{dtype} and {param.dtype}" |
| ) |
| if dtype is None and not param.is_floating_point(): |
| raise ValueError("Integer parameters are unsupported") |
| if ( |
| requires_grad is not None |
| and param.requires_grad != requires_grad |
| ): |
| raise ValueError( |
| "`FlatParameter` requires uniform `requires_grad`" |
| ) |
| dtype = param.dtype |
| requires_grad = param.requires_grad |
| shared_param_memo[param] = (submodule, submodule_name, param_name) |
| params_to_flatten.append(param) |
| param_infos.append(ParamInfo(param_name, submodule, submodule_name)) |
| numels.append(param.numel()) |
| shapes.append(param.shape) |
| prefixed_param_name = ( |
| submodule_name + "." + param_name |
| if submodule_name |
| else param_name |
| ) |
| prefixed_param_names.append(prefixed_param_name) |
| assert requires_grad is not None |
| self.flat_param = FlatParamHandle.flatten_params( |
| params_to_flatten, requires_grad |
| ) |
| self.flat_param._init_metadata( |
| param_infos, |
| numels, |
| shapes, |
| prefixed_param_names, |
| shared_param_infos, |
| ) |
| |
| @staticmethod |
| def flatten_params( |
| params: Sequence[torch.Tensor], |
| requires_grad: bool, |
| ) -> FlatParameter: |
| """ |
| Flattens the parameters in ``params`` into a single |
| :class:`FlatParameter`. This should be the only way used to construct |
| :class:`FlatParameter` s. |
| |
| We expose this factory method for checkpointing (e.g. sharded state |
| dict). The flattened parameter's metadata should only be initialized |
| once (see :meth:`_init_metadata`), but its tensor data may be reloaded. |
| """ |
| with torch.no_grad(): |
| flat_params = [ |
| p.detach().reshape(-1) if isinstance(p, nn.Parameter) else p.reshape(-1) |
| for p in params |
| ] |
| flat_param_data = torch.cat(flat_params, dim=0) |
| flat_param = FlatParameter(flat_param_data, requires_grad=requires_grad) |
| return flat_param |
| |
| ################################### |
| # SHARD INITIALIZATION & METADATA # |
| ################################### |
| @torch.no_grad() |
| def shard(self, process_group: dist.ProcessGroup): |
| """ |
| Shards the handle's ``FlatParameter``. In terms of memory, this |
| allocates new memory for the sharded flattened parameter and frees the |
| unsharded flattened parameter's storage. |
| |
| Postcondition: ``self.flat_param`` is the sharded flattened parameter. |
| ``process_group``, ``rank``, and ``world_size`` attributes are set. |
| |
| TODO (awgu): Once we retire ``FlattenParamsWrapper``, we should pass |
| the process group directly to the ``FlatParamHandle`` constructor. For |
| now, we decouple ``FlattenParamsWrapper` from a process group, but this |
| makes the process-group-related attributes not necessarily defined. |
| """ |
| if not self.uses_sharded_strategy: |
| return |
| flat_param = self.flat_param |
| self.process_group = process_group |
| self.rank = process_group.rank() |
| self.world_size = process_group.size() |
| assert ( |
| flat_param.storage_offset() == 0 |
| ), "The `FlatParameter` is not the sole occupant of its storage" |
| orig_storage = flat_param.storage() |
| local_shard, numel_padded = FlatParamHandle._get_shard( |
| flat_param, self.rank, self.world_size |
| ) |
| flat_param.set_(local_shard) # type: ignore[call-overload] |
| self._init_shard_metadata(local_shard.numel(), numel_padded, self.rank) |
| if orig_storage.size() > 0: |
| orig_storage.resize_(0) |
| |
| def _init_shard_metadata( |
| self, |
| sharded_flat_param_numel: int, |
| numel_padded: int, |
| rank: int, |
| ) -> None: |
| """ |
| Initializes shard-related metadata for this rank's shard of the |
| flattened parameter: ``_shard_param_offsets``, ``_shard_indices``, and |
| ``_shard_numel_padded``. |
| |
| Args: |
| sharded_flat_param_numel (int): Numel of each rank's sharded |
| flattened parameter with padding (i.e. including |
| ``numel_padded``). |
| numel_padded (int): Numel padded for this rank's sharded flattened |
| parameter. |
| rank (int): Caller's rank. |
| """ |
| if numel_padded > sharded_flat_param_numel: |
| raise ValueError( |
| f"Sharded flattened parameter with {sharded_flat_param_numel} " |
| f"numel cannot have {numel_padded} numel padded" |
| ) |
| start = sharded_flat_param_numel * rank |
| end = sharded_flat_param_numel * (rank + 1) - 1 # inclusive |
| ( |
| self.flat_param._shard_param_offsets, # type: ignore[attr-defined] |
| self.flat_param._shard_indices, # type: ignore[attr-defined] |
| ) = self._get_shard_metadata(start, end) |
| self.flat_param._shard_numel_padded = numel_padded # type: ignore[attr-defined] |
| |
| def _get_shard_metadata( |
| self, |
| start: int, |
| end: int, |
| ) -> Tuple[Tuple[Tuple[int, int], ...], Tuple[int, int]]: |
| """ |
| Computes the shard metadata based on ``start`` and ``end``, which give |
| the closed interval of the unsharded flattened parameter specifying the |
| shard. |
| |
| Args: |
| start (int): Start index (in units of numel) of this rank's shard |
| of the flattened parameter. |
| end (int): End index (in units of numel and inclusive) of this |
| rank's shard of the flattened parameter. |
| |
| Return: |
| Tuple[Tuple[Tuple[int, int], ...], Tuple[int, int]]: See |
| ``_shard_param_offsets`` and ``_shard_indices`` in |
| :class:`FlatParameter` 's docstring. |
| """ |
| flat_param_offsets = self._get_flat_param_offsets() |
| # Indices of the original parameters in this rank's sharded flattened |
| # parameter |
| shard_param_indices_range = [] # elements will be consecutive |
| # [start, end] offsets giving this rank's part of the flattened |
| # original module parameter (which will be [0, `p.numel()`-1] for any |
| # parameter that is not sharded across ranks) |
| shard_param_offsets = [] |
| for i, (param_start, param_end) in enumerate(flat_param_offsets): |
| if start > param_end or end < param_start: |
| continue |
| if start <= param_start: |
| intra_param_start = 0 |
| else: |
| intra_param_start = start - param_start |
| intra_param_end = min(param_end, end) - param_start |
| shard_param_indices_range.append(i) |
| shard_param_offsets.append( |
| (intra_param_start, intra_param_end) |
| ) # both inclusive |
| if len(shard_param_indices_range) == 0: |
| shard_param_indices = (0, 0) |
| assert len(shard_param_offsets) == 0 |
| else: |
| shard_param_indices = ( |
| shard_param_indices_range[0], |
| shard_param_indices_range[-1], |
| ) |
| assert ( |
| len(shard_param_offsets) |
| == shard_param_indices[-1] - shard_param_indices[0] + 1 |
| ) |
| return tuple(shard_param_offsets), shard_param_indices |
| |
| @staticmethod |
| def _get_unpadded_shard( |
| tensor: Tensor, |
| rank: int, |
| world_size: int, |
| ) -> Tuple[Tensor, int]: |
| """ |
| Returns the shard of ``tensor`` without any padding for the given |
| ``rank`` and ``world_size`` and the numel to pad for that shard. |
| |
| If ``tensor`` is already flattened or may be viewed in the flattened |
| shape (which is true in the expected usage), then this method does not |
| allocate any new tensor memory. |
| """ |
| chunks = torch.flatten(tensor).chunk(world_size) |
| if len(chunks) < (rank + 1): |
| # This rank gets an empty chunk fully padded with zeros since there |
| # are not enough chunks across ranks |
| chunk = chunks[0].new_empty(0) |
| else: |
| chunk = chunks[rank] |
| numel_to_pad = chunks[0].numel() - chunk.numel() |
| assert ( |
| numel_to_pad >= 0 |
| ), "Chunk's size should be at most the first chunk's size" |
| return chunk, numel_to_pad |
| |
| @staticmethod |
| def _get_shard( |
| tensor: Tensor, |
| rank: int, |
| world_size: int, |
| ) -> Tuple[Tensor, int]: |
| """ |
| Returns the shard of ``tensor`` with padding for the given ``rank`` and |
| ``world_size`` and the numel padded for that shard. |
| |
| This method allocates new memory (via :meth:`clone`) since the |
| unsharded ``tensor`` may be deallocated after this method returns. |
| """ |
| chunk, numel_to_pad = FlatParamHandle._get_unpadded_shard( |
| tensor, rank, world_size |
| ) |
| shard = chunk.clone() |
| if numel_to_pad > 0: |
| shard = F.pad(shard, [0, numel_to_pad]) |
| return shard, numel_to_pad |
| |
| @staticmethod |
| def _get_sharded_size(tensor: Tensor, rank: int, world_size: int) -> torch.Size: |
| """ |
| Returns the shape of ``tensor`` after sharding including padding. This |
| requires ``tensor`` to have 1D shape and ensures that the returned |
| shape is 1D. |
| """ |
| assert len(tensor.shape) == 1, f"{tensor.shape}" |
| unpadded_sharded_tensor, numel_to_pad = FlatParamHandle._get_unpadded_shard( |
| tensor, rank, world_size |
| ) |
| unpadded_sharded_size = unpadded_sharded_tensor.size() |
| assert len(unpadded_sharded_size) == 1, f"{unpadded_sharded_size}" |
| return torch.Size([unpadded_sharded_size[0] + numel_to_pad]) |
| |
| def _get_flat_param_offsets(self) -> List[Tuple[int, int]]: |
| """Returns [start, end] offsets of each original parameter's flattened |
| data in the unsharded flattened parameter (without padding).""" |
| cumulative_sum = list(accumulate(self.flat_param._numels)) |
| starts = [0] + cumulative_sum[:-1] |
| ends = [end - 1 for end in cumulative_sum] # inclusive |
| param_offsets = list(zip(starts, ends)) |
| return param_offsets |
| |
| def shard_metadata( |
| self, |
| ) -> FlatParamShardMetadata: |
| """Returns shard-related metadata specific to this rank's shard of the |
| flattened parameter.""" |
| assert hasattr(self.flat_param, "_shard_indices") and hasattr( |
| self.flat_param, "_shard_param_offsets" |
| ), "Shard metadata has not been initialized" |
| shard_param_start_index = self.flat_param._shard_indices[0] # type: ignore[attr-defined] |
| shard_param_end_index = self.flat_param._shard_indices[1] # type: ignore[attr-defined] |
| sl = ( |
| slice(shard_param_start_index, shard_param_end_index + 1) |
| if shard_param_start_index <= shard_param_end_index |
| else slice(0, 0) |
| ) |
| return FlatParamShardMetadata( |
| self.flat_param._prefixed_param_names[sl], |
| self.flat_param._shapes[sl], |
| self.flat_param._numels[sl], |
| self.flat_param._shard_param_offsets[:], # type: ignore[attr-defined] |
| ) |
| |
| ################### |
| # UNSHARD/RESHARD # |
| ################### |
| def pre_unshard(self) -> bool: |
| """ |
| Returns: ``False`` if this is a no-op and ``True`` otherwise. |
| |
| Postcondition: ``self.flat_param`` 's data is on the device for |
| communication and is what should be all-gathered. This means that it |
| matches the dtype of the expected unsharded parameter. |
| """ |
| ret = False |
| if ( |
| self.uses_sharded_strategy |
| and not self._config.offload_params |
| and not self.needs_unshard() |
| ): |
| pass # no-op |
| elif self._uses_param_mixed_precision and not self._force_full_precision: |
| self._use_low_precision_shard() |
| ret = True |
| elif self._config.offload_params and self.flat_param.device != self.device: |
| # NOTE: This creates a new tensor distinct from any attributes. |
| self._flat_param_to(self.device, non_blocking=True) |
| ret = True |
| self._check_on_compute_device(self.flat_param) |
| return ret |
| |
| def _use_low_precision_shard(self): |
| """ |
| Allocates the low precision shard directly on the compute device and |
| switches to using the low precision sharded flattened parameter. |
| """ |
| self._check_low_precision_shard() |
| flat_param = self.flat_param |
| _alloc_storage( |
| flat_param._mp_shard, flat_param._local_shard.size() # type: ignore[attr-defined] |
| ) |
| # `copy_()` implicitly casts to the low precision |
| flat_param._mp_shard.copy_( # type: ignore[attr-defined] |
| flat_param._local_shard.to( # type: ignore[attr-defined] |
| self.device, non_blocking=True |
| ) |
| ) |
| # Invariant: `_mp_shard` is always on the compute device. |
| flat_param.data = flat_param._mp_shard # type: ignore[attr-defined] |
| |
| def unshard(self): |
| """ |
| Runs the unshard logic. This includes all-gathering the flattened |
| parameter and switching to using the unsharded flattened parameter. If |
| the handle does not need unsharding, then this only switches to using |
| the unsharded flattened parameter. For ``NO_SHARD``, this is a no-op. |
| |
| If FSDP is in :meth:`summon_full_params` and the handle uses parameter |
| mixed precision, then the parameter is forced to full precision. |
| """ |
| if not self.needs_unshard(): |
| if self.uses_sharded_strategy: |
| # The handle may have been resharded without freeing the padded |
| # unsharded flattened parameter, in which case we need to |
| # switch to using the unsharded parameter |
| unsharded_flat_param = self._get_padded_unsharded_flat_param() |
| self._use_unsharded_flat_param(unsharded_flat_param) |
| return |
| unsharded_flat_param = self._alloc_padded_unsharded_flat_param() |
| self._all_gather_flat_param(unsharded_flat_param) |
| |
| def needs_unshard(self) -> bool: |
| """Returns if the handle's flattened parameter needs to be unsharded.""" |
| if not self.uses_sharded_strategy: |
| return False |
| unsharded_flat_param = self._get_padded_unsharded_flat_param() |
| already_unsharded = ( |
| unsharded_flat_param.storage().size() == unsharded_flat_param.numel() |
| ) |
| return not already_unsharded |
| |
| def _alloc_padded_unsharded_flat_param(self): |
| """ |
| Allocates the *padded* unsharded flattened parameter. The unpadded |
| unsharded flattened parameter is always a view into the padded one. |
| This padded parameter is saved to a different attribute on the |
| ``FlatParameter`` depending on if we force full precision. |
| """ |
| self._check_sharded_strategy() |
| flat_param = self.flat_param |
| unsharded_flat_param = self._get_padded_unsharded_flat_param() |
| self._check_storage_freed(unsharded_flat_param) |
| _alloc_storage(unsharded_flat_param, flat_param._padded_unsharded_size) # type: ignore[attr-defined] |
| return unsharded_flat_param |
| |
| def _get_padded_unsharded_flat_param(self) -> torch.Tensor: |
| """ |
| Returns a reference to the padded unsharded flattened parameter |
| depending on the calling context. This should only be called if using a |
| sharded strategy. |
| """ |
| self._check_sharded_strategy() |
| flat_param = self.flat_param |
| if self._force_full_precision: |
| # When parameter mixed precision is enabled, we use a different |
| # tensor as the all-gather destination to preserve the invariant |
| # that `_full_param_padded` is in the low precision |
| unsharded_flat_param = flat_param._full_prec_full_param_padded # type: ignore[attr-defined] |
| p_assert( |
| unsharded_flat_param.dtype != self._config.param_dtype, |
| f"Expects full precision but got {self._config.param_dtype}", |
| ) |
| else: |
| unsharded_flat_param = flat_param._full_param_padded # type: ignore[attr-defined] |
| return unsharded_flat_param |
| |
| def _all_gather_flat_param( |
| self, |
| padded_unsharded_flat_param: Tensor, |
| ) -> None: |
| """ |
| All-gathers the handle's flattened parameter to the destination |
| ``padded_unsharded_flat_param``, and switches to using the all-gathered |
| tensor. |
| """ |
| p_assert( |
| hasattr(self, "process_group") and hasattr(self, "world_size"), |
| "Expects a process group and world size to have been set via `shard()`", |
| ) |
| sharded_flat_param = self.flat_param.data |
| expected_numel = sharded_flat_param.numel() * self.world_size |
| p_assert( |
| padded_unsharded_flat_param.numel() == expected_numel, |
| f"Expects {expected_numel} numel but got {padded_unsharded_flat_param.numel()}", |
| ) |
| dist._all_gather_base( |
| padded_unsharded_flat_param, |
| sharded_flat_param, |
| self.process_group, |
| ) |
| self._use_unsharded_flat_param(padded_unsharded_flat_param) |
| |
| def _use_unsharded_flat_param( |
| self, |
| padded_unsharded_flat_param: torch.Tensor, |
| ) -> None: |
| """ |
| Switches to using the *unpadded* unsharded flattened parameter, which |
| is a view into the *padded* unsharded flattened parameter. |
| """ |
| unsharded_size = self.flat_param._unpadded_unsharded_size |
| self.flat_param.data = padded_unsharded_flat_param[ |
| : unsharded_size.numel() |
| ].view(unsharded_size) |
| |
| def post_unshard(self): |
| """ |
| Runs the post-unshard logic. This includes freeing the low precision |
| shard if needed. |
| """ |
| if self._uses_param_mixed_precision and self.uses_sharded_strategy: |
| self._free_low_precision_sharded_param() |
| self._check_on_compute_device(self.flat_param) |
| |
| def _free_low_precision_sharded_param(self): |
| """Frees the low precision sharded flattened parameter.""" |
| self._check_low_precision_shard() |
| _free_storage(self.flat_param._mp_shard) # type: ignore[attr-defined] |
| |
| def prepare_gradient(self): |
| """ |
| Prepares the gradient for the backward computation by saving and |
| clearing any existing sharded gradient in ``.grad`` to enable computing |
| a new unsharded gradient. |
| """ |
| p_assert( |
| self._training_state |
| in (HandleTrainingState.BACKWARD_PRE, HandleTrainingState.IDLE), |
| "Expects to be in `BACKWARD_PRE` or `IDLE` (if prefetching)", |
| ) |
| flat_param = self.flat_param |
| if flat_param.grad is not None and ( |
| flat_param.grad.size() != flat_param._unpadded_unsharded_size |
| or flat_param.grad.device != flat_param.device # grad on CPU |
| ): |
| self._check_on_compute_device(self.flat_param) |
| grad_offloaded = flat_param.grad.device != self.device |
| p_assert( |
| not grad_offloaded or self._config.offload_params, |
| f"Expects the sharded gradient to be on {self.device} " |
| f"but got {flat_param.grad.device}", |
| ) |
| prev_iter_synced_gradients = ( |
| flat_param.grad.size() |
| == flat_param._local_shard.size() # type: ignore[attr-defined] |
| ) |
| if prev_iter_synced_gradients: |
| # TODO (awgu): Gradient accumulation outside `no_sync()` |
| # does not work with CPU offloading. The issue should be |
| # that, in the post-backward hook, we cannot do an addition |
| # between a CPU tensor (the existing sharded gradient) and |
| # a GPU tensor (the new sharded gradient). |
| if not grad_offloaded: |
| flat_param._saved_grad_shard = flat_param.grad.data # type: ignore[attr-defined] |
| else: |
| padded_unsharded_size = flat_param._padded_unsharded_size # type: ignore[attr-defined] |
| p_assert( |
| flat_param.grad.size() == padded_unsharded_size, |
| "Expects `.grad` to be the unsharded gradient in " |
| f"`no_sync()` with size {padded_unsharded_size} " |
| f"but got size {flat_param.grad.size()}", |
| ) |
| flat_param.grad = None |
| |
| @contextlib.contextmanager |
| def to_cpu(self): |
| """ |
| Moves the unpadded unsharded flattened parameter to CPU while in the |
| context and moves it back to the previous device upon exit. For now, |
| this assumes the ``FlatParameter`` is the unpadded unsharded flattened |
| parameter since (1) there is no reason to include the padding in the |
| copy and (2) there is no use case for the sharded flattened parameter. |
| |
| Precondition: ``self.flat_param`` 's data is the unpadded unsharded |
| flattened parameter on the compute device, and the handle uses a |
| sharded strategy. |
| Postcondition: Same as the precondition. |
| """ |
| self._check_sharded_strategy() |
| p_assert( |
| self.flat_param.size() == self.flat_param._unpadded_unsharded_size, |
| f"Expects size {self.flat_param._unpadded_unsharded_size} but got {self.flat_param.size()}", |
| ) |
| self._check_on_compute_device(self.flat_param) |
| # Check that the unpadded unsharded flattened parameter is a view into |
| # the padded unsharded flattened parameter as expected |
| # NOTE: This check is not strictly needed for correctness but is a |
| # useful sanity check since the tensor should only be used internally. |
| unpadded_storage_ptr = self.flat_param.storage().data_ptr() |
| padded_storage_ptr = ( |
| self._get_padded_unsharded_flat_param().storage().data_ptr() |
| ) |
| p_assert( |
| unpadded_storage_ptr == padded_storage_ptr, |
| "Expects the unpadded parameter to be a view into the padded parameter", |
| ) |
| self._flat_param_to(torch.device("cpu")) |
| self._free_unsharded_flat_param() |
| try: |
| yield |
| finally: |
| p_assert( |
| self.flat_param.size() == self.flat_param._unpadded_unsharded_size, |
| f"Expects size {self.flat_param._unpadded_unsharded_size} but got {self.flat_param.size()}", |
| ) |
| padded_unsharded_flat_param = self._alloc_padded_unsharded_flat_param() |
| # Copy from CPU to the compute device |
| padded_unsharded_flat_param[: self.flat_param.numel()].copy_( |
| self.flat_param |
| ) |
| self._use_unsharded_flat_param(padded_unsharded_flat_param) |
| |
| def reshard(self, free_unsharded_flat_param: bool): |
| """ |
| Runs the reshard logic. This includes freeing the unsharded flattened |
| parameter if ``free_unsharded_flat_param`` and switching to using the |
| sharded flattened parameter. |
| """ |
| if free_unsharded_flat_param: |
| self._free_unsharded_flat_param() |
| self._use_sharded_flat_param() |
| |
| def post_reshard(self): |
| """ |
| Runs the post-reshard logic. This includes freeing any memory that |
| can now be freed given that the ``FlatParameter`` points to the full |
| precision sharded flattened parameter. |
| |
| Precondition: ``self.flat_param`` 's data points to the full precision |
| sharded flattened parameter. |
| """ |
| # For `NO_SHARD`, `_mp_shard` is not freed in the post-unshard since |
| # it is also the low precision *unsharded* flattened parameter. Hence, |
| # we delay the free until the reshard. |
| if ( |
| self._uses_param_mixed_precision |
| and not self.uses_sharded_strategy |
| and not self._force_full_precision # did not use the low precision shard |
| ): |
| self._free_low_precision_sharded_param() |
| |
| def _free_unsharded_flat_param(self): |
| """ |
| Frees the padded unsharded flattened parameter. The tensor to free |
| depends on the calling context since the unshard may have forced full |
| precision, in which case a different tensor is used. |
| """ |
| self._check_sharded_strategy() |
| unsharded_flat_param = self._get_padded_unsharded_flat_param() |
| self._check_storage_allocated(unsharded_flat_param) |
| self._check_on_compute_device(unsharded_flat_param) |
| # Do not free the memory until all ops in the current stream finish |
| unsharded_flat_param.record_stream( |
| cast(torch._C.Stream, torch.cuda.current_stream()) |
| ) |
| _free_storage(unsharded_flat_param) |
| |
| def _use_sharded_flat_param(self) -> None: |
| """Switches to using the sharded flattened parameter.""" |
| flat_param = self.flat_param |
| if self._config.offload_params: |
| device = flat_param._local_shard.device # type: ignore[attr-defined] |
| p_assert( |
| device == torch.device("cpu"), |
| f"Expects the local shard to be on CPU but got {device}", |
| ) |
| flat_param.data = flat_param._local_shard # type: ignore[attr-defined] |
| |
| ######### |
| # VIEWS # |
| ######### |
| @staticmethod |
| def _get_unflat_views( |
| flat_param: FlatParameter, |
| tensor: Optional[torch.Tensor] = None, |
| ) -> Iterator[Tensor]: |
| """ |
| Returns unflattened ``Tensor`` views into ``tensor`` if it is not |
| ``None`` or ``flat_param`` otherwise, where the unflattening is based |
| on ``flat_param`` 's metadata. |
| |
| In other words, to get views into the unsharded flattened parameter, |
| pass ``tensor`` as ``None``, but to get views into tensor optimizer |
| state, pass ``tensor`` as the optimizer state tensor. |
| """ |
| if tensor is None: |
| tensor = flat_param |
| p_assert( |
| tensor.numel() == flat_param._unpadded_unsharded_size.numel(), |
| f"Expects {flat_param._unpadded_unsharded_size.numel()} numel but got " |
| f"{tensor.numel()} numel", |
| ) |
| views = ( |
| subtensor.view(shape) |
| for (subtensor, shape) in zip( |
| torch.split(tensor, flat_param._numels, dim=0), flat_param._shapes # type: ignore[arg-type] |
| ) |
| ) |
| return views |
| |
| def _unflatten(self, as_params: bool) -> None: |
| """ |
| Unflattens the unsharded flattened parameter by setting the original |
| module parameter variables to be views into it. |
| |
| Args: |
| as_params (bool): If ``True``, then registers the original |
| parameters as ``nn.Parameter`` s; if ``False``, then registers |
| the original parameters only as ``Tensor`` s. ``False`` should |
| be used during forward/backward computation and when hiding the |
| original parameters from :meth:`nn.Module.named_parameters`. |
| """ |
| views = self._get_unflat_views(self.flat_param) |
| for view, (param_name, module, _) in zip(views, self.flat_param._param_infos): |
| if hasattr(module, param_name): |
| delattr(module, param_name) |
| if as_params: |
| module.register_parameter(param_name, nn.Parameter(view)) |
| else: |
| setattr(module, param_name, view) |
| for ( |
| param_name, |
| module, |
| _, |
| prim_param_name, |
| prim_module, |
| _, |
| ) in self.flat_param._shared_param_infos: |
| if hasattr(module, param_name): |
| delattr(module, param_name) |
| assert hasattr(prim_module, prim_param_name) |
| param: Union[Tensor, nn.Parameter] = getattr(prim_module, prim_param_name) |
| if as_params: |
| assert isinstance(param, nn.Parameter) |
| module.register_parameter(param_name, param) |
| else: |
| setattr(module, param_name, param) |
| |
| @contextlib.contextmanager |
| def unflatten_as_params(self) -> Generator: |
| """ |
| Assumes the flattened parameter is unsharded. When in the context, |
| unflattens the original parameters as ``nn.Parameter`` views into the |
| flattened parameter, and after the context, restores the original |
| parameters as ``Tensor`` views into the flattened parameter. |
| """ |
| self._unflatten(as_params=True) |
| try: |
| yield |
| finally: |
| self._unflatten(as_params=False) |
| |
| ########### |
| # HELPERS # |
| ########### |
| def _flat_param_to(self, *args, **kwargs): |
| """Wraps an in-place call to ``.to()`` for ``self.flat_param``.""" |
| self.flat_param.data = self.flat_param.to(*args, **kwargs) |
| |
| def _get_modules(self) -> Set[nn.Module]: |
| """Returns a :class:`set` of the modules whose parameters are included |
| in this handle's flattened parameter.""" |
| return set(pi.module for pi in self.flat_param._param_infos).union( |
| set(spi.module for spi in self.flat_param._shared_param_infos) |
| ) |
| |
| def parameter_module_names(self) -> Iterator[Tuple[str, str]]: |
| shared_param_infos = [ |
| ParamInfo(param_name, module, module_name) |
| for ( |
| param_name, |
| module, |
| module_name, |
| _, |
| _, |
| _, |
| ) in self.flat_param._shared_param_infos |
| ] |
| for param_name, _, module_name in chain( |
| self.flat_param._param_infos, shared_param_infos |
| ): |
| yield (param_name, module_name) |
| |
| ####################### |
| # CHECKS & INVARIANTS # |
| ####################### |
| def _check_sharded_strategy(self): |
| p_assert(self.uses_sharded_strategy, "Expects sharded strategy") |
| |
| def _check_on_compute_device(self, tensor: Tensor): |
| p_assert( |
| tensor.device == self.device, |
| f"Expects tensor to be on the compute device {self.device}", |
| ) |
| |
| @staticmethod |
| def _check_storage_freed(tensor: Tensor): |
| storage_size: int = tensor.storage().size() |
| p_assert( |
| storage_size == 0, |
| f"Expects storage to be freed but got storage with size {storage_size}", |
| ) |
| |
| @staticmethod |
| def _check_storage_allocated(tensor: Tensor): |
| storage_size: int = tensor.storage().size() |
| p_assert(storage_size > 0, "Expects storage to be allocated") |
| |
| def _check_low_precision_shard(self): |
| p_assert( |
| self._uses_param_mixed_precision, |
| "Not using low precision for parameters", |
| ) |
| p_assert( |
| getattr(self.flat_param, "_mp_shard", None) is not None, |
| "Expects `_mp_shard` to exist", |
| ) |
| device = self.flat_param._mp_shard.device # type: ignore[attr-defined] |
| p_assert( |
| device == self.device, |
| f"Expects the low precision shard to be on {self.device} but got {device}", |
| ) |
| |
| ############## |
| # PROPERTIES # |
| ############## |
| @property |
| def uses_sharded_strategy(self) -> bool: |
| return self._config.sharding_strategy != HandleShardingStrategy.NO_SHARD |
| |
| @property |
| def _uses_param_mixed_precision(self) -> bool: |
| return self._config.param_dtype is not None |
| |
| @property |
| def _force_full_precision(self) -> bool: |
| return ( |
| self._training_state == HandleTrainingState.SUMMON_FULL_PARAMS |
| and self._uses_param_mixed_precision |
| ) |