| """ |
| This file includes public APIs for FSDP such as the classes used for the |
| constructor arguments. |
| """ |
| |
| from dataclasses import dataclass |
| from enum import auto, Enum |
| |
| from typing import Optional, Sequence, Type |
| |
| import torch |
| from torch.nn.modules.batchnorm import _BatchNorm |
| |
| __all__ = [ |
| "ShardingStrategy", |
| "BackwardPrefetch", |
| "MixedPrecision", |
| "CPUOffload", |
| "StateDictType", |
| "StateDictConfig", |
| "FullStateDictConfig", |
| "LocalStateDictConfig", |
| "ShardedStateDictConfig", |
| "OptimStateDictConfig", |
| "FullOptimStateDictConfig", |
| "LocalOptimStateDictConfig", |
| "ShardedOptimStateDictConfig", |
| "StateDictSettings", |
| ] |
| |
| |
| class ShardingStrategy(Enum): |
| """ |
| This specifies the sharding strategy to be used for distributed training by |
| :class:`FullyShardedDataParallel`. |
| |
| - ``FULL_SHARD``: Parameters, gradients, and optimizer states are sharded. |
| For the parameters, this strategy unshards (via all-gather) before the |
| forward, reshards after the forward, unshards before the backward |
| computation, and reshards after the backward computation. For gradients, |
| it synchronizes and shards them (via reduce-scatter) after the backward |
| computation. The sharded optimizer states are updated locally per rank. |
| - ``SHARD_GRAD_OP``: Gradients and optimizer states are sharded during |
| computation, and additionally, parameters are sharded outside |
| computation. For the parameters, this strategy unshards before the |
| forward, does not reshard them after the forward, and only reshards them |
| after the backward computation. The sharded optimizer states are updated |
| locally per rank. Inside ``no_sync()``, the parameters are not resharded |
| after the backward computation. |
| - ``NO_SHARD``: Parameters, gradients, and optimizer states are not sharded |
| but instead replicated across ranks similar to PyTorch's |
| :class:`DistributedDataParallel` API. For gradients, this strategy |
| synchronizes them (via all-reduce) after the backward computation. The |
| unsharded optimizer states are updated locally per rank. |
| - ``HYBRID_SHARD``: Apply ``FULL_SHARD`` within a node, and replicate parameters across |
| nodes. This results in reduced communication volume as expensive all-gathers and |
| reduce-scatters are only done within a node, which can be more performant for |
| medium-sized models. |
| - ``_HYBRID_SHARD_ZERO2``: Apply ``SHARD_GRAD_OP`` within a node, and replicate parameters across |
| nodes. This is like ``HYBRID_SHARD``, except this may provide even higher throughput |
| since the unsharded parameters are not freed after the forward pass, saving the |
| all-gathers in the pre-backward. |
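| |
| For example, a strategy is selected by passing it to the FSDP constructor's |
| ``sharding_strategy`` argument. A minimal sketch, assuming that ``model`` is |
| defined and the default process group is initialized: |
| |
| >>> # xdoctest: +SKIP("undefined variables") |
| >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy |
| >>> fsdp_model = FSDP(model, sharding_strategy=ShardingStrategy.SHARD_GRAD_OP) |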
| """ |
| |
| FULL_SHARD = auto() |
| SHARD_GRAD_OP = auto() |
| NO_SHARD = auto() |
| HYBRID_SHARD = auto() |
| _HYBRID_SHARD_ZERO2 = auto() |
| |
| |
| class BackwardPrefetch(Enum): |
| """ |
| This configures explicit backward prefetching, which improves throughput by |
| enabling communication and computation overlap in the backward pass at the |
| cost of slightly increased memory usage. |
| |
| - ``BACKWARD_PRE``: This enables the most overlap but increases memory |
| usage the most. This prefetches the next set of parameters *before* the |
| current set of parameters' gradient computation. This overlaps the *next |
| all-gather* and the *current gradient computation*, and at the peak, it |
| holds the current set of parameters, next set of parameters, and current |
| set of gradients in memory. |
| - ``BACKWARD_POST``: This enables less overlap but uses less memory. This |
| prefetches the next set of parameters *after* the current |
| set of parameters' gradient computation. This overlaps the *current |
| reduce-scatter* and the *next gradient computation*, and it frees the |
| current set of parameters before allocating memory for the next set of |
| parameters, only holding the next set of parameters and current set of |
| gradients in memory at the peak. |
| - FSDP's ``backward_prefetch`` argument accepts ``None``, which disables |
| the backward prefetching altogether. This has no overlap and does not |
| increase memory usage. In general, we do not recommend this setting since |
| it may degrade throughput significantly. |
| |
| For more technical context: For a single process group using NCCL backend, |
| any collectives, even if issued from different streams, contend for the |
| same per-device NCCL stream, which implies that the relative order in which |
| the collectives are issued matters for overlapping. The two backward |
| prefetching values correspond to different issue orders. |
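| |
| For example, a prefetching mode is selected via FSDP's ``backward_prefetch`` |
| constructor argument. A minimal sketch, assuming that ``model`` is defined |
| and the default process group is initialized: |
| |
| >>> # xdoctest: +SKIP("undefined variables") |
| >>> from torch.distributed.fsdp import BackwardPrefetch, FullyShardedDataParallel as FSDP |
| >>> fsdp_model = FSDP(model, backward_prefetch=BackwardPrefetch.BACKWARD_PRE) |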
| """ |
| |
| # NOTE: For both modes, the ordering that defines "current" and "next" is |
| # not always exact in the current implementation. A mistargeted prefetch |
| # simply means that the parameter memory is allocated earlier than needed, |
| # possibly increasing peak memory usage, but does not affect correctness. |
| BACKWARD_PRE = auto() |
| BACKWARD_POST = auto() |
| |
| |
| @dataclass |
| class MixedPrecision: |
| """ |
| This configures FSDP-native mixed precision training. |
| |
| Attributes: |
| param_dtype (Optional[torch.dtype]): This specifies the dtype for model |
| parameters during forward and backward and thus the dtype for |
| forward and backward computation. Outside forward and backward, the |
| *sharded* parameters are kept in full precision (e.g. for the |
| optimizer step), and for model checkpointing, the parameters are |
| always saved in full precision. (Default: ``None``) |
| reduce_dtype (Optional[torch.dtype]): This specifies the dtype for |
| gradient reduction (i.e. reduce-scatter or all-reduce). If this is |
| ``None`` but ``param_dtype`` is not ``None``, then this takes on |
| the ``param_dtype`` value, still running gradient reduction in low |
| precision. This is permitted to differ from ``param_dtype``, e.g. |
| to force gradient reduction to run in full precision. (Default: |
| ``None``) |
| buffer_dtype (Optional[torch.dtype]): This specifies the dtype for |
| buffers. FSDP does not shard buffers. Rather, FSDP casts them to |
| ``buffer_dtype`` in the first forward pass and keeps them in that |
| dtype thereafter. For model checkpointing, the buffers are saved |
| in full precision except for ``LOCAL_STATE_DICT``. (Default: |
| ``None``) |
| keep_low_precision_grads (bool): If ``False``, then FSDP upcasts |
| gradients to full precision after the backward pass in preparation |
| for the optimizer step. If ``True``, then FSDP keeps the gradients |
| in the dtype used for gradient reduction, which can save memory if |
| using a custom optimizer that supports running in low precision. |
| (Default: ``False``) |
| cast_forward_inputs (bool): If ``True``, then this FSDP module casts |
| its forward args and kwargs to ``param_dtype``. This is to ensure |
| that parameter and input dtypes match for forward computation, as |
| required by many ops. This may need to be set to ``True`` when only |
| applying mixed precision to some but not all FSDP modules, in which |
| case a mixed-precision FSDP submodule needs to recast its inputs. |
| (Default: ``False``) |
| cast_root_forward_inputs (bool): If ``True``, then the root FSDP module |
| casts its forward args and kwargs to ``param_dtype``, overriding |
| the value of ``cast_forward_inputs``. For non-root FSDP modules, |
| this does not do anything. (Default: ``True``) |
| _module_classes_to_ignore (Sequence[Type[nn.Module]]): This specifies |
| module classes to ignore for mixed precision when using an |
| ``auto_wrap_policy``: Modules of these classes will have FSDP |
| applied to them separately with mixed precision disabled (meaning |
| that the final FSDP construction would deviate from the specified |
| policy). If ``auto_wrap_policy`` is not specified, then this does |
| not do anything. This API is experimental and subject to change. |
| (Default: ``(_BatchNorm,)``) |
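| |
| For instance, a typical configuration that runs forward/backward computation |
| and gradient reduction in ``bfloat16`` can be sketched as follows (assuming |
| ``model`` is defined): |
| |
| >>> # xdoctest: +SKIP("undefined variables") |
| >>> mp = MixedPrecision( |
| >>> param_dtype=torch.bfloat16, |
| >>> reduce_dtype=torch.bfloat16, |
| >>> buffer_dtype=torch.bfloat16, |
| >>> ) |
| >>> fsdp_model = FSDP(model, mixed_precision=mp) |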
| |
| .. note:: This API is experimental and subject to change. |
| |
| .. note:: Only floating point tensors are cast to their specified dtypes. |
| |
| .. note:: In ``summon_full_params``, parameters are forced to full |
| precision, but buffers are not. |
| |
| .. note:: Layer norm and batch norm accumulate in ``float32`` even when |
| their inputs are in a low precision like ``float16`` or ``bfloat16``. |
| Disabling FSDP's mixed precision for those norm modules only means that |
| the affine parameters are kept in ``float32``. However, this incurs |
| separate all-gathers and reduce-scatters for those norm modules, which |
| may be inefficient, so if the workload permits, the user should prefer |
| to still apply mixed precision to those modules. |
| |
| .. note:: By default, if the user passes a model with any ``_BatchNorm`` |
| modules and specifies an ``auto_wrap_policy``, then the batch norm |
| modules will have FSDP applied to them separately with mixed precision |
| disabled. See the ``_module_classes_to_ignore`` argument. |
| |
| .. note:: ``MixedPrecision`` has ``cast_root_forward_inputs=True`` and |
| ``cast_forward_inputs=False`` by default. For the root FSDP instance, |
| its ``cast_root_forward_inputs`` takes precedence over its |
| ``cast_forward_inputs``. For non-root FSDP instances, their |
| ``cast_root_forward_inputs`` values are ignored. The default setting is |
| sufficient for the typical case where each FSDP instance has the same |
| ``MixedPrecision`` configuration and only needs to cast inputs to the |
| ``param_dtype`` at the beginning of the model's forward pass. |
| |
| .. note:: For nested FSDP instances with different ``MixedPrecision`` |
| configurations, we recommend setting individual ``cast_forward_inputs`` |
| values to configure casting inputs or not before each instance's |
| forward. In such a case, since the casts happen before each FSDP |
| instance's forward, a parent FSDP instance should have its non-FSDP |
| submodules run before its FSDP submodules to avoid the activation dtype |
| being changed due to a different ``MixedPrecision`` configuration. |
| |
| Example:: |
| |
| >>> # xdoctest: +SKIP("undefined variables") |
| >>> model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3)) |
| >>> model[1] = FSDP( |
| >>> model[1], |
| >>> mixed_precision=MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True), |
| >>> ) |
| >>> model = FSDP( |
| >>> model, |
| >>> mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True), |
| >>> ) |
| |
| The above shows a working example. On the other hand, if ``model[1]`` |
| were replaced with ``model[0]``, meaning that the submodule using |
| different ``MixedPrecision`` ran its forward first, then ``model[1]`` |
| would incorrectly see ``float16`` activations instead of ``bfloat16`` |
| ones. |
| |
| """ |
| |
| param_dtype: Optional[torch.dtype] = None |
| reduce_dtype: Optional[torch.dtype] = None |
| buffer_dtype: Optional[torch.dtype] = None |
| keep_low_precision_grads: bool = False |
| cast_forward_inputs: bool = False |
| cast_root_forward_inputs: bool = True |
| _module_classes_to_ignore: Sequence[Type[torch.nn.Module]] = (_BatchNorm,) |
| |
| |
| @dataclass |
| class CPUOffload: |
| """ |
| This configures CPU offloading. |
| |
| Attributes: |
| offload_params (bool): This specifies whether to offload parameters to |
| CPU when not involved in computation. If ``True``, then this |
| offloads gradients to CPU as well, meaning that the optimizer step |
| runs on CPU. |
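| |
| For example, parameter offloading is enabled via FSDP's ``cpu_offload`` |
| constructor argument. A minimal sketch, assuming ``model`` is defined: |
| |
| >>> # xdoctest: +SKIP("undefined variables") |
| >>> fsdp_model = FSDP(model, cpu_offload=CPUOffload(offload_params=True)) |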
| """ |
| |
| offload_params: bool = False |
| |
| |
| class StateDictType(Enum): |
| """ |
| This enum indicates which type of ``state_dict`` the FSDP module is |
| currently processing (returning or loading). |
| The default value is ``FULL_STATE_DICT`` to comply with the PyTorch convention. |
| .. note:: |
| FSDP currently supports three types of ``state_dict``: |
| 1. ``state_dict/load_state_dict``: this pair of APIs return and load |
| the non-sharded, unflattened parameters. The semantics are the |
| same as using DDP. |
| 2. ``_local_state_dict/_load_local_state_dict``: this pair of APIs return |
| and load the local sharded, flattened parameters. The values returned |
| by ``_local_state_dict`` can be directly used by FSDP and are only |
| meaningful to FSDP (because the parameters are flattened). Note that |
| these APIs are meant for use via the :func:`state_dict_type` |
| context manager as follows: |
| >>> # xdoctest: +SKIP("undefined variables") |
| >>> with fsdp.state_dict_type(StateDictType.LOCAL_STATE_DICT): |
| ... state = fsdp.state_dict() # gets the local state dict |
| 3. ``_sharded_state_dict/_load_sharded_state_dict``: this pair of APIs |
| return and load sharded, unflattened parameters. The ``state_dict`` |
| returned by ``_sharded_state_dict`` can be used by all other parallel |
| schemes (resharding may be required). |
| """ |
| |
| FULL_STATE_DICT = auto() |
| LOCAL_STATE_DICT = auto() |
| SHARDED_STATE_DICT = auto() |
| |
| |
| @dataclass |
| class StateDictConfig: |
| """ |
| ``StateDictConfig`` is the base class for all ``state_dict`` configuration |
| classes. Users should instantiate a child class (e.g. |
| ``FullStateDictConfig``) in order to configure settings for the |
| corresponding ``state_dict`` type supported by FSDP. |
| |
| Attributes: |
| offload_to_cpu (bool): If ``True``, then FSDP offloads the state dict |
| values to CPU, and if ``False``, then FSDP keeps them on GPU. |
| (Default: ``False``) |
| """ |
| |
| offload_to_cpu: bool = False |
| |
| |
| @dataclass |
| class FullStateDictConfig(StateDictConfig): |
| """ |
| ``FullStateDictConfig`` is a config class meant to be used with |
| ``StateDictType.FULL_STATE_DICT``. We recommend enabling both |
| ``offload_to_cpu=True`` and ``rank0_only=True`` when saving full state |
| dicts to save GPU memory and CPU memory, respectively. This config class |
| is meant to be used via the :func:`state_dict_type` context manager as |
| follows: |
| |
| >>> # xdoctest: +SKIP("undefined variables") |
| >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP |
| >>> fsdp = FSDP(model, auto_wrap_policy=...) |
| >>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) |
| >>> with FSDP.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg): |
| >>> state = fsdp.state_dict() |
| >>> # `state` will be empty on nonzero ranks and contain CPU tensors on rank 0. |
| >>> # To reload checkpoint for inference, finetuning, transfer learning, etc: |
| >>> model = model_fn() # Initialize model in preparation for wrapping with FSDP |
| >>> if dist.get_rank() == 0: |
| >>> # Load checkpoint only on rank 0 to avoid memory redundancy |
| >>> state_dict = torch.load("my_checkpoint.pt") |
| >>> model.load_state_dict(state_dict) |
| >>> # All ranks initialize FSDP module as usual. `sync_module_states` argument |
| >>> # communicates loaded checkpoint states from rank 0 to rest of the world. |
| >>> fsdp = FSDP(model, device_id=torch.cuda.current_device(), auto_wrap_policy=..., sync_module_states=True) |
| >>> # After this point, all ranks have FSDP model with loaded checkpoint. |
| |
| Attributes: |
| rank0_only (bool): If ``True``, then only rank 0 saves the full state |
| dict, and nonzero ranks save an empty dict. If ``False``, then all |
| ranks save the full state dict. (Default: ``False``) |
| """ |
| |
| rank0_only: bool = False |
| |
| |
| @dataclass |
| class LocalStateDictConfig(StateDictConfig): |
| pass |
| |
| |
| @dataclass |
| class ShardedStateDictConfig(StateDictConfig): |
| """ |
| ``ShardedStateDictConfig`` is a config class meant to be used with |
| ``StateDictType.SHARDED_STATE_DICT``. |
| |
| Attributes: |
| _use_dtensor (bool): If ``True``, then FSDP saves the state dict values |
| as ``DTensor``, and if ``False``, then FSDP saves them as |
| ``ShardedTensor``. (Default: ``False``) |
| |
| .. warning:: ``_use_dtensor`` is a private field of :class:`ShardedStateDictConfig` |
| and it is used by FSDP to determine the type of state dict values. Users should not |
| manually modify ``_use_dtensor``. |
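| |
| A sketch of typical usage via the :func:`state_dict_type` context manager, |
| mirroring the ``FullStateDictConfig`` example above (assumes ``fsdp`` is an |
| FSDP-wrapped module): |
| |
| >>> # xdoctest: +SKIP("undefined variables") |
| >>> cfg = ShardedStateDictConfig(offload_to_cpu=True) |
| >>> with FSDP.state_dict_type(fsdp, StateDictType.SHARDED_STATE_DICT, cfg): |
| >>> state = fsdp.state_dict() # sharded values, one shard per rank |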
| """ |
| |
| _use_dtensor: bool = False |
| |
| |
| @dataclass |
| class OptimStateDictConfig: |
| """ |
| ``OptimStateDictConfig`` is the base class for all ``optim_state_dict`` |
| configuration classes. Users should instantiate a child class (e.g. |
| ``FullOptimStateDictConfig``) in order to configure settings for the |
| corresponding ``optim_state_dict`` type supported by FSDP. |
| |
| Attributes: |
| offload_to_cpu (bool): If ``True``, then FSDP offloads the state dict's |
| tensor values to CPU, and if ``False``, then FSDP keeps them on the |
| original device (which is GPU unless parameter CPU offloading is |
| enabled). (Default: ``True``) |
| """ |
| |
| offload_to_cpu: bool = True |
| |
| |
| @dataclass |
| class FullOptimStateDictConfig(OptimStateDictConfig): |
| """ |
| Attributes: |
| rank0_only (bool): If ``True``, then only rank 0 saves the full state |
| dict, and nonzero ranks save an empty dict. If ``False``, then all |
| ranks save the full state dict. (Default: ``False``) |
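| |
| A sketch of typical usage together with :func:`state_dict_type` and |
| ``FSDP.optim_state_dict`` (assumes ``fsdp`` is an FSDP-wrapped module and |
| ``optim`` is its optimizer; see :class:`FullStateDictConfig` for the model |
| state dict analogue): |
| |
| >>> # xdoctest: +SKIP("undefined variables") |
| >>> optim_cfg = FullOptimStateDictConfig(rank0_only=True, offload_to_cpu=True) |
| >>> with FSDP.state_dict_type( |
| >>> fsdp, |
| >>> StateDictType.FULL_STATE_DICT, |
| >>> FullStateDictConfig(rank0_only=True, offload_to_cpu=True), |
| >>> optim_cfg, |
| >>> ): |
| >>> osd = FSDP.optim_state_dict(fsdp, optim) |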
| """ |
| |
| rank0_only: bool = False |
| |
| |
| @dataclass |
| class LocalOptimStateDictConfig(OptimStateDictConfig): |
| offload_to_cpu: bool = False |
| |
| |
| @dataclass |
| class ShardedOptimStateDictConfig(OptimStateDictConfig): |
| """ |
| ``ShardedOptimStateDictConfig`` is a config class meant to be used with |
| ``StateDictType.SHARDED_STATE_DICT``. |
| |
| Attributes: |
| _use_dtensor (bool): If ``True``, then FSDP saves the state dict values |
| as ``DTensor``, and if ``False``, then FSDP saves them as |
| ``ShardedTensor``. (Default: ``False``) |
| |
| .. warning:: ``_use_dtensor`` is a private field of :class:`ShardedOptimStateDictConfig` |
| and it is used by FSDP to determine the type of state dict values. Users should not |
| manually modify ``_use_dtensor``. |
| """ |
| |
| _use_dtensor: bool = False |
| |
| |
| @dataclass |
| class StateDictSettings: |
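| """ |
| ``StateDictSettings`` bundles a :class:`StateDictType` together with its |
| corresponding model and optimizer state dict configs. |
| """ |
| |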
| state_dict_type: StateDictType |
| state_dict_config: StateDictConfig |
| optim_state_dict_config: OptimStateDictConfig |