[FSDP] ufmt /fsdp (#87811)
This applies `ufmt` to all of the FSDP files in the `torch/distributed/fsdp/` directory.
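`ufmt` combines µsort (import sorting) with black (code formatting), which is why the diff below contains both import reordering and line re-wrapping. As a minimal sketch (not part of this PR, assuming the standard `ufmt` CLI subcommands published on PyPI), the same formatting can be reproduced and verified from a small Python script:
```
# Minimal sketch: reformat the FSDP directory with ufmt and verify it is
# clean, by shelling out to the ufmt CLI. Assumes `pip install ufmt`.
import subprocess

FSDP_DIR = "torch/distributed/fsdp/"

# Rewrite files in place (import sorting via usort + formatting via black).
subprocess.run(["ufmt", "format", FSDP_DIR], check=True)

# Exit non-zero if any file would still be changed (useful as a CI check).
subprocess.run(["ufmt", "check", FSDP_DIR], check=True)
```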
**Test Plan**
CI
**Notes**
For VSCode users:
- Install `ufmt`: https://pypi.org/project/ufmt/
- Install the VSCode `ufmt` extension: https://marketplace.visualstudio.com/items?itemName=omnilib.ufmt
- Add the following to `settings.json`:
```
{
    "[python]": {
        "editor.defaultFormatter": "omnilib.ufmt",
        "editor.formatOnSave": true,
    },
}
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87811
Approved by: https://github.com/rohan-varma, https://github.com/fegin
diff --git a/torch/distributed/fsdp/_fsdp_extensions.py b/torch/distributed/fsdp/_fsdp_extensions.py
index abe0d90..1f087f4 100644
--- a/torch/distributed/fsdp/_fsdp_extensions.py
+++ b/torch/distributed/fsdp/_fsdp_extensions.py
@@ -5,7 +5,6 @@
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed._shard.sharded_tensor.shard import Shard
-
from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor
diff --git a/torch/distributed/fsdp/_optim_utils.py b/torch/distributed/fsdp/_optim_utils.py
index a5e1ab6..f87f871 100644
--- a/torch/distributed/fsdp/_optim_utils.py
+++ b/torch/distributed/fsdp/_optim_utils.py
@@ -3,6 +3,7 @@
import functools
from typing import (
Any,
+ cast,
Dict,
Iterable,
Iterator,
@@ -12,18 +13,18 @@
Sequence,
Tuple,
Union,
- cast,
)
import torch
import torch.distributed as dist
+
# Import the entire FSDP file to avoid circular imports
import torch.distributed.fsdp.fully_sharded_data_parallel as FSDP
import torch.nn as nn
from torch.distributed._shard.sharded_tensor import ShardedTensor
+from torch.distributed.fsdp._fsdp_extensions import _ext_chunk_tensor
from torch.distributed.fsdp._shard_utils import _gather_state_dict
from torch.distributed.fsdp.flat_param import FlatParameter, FlatParamHandle
-from torch.distributed.fsdp._fsdp_extensions import _ext_chunk_tensor
def sorted_items(dictionary: Dict[str, Any]) -> Iterator[Tuple[str, Any]]:
@@ -298,9 +299,9 @@
unflat_osd_state = unflat_osd["state"]
for param, unflat_param_names in param_to_unflat_param_names.items():
if isinstance(param, FlatParameter): # flatten FSDP parameters' states
- assert param in flat_param_to_fsdp_module, (
- f"Check the `flat_param_to_fsdp_module` construction\nparam: {param}"
- )
+ assert (
+ param in flat_param_to_fsdp_module
+ ), f"Check the `flat_param_to_fsdp_module` construction\nparam: {param}"
fsdp_module = flat_param_to_fsdp_module[param]
flat_state = _flatten_optim_state(
unflat_osd_state,
diff --git a/torch/distributed/fsdp/_shard_utils.py b/torch/distributed/fsdp/_shard_utils.py
index b0382b4..0cc9dd6 100644
--- a/torch/distributed/fsdp/_shard_utils.py
+++ b/torch/distributed/fsdp/_shard_utils.py
@@ -250,10 +250,8 @@
requires_grad=False,
memory_format=torch.contiguous_format,
pin_memory=tensor.is_pinned(),
- )
+ ),
)
return ShardedTensor._init_from_local_shards_and_global_metadata(
- local_shards,
- sharded_tensor_metadata=sharded_tensor_metadata,
- process_group=pg
+ local_shards, sharded_tensor_metadata=sharded_tensor_metadata, process_group=pg
)
diff --git a/torch/distributed/fsdp/_state_dict_utils.py b/torch/distributed/fsdp/_state_dict_utils.py
index ed4b8f2..90083ef 100644
--- a/torch/distributed/fsdp/_state_dict_utils.py
+++ b/torch/distributed/fsdp/_state_dict_utils.py
@@ -6,25 +6,24 @@
import torch
import torch.distributed as dist
import torch.distributed.algorithms._checkpoint.checkpoint_wrapper as checkpoint_wrapper
+
# Import the entire FSDP file to avoid circular imports
import torch.distributed.fsdp.fully_sharded_data_parallel as FSDP
import torch.nn as nn
import torch.nn.functional as F
-
from torch.distributed._shard.sharded_tensor import (
+ init_from_local_shards,
Shard,
ShardedTensor,
- init_from_local_shards,
)
-from torch.distributed.utils import (
- _replace_by_prefix,
-)
+from torch.distributed.utils import _replace_by_prefix
-from ._fsdp_extensions import _ext_chunk_tensor, _ext_pre_load_state_dict_transform
-from ._fsdp_extensions import _extensions as _user_extensions
-from .flat_param import (
- FlatParamHandle,
+from ._fsdp_extensions import (
+ _ext_chunk_tensor,
+ _ext_pre_load_state_dict_transform,
+ _extensions as _user_extensions,
)
+from .flat_param import FlatParamHandle
def _full_post_state_dict_hook(
@@ -53,16 +52,12 @@
# exiting `summon_full_params()` via the parameter shape. However, for
# `NO_SHARD`, we cannot tell from the shape, so we do not return early.
if (
- (
- not module._use_orig_params
- and FSDP.FLAT_PARAM in module.module._parameters
- )
- or (
- module._use_orig_params
- and module._handles
- and module._handles[0].uses_sharded_strategy
- and module._handles[0].is_sharded(module._handles[0].flat_param)
- )
+ not module._use_orig_params and FSDP.FLAT_PARAM in module.module._parameters
+ ) or (
+ module._use_orig_params
+ and module._handles
+ and module._handles[0].uses_sharded_strategy
+ and module._handles[0].is_sharded(module._handles[0].flat_param)
):
return state_dict
@@ -79,7 +74,7 @@
# do not have prefix considered as they are not computed in `state_dict`
# call.
if clean_key.startswith(clean_prefix):
- clean_key = clean_key[len(clean_prefix):]
+ clean_key = clean_key[len(clean_prefix) :]
# Clone non-ignored parameters before exiting the
# `_summon_full_params()` context
@@ -88,8 +83,9 @@
f"only has {state_dict.keys()}. prefix={prefix}, "
f"module_name={module_name} param_name={param_name} rank={module.rank}."
)
- if clean_key not in module._ignored_param_names and \
- not getattr(state_dict[fqn], "_has_been_cloned", False):
+ if clean_key not in module._ignored_param_names and not getattr(
+ state_dict[fqn], "_has_been_cloned", False
+ ):
try:
state_dict[fqn] = state_dict[fqn].clone().detach()
state_dict[fqn]._has_been_cloned = True # type: ignore[attr-defined]
@@ -129,11 +125,9 @@
) -> None:
# We do not expect to be calling pre-hooks twice without post-hook
# call in between.
- assert getattr(module, '_full_param_ctx', None) is None
+ assert getattr(module, "_full_param_ctx", None) is None
# Note that it needs writeback=True to persist.
- module._full_param_ctx = module._summon_full_params(
- recurse=False, writeback=True
- )
+ module._full_param_ctx = module._summon_full_params(recurse=False, writeback=True)
module._full_param_ctx.__enter__()
_replace_by_prefix(state_dict, prefix, prefix + f"{FSDP.FSDP_PREFIX}")
@@ -141,7 +135,7 @@
def _full_post_load_state_dict_hook(module, *args, **kwargs) -> None:
# We should exit summon_full_params context.
module._assert_state([FSDP.TrainingState_.SUMMON_FULL_PARAMS])
- assert getattr(module, '_full_param_ctx', None) is not None
+ assert getattr(module, "_full_param_ctx", None) is not None
module._full_param_ctx.__exit__(None, None, None)
module._full_param_ctx = None
@@ -189,7 +183,9 @@
def _local_pre_load_state_dict_hook(
- module, state_dict: Dict[str, Any], prefix: str,
+ module,
+ state_dict: Dict[str, Any],
+ prefix: str,
) -> None:
"""
This hook finds the local flat_param for this FSDP module from the
@@ -253,7 +249,7 @@
rank=module.rank,
world_size=module.world_size,
num_devices_per_node=torch.cuda.device_count(),
- pg=module.process_group
+ pg=module.process_group,
)
if module._state_dict_config.offload_to_cpu:
sharded_tensor = sharded_tensor.cpu()
@@ -271,7 +267,9 @@
def _sharded_pre_load_state_dict_hook(
- module, state_dict: Dict[str, Any], prefix: str,
+ module,
+ state_dict: Dict[str, Any],
+ prefix: str,
) -> None:
"""
The hook combines the unflattened, sharded parameters (ShardedTensor) to
@@ -331,7 +329,9 @@
# Get the chunk from the loaded flat_param for the local rank.
loaded_flat_tensor, num_to_pad = FlatParamHandle._get_shard(
- loaded_flat_param, module.rank, module.world_size,
+ loaded_flat_param,
+ module.rank,
+ module.world_size,
)
loaded_flat_tensor.to(flat_param.device)
assert all(s1 == s2 for s1, s2 in zip(loaded_shapes, flat_param._shapes)), (
@@ -377,10 +377,7 @@
# back to their mixed precision type. This is because buffers are cast
# during lazy_init() and stay at their mixed precision type before/after
# forward/backward. As a result state_dict() should maintain this.
- if (
- fsdp_module._is_root
- and fsdp_module._mixed_precision_enabled_for_buffers()
- ):
+ if fsdp_module._is_root and fsdp_module._mixed_precision_enabled_for_buffers():
fsdp_module._cast_buffers(recurse=True)
return processed_state_dict
diff --git a/torch/distributed/fsdp/_symbolic_trace.py b/torch/distributed/fsdp/_symbolic_trace.py
index 026595f..f6fe5e4 100644
--- a/torch/distributed/fsdp/_symbolic_trace.py
+++ b/torch/distributed/fsdp/_symbolic_trace.py
@@ -5,7 +5,6 @@
import torch
-
__all__ = ["TracingConfig"]
@@ -140,13 +139,18 @@
if args is not None:
named_params: List[Tuple[str, torch.nn.Parameter]] = []
for arg in args:
- if isinstance(arg, torch.fx.Proxy) and arg.node.target in prefixed_param_name_to_param:
+ if (
+ isinstance(arg, torch.fx.Proxy)
+ and arg.node.target in prefixed_param_name_to_param
+ ):
param = prefixed_param_name_to_param[arg.node.target]
named_params.append((arg.node.target, param))
if param not in set(execution_info.param_exec_order):
execution_info.param_exec_order.append(param)
if named_params:
- execution_info.module_to_execution_infos[module].append((module, named_params))
+ execution_info.module_to_execution_infos[module].append(
+ (module, named_params)
+ )
elif kind == "call_module":
named_params = list(module.named_parameters())
if named_params:
@@ -234,7 +238,10 @@
)
prefixed_param_name_to_param = dict(root_module.named_parameters())
tracer.create_proxy = functools.partial(
- _patched_create_proxy, original_create_proxy, execution_info, prefixed_param_name_to_param
+ _patched_create_proxy,
+ original_create_proxy,
+ execution_info,
+ prefixed_param_name_to_param,
)
try:
yield
diff --git a/torch/distributed/fsdp/_utils.py b/torch/distributed/fsdp/_utils.py
index bd37ce5..eb72042 100644
--- a/torch/distributed/fsdp/_utils.py
+++ b/torch/distributed/fsdp/_utils.py
@@ -10,14 +10,11 @@
)
from torch.nn.utils.rnn import PackedSequence
-
FSDP_FLATTENED = "_fsdp_flattened"
def _contains_batchnorm(module):
- return any(
- isinstance(mod, _BatchNorm) for mod in module.modules()
- )
+ return any(isinstance(mod, _BatchNorm) for mod in module.modules())
def _override_batchnorm_mixed_precision(module):
@@ -27,11 +24,14 @@
def _apply_to_tensors(
- fn: Callable, container: Union[torch.Tensor, Dict, List, Tuple, Set, OrderedDict, PackedSequence]
+ fn: Callable,
+ container: Union[torch.Tensor, Dict, List, Tuple, Set, OrderedDict, PackedSequence],
) -> Any:
"""Recursively apply to all tensor in different kinds of container types."""
- def apply(x: Union[torch.Tensor, Dict, List, Tuple, Set, OrderedDict, PackedSequence]) -> Any:
+ def apply(
+ x: Union[torch.Tensor, Dict, List, Tuple, Set, OrderedDict, PackedSequence]
+ ) -> Any:
if torch.is_tensor(x):
return fn(x)
elif hasattr(x, "__dataclass_fields__"):
@@ -75,6 +75,7 @@
module prefix name (e.g. "module.submodule." just like in model state dict)
and makes that available to ``module_fn``.
"""
+
def f(module: torch.nn.Module, prefix: str, *args, **kwargs):
# Call the module function before recursing over children (pre-order)
module_fn(module, prefix, *args, **kwargs)
diff --git a/torch/distributed/fsdp/flat_param.py b/torch/distributed/fsdp/flat_param.py
index 266dc80..3e4eca0 100644
--- a/torch/distributed/fsdp/flat_param.py
+++ b/torch/distributed/fsdp/flat_param.py
@@ -33,7 +33,6 @@
p_assert,
)
-
__all__ = [
"FlatParameter",
"FlatParamHandle",
@@ -1507,7 +1506,8 @@
# memory and owns the gradient storage, so it will never
# require gradient writeback.
flat_param_grad = (
- flat_param.grad if self.uses_sharded_strategy or not self._config.offload_params
+ flat_param.grad
+ if self.uses_sharded_strategy or not self._config.offload_params
else flat_param._cpu_grad # type: ignore[attr-defined]
)
needs_grad_writeback = flat_param_grad is None or not _same_storage(
diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py
index 5fb2e5c..8cd1847 100644
--- a/torch/distributed/fsdp/fully_sharded_data_parallel.py
+++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py
@@ -8,10 +8,11 @@
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
-from enum import Enum, auto
+from enum import auto, Enum
from typing import (
Any,
Callable,
+ cast,
Deque,
Dict,
Generator,
@@ -22,7 +23,6 @@
Set,
Tuple,
Union,
- cast,
)
import torch
@@ -35,15 +35,10 @@
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
_CHECKPOINT_PREFIX,
)
-from torch.distributed.algorithms._comm_hooks import (
- LOW_PRECISION_HOOKS,
- default_hooks,
-)
+from torch.distributed.algorithms._comm_hooks import default_hooks, LOW_PRECISION_HOOKS
from torch.distributed.distributed_c10d import _get_default_group
-from torch.distributed.utils import (
- _sync_params_and_buffers,
- _to_kwargs,
-)
+from torch.distributed.utils import _sync_params_and_buffers, _to_kwargs
+
from ._optim_utils import (
_broadcast_pos_dim_tensor_states,
_broadcast_processed_optim_state_dict,
@@ -57,9 +52,9 @@
_rekey_sharded_optim_state_dict,
)
from ._state_dict_utils import (
+ _post_load_state_dict_hook,
_post_state_dict_hook,
_pre_load_state_dict_hook,
- _post_load_state_dict_hook,
)
from ._utils import (
_apply_to_modules,
@@ -78,10 +73,10 @@
HandleTrainingState,
)
from .wrap import (
- ParamExecOrderWrapPolicy,
_or_policy,
_recursive_wrap,
_wrap_batchnorm_individually,
+ ParamExecOrderWrapPolicy,
)
_TORCHDISTX_AVAIL = True
@@ -94,18 +89,23 @@
if not hasattr(torch, "fx"):
_TORCH_FX_AVAIL = False
if _TORCH_FX_AVAIL:
- from ._symbolic_trace import (
- TracingConfig,
- _init_execution_info,
- _patch_tracer,
- )
+ from ._symbolic_trace import _init_execution_info, _patch_tracer, TracingConfig
__all__ = [
- "FullyShardedDataParallel", "ShardingStrategy", "MixedPrecision",
- "CPUOffload", "BackwardPrefetch", "StateDictType", "StateDictConfig",
- "FullStateDictConfig", "LocalStateDictConfig", "ShardedStateDictConfig",
- "OptimStateKeyType", "TrainingState_", "clean_tensor_name",
+ "FullyShardedDataParallel",
+ "ShardingStrategy",
+ "MixedPrecision",
+ "CPUOffload",
+ "BackwardPrefetch",
+ "StateDictType",
+ "StateDictConfig",
+ "FullStateDictConfig",
+ "LocalStateDictConfig",
+ "ShardedStateDictConfig",
+ "OptimStateKeyType",
+ "TrainingState_",
+ "clean_tensor_name",
]
@@ -148,6 +148,7 @@
``NO_SHARD`` inter-node.
"""
+
FULL_SHARD = auto()
SHARD_GRAD_OP = auto()
NO_SHARD = auto()
@@ -197,6 +198,7 @@
would occur in the `param_dtype` precision, if given, otherwise, in the
original parameter precision.
"""
+
# maintain a tensor of this dtype that the fp32 param shard will be cast to.
# Will control the precision of model params, inputs, and thus compute as
# well.
@@ -309,6 +311,7 @@
order to configure settings for the particular type of ``state_dict``
implementation FSDP will use.
"""
+
offload_to_cpu: bool = False
@@ -340,6 +343,7 @@
>>> fsdp = FSDP(model, device_id=torch.cuda.current_device(), auto_wrap_policy=..., sync_module_states=True)
>>> # After this point, all ranks have FSDP model with loaded checkpoint.
"""
+
rank0_only: bool = False
@@ -366,9 +370,10 @@
class _ExecOrderWarnStatus(Enum):
"""Used internally for execution order validation."""
- NONE = auto() # no deviation yet
+
+ NONE = auto() # no deviation yet
WARNING = auto() # deviated this iteration; currently issuing warnings
- WARNED = auto() # deviated in a previous iteration
+ WARNED = auto() # deviated in a previous iteration
class _ExecOrderData:
@@ -403,9 +408,10 @@
self._forward_prefetch_limit = forward_prefetch_limit
# Data structures for execution order validation
- self._checking_order: bool = (
- debug_level in [dist.DebugLevel.INFO, dist.DebugLevel.DETAIL]
- )
+ self._checking_order: bool = debug_level in [
+ dist.DebugLevel.INFO,
+ dist.DebugLevel.DETAIL,
+ ]
self.process_group: Optional[dist.ProcessGroup] = None
self.world_size: Optional[int] = None
self.all_handles: List[FlatParamHandle] = []
@@ -454,7 +460,9 @@
prefetch given the current handles key. If there are no valid handles
keys to prefetch, then this returns an empty :class:`list`.
"""
- current_index = self.handles_to_post_forward_order_index.get(current_handles_key, None)
+ current_index = self.handles_to_post_forward_order_index.get(
+ current_handles_key, None
+ )
if current_index is None:
return None
target_index = current_index - 1
@@ -462,9 +470,7 @@
for _ in range(self._backward_prefetch_limit):
if target_index < 0:
break
- target_handles_keys.append(
- self.handles_post_forward_order[target_index]
- )
+ target_handles_keys.append(self.handles_post_forward_order[target_index])
target_index -= 1
return target_handles_keys
@@ -477,7 +483,9 @@
prefetch given the current handles key. If there are no valid handles
keys to prefetch, then this returns an empty :class:`list`.
"""
- current_index = self.handles_to_pre_forward_order_index.get(current_handles_key, None)
+ current_index = self.handles_to_pre_forward_order_index.get(
+ current_handles_key, None
+ )
if current_index is None:
return None
target_index = current_index + 1
@@ -485,9 +493,7 @@
for _ in range(self._forward_prefetch_limit):
if target_index >= len(self.handles_pre_forward_order):
break
- target_handles_keys.append(
- self.handles_pre_forward_order[target_index]
- )
+ target_handles_keys.append(self.handles_pre_forward_order[target_index])
target_index += 1
return target_handles_keys
@@ -511,7 +517,9 @@
self.handles_to_post_forward_order_index[handles_key] = index
self.handles_post_forward_order.append(handles_key)
- def record_pre_forward(self, handles: List[FlatParamHandle], is_training: bool) -> None:
+ def record_pre_forward(
+ self, handles: List[FlatParamHandle], is_training: bool
+ ) -> None:
"""
Records ``handles`` in the pre-forward order, where ``handles`` should
be a group of handles used in the same module's forward. If ``handles``
@@ -597,7 +605,7 @@
(
rank,
world_indices[
- rank * num_valid_indices: (rank + 1) * num_valid_indices
+ rank * num_valid_indices : (rank + 1) * num_valid_indices
],
)
for rank in range(self.world_size)
@@ -683,7 +691,9 @@
continue
handle = self.all_handles[index]
flat_param = handle.flat_param
- prefixed_param_names.append(self.flat_param_to_prefixed_param_names[flat_param])
+ prefixed_param_names.append(
+ self.flat_param_to_prefixed_param_names[flat_param]
+ )
return prefixed_param_names
def _get_names_from_handles(
@@ -700,7 +710,9 @@
flat_param = handle.flat_param
if flat_param not in self.flat_param_to_prefixed_param_names:
continue
- prefixed_param_names.append(self.flat_param_to_prefixed_param_names[flat_param])
+ prefixed_param_names.append(
+ self.flat_param_to_prefixed_param_names[flat_param]
+ )
return prefixed_param_names
def next_iter(self):
@@ -970,6 +982,7 @@
the sharded strategies that schedule all-gathers. Enabling this can
help lower the number of CUDA malloc retries.
"""
+
def __init__(
self,
module: nn.Module,
@@ -1062,10 +1075,16 @@
self._buffer_name_to_orig_dtype: Dict[str, torch.dtype] = {}
self._check_single_device_module(module, ignored_params)
- device_from_device_id: Optional[torch.device] = self._get_device_from_device_id(device_id)
- self._materialize_module(module, param_init_fn, ignored_params, device_from_device_id)
+ device_from_device_id: Optional[torch.device] = self._get_device_from_device_id(
+ device_id
+ )
+ self._materialize_module(
+ module, param_init_fn, ignored_params, device_from_device_id
+ )
self._move_module_to_device(module, ignored_params, device_from_device_id)
- self.compute_device = self._get_compute_device(module, ignored_params, device_from_device_id)
+ self.compute_device = self._get_compute_device(
+ module, ignored_params, device_from_device_id
+ )
params_to_flatten = list(self._get_orig_params(module, ignored_params))
if sync_module_states:
self._sync_module_states(module, params_to_flatten)
@@ -1098,7 +1117,10 @@
self.params.append(handle.flat_param)
self._register_param_handle(handle)
handle.shard()
- if self.cpu_offload.offload_params and handle.flat_param.device != torch.device("cpu"):
+ if (
+ self.cpu_offload.offload_params
+ and handle.flat_param.device != torch.device("cpu")
+ ):
handle.flat_param_to(torch.device("cpu"))
if not use_orig_params:
self._check_orig_params_flattened(ignored_params)
@@ -1301,8 +1323,7 @@
self,
device_id: Optional[Union[int, torch.device]],
) -> Optional[torch.device]:
- """
- """
+ """ """
if device_id is None:
return None
device = (
@@ -1341,11 +1362,15 @@
``reset_parameters()``, and for torchdistX fake tensors, this calls
``deferred_init.materialize_module()``.
"""
- is_meta_module = any(p.is_meta for p in self._get_orig_params(module, ignored_params))
+ is_meta_module = any(
+ p.is_meta for p in self._get_orig_params(module, ignored_params)
+ )
is_torchdistX_deferred_init = (
not is_meta_module
and _TORCHDISTX_AVAIL
- and any(fake.is_fake(p) for p in self._get_orig_params(module, ignored_params))
+ and any(
+ fake.is_fake(p) for p in self._get_orig_params(module, ignored_params)
+ )
)
if (
is_meta_module or is_torchdistX_deferred_init
@@ -1357,7 +1382,9 @@
param_init_fn(module)
elif is_meta_module:
# Run default meta device initialization
- materialization_device = device_from_device_id or torch.cuda.current_device()
+ materialization_device = (
+ device_from_device_id or torch.cuda.current_device()
+ )
module.to_empty(device=materialization_device)
try:
with torch.no_grad():
@@ -1483,7 +1510,10 @@
module_states.append(buffer.detach())
module_states.extend(param.detach() for param in params)
_sync_params_and_buffers(
- self.process_group, module_states, _PARAM_BROADCAST_BUCKET_SIZE, src=0,
+ self.process_group,
+ module_states,
+ _PARAM_BROADCAST_BUCKET_SIZE,
+ src=0,
)
def _get_orig_params(
@@ -1573,7 +1603,7 @@
p_assert(
len(handles) == len(free_unsharded_flat_params),
"Expects both lists to have equal length but got "
- f"{len(handles)} and {len(free_unsharded_flat_params)}"
+ f"{len(handles)} and {len(free_unsharded_flat_params)}",
)
for handle, free_unsharded_flat_param in zip(
handles,
@@ -1651,9 +1681,10 @@
the input ``module``.
"""
return [
- submodule for submodule in module.modules()
- if isinstance(submodule, FullyShardedDataParallel) and
- (not root_only or submodule.check_is_root())
+ submodule
+ for submodule in module.modules()
+ if isinstance(submodule, FullyShardedDataParallel)
+ and (not root_only or submodule.check_is_root())
]
def apply(self, fn: Callable[[nn.Module], None]) -> "FullyShardedDataParallel":
@@ -1728,6 +1759,7 @@
precision given by ``dtype``, while respecting the existing
``requires_grad`` on the tensors.
"""
+
def cast_fn(x: torch.Tensor) -> torch.Tensor:
if not torch.is_floating_point(x):
return x
@@ -1741,7 +1773,7 @@
with torch.no_grad():
return (
_apply_to_tensors(cast_fn, args),
- _apply_to_tensors(cast_fn, kwargs)
+ _apply_to_tensors(cast_fn, kwargs),
)
def _cast_buffers(
@@ -1775,9 +1807,15 @@
if memo is None:
memo = set()
for module in self.modules():
- if module is not self and isinstance(module, FullyShardedDataParallel) and recurse:
+ if (
+ module is not self
+ and isinstance(module, FullyShardedDataParallel)
+ and recurse
+ ):
# Allow any child FSDP instances to handle their own buffers.
- module._cast_buffers(device=device, dtype=dtype, memo=memo, recurse=recurse)
+ module._cast_buffers(
+ device=device, dtype=dtype, memo=memo, recurse=recurse
+ )
elif module not in memo:
memo.add(module)
for name, buf in module.named_buffers(recurse=False):
@@ -1863,7 +1901,9 @@
fsdp_module.limit_all_gathers = self.limit_all_gathers
fsdp_module._free_event_queue = self._free_event_queue
fsdp_module._handles_prefetched = self._handles_prefetched
- fsdp_module._needs_pre_backward_unshard = self._needs_pre_backward_unshard
+ fsdp_module._needs_pre_backward_unshard = (
+ self._needs_pre_backward_unshard
+ )
for handle in fsdp_module._handles:
fsdp_module._init_param_attributes(handle)
if inconsistent_limit_all_gathers:
@@ -1936,13 +1976,11 @@
# fwd/bwd, it is freed and we only hold on to the full precision shard.
# As a result, this reduced precision shard is not allocated if we are
# not in the forward/backward pass.
- if (
- self._mixed_precision_enabled_for_params()
- ):
+ if self._mixed_precision_enabled_for_params():
p._mp_shard = torch.zeros_like(
p._local_shard,
device=self.compute_device,
- dtype=self.mixed_precision.param_dtype
+ dtype=self.mixed_precision.param_dtype,
)
_free_storage(p._mp_shard)
@@ -1957,7 +1995,8 @@
# into full_param_padded it can occur without issues and result in
# full_param_padded having the expected param_dtype.
full_param_dtype = (
- p.dtype if not self._mixed_precision_enabled_for_params()
+ p.dtype
+ if not self._mixed_precision_enabled_for_params()
else self.mixed_precision.param_dtype
)
p._full_param_padded = torch.zeros( # type: ignore[attr-defined]
@@ -2024,7 +2063,9 @@
for handles_key in handles_to_prefetch:
# Prefetch the next set of handles without synchronizing to allow
# the sync to happen as late as possible to maximize overlap
- self._unshard(handles_key, self._streams["unshard"], self._streams["pre_unshard"])
+ self._unshard(
+ handles_key, self._streams["unshard"], self._streams["pre_unshard"]
+ )
self._handles_prefetched[handles_key] = True
def _get_handles_to_prefetch(
@@ -2048,33 +2089,31 @@
p_assert(
training_state in valid_training_states,
f"Prefetching is only supported in {valid_training_states} but "
- f"currently in {training_state}"
+ f"currently in {training_state}",
)
eod = self._exec_order_data
target_handles_keys: List[_HandlesKey] = []
if (
- (
- training_state == HandleTrainingState.BACKWARD_PRE
- and self.backward_prefetch == BackwardPrefetch.BACKWARD_PRE
- )
- or (
- training_state == HandleTrainingState.BACKWARD_POST
- and self.backward_prefetch == BackwardPrefetch.BACKWARD_POST
- )
+ training_state == HandleTrainingState.BACKWARD_PRE
+ and self.backward_prefetch == BackwardPrefetch.BACKWARD_PRE
+ ) or (
+ training_state == HandleTrainingState.BACKWARD_POST
+ and self.backward_prefetch == BackwardPrefetch.BACKWARD_POST
):
target_handles_keys = [
- target_handles_key for target_handles_key in
- eod.get_handles_to_backward_prefetch(current_handles_key)
+ target_handles_key
+ for target_handles_key in eod.get_handles_to_backward_prefetch(
+ current_handles_key
+ )
if self._needs_pre_backward_unshard.get(target_handles_key, False)
and not self._handles_prefetched.get(target_handles_key, False)
]
- elif (
- training_state == HandleTrainingState.FORWARD
- and self.forward_prefetch
- ):
+ elif training_state == HandleTrainingState.FORWARD and self.forward_prefetch:
target_handles_keys = [
- target_handles_key for target_handles_key in
- eod.get_handles_to_forward_prefetch(current_handles_key)
+ target_handles_key
+ for target_handles_key in eod.get_handles_to_forward_prefetch(
+ current_handles_key
+ )
if self._needs_pre_forward_unshard.get(target_handles_key, False)
and not self._handles_prefetched.get(target_handles_key, False)
]
@@ -2089,7 +2128,7 @@
training_states = set(handle._training_state for handle in handles_key)
p_assert(
len(training_states) == 1,
- f"Expects uniform training state but got {training_states}"
+ f"Expects uniform training state but got {training_states}",
)
return next(iter(training_states))
@@ -2157,7 +2196,9 @@
"All FSDP modules should have the same type of state_dict_config."
)
- expected_state_dict_config_type = _state_dict_type_to_config[state_dict_type]
+ expected_state_dict_config_type = _state_dict_type_to_config[
+ state_dict_type
+ ]
if expected_state_dict_config_type != type(state_dict_config):
raise RuntimeError(
f"Expected state_dict_config of type {expected_state_dict_config_type} "
@@ -2200,10 +2241,11 @@
prev_state_dict_type = None
prev_state_dict_config = None
try:
- prev_state_dict_type, prev_state_dict_config = (
- FullyShardedDataParallel.set_state_dict_type(
- module, state_dict_type, state_dict_config
- )
+ (
+ prev_state_dict_type,
+ prev_state_dict_config,
+ ) = FullyShardedDataParallel.set_state_dict_type(
+ module, state_dict_type, state_dict_config
)
yield
except Exception as e:
@@ -2233,18 +2275,14 @@
def _param_fqns(self) -> Iterator[Tuple[str, str, str]]:
if not self._has_params:
return
- for param_name, module_name in (
- self._handles[0].parameter_module_names()
- ):
+ for param_name, module_name in self._handles[0].parameter_module_names():
module_name = self._convert_to_wrapped_module_name(module_name)
fqn = f"{module_name}{param_name}"
yield fqn, param_name, module_name
@property
def _shared_param_fqns(self) -> Iterator[Tuple[str, str, str]]:
- for param_name, module_name in (
- self._handles[0].shared_parameter_module_names()
- ):
+ for param_name, module_name in self._handles[0].shared_parameter_module_names():
module_name = self._convert_to_wrapped_module_name(module_name)
fqn = f"{module_name}{param_name}"
yield fqn, param_name, module_name
@@ -2297,17 +2335,21 @@
if self._state_dict_type == StateDictType.FULL_STATE_DICT:
# Get config args
full_state_dict_config = (
- self._state_dict_config if self._state_dict_config is not None
+ self._state_dict_config
+ if self._state_dict_config is not None
else FullStateDictConfig()
)
rank0_only = full_state_dict_config.rank0_only
offload_to_cpu = full_state_dict_config.offload_to_cpu
summon_ctx = (
self._summon_full_params(
- recurse=False, writeback=False, offload_to_cpu=offload_to_cpu, rank0_only=rank0_only
+ recurse=False,
+ writeback=False,
+ offload_to_cpu=offload_to_cpu,
+ rank0_only=rank0_only,
)
- if self.training_state != TrainingState_.SUMMON_FULL_PARAMS else
- contextlib.suppress()
+ if self.training_state != TrainingState_.SUMMON_FULL_PARAMS
+ else contextlib.suppress()
)
with summon_ctx:
# Since buffers are not sharded and stay casted, restore them to their
@@ -2316,10 +2358,7 @@
# buffers stay casted after forward/backward. We must have the
# call here instead of above because _summon_full_params itself
# calls _lazy_init() which would cast the buffers.
- if (
- self._is_root
- and self._mixed_precision_enabled_for_buffers()
- ):
+ if self._is_root and self._mixed_precision_enabled_for_buffers():
self._cast_buffers(
dtype=self._buffer_name_to_orig_dtype, recurse=False
)
@@ -2332,13 +2371,10 @@
return {}
elif (
- self._state_dict_type == StateDictType.LOCAL_STATE_DICT or
- self._state_dict_type == StateDictType.SHARDED_STATE_DICT
+ self._state_dict_type == StateDictType.LOCAL_STATE_DICT
+ or self._state_dict_type == StateDictType.SHARDED_STATE_DICT
):
- if (
- self._has_params and
- not self._handles[0].uses_sharded_strategy
- ):
+ if self._has_params and not self._handles[0].uses_sharded_strategy:
raise RuntimeError(
"sharded_state_dict/local_state_dict can only be called "
"when parameters are flatten and sharded."
@@ -2352,17 +2388,22 @@
Runs the forward pass for the wrapped module, inserting FSDP-specific
pre- and post-forward sharding logic.
"""
- with torch.autograd.profiler.record_function("FullyShardedDataParallel.forward"):
+ with torch.autograd.profiler.record_function(
+ "FullyShardedDataParallel.forward"
+ ):
self._lazy_init()
args, kwargs = self._fsdp_root_pre_forward(*args, **kwargs)
unused = None
- unshard_fn = functools.partial(self._pre_forward_unshard, handles=self._handles)
+ unshard_fn = functools.partial(
+ self._pre_forward_unshard, handles=self._handles
+ )
# Do not free the root's parameters in the post-forward for
# `FULL_SHARD` with the intention that they are immediately used
# for backward computation (though this may not be true)
free_unsharded_flat_params = [
not self._is_root
- and handle._config.sharding_strategy == HandleShardingStrategy.FULL_SHARD
+ and handle._config.sharding_strategy
+ == HandleShardingStrategy.FULL_SHARD
for handle in self._handles
]
reshard_fn = functools.partial(
@@ -2375,7 +2416,7 @@
p_assert(
handle.flat_param.device == self.compute_device,
"Expected `FlatParameter` to be on the compute device "
- f"{self.compute_device} but got {handle.flat_param.device}"
+ f"{self.compute_device} but got {handle.flat_param.device}",
)
output = self._fsdp_wrapped_module(*args, **kwargs)
return self._post_forward(self._handles, reshard_fn, unused, unused, output)
@@ -2418,7 +2459,9 @@
) -> None:
"""Unshards parameters in the pre-forward."""
if handles:
- self._unshard(handles, self._streams["unshard"], self._streams["pre_unshard"])
+ self._unshard(
+ handles, self._streams["unshard"], self._streams["pre_unshard"]
+ )
handles_key = tuple(handles)
self._needs_pre_forward_unshard[handles_key] = False
torch.cuda.current_stream().wait_stream(self._streams["unshard"])
@@ -2476,7 +2519,9 @@
if self._mixed_precision_enabled_for_params():
input_dtype = self.mixed_precision.param_dtype
args, kwargs = self._cast_fp_inputs_to_dtype(
- input_dtype, *args, **kwargs,
+ input_dtype,
+ *args,
+ **kwargs,
)
return args, kwargs
@@ -2525,7 +2570,7 @@
offload_to_cpu: bool = False,
with_grads: bool = False,
) -> Generator:
- r""" A context manager to expose full params for FSDP instances.
+ r"""A context manager to expose full params for FSDP instances.
Can be useful *after* forward/backward for a model to get
the params for additional processing or checking. It can take a non-FSDP
module and will summon full params for all contained FSDP modules as
@@ -2663,7 +2708,9 @@
handle._training_state = HandleTrainingState.SUMMON_FULL_PARAMS
self._clear_grads_if_needed()
- free_unsharded_flat_params = [handle.needs_unshard() for handle in self._handles]
+ free_unsharded_flat_params = [
+ handle.needs_unshard() for handle in self._handles
+ ]
# No need to call `wait_stream()` since we unshard in the computation
# stream directly
computation_stream = torch.cuda.current_stream()
@@ -2742,7 +2789,7 @@
handle.rank,
handle.world_size,
)
- handle.flat_param._local_shard[:param_shard.numel()].copy_(param_shard)
+ handle.flat_param._local_shard[: param_shard.numel()].copy_(param_shard)
if writeback_grad:
existing_grad = handle.sharded_grad
if existing_grad is not None:
@@ -2751,7 +2798,7 @@
handle.rank,
handle.world_size,
)
- existing_grad[:grad_shard.numel()].copy_(grad_shard)
+ existing_grad[: grad_shard.numel()].copy_(grad_shard)
@contextlib.contextmanager
def _unflatten_as_params(self) -> Generator:
@@ -2825,7 +2872,7 @@
p_assert(
len(self._handles) <= 1,
"Expects <=1 handle per FSDP instance; needs to be refactored "
- "for >1 handle (e.g. non-recursive wrapping)"
+ "for >1 handle (e.g. non-recursive wrapping)",
)
if not self._handles:
return
@@ -2833,7 +2880,7 @@
p_assert(
handle._use_orig_params,
f"Inconsistent `_use_orig_params` -- FSDP: {self._use_orig_params} "
- f"handle: {handle._use_orig_params}"
+ f"handle: {handle._use_orig_params}",
)
handle._deregister_orig_params()
self._register_flat_param()
@@ -2973,7 +3020,9 @@
# If the handles have been prefetched, this `_unshard()` simply
# switches to using the unsharded parameter
- self._unshard(_handles, self._streams["unshard"], self._streams["pre_unshard"])
+ self._unshard(
+ _handles, self._streams["unshard"], self._streams["pre_unshard"]
+ )
torch.cuda.current_stream().wait_stream(self._streams["unshard"])
# Set this to `False` to ensure that a mistargeted prefetch
@@ -3022,7 +3071,7 @@
p_assert(
temp_flat_param.grad_fn is not None,
"The `grad_fn` is needed to access the `AccumulateGrad` and "
- "register the post-backward hook"
+ "register the post-backward hook",
)
acc_grad = temp_flat_param.grad_fn.next_functions[0][0]
hook_handle = acc_grad.register_hook(
@@ -3055,11 +3104,16 @@
):
# First hook callback will see PRE state. If we have multiple params,
# then subsequent hook callbacks will see POST state.
- self._assert_state([TrainingState_.BACKWARD_PRE, TrainingState_.BACKWARD_POST])
+ self._assert_state(
+ [TrainingState_.BACKWARD_PRE, TrainingState_.BACKWARD_POST]
+ )
self.training_state = TrainingState_.BACKWARD_POST
handle._training_state = HandleTrainingState.BACKWARD_POST
- if self._use_param_exec_order_policy() and self._param_exec_order_prep_stage:
+ if (
+ self._use_param_exec_order_policy()
+ and self._param_exec_order_prep_stage
+ ):
# In self._fsdp_params_exec_order, the parameters are ordered based on
# the execution order in the backward pass in the first iteration.
self._fsdp_params_exec_order.append(param)
@@ -3103,7 +3157,9 @@
# TODO: Make this a communication hook when communication hooks
# are implemented for FSDP. Note that this is a noop if the
# reduce_dtype matches the param dtype.
- param.grad.data = param.grad.data.to(self.mixed_precision.reduce_dtype)
+ param.grad.data = param.grad.data.to(
+ self.mixed_precision.reduce_dtype
+ )
if self._exec_order_data.is_first_iter:
# For all sharding strategies communication is performed through `_communication_hook`:
@@ -3112,11 +3168,11 @@
# and `_communication_hook_state`, required for communication not `None`.`
p_assert(
self._communication_hook is not None,
- "Communication hook should not be None"
+ "Communication hook should not be None",
)
p_assert(
self._communication_hook_state is not None,
- "Communication hook state should not be None"
+ "Communication hook state should not be None",
)
grad = param.grad.data
if handle.uses_sharded_strategy:
@@ -3138,7 +3194,9 @@
num_pad = self.world_size * chunks[0].numel() - grad.numel()
input_flattened = F.pad(grad_flatten, [0, num_pad])
output = torch.zeros_like(chunks[0])
- self._communication_hook(self._communication_hook_state, input_flattened, output)
+ self._communication_hook(
+ self._communication_hook_state, input_flattened, output
+ )
self._cast_grad_to_param_dtype(output, param)
@@ -3153,13 +3211,13 @@
param._saved_grad_shard.shape == output.shape, # type: ignore[attr-defined]
"Shape mismatch when accumulating gradients: " # type: ignore[attr-defined]
f"existing grad shape={param._saved_grad_shard.shape} "
- f"new grad shape={output.shape}" # type: ignore[attr-defined]
+ f"new grad shape={output.shape}", # type: ignore[attr-defined]
)
p_assert(
param._saved_grad_shard.device == output.device, # type: ignore[attr-defined]
"Device mismatch when accumulating gradients: " # type: ignore[attr-defined]
f"existing grad device={param._saved_grad_shard.device} "
- f"new grad device={output.device}" # type: ignore[attr-defined]
+ f"new grad device={output.device}", # type: ignore[attr-defined]
)
param._saved_grad_shard += output # type: ignore[attr-defined]
else:
@@ -3167,7 +3225,9 @@
grad = param._saved_grad_shard # type: ignore[attr-defined]
else:
if self.sharding_strategy == ShardingStrategy.NO_SHARD:
- self._communication_hook(self._communication_hook_state, param.grad)
+ self._communication_hook(
+ self._communication_hook_state, param.grad
+ )
# For NO_SHARD keeping grads in the reduced precision, we
# can simply omit the cast as needed, we can't do this for
@@ -3221,12 +3281,9 @@
dtype cast happens in the hook instead.
"""
self._assert_state(TrainingState_.BACKWARD_POST)
- if (
- not self._low_precision_hook_enabled()
- and (
- self._mixed_precision_enabled_for_params()
- or self._mixed_precision_enabled_for_reduce()
- )
+ if not self._low_precision_hook_enabled() and (
+ self._mixed_precision_enabled_for_params()
+ or self._mixed_precision_enabled_for_reduce()
):
low_prec_grad_data = grad.data
grad.data = grad.data.to(dtype=param.dtype)
@@ -3236,9 +3293,8 @@
def _should_free_unsharded_flat_param(self, handle: FlatParamHandle):
return (
- (self._sync_gradients and handle.uses_sharded_strategy)
- or handle._config.sharding_strategy == HandleShardingStrategy.FULL_SHARD
- )
+ self._sync_gradients and handle.uses_sharded_strategy
+ ) or handle._config.sharding_strategy == HandleShardingStrategy.FULL_SHARD
def _queue_wait_for_post_backward(self) -> None:
"""
@@ -3247,7 +3303,7 @@
"""
p_assert(
self._is_root,
- "`_queue_wait_for_post_backward()` should be called on the root FSDP instance"
+ "`_queue_wait_for_post_backward()` should be called on the root FSDP instance",
)
if self._post_backward_callback_queued:
return
@@ -3295,18 +3351,21 @@
# TODO: This already-resharded check is brittle:
# https://github.com/pytorch/pytorch/issues/83956
already_resharded = (
- handle.flat_param.data_ptr() == handle.flat_param._local_shard.data_ptr()
+ handle.flat_param.data_ptr()
+ == handle.flat_param._local_shard.data_ptr()
)
if already_resharded:
continue
- free_unsharded_flat_params.append(self._should_free_unsharded_flat_param(handle))
+ free_unsharded_flat_params.append(
+ self._should_free_unsharded_flat_param(handle)
+ )
handles_to_reshard.append(handle)
self._reshard(handles_to_reshard, free_unsharded_flat_params)
except Exception as e:
p_assert(
False,
f"Got exception while resharding module {fsdp_module}: {str(e)}",
- raise_assertion_error=False
+ raise_assertion_error=False,
)
raise e
@@ -3318,7 +3377,7 @@
if hasattr(p, "_post_backward_hook_state"):
p_assert(
len(p._post_backward_hook_state) == 2, # type: ignore[attr-defined]
- "p._post_backward_hook_state fields are not valid."
+ "p._post_backward_hook_state fields are not valid.",
)
p._post_backward_hook_state[1].remove() # type: ignore[attr-defined]
delattr(p, "_post_backward_hook_state")
@@ -3331,8 +3390,8 @@
continue
handle.prepare_gradient_for_optim()
p_assert(
- hasattr(p, '_post_backward_called'),
- "Expected flag _post_backward_called to be set on param."
+ hasattr(p, "_post_backward_called"),
+ "Expected flag _post_backward_called to be set on param.",
)
# Reset _post_backward_called in preparation for the next iteration.
p._post_backward_called = False
@@ -3479,22 +3538,25 @@
norm_type = float(norm_type)
# Compute the local gradient norm (only including this rank's shard
# of the gradients)
- local_norm = _get_grad_norm(self.parameters(), norm_type).to(self.compute_device)
+ local_norm = _get_grad_norm(self.parameters(), norm_type).to(
+ self.compute_device
+ )
# Reconstruct the total gradient norm depending on the norm type
if norm_type == math.inf:
total_norm = local_norm
- dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=self.process_group)
+ dist.all_reduce(
+ total_norm, op=torch.distributed.ReduceOp.MAX, group=self.process_group
+ )
else:
- total_norm = local_norm ** norm_type
+ total_norm = local_norm**norm_type
dist.all_reduce(total_norm, group=self.process_group)
total_norm = total_norm ** (1.0 / norm_type)
if self.cpu_offload.offload_params:
total_norm = total_norm.cpu()
- clip_coef = (
- torch.tensor(max_norm, dtype=total_norm.dtype, device=total_norm.device)
- / (total_norm + 1e-6)
- )
+ clip_coef = torch.tensor(
+ max_norm, dtype=total_norm.dtype, device=total_norm.device
+ ) / (total_norm + 1e-6)
# Multiplying by the clamped coefficient is meaningless when it is
# equal to 1, but it avoids the host-device sync that would result from
# `if clip_coef < 1`
@@ -3537,9 +3599,12 @@
def full_optim_state_dict(
model: torch.nn.Module,
optim: torch.optim.Optimizer,
- optim_input: Optional[Union[
- List[Dict[str, Any]], Iterable[torch.nn.Parameter],
- ]] = None,
+ optim_input: Optional[
+ Union[
+ List[Dict[str, Any]],
+ Iterable[torch.nn.Parameter],
+ ]
+ ] = None,
rank0_only: bool = True,
group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
@@ -3592,7 +3657,8 @@
FullyShardedDataParallel._raise_on_use_orig_params_optim_checkpoint(model)
FullyShardedDataParallel._warn_optim_input(optim_input)
using_optim_input = FullyShardedDataParallel._is_using_optim_input(
- optim_input, optim,
+ optim_input,
+ optim,
)
return _optim_state_dict(
model=model,
@@ -3610,7 +3676,8 @@
optim: torch.optim.Optimizer,
optim_input: Optional[
Union[
- List[Dict[str, Any]], Iterable[torch.nn.Parameter],
+ List[Dict[str, Any]],
+ Iterable[torch.nn.Parameter],
]
] = None,
group: Optional[dist.ProcessGroup] = None,
@@ -3629,7 +3696,8 @@
FullyShardedDataParallel._raise_on_use_orig_params_optim_checkpoint(model)
FullyShardedDataParallel._warn_optim_input(optim_input)
using_optim_input = FullyShardedDataParallel._is_using_optim_input(
- optim_input, optim,
+ optim_input,
+ optim,
)
# TODO: The ultimate goal of the optimizer state APIs should be the same
# as state_dict/load_state_dict -- using one API to get optimizer states
@@ -3655,7 +3723,8 @@
model: torch.nn.Module,
optim_input: Optional[
Union[
- List[Dict[str, Any]], Iterable[torch.nn.Parameter],
+ List[Dict[str, Any]],
+ Iterable[torch.nn.Parameter],
]
] = None,
optim: Optional[torch.optim.Optimizer] = None,
@@ -3717,13 +3786,20 @@
FullyShardedDataParallel._raise_on_use_orig_params_optim_checkpoint(model)
FullyShardedDataParallel._warn_optim_input(optim_input)
using_optim_input = FullyShardedDataParallel._is_using_optim_input(
- optim_input, optim,
+ optim_input,
+ optim,
)
sharded_osd = _flatten_optim_state_dict(
- full_optim_state_dict, model, True,
+ full_optim_state_dict,
+ model,
+ True,
)
return _rekey_sharded_optim_state_dict(
- sharded_osd, model, optim, optim_input, using_optim_input,
+ sharded_osd,
+ model,
+ optim,
+ optim_input,
+ using_optim_input,
)
@staticmethod
@@ -3732,7 +3808,8 @@
model: torch.nn.Module,
optim_input: Optional[
Union[
- List[Dict[str, Any]], Iterable[torch.nn.Parameter],
+ List[Dict[str, Any]],
+ Iterable[torch.nn.Parameter],
]
] = None,
optim: Optional[torch.optim.Optimizer] = None,
@@ -3756,7 +3833,8 @@
FullyShardedDataParallel._raise_on_use_orig_params_optim_checkpoint(model)
FullyShardedDataParallel._warn_optim_input(optim_input)
using_optim_input = FullyShardedDataParallel._is_using_optim_input(
- optim_input, optim,
+ optim_input,
+ optim,
)
# TODO: The implementation is the same as ``shard_full_optim_state_dict``.
# See the TODO in ``shard_full_optim_state_dict`` for the future
@@ -3767,16 +3845,23 @@
shard_state=True,
)
return _rekey_sharded_optim_state_dict(
- flattened_osd, model, optim, optim_input, using_optim_input,
+ flattened_osd,
+ model,
+ optim,
+ optim_input,
+ using_optim_input,
)
@staticmethod
def scatter_full_optim_state_dict(
full_optim_state_dict: Optional[Dict[str, Any]],
model: torch.nn.Module,
- optim_input: Optional[Union[
- List[Dict[str, Any]], Iterable[torch.nn.Parameter],
- ]] = None,
+ optim_input: Optional[
+ Union[
+ List[Dict[str, Any]],
+ Iterable[torch.nn.Parameter],
+ ]
+ ] = None,
optim: Optional[torch.optim.Optimizer] = None,
group: Optional[Any] = None,
) -> Dict[str, Any]:
@@ -3838,7 +3923,8 @@
FullyShardedDataParallel._raise_on_use_orig_params_optim_checkpoint(model)
FullyShardedDataParallel._warn_optim_input(optim_input)
using_optim_input = FullyShardedDataParallel._is_using_optim_input(
- optim_input, optim,
+ optim_input,
+ optim,
)
# Try to use the passed-in process group, the model's process group,
# or the default process group (i.e. `None`) in that priority order
@@ -3848,8 +3934,9 @@
world_size = dist.get_world_size(group)
# Check for a valid broadcast device, preferring GPU when available
using_nccl = dist.distributed_c10d._check_for_nccl_backend(group)
- broadcast_device = torch.device("cuda") if torch.cuda.is_available() \
- else torch.device("cpu")
+ broadcast_device = (
+ torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ )
if using_nccl and not torch.cuda.is_available():
raise RuntimeError("NCCL requires a GPU for collectives")
# Flatten the optimizer state dict and construct a copy with the
@@ -3867,18 +3954,28 @@
# Broadcast the optim state dict without positive-dimension tensor
# state and the FSDP parameter IDs from rank 0 to all ranks
processed_osd = _broadcast_processed_optim_state_dict(
- processed_osd if rank == 0 else None, rank, group,
+ processed_osd if rank == 0 else None,
+ rank,
+ group,
)
# Broadcast positive-dimension tensor state (both sharded tensors for
# FSDP parameters and unsharded tensors for non-FSDP parameters)
sharded_osd = _broadcast_pos_dim_tensor_states(
- processed_osd, flat_osd if rank == 0 else None, rank, world_size,
- group, broadcast_device,
+ processed_osd,
+ flat_osd if rank == 0 else None,
+ rank,
+ world_size,
+ group,
+ broadcast_device,
)
# Rekey the optimizer state dict to use parameter IDs according to this
# rank's `optim`
sharded_osd = _rekey_sharded_optim_state_dict(
- sharded_osd, model, optim, optim_input, using_optim_input,
+ sharded_osd,
+ model,
+ optim,
+ optim_input,
+ using_optim_input,
)
return sharded_osd
@@ -3887,9 +3984,12 @@
optim_state_dict: Dict[str, Any],
optim_state_key_type: OptimStateKeyType,
model: torch.nn.Module,
- optim_input: Optional[Union[
- List[Dict[str, Any]], Iterable[torch.nn.Parameter],
- ]] = None,
+ optim_input: Optional[
+ Union[
+ List[Dict[str, Any]],
+ Iterable[torch.nn.Parameter],
+ ]
+ ] = None,
optim: Optional[torch.optim.Optimizer] = None,
) -> Dict[str, Any]:
"""
@@ -3926,30 +4026,30 @@
"""
FullyShardedDataParallel._warn_optim_input(optim_input)
using_optim_input = FullyShardedDataParallel._is_using_optim_input(
- optim_input, optim,
+ optim_input,
+ optim,
)
assert optim_state_key_type in (
- OptimStateKeyType.PARAM_NAME, OptimStateKeyType.PARAM_ID,
+ OptimStateKeyType.PARAM_NAME,
+ OptimStateKeyType.PARAM_ID,
)
osd = optim_state_dict # alias
# Validate that the existing parameter keys are uniformly typed
- uses_param_name_mask = [
- type(param_key) is str for param_key in osd["state"]
- ]
- uses_param_id_mask = [
- type(param_key) is int for param_key in osd["state"]
- ]
- if (
- (any(uses_param_name_mask) and not all(uses_param_name_mask))
- or (any(uses_param_id_mask) and not all(uses_param_id_mask))
+ uses_param_name_mask = [type(param_key) is str for param_key in osd["state"]]
+ uses_param_id_mask = [type(param_key) is int for param_key in osd["state"]]
+ if (any(uses_param_name_mask) and not all(uses_param_name_mask)) or (
+ any(uses_param_id_mask) and not all(uses_param_id_mask)
):
error_msg = f"Invalid parameter keys: {osd['state'].keys()}"
raise ValueError(error_msg)
# Return directly if the existing key type matches the target key type
- if (optim_state_key_type == OptimStateKeyType.PARAM_NAME and
- all(uses_param_name_mask)) or \
- (optim_state_key_type == OptimStateKeyType.PARAM_ID and
- all(uses_param_id_mask)):
+ if (
+ optim_state_key_type == OptimStateKeyType.PARAM_NAME
+ and all(uses_param_name_mask)
+ ) or (
+ optim_state_key_type == OptimStateKeyType.PARAM_ID
+ and all(uses_param_id_mask)
+ ):
return osd
# Otherwise, actually perform the re-keying
new_osd = {}
@@ -3969,10 +4069,12 @@
}
new_osd["param_groups"] = copy.deepcopy(osd["param_groups"])
for param_group in new_osd["param_groups"]:
- param_group["params"] = sorted([
- param_id_to_param_name[param_id]
- for param_id in param_group["params"]
- ])
+ param_group["params"] = sorted(
+ [
+ param_id_to_param_name[param_id]
+ for param_id in param_group["params"]
+ ]
+ )
return new_osd
elif optim_state_key_type == OptimStateKeyType.PARAM_ID: # name -> ID
param_name_to_param = _get_param_name_to_param(model)
@@ -3994,10 +4096,12 @@
}
new_osd["param_groups"] = copy.deepcopy(osd["param_groups"])
for param_group in new_osd["param_groups"]:
- param_group["params"] = sorted([
- param_name_to_param_id[param_name]
- for param_name in param_group["params"]
- ])
+ param_group["params"] = sorted(
+ [
+ param_name_to_param_id[param_name]
+ for param_name in param_group["params"]
+ ]
+ )
return new_osd
return new_osd # should never reach here
@@ -4056,12 +4160,17 @@
"""
if not self.check_is_root():
- raise AssertionError("register_comm_hook can only be called on a root instance.")
+ raise AssertionError(
+ "register_comm_hook can only be called on a root instance."
+ )
for submodule in self.fsdp_modules(self):
- assert not submodule._hook_registered, "communication hook can be only registered once"
+ assert (
+ not submodule._hook_registered
+ ), "communication hook can be only registered once"
submodule._hook_registered = True
- assert submodule._communication_hook == self._get_default_comm_hook(),\
- f"communication hook should be default, but it is {submodule._communication_hook.__name__} instead"
+ assert (
+ submodule._communication_hook == self._get_default_comm_hook()
+ ), f"communication hook should be default, but it is {submodule._communication_hook.__name__} instead"
submodule._communication_hook_state = state
submodule._communication_hook = hook
@@ -4073,10 +4182,7 @@
assert (
auto_wrap_policy.tracing_config is None
), "tracing_config should be None when torch.fx is not enabled"
- elif isinstance(
- auto_wrap_policy.tracing_config,
- TracingConfig
- ):
+ elif isinstance(auto_wrap_policy.tracing_config, TracingConfig):
tracer = auto_wrap_policy.tracing_config.tracer
execution_info = _init_execution_info(module)
@@ -4110,8 +4216,7 @@
# A list that stores the flatten parameters and its name based on the parameter execution order
self._fsdp_params_exec_order: List[FlatParameter] = []
if _TORCH_FX_AVAIL and isinstance(
- auto_wrap_policy.tracing_config,
- TracingConfig
+ auto_wrap_policy.tracing_config, TracingConfig
):
# Initialize a dict that maps each module to its parent FSDP wrap
module_to_fsdp: Dict[nn.Module, FullyShardedDataParallel] = dict()
@@ -4137,8 +4242,7 @@
def _use_param_exec_order_policy(self) -> bool:
return (
- hasattr(self, "_param_exec_order_policy")
- and self._param_exec_order_policy
+ hasattr(self, "_param_exec_order_policy") and self._param_exec_order_policy
)
def _is_param_exec_order_prep_stage(self) -> bool:
@@ -4148,8 +4252,8 @@
)
if not is_prep_stage:
for p in self.parameters():
- assert (
- not hasattr(p, "_params_exec_order_hook_handle")
+ assert not hasattr(
+ p, "_params_exec_order_hook_handle"
), "When not in execution order prep stage, all _params_exec_order_hook_handle should be removed."
return is_prep_stage
@@ -4168,7 +4272,9 @@
grads = [param.grad for param in params_with_grad]
grad_dtypes = set(grad.dtype for grad in grads)
if len(grad_dtypes) != 1:
- raise ValueError(f"Requires uniform dtype across all gradients but got {grad_dtypes}")
+ raise ValueError(
+ f"Requires uniform dtype across all gradients but got {grad_dtypes}"
+ )
# Compute the gradient norm in FP32, where we treat the gradients as a
# single vector
grad_norm = torch.linalg.vector_norm(
@@ -4206,15 +4312,14 @@
in the module walk order; if ``False``, then includes all of the
unflattened parameter names.
"""
+
def module_fn(module, prefix, param_to_unflat_param_names):
for param_name, param in module.named_parameters(recurse=False):
module_prefixed_param_names = (
- param._fqns if type(param) is FlatParameter
- else [param_name]
+ param._fqns if type(param) is FlatParameter else [param_name]
) # prefixed from `module`
fully_prefixed_param_names = [
- clean_tensor_name(prefix + name)
- for name in module_prefixed_param_names
+ clean_tensor_name(prefix + name) for name in module_prefixed_param_names
] # fully prefixed from the top level including `prefix`
# If this parameter has already been visited, then it is a
# shared parameter; then, only take the first parameter name
@@ -4229,7 +4334,10 @@
param_to_unflat_param_names: Dict[torch.nn.Parameter, List[str]] = {}
return _apply_to_modules(
- model, module_fn, return_fn, param_to_unflat_param_names,
+ model,
+ module_fn,
+ return_fn,
+ param_to_unflat_param_names,
)
@@ -4250,16 +4358,16 @@
"""
param_to_param_names = _get_param_to_unflat_param_names(model)
for param_names in param_to_param_names.values():
- assert len(param_names) > 0, "`_get_param_to_unflat_param_names()` " \
- "should not construct empty lists"
+ assert len(param_names) > 0, (
+ "`_get_param_to_unflat_param_names()` " "should not construct empty lists"
+ )
if len(param_names) > 1:
raise RuntimeError(
"Each parameter should only map to one parameter name but got "
f"{len(param_names)}: {param_names}"
)
param_to_param_name = {
- param: param_names[0]
- for param, param_names in param_to_param_names.items()
+ param: param_names[0] for param, param_names in param_to_param_names.items()
}
return param_to_param_name
diff --git a/torch/distributed/fsdp/sharded_grad_scaler.py b/torch/distributed/fsdp/sharded_grad_scaler.py
index 27ba44e..86dbfd7 100644
--- a/torch/distributed/fsdp/sharded_grad_scaler.py
+++ b/torch/distributed/fsdp/sharded_grad_scaler.py
@@ -1,12 +1,12 @@
-from collections import abc, defaultdict
import logging
+from collections import abc, defaultdict
from typing import Dict, List, Optional, Union
import torch
-from torch.cuda import FloatTensor # type: ignore[attr-defined]
-from torch.cuda.amp.grad_scaler import GradScaler, OptState, _MultiDeviceReplicator
-from torch.distributed.distributed_c10d import ProcessGroup
import torch.distributed as dist
+from torch.cuda import FloatTensor # type: ignore[attr-defined]
+from torch.cuda.amp.grad_scaler import _MultiDeviceReplicator, GradScaler, OptState
+from torch.distributed.distributed_c10d import ProcessGroup
from torch.optim.sgd import SGD
@@ -23,6 +23,7 @@
Lazily serves tensor to request device. This class extends
_MultiDeviceReplicator to allow support for "cpu" as a device.
"""
+
def __init__(self, master_tensor: torch.Tensor) -> None:
assert _is_supported_device(master_tensor)
self.master = master_tensor
@@ -77,9 +78,10 @@
process_group (ProcessGroup, optional, default=torch.distributed.group.WORLD):
process group for sharding
"""
+
def __init__(
self,
- init_scale: float = 2.0 ** 16,
+ init_scale: float = 2.0**16,
backoff_factor: float = 0.5,
growth_factor: float = 2.0,
growth_interval: int = 2000,
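
For orientation, a minimal training-step sketch with `ShardedGradScaler`, whose API mirrors `torch.cuda.amp.GradScaler`; `fsdp_model`, `optimizer`, `loss_fn`, and `loader` are assumed to exist:

```
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

scaler = ShardedGradScaler(init_scale=2.0**16)
for inputs, targets in loader:
    optimizer.zero_grad(set_to_none=True)
    loss = loss_fn(fsdp_model(inputs), targets)
    scaler.scale(loss).backward()  # scale the loss before backward
    scaler.step(optimizer)         # unscales grads; skips the step on inf/nan
    scaler.update()                # adjust the scale for the next iteration
```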
@@ -97,7 +99,9 @@
self.process_group = process_group
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
- def scale(self, outputs: Union[torch.Tensor, List[torch.Tensor]]) -> Union[torch.Tensor, List[torch.Tensor]]:
+ def scale(
+ self, outputs: Union[torch.Tensor, List[torch.Tensor]]
+ ) -> Union[torch.Tensor, List[torch.Tensor]]:
if not self._enabled:
return outputs
@@ -106,7 +110,9 @@
if self._scale is None:
self._lazy_init_scale_growth_tracker(outputs.device)
assert self._scale is not None
- scaled_output = outputs * self._scale.to(device=outputs.device, non_blocking=True)
+ scaled_output = outputs * self._scale.to(
+ device=outputs.device, non_blocking=True
+ )
# Here we ensure the return dtype is the same as the outputs dtype.
# For the FSDP + Mixed Precision use case, the loss output is in the Mixed Precision
# format (fp16, bf16) and so the scaled loss should be of the same dtype.
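
The dtype comment above can be illustrated in isolation: multiplying a low-precision loss by the FP32 scale promotes the result to FP32, so a cast back to the outputs' dtype is needed. Illustrative only; assumes a CUDA device is available:

```
import torch

loss = torch.ones((), dtype=torch.bfloat16, device="cuda")
scale = torch.full((1,), 2.0**16, dtype=torch.float32, device="cuda")
scaled = loss * scale                # type promotion -> float32
assert scaled.dtype == torch.float32
scaled = scaled.type(loss.dtype)     # cast back so callers still see bf16
assert scaled.dtype == torch.bfloat16
```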
@@ -114,7 +120,9 @@
stash: List[_GeneralMultiDeviceReplicator] = []
- def apply_scale(val: Union[torch.Tensor, abc.Iterable]) -> Union[torch.Tensor, abc.Iterable]:
+ def apply_scale(
+ val: Union[torch.Tensor, abc.Iterable]
+ ) -> Union[torch.Tensor, abc.Iterable]:
if isinstance(val, torch.Tensor):
assert _is_supported_device(val)
if len(stash) == 0:
@@ -150,20 +158,30 @@
for grad in grads:
for tensor in grad:
if tensor.device != expected_device:
- logging.error("tensor device is %s and expected device is %s" % (tensor.device, expected_device))
+ logging.error(
+ "tensor device is %s and expected device is %s"
+ % (tensor.device, expected_device)
+ )
raise ValueError("Gradients must be on the same device.")
# check for non_overlapping_and_dense doesn't exist in the python world
# as remarked here https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/AmpKernels.cu#L108
# we assume tensor is not MTA(multi tensor apply) safe. iterate through each item regardless of dtype
- if torch.isinf(tensor).any().item() is True or torch.isnan(tensor).any().item() is True:
+ if (
+ torch.isinf(tensor).any().item() is True
+ or torch.isnan(tensor).any().item() is True
+ ):
found_inf.data = torch.tensor([1.0])
break
else:
tensor.data *= inv_scale.item()
def _unscale_grads_(
- self, optimizer: SGD, inv_scale: torch.Tensor, found_inf: torch.Tensor, allow_fp16: bool = True
+ self,
+ optimizer: SGD,
+ inv_scale: torch.Tensor,
+ found_inf: torch.Tensor,
+ allow_fp16: bool = True,
) -> Dict[torch.device, torch.Tensor]:
per_device_inv_scale = _GeneralMultiDeviceReplicator(inv_scale)
per_device_found_inf = _GeneralMultiDeviceReplicator(found_inf)
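
The per-tensor loop above boils down to the following pattern. A simplified sketch, assuming `grads` is a flat list of gradient tensors on one device and `scale` is the current one-element FP32 scale tensor:

```
import torch

inv_scale = scale.double().reciprocal().float()  # reciprocal in FP64 for precision
found_inf = torch.full((1,), 0.0, dtype=torch.float32)
for g in grads:
    if torch.isinf(g).any() or torch.isnan(g).any():
        found_inf.fill_(1.0)    # record overflow; the optimizer step will be skipped
        break
    g.mul_(inv_scale.item())    # unscale in place
```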
@@ -195,7 +213,9 @@
else:
to_unscale = param.grad
- per_device_and_dtype_grads[to_unscale.device][to_unscale.dtype].append(to_unscale)
+ per_device_and_dtype_grads[to_unscale.device][
+ to_unscale.dtype
+ ].append(to_unscale)
for device, per_dtype_grads in per_device_and_dtype_grads.items():
for grads in per_dtype_grads.values():
@@ -222,16 +242,22 @@
optimizer_state = self._per_optimizer_states[id(optimizer)]
if optimizer_state["stage"] is OptState.UNSCALED:
- raise RuntimeError("unscale_() has already been called on this optimizer since the last update().")
+ raise RuntimeError(
+ "unscale_() has already been called on this optimizer since the last update()."
+ )
elif optimizer_state["stage"] is OptState.STEPPED:
raise RuntimeError("unscale_() is being called after step().")
# FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
assert self._scale is not None
inv_scale = self._scale.double().reciprocal().float()
- found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=self._scale.device)
+ found_inf = torch.full(
+ (1,), 0.0, dtype=torch.float32, device=self._scale.device
+ )
- optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, True)
+ optimizer_state["found_inf_per_device"] = self._unscale_grads_(
+ optimizer, inv_scale, found_inf, True
+ )
optimizer_state["stage"] = OptState.UNSCALED
# Synchronize the detected inf across the ranks
@@ -241,10 +267,18 @@
for v in optimizer_state["found_inf_per_device"].values():
if v.device.type == "cpu":
v_on_cuda = v.cuda()
- future_handles.append(dist.all_reduce(v_on_cuda, async_op=True, group=self.process_group).get_future())
+ future_handles.append(
+ dist.all_reduce(
+ v_on_cuda, async_op=True, group=self.process_group
+ ).get_future()
+ )
v.copy_(v_on_cuda.cpu())
else:
- future_handles.append(dist.all_reduce(v, async_op=True, group=self.process_group).get_future())
+ future_handles.append(
+ dist.all_reduce(
+ v, async_op=True, group=self.process_group
+ ).get_future()
+ )
# Make sure that the calls are done before moving out.
if future_handles:
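
The synchronization above uses the async-collective-plus-future pattern. A standalone sketch, assuming a default process group has been initialized and `found_inf` is a CUDA tensor:

```
import torch
import torch.distributed as dist

work = dist.all_reduce(found_inf, async_op=True)  # returns immediately
future = work.get_future()
torch.futures.wait_all([future])                  # block until the reduce completes
```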
diff --git a/torch/distributed/fsdp/wrap.py b/torch/distributed/fsdp/wrap.py
index 8013da8..c529bcd 100644
--- a/torch/distributed/fsdp/wrap.py
+++ b/torch/distributed/fsdp/wrap.py
@@ -5,22 +5,11 @@
import contextlib
from dataclasses import dataclass
-from typing import (
- Any,
- Callable,
- Dict,
- Generator,
- Optional,
- Set,
- Tuple,
- Type,
- cast,
-)
+from typing import Any, Callable, cast, Dict, Generator, Optional, Set, Tuple, Type
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm
-
__all__ = [
"always_wrap_policy",
"lambda_auto_wrap_policy",
@@ -41,11 +30,9 @@
"""
return True
+
def lambda_auto_wrap_policy(
- module: nn.Module,
- recurse: bool,
- unwrapped_params: int,
- lambda_fn: Callable
+ module: nn.Module, recurse: bool, unwrapped_params: int, lambda_fn: Callable
) -> bool:
"""
A convenient auto wrap policy to wrap submodules based on an arbitrary user
@@ -78,6 +65,7 @@
     # if not recursing, decide whether we should wrap for the leaf node or remainder
return lambda_fn(module)
+
def transformer_auto_wrap_policy(
module: nn.Module,
recurse: bool,
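
For context, `lambda_auto_wrap_policy` is typically bound with `functools.partial` and handed to FSDP. A hedged usage sketch in which `MyBlock` and `model` are placeholders:

```
import functools

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy

policy = functools.partial(
    lambda_auto_wrap_policy,
    lambda_fn=lambda m: isinstance(m, MyBlock),  # wrap each MyBlock as its own unit
)
fsdp_model = FSDP(model, auto_wrap_policy=policy)
```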
@@ -121,6 +109,7 @@
     # if not recursing, decide whether we should wrap for the leaf node or remainder
return isinstance(module, tuple(transformer_layer_cls))
+
def _wrap_batchnorm_individually(
module: nn.Module,
recurse: bool,
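
Similarly, `transformer_auto_wrap_policy` is usually partially applied with the transformer block class(es) to treat as wrapping units; `TransformerBlock` and `model` below are placeholders:

```
import functools

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={TransformerBlock},  # each block becomes an FSDP unit
)
fsdp_model = FSDP(model, auto_wrap_policy=policy)
```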
@@ -138,6 +127,7 @@
# BN layer or not.
return isinstance(module, _BatchNorm)
+
def _or_policy(
module: nn.Module,
recurse: bool,
@@ -148,9 +138,7 @@
A policy that wraps ``module`` if any policy in the passed in iterable of
``policies`` returns ``True``.
"""
- return any(
- policy(module, recurse, unwrapped_params) for policy in policies
- )
+ return any(policy(module, recurse, unwrapped_params) for policy in policies)
def size_based_auto_wrap_policy(
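
`_or_policy` composes policies by OR-ing their decisions. A sketch combining the batch-norm policy with a size-based one; both underscore-prefixed helpers are private, so treat this as illustrative rather than a supported recipe:

```
import functools

from torch.distributed.fsdp.wrap import (
    _or_policy,
    _wrap_batchnorm_individually,
    size_based_auto_wrap_policy,
)

# Wrap a module if EITHER it is a BatchNorm layer OR its subtree is large enough.
combined_policy = functools.partial(
    _or_policy,
    policies=[
        _wrap_batchnorm_individually,
        functools.partial(size_based_auto_wrap_policy, min_num_params=int(1e8)),
    ],
)
```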
@@ -333,13 +321,14 @@
``full``, ``full_like``, ``eye``, ``empty``, ``tensor``). For those cases,
users can set ``tracing_config = None`` to disable symbolic tracing.
"""
+
init_policy: Callable = always_wrap_policy
tracing_config: Any = None
def _wrap(module: nn.Module, wrapper_cls: Callable, **kwargs) -> nn.Module:
assert wrapper_cls is not None
- if hasattr(module, '_wrap_overrides'):
+ if hasattr(module, "_wrap_overrides"):
# If module has a _wrap_overrides attribute, we force overriding the
# FSDP config with these attributes for this module. Currently this
# is only used to disable mixed precision for BatchNorm when
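
A hedged sketch of the override idea described in the comment above, using a hypothetical helper name (`_apply_wrap_overrides` is not part of the source); per-module `_wrap_overrides` entries take precedence over the globally passed FSDP kwargs:

```
from typing import Any, Dict

import torch.nn as nn


def _apply_wrap_overrides(module: nn.Module, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    # e.g. a BatchNorm module may carry {"mixed_precision": None} to opt out
    # of mixed precision; module-level overrides win over the global config.
    overrides = getattr(module, "_wrap_overrides", {})
    return {**kwargs, **overrides}
```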
@@ -357,7 +346,7 @@
ignored_modules: Set[nn.Module],
ignored_params: Set[nn.Parameter],
only_wrap_children: bool = False,
- **kwargs: Any
+ **kwargs: Any,
) -> Tuple[nn.Module, int]:
"""
Automatically wrap child modules of *module* that meet the given
@@ -389,9 +378,7 @@
pass
# We count all params, assuming none of them are already wrapped.
- num_params = sum(
- p.numel() for p in module.parameters() if p not in ignored_params
- )
+ num_params = sum(p.numel() for p in module.parameters() if p not in ignored_params)
assert auto_wrap_policy is not None
if auto_wrap_policy(module=module, recurse=True, unwrapped_params=num_params):
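
End to end, the count computed above is what a size-based policy compares against its threshold. A usage sketch in which `model` is a placeholder:

```
import functools

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Wrap any subtree whose not-yet-wrapped parameter count exceeds ~100M elements.
policy = functools.partial(size_based_auto_wrap_policy, min_num_params=int(1e8))
fsdp_model = FSDP(model, auto_wrap_policy=policy)
```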