[FSDP] ufmt /fsdp (#87811)

This applies `ufmt` to all FSDP files under the `torch/distributed/fsdp/` directory.

**Test Plan**
CI

**Notes**
For VSCode users:
- Install `ufmt`: https://pypi.org/project/ufmt/
- Install VSCode `ufmt` extension: https://marketplace.visualstudio.com/items?itemName=omnilib.ufmt
- Include in `settings.json`:
```json
{
    "[python]": {
        "editor.defaultFormatter": "omnilib.ufmt",
        "editor.formatOnSave": true
    }
}
```
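
For workflows outside VSCode, the directory can also be formatted by running the `ufmt` CLI directly. The snippet below is a minimal sketch (not part of the original PR) that shells out to the CLI via `subprocess`; it assumes `ufmt` is installed in the environment (`pip install ufmt`) and available on `PATH`.

```python
# Hedged sketch: invoke the ufmt CLI on the FSDP directory.
# `ufmt format` rewrites files in place; `ufmt check` only reports
# which files would change (useful as a pre-commit or CI step).
import subprocess

subprocess.run(["ufmt", "format", "torch/distributed/fsdp/"], check=True)
```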

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87811
Approved by: https://github.com/rohan-varma, https://github.com/fegin
diff --git a/torch/distributed/fsdp/_fsdp_extensions.py b/torch/distributed/fsdp/_fsdp_extensions.py
index abe0d90..1f087f4 100644
--- a/torch/distributed/fsdp/_fsdp_extensions.py
+++ b/torch/distributed/fsdp/_fsdp_extensions.py
@@ -5,7 +5,6 @@
 import torch.distributed as dist
 from torch.distributed._shard.sharded_tensor.api import ShardedTensor
 from torch.distributed._shard.sharded_tensor.shard import Shard
-
 from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor
 
 
diff --git a/torch/distributed/fsdp/_optim_utils.py b/torch/distributed/fsdp/_optim_utils.py
index a5e1ab6..f87f871 100644
--- a/torch/distributed/fsdp/_optim_utils.py
+++ b/torch/distributed/fsdp/_optim_utils.py
@@ -3,6 +3,7 @@
 import functools
 from typing import (
     Any,
+    cast,
     Dict,
     Iterable,
     Iterator,
@@ -12,18 +13,18 @@
     Sequence,
     Tuple,
     Union,
-    cast,
 )
 
 import torch
 import torch.distributed as dist
+
 # Import the entire FSDP file to avoid circular imports
 import torch.distributed.fsdp.fully_sharded_data_parallel as FSDP
 import torch.nn as nn
 from torch.distributed._shard.sharded_tensor import ShardedTensor
+from torch.distributed.fsdp._fsdp_extensions import _ext_chunk_tensor
 from torch.distributed.fsdp._shard_utils import _gather_state_dict
 from torch.distributed.fsdp.flat_param import FlatParameter, FlatParamHandle
-from torch.distributed.fsdp._fsdp_extensions import _ext_chunk_tensor
 
 
 def sorted_items(dictionary: Dict[str, Any]) -> Iterator[Tuple[str, Any]]:
@@ -298,9 +299,9 @@
     unflat_osd_state = unflat_osd["state"]
     for param, unflat_param_names in param_to_unflat_param_names.items():
         if isinstance(param, FlatParameter):  # flatten FSDP parameters' states
-            assert param in flat_param_to_fsdp_module, (
-                f"Check the `flat_param_to_fsdp_module` construction\nparam: {param}"
-            )
+            assert (
+                param in flat_param_to_fsdp_module
+            ), f"Check the `flat_param_to_fsdp_module` construction\nparam: {param}"
             fsdp_module = flat_param_to_fsdp_module[param]
             flat_state = _flatten_optim_state(
                 unflat_osd_state,
diff --git a/torch/distributed/fsdp/_shard_utils.py b/torch/distributed/fsdp/_shard_utils.py
index b0382b4..0cc9dd6 100644
--- a/torch/distributed/fsdp/_shard_utils.py
+++ b/torch/distributed/fsdp/_shard_utils.py
@@ -250,10 +250,8 @@
             requires_grad=False,
             memory_format=torch.contiguous_format,
             pin_memory=tensor.is_pinned(),
-        )
+        ),
     )
     return ShardedTensor._init_from_local_shards_and_global_metadata(
-        local_shards,
-        sharded_tensor_metadata=sharded_tensor_metadata,
-        process_group=pg
+        local_shards, sharded_tensor_metadata=sharded_tensor_metadata, process_group=pg
     )
diff --git a/torch/distributed/fsdp/_state_dict_utils.py b/torch/distributed/fsdp/_state_dict_utils.py
index ed4b8f2..90083ef 100644
--- a/torch/distributed/fsdp/_state_dict_utils.py
+++ b/torch/distributed/fsdp/_state_dict_utils.py
@@ -6,25 +6,24 @@
 import torch
 import torch.distributed as dist
 import torch.distributed.algorithms._checkpoint.checkpoint_wrapper as checkpoint_wrapper
+
 # Import the entire FSDP file to avoid circular imports
 import torch.distributed.fsdp.fully_sharded_data_parallel as FSDP
 import torch.nn as nn
 import torch.nn.functional as F
-
 from torch.distributed._shard.sharded_tensor import (
+    init_from_local_shards,
     Shard,
     ShardedTensor,
-    init_from_local_shards,
 )
-from torch.distributed.utils import (
-    _replace_by_prefix,
-)
+from torch.distributed.utils import _replace_by_prefix
 
-from ._fsdp_extensions import _ext_chunk_tensor, _ext_pre_load_state_dict_transform
-from ._fsdp_extensions import _extensions as _user_extensions
-from .flat_param import (
-    FlatParamHandle,
+from ._fsdp_extensions import (
+    _ext_chunk_tensor,
+    _ext_pre_load_state_dict_transform,
+    _extensions as _user_extensions,
 )
+from .flat_param import FlatParamHandle
 
 
 def _full_post_state_dict_hook(
@@ -53,16 +52,12 @@
     # exiting `summon_full_params()` via the parameter shape. However, for
     # `NO_SHARD`, we cannot tell from the shape, so we do not return early.
     if (
-        (
-            not module._use_orig_params
-            and FSDP.FLAT_PARAM in module.module._parameters
-        )
-        or (
-            module._use_orig_params
-            and module._handles
-            and module._handles[0].uses_sharded_strategy
-            and module._handles[0].is_sharded(module._handles[0].flat_param)
-        )
+        not module._use_orig_params and FSDP.FLAT_PARAM in module.module._parameters
+    ) or (
+        module._use_orig_params
+        and module._handles
+        and module._handles[0].uses_sharded_strategy
+        and module._handles[0].is_sharded(module._handles[0].flat_param)
     ):
         return state_dict
 
@@ -79,7 +74,7 @@
         # do not have prefix considered as they are not computed in `state_dict`
         # call.
         if clean_key.startswith(clean_prefix):
-            clean_key = clean_key[len(clean_prefix):]
+            clean_key = clean_key[len(clean_prefix) :]
 
         # Clone non-ignored parameters before exiting the
         # `_summon_full_params()` context
@@ -88,8 +83,9 @@
             f"only has {state_dict.keys()}. prefix={prefix}, "
             f"module_name={module_name} param_name={param_name} rank={module.rank}."
         )
-        if clean_key not in module._ignored_param_names and \
-                not getattr(state_dict[fqn], "_has_been_cloned", False):
+        if clean_key not in module._ignored_param_names and not getattr(
+            state_dict[fqn], "_has_been_cloned", False
+        ):
             try:
                 state_dict[fqn] = state_dict[fqn].clone().detach()
                 state_dict[fqn]._has_been_cloned = True  # type: ignore[attr-defined]
@@ -129,11 +125,9 @@
 ) -> None:
     # We do not expect to be calling pre-hooks twice without post-hook
     # call in between.
-    assert getattr(module, '_full_param_ctx', None) is None
+    assert getattr(module, "_full_param_ctx", None) is None
     # Note that it needs writeback=True to persist.
-    module._full_param_ctx = module._summon_full_params(
-        recurse=False, writeback=True
-    )
+    module._full_param_ctx = module._summon_full_params(recurse=False, writeback=True)
     module._full_param_ctx.__enter__()
     _replace_by_prefix(state_dict, prefix, prefix + f"{FSDP.FSDP_PREFIX}")
 
@@ -141,7 +135,7 @@
 def _full_post_load_state_dict_hook(module, *args, **kwargs) -> None:
     # We should exit summon_full_params context.
     module._assert_state([FSDP.TrainingState_.SUMMON_FULL_PARAMS])
-    assert getattr(module, '_full_param_ctx', None) is not None
+    assert getattr(module, "_full_param_ctx", None) is not None
     module._full_param_ctx.__exit__(None, None, None)
     module._full_param_ctx = None
 
@@ -189,7 +183,9 @@
 
 
 def _local_pre_load_state_dict_hook(
-    module, state_dict: Dict[str, Any], prefix: str,
+    module,
+    state_dict: Dict[str, Any],
+    prefix: str,
 ) -> None:
     """
     This hook finds the local flat_param for this FSDP module from the
@@ -253,7 +249,7 @@
                 rank=module.rank,
                 world_size=module.world_size,
                 num_devices_per_node=torch.cuda.device_count(),
-                pg=module.process_group
+                pg=module.process_group,
             )
             if module._state_dict_config.offload_to_cpu:
                 sharded_tensor = sharded_tensor.cpu()
@@ -271,7 +267,9 @@
 
 
 def _sharded_pre_load_state_dict_hook(
-    module, state_dict: Dict[str, Any], prefix: str,
+    module,
+    state_dict: Dict[str, Any],
+    prefix: str,
 ) -> None:
     """
     The hook combines the unflattened, sharded parameters (ShardedTensor) to
@@ -331,7 +329,9 @@
 
     # Get the chunk from the loaded flat_param for the local rank.
     loaded_flat_tensor, num_to_pad = FlatParamHandle._get_shard(
-        loaded_flat_param, module.rank, module.world_size,
+        loaded_flat_param,
+        module.rank,
+        module.world_size,
     )
     loaded_flat_tensor.to(flat_param.device)
     assert all(s1 == s2 for s1, s2 in zip(loaded_shapes, flat_param._shapes)), (
@@ -377,10 +377,7 @@
     # back to their mixed precision type. This is because buffers are cast
     # during lazy_init() and stay at their mixed precision type before/after
     # forward/backward. As a result state_dict() should maintain this.
-    if (
-        fsdp_module._is_root
-        and fsdp_module._mixed_precision_enabled_for_buffers()
-    ):
+    if fsdp_module._is_root and fsdp_module._mixed_precision_enabled_for_buffers():
         fsdp_module._cast_buffers(recurse=True)
     return processed_state_dict
 
diff --git a/torch/distributed/fsdp/_symbolic_trace.py b/torch/distributed/fsdp/_symbolic_trace.py
index 026595f..f6fe5e4 100644
--- a/torch/distributed/fsdp/_symbolic_trace.py
+++ b/torch/distributed/fsdp/_symbolic_trace.py
@@ -5,7 +5,6 @@
 
 import torch
 
-
 __all__ = ["TracingConfig"]
 
 
@@ -140,13 +139,18 @@
         if args is not None:
             named_params: List[Tuple[str, torch.nn.Parameter]] = []
             for arg in args:
-                if isinstance(arg, torch.fx.Proxy) and arg.node.target in prefixed_param_name_to_param:
+                if (
+                    isinstance(arg, torch.fx.Proxy)
+                    and arg.node.target in prefixed_param_name_to_param
+                ):
                     param = prefixed_param_name_to_param[arg.node.target]
                     named_params.append((arg.node.target, param))
                     if param not in set(execution_info.param_exec_order):
                         execution_info.param_exec_order.append(param)
             if named_params:
-                execution_info.module_to_execution_infos[module].append((module, named_params))
+                execution_info.module_to_execution_infos[module].append(
+                    (module, named_params)
+                )
     elif kind == "call_module":
         named_params = list(module.named_parameters())
         if named_params:
@@ -234,7 +238,10 @@
     )
     prefixed_param_name_to_param = dict(root_module.named_parameters())
     tracer.create_proxy = functools.partial(
-        _patched_create_proxy, original_create_proxy, execution_info, prefixed_param_name_to_param
+        _patched_create_proxy,
+        original_create_proxy,
+        execution_info,
+        prefixed_param_name_to_param,
     )
     try:
         yield
diff --git a/torch/distributed/fsdp/_utils.py b/torch/distributed/fsdp/_utils.py
index bd37ce5..eb72042 100644
--- a/torch/distributed/fsdp/_utils.py
+++ b/torch/distributed/fsdp/_utils.py
@@ -10,14 +10,11 @@
 )
 from torch.nn.utils.rnn import PackedSequence
 
-
 FSDP_FLATTENED = "_fsdp_flattened"
 
 
 def _contains_batchnorm(module):
-    return any(
-        isinstance(mod, _BatchNorm) for mod in module.modules()
-    )
+    return any(isinstance(mod, _BatchNorm) for mod in module.modules())
 
 
 def _override_batchnorm_mixed_precision(module):
@@ -27,11 +24,14 @@
 
 
 def _apply_to_tensors(
-    fn: Callable, container: Union[torch.Tensor, Dict, List, Tuple, Set, OrderedDict, PackedSequence]
+    fn: Callable,
+    container: Union[torch.Tensor, Dict, List, Tuple, Set, OrderedDict, PackedSequence],
 ) -> Any:
     """Recursively apply to all tensor in different kinds of container types."""
 
-    def apply(x: Union[torch.Tensor, Dict, List, Tuple, Set, OrderedDict, PackedSequence]) -> Any:
+    def apply(
+        x: Union[torch.Tensor, Dict, List, Tuple, Set, OrderedDict, PackedSequence]
+    ) -> Any:
         if torch.is_tensor(x):
             return fn(x)
         elif hasattr(x, "__dataclass_fields__"):
@@ -75,6 +75,7 @@
     module prefix name (e.g. "module.submodule." just like in model state dict)
     and makes that available to ``module_fn``.
     """
+
     def f(module: torch.nn.Module, prefix: str, *args, **kwargs):
         # Call the module function before recursing over children (pre-order)
         module_fn(module, prefix, *args, **kwargs)
diff --git a/torch/distributed/fsdp/flat_param.py b/torch/distributed/fsdp/flat_param.py
index 266dc80..3e4eca0 100644
--- a/torch/distributed/fsdp/flat_param.py
+++ b/torch/distributed/fsdp/flat_param.py
@@ -33,7 +33,6 @@
     p_assert,
 )
 
-
 __all__ = [
     "FlatParameter",
     "FlatParamHandle",
@@ -1507,7 +1506,8 @@
                 # memory and owns the gradient storage, so it will never
                 # require gradient writeback.
                 flat_param_grad = (
-                    flat_param.grad if self.uses_sharded_strategy or not self._config.offload_params
+                    flat_param.grad
+                    if self.uses_sharded_strategy or not self._config.offload_params
                     else flat_param._cpu_grad  # type: ignore[attr-defined]
                 )
                 needs_grad_writeback = flat_param_grad is None or not _same_storage(
diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py
index 5fb2e5c..8cd1847 100644
--- a/torch/distributed/fsdp/fully_sharded_data_parallel.py
+++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py
@@ -8,10 +8,11 @@
 import warnings
 from contextlib import contextmanager
 from dataclasses import dataclass
-from enum import Enum, auto
+from enum import auto, Enum
 from typing import (
     Any,
     Callable,
+    cast,
     Deque,
     Dict,
     Generator,
@@ -22,7 +23,6 @@
     Set,
     Tuple,
     Union,
-    cast,
 )
 
 import torch
@@ -35,15 +35,10 @@
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     _CHECKPOINT_PREFIX,
 )
-from torch.distributed.algorithms._comm_hooks import (
-    LOW_PRECISION_HOOKS,
-    default_hooks,
-)
+from torch.distributed.algorithms._comm_hooks import default_hooks, LOW_PRECISION_HOOKS
 from torch.distributed.distributed_c10d import _get_default_group
-from torch.distributed.utils import (
-    _sync_params_and_buffers,
-    _to_kwargs,
-)
+from torch.distributed.utils import _sync_params_and_buffers, _to_kwargs
+
 from ._optim_utils import (
     _broadcast_pos_dim_tensor_states,
     _broadcast_processed_optim_state_dict,
@@ -57,9 +52,9 @@
     _rekey_sharded_optim_state_dict,
 )
 from ._state_dict_utils import (
+    _post_load_state_dict_hook,
     _post_state_dict_hook,
     _pre_load_state_dict_hook,
-    _post_load_state_dict_hook,
 )
 from ._utils import (
     _apply_to_modules,
@@ -78,10 +73,10 @@
     HandleTrainingState,
 )
 from .wrap import (
-    ParamExecOrderWrapPolicy,
     _or_policy,
     _recursive_wrap,
     _wrap_batchnorm_individually,
+    ParamExecOrderWrapPolicy,
 )
 
 _TORCHDISTX_AVAIL = True
@@ -94,18 +89,23 @@
 if not hasattr(torch, "fx"):
     _TORCH_FX_AVAIL = False
 if _TORCH_FX_AVAIL:
-    from ._symbolic_trace import (
-        TracingConfig,
-        _init_execution_info,
-        _patch_tracer,
-    )
+    from ._symbolic_trace import _init_execution_info, _patch_tracer, TracingConfig
 
 
 __all__ = [
-    "FullyShardedDataParallel", "ShardingStrategy", "MixedPrecision",
-    "CPUOffload", "BackwardPrefetch", "StateDictType", "StateDictConfig",
-    "FullStateDictConfig", "LocalStateDictConfig", "ShardedStateDictConfig",
-    "OptimStateKeyType", "TrainingState_", "clean_tensor_name",
+    "FullyShardedDataParallel",
+    "ShardingStrategy",
+    "MixedPrecision",
+    "CPUOffload",
+    "BackwardPrefetch",
+    "StateDictType",
+    "StateDictConfig",
+    "FullStateDictConfig",
+    "LocalStateDictConfig",
+    "ShardedStateDictConfig",
+    "OptimStateKeyType",
+    "TrainingState_",
+    "clean_tensor_name",
 ]
 
 
@@ -148,6 +148,7 @@
                                   ``NO_SHARD`` inter-node.
 
     """
+
     FULL_SHARD = auto()
     SHARD_GRAD_OP = auto()
     NO_SHARD = auto()
@@ -197,6 +198,7 @@
         would occur in the `param_dtype` precision, if given, otherwise, in the
         original parameter precision.
     """
+
     # maintain a tensor of this dtype that the fp32 param shard will be cast to.
     # Will control the precision of model params, inputs, and thus compute as
     # well.
@@ -309,6 +311,7 @@
     order to configure settings for the particular type of ``state_dict``
     implementation FSDP will use.
     """
+
     offload_to_cpu: bool = False
 
 
@@ -340,6 +343,7 @@
         >>> fsdp = FSDP(model, device_id=torch.cuda.current_device(), auto_wrap_policy=..., sync_module_states=True)
         >>> # After this point, all ranks have FSDP model with loaded checkpoint.
     """
+
     rank0_only: bool = False
 
 
@@ -366,9 +370,10 @@
 
 class _ExecOrderWarnStatus(Enum):
     """Used internally for execution order validation."""
-    NONE = auto()     # no deviation yet
+
+    NONE = auto()  # no deviation yet
     WARNING = auto()  # deviated this iteration; currently issuing warnings
-    WARNED = auto()   # deviated in a previous iteration
+    WARNED = auto()  # deviated in a previous iteration
 
 
 class _ExecOrderData:
@@ -403,9 +408,10 @@
         self._forward_prefetch_limit = forward_prefetch_limit
 
         # Data structures for execution order validation
-        self._checking_order: bool = (
-            debug_level in [dist.DebugLevel.INFO, dist.DebugLevel.DETAIL]
-        )
+        self._checking_order: bool = debug_level in [
+            dist.DebugLevel.INFO,
+            dist.DebugLevel.DETAIL,
+        ]
         self.process_group: Optional[dist.ProcessGroup] = None
         self.world_size: Optional[int] = None
         self.all_handles: List[FlatParamHandle] = []
@@ -454,7 +460,9 @@
         prefetch given the current handles key. If there are no valid handles
         keys to prefetch, then this returns an empty :class:`list`.
         """
-        current_index = self.handles_to_post_forward_order_index.get(current_handles_key, None)
+        current_index = self.handles_to_post_forward_order_index.get(
+            current_handles_key, None
+        )
         if current_index is None:
             return None
         target_index = current_index - 1
@@ -462,9 +470,7 @@
         for _ in range(self._backward_prefetch_limit):
             if target_index < 0:
                 break
-            target_handles_keys.append(
-                self.handles_post_forward_order[target_index]
-            )
+            target_handles_keys.append(self.handles_post_forward_order[target_index])
             target_index -= 1
         return target_handles_keys
 
@@ -477,7 +483,9 @@
         prefetch given the current handles key. If there are no valid handles
         keys to prefetch, then this returns an empty :class:`list`.
         """
-        current_index = self.handles_to_pre_forward_order_index.get(current_handles_key, None)
+        current_index = self.handles_to_pre_forward_order_index.get(
+            current_handles_key, None
+        )
         if current_index is None:
             return None
         target_index = current_index + 1
@@ -485,9 +493,7 @@
         for _ in range(self._forward_prefetch_limit):
             if target_index >= len(self.handles_pre_forward_order):
                 break
-            target_handles_keys.append(
-                self.handles_pre_forward_order[target_index]
-            )
+            target_handles_keys.append(self.handles_pre_forward_order[target_index])
             target_index += 1
         return target_handles_keys
 
@@ -511,7 +517,9 @@
         self.handles_to_post_forward_order_index[handles_key] = index
         self.handles_post_forward_order.append(handles_key)
 
-    def record_pre_forward(self, handles: List[FlatParamHandle], is_training: bool) -> None:
+    def record_pre_forward(
+        self, handles: List[FlatParamHandle], is_training: bool
+    ) -> None:
         """
         Records ``handles`` in the pre-forward order, where ``handles`` should
         be a group of handles used in the same module's forward. If ``handles``
@@ -597,7 +605,7 @@
                     (
                         rank,
                         world_indices[
-                            rank * num_valid_indices: (rank + 1) * num_valid_indices
+                            rank * num_valid_indices : (rank + 1) * num_valid_indices
                         ],
                     )
                     for rank in range(self.world_size)
@@ -683,7 +691,9 @@
                 continue
             handle = self.all_handles[index]
             flat_param = handle.flat_param
-            prefixed_param_names.append(self.flat_param_to_prefixed_param_names[flat_param])
+            prefixed_param_names.append(
+                self.flat_param_to_prefixed_param_names[flat_param]
+            )
         return prefixed_param_names
 
     def _get_names_from_handles(
@@ -700,7 +710,9 @@
             flat_param = handle.flat_param
             if flat_param not in self.flat_param_to_prefixed_param_names:
                 continue
-            prefixed_param_names.append(self.flat_param_to_prefixed_param_names[flat_param])
+            prefixed_param_names.append(
+                self.flat_param_to_prefixed_param_names[flat_param]
+            )
         return prefixed_param_names
 
     def next_iter(self):
@@ -970,6 +982,7 @@
             the sharded strategies that schedule all-gathers. Enabling this can
             help lower the number of CUDA malloc retries.
     """
+
     def __init__(
         self,
         module: nn.Module,
@@ -1062,10 +1075,16 @@
         self._buffer_name_to_orig_dtype: Dict[str, torch.dtype] = {}
 
         self._check_single_device_module(module, ignored_params)
-        device_from_device_id: Optional[torch.device] = self._get_device_from_device_id(device_id)
-        self._materialize_module(module, param_init_fn, ignored_params, device_from_device_id)
+        device_from_device_id: Optional[torch.device] = self._get_device_from_device_id(
+            device_id
+        )
+        self._materialize_module(
+            module, param_init_fn, ignored_params, device_from_device_id
+        )
         self._move_module_to_device(module, ignored_params, device_from_device_id)
-        self.compute_device = self._get_compute_device(module, ignored_params, device_from_device_id)
+        self.compute_device = self._get_compute_device(
+            module, ignored_params, device_from_device_id
+        )
         params_to_flatten = list(self._get_orig_params(module, ignored_params))
         if sync_module_states:
             self._sync_module_states(module, params_to_flatten)
@@ -1098,7 +1117,10 @@
             self.params.append(handle.flat_param)
             self._register_param_handle(handle)
             handle.shard()
-            if self.cpu_offload.offload_params and handle.flat_param.device != torch.device("cpu"):
+            if (
+                self.cpu_offload.offload_params
+                and handle.flat_param.device != torch.device("cpu")
+            ):
                 handle.flat_param_to(torch.device("cpu"))
         if not use_orig_params:
             self._check_orig_params_flattened(ignored_params)
@@ -1301,8 +1323,7 @@
         self,
         device_id: Optional[Union[int, torch.device]],
     ) -> Optional[torch.device]:
-        """
-        """
+        """ """
         if device_id is None:
             return None
         device = (
@@ -1341,11 +1362,15 @@
         ``reset_parameters()``, and for torchdistX fake tensors, this calls
         ``deferred_init.materialize_module()``.
         """
-        is_meta_module = any(p.is_meta for p in self._get_orig_params(module, ignored_params))
+        is_meta_module = any(
+            p.is_meta for p in self._get_orig_params(module, ignored_params)
+        )
         is_torchdistX_deferred_init = (
             not is_meta_module
             and _TORCHDISTX_AVAIL
-            and any(fake.is_fake(p) for p in self._get_orig_params(module, ignored_params))
+            and any(
+                fake.is_fake(p) for p in self._get_orig_params(module, ignored_params)
+            )
         )
         if (
             is_meta_module or is_torchdistX_deferred_init
@@ -1357,7 +1382,9 @@
             param_init_fn(module)
         elif is_meta_module:
             # Run default meta device initialization
-            materialization_device = device_from_device_id or torch.cuda.current_device()
+            materialization_device = (
+                device_from_device_id or torch.cuda.current_device()
+            )
             module.to_empty(device=materialization_device)
             try:
                 with torch.no_grad():
@@ -1483,7 +1510,10 @@
                 module_states.append(buffer.detach())
         module_states.extend(param.detach() for param in params)
         _sync_params_and_buffers(
-            self.process_group, module_states, _PARAM_BROADCAST_BUCKET_SIZE, src=0,
+            self.process_group,
+            module_states,
+            _PARAM_BROADCAST_BUCKET_SIZE,
+            src=0,
         )
 
     def _get_orig_params(
@@ -1573,7 +1603,7 @@
         p_assert(
             len(handles) == len(free_unsharded_flat_params),
             "Expects both lists to have equal length but got "
-            f"{len(handles)} and {len(free_unsharded_flat_params)}"
+            f"{len(handles)} and {len(free_unsharded_flat_params)}",
         )
         for handle, free_unsharded_flat_param in zip(
             handles,
@@ -1651,9 +1681,10 @@
             the input ``module``.
         """
         return [
-            submodule for submodule in module.modules()
-            if isinstance(submodule, FullyShardedDataParallel) and
-            (not root_only or submodule.check_is_root())
+            submodule
+            for submodule in module.modules()
+            if isinstance(submodule, FullyShardedDataParallel)
+            and (not root_only or submodule.check_is_root())
         ]
 
     def apply(self, fn: Callable[[nn.Module], None]) -> "FullyShardedDataParallel":
@@ -1728,6 +1759,7 @@
         precision given by ``dtype``, while respecting the existing
         ``requires_grad`` on the tensors.
         """
+
         def cast_fn(x: torch.Tensor) -> torch.Tensor:
             if not torch.is_floating_point(x):
                 return x
@@ -1741,7 +1773,7 @@
         with torch.no_grad():
             return (
                 _apply_to_tensors(cast_fn, args),
-                _apply_to_tensors(cast_fn, kwargs)
+                _apply_to_tensors(cast_fn, kwargs),
             )
 
     def _cast_buffers(
@@ -1775,9 +1807,15 @@
         if memo is None:
             memo = set()
         for module in self.modules():
-            if module is not self and isinstance(module, FullyShardedDataParallel) and recurse:
+            if (
+                module is not self
+                and isinstance(module, FullyShardedDataParallel)
+                and recurse
+            ):
                 # Allow any child FSDP instances to handle their own buffers.
-                module._cast_buffers(device=device, dtype=dtype, memo=memo, recurse=recurse)
+                module._cast_buffers(
+                    device=device, dtype=dtype, memo=memo, recurse=recurse
+                )
             elif module not in memo:
                 memo.add(module)
                 for name, buf in module.named_buffers(recurse=False):
@@ -1863,7 +1901,9 @@
                     fsdp_module.limit_all_gathers = self.limit_all_gathers
                 fsdp_module._free_event_queue = self._free_event_queue
                 fsdp_module._handles_prefetched = self._handles_prefetched
-                fsdp_module._needs_pre_backward_unshard = self._needs_pre_backward_unshard
+                fsdp_module._needs_pre_backward_unshard = (
+                    self._needs_pre_backward_unshard
+                )
                 for handle in fsdp_module._handles:
                     fsdp_module._init_param_attributes(handle)
         if inconsistent_limit_all_gathers:
@@ -1936,13 +1976,11 @@
         # fwd/bwd, it is freed and we only hold on to the full precision shard.
         # As a result, this reduced precision shard is not allocated if we are
         # not in the forward/backward pass.
-        if (
-            self._mixed_precision_enabled_for_params()
-        ):
+        if self._mixed_precision_enabled_for_params():
             p._mp_shard = torch.zeros_like(
                 p._local_shard,
                 device=self.compute_device,
-                dtype=self.mixed_precision.param_dtype
+                dtype=self.mixed_precision.param_dtype,
             )
             _free_storage(p._mp_shard)
 
@@ -1957,7 +1995,8 @@
             # into full_param_padded it can occur without issues and result in
             # full_param_padded having the expected param_dtype.
             full_param_dtype = (
-                p.dtype if not self._mixed_precision_enabled_for_params()
+                p.dtype
+                if not self._mixed_precision_enabled_for_params()
                 else self.mixed_precision.param_dtype
             )
             p._full_param_padded = torch.zeros(  # type: ignore[attr-defined]
@@ -2024,7 +2063,9 @@
         for handles_key in handles_to_prefetch:
             # Prefetch the next set of handles without synchronizing to allow
             # the sync to happen as late as possible to maximize overlap
-            self._unshard(handles_key, self._streams["unshard"], self._streams["pre_unshard"])
+            self._unshard(
+                handles_key, self._streams["unshard"], self._streams["pre_unshard"]
+            )
             self._handles_prefetched[handles_key] = True
 
     def _get_handles_to_prefetch(
@@ -2048,33 +2089,31 @@
         p_assert(
             training_state in valid_training_states,
             f"Prefetching is only supported in {valid_training_states} but "
-            f"currently in {training_state}"
+            f"currently in {training_state}",
         )
         eod = self._exec_order_data
         target_handles_keys: List[_HandlesKey] = []
         if (
-            (
-                training_state == HandleTrainingState.BACKWARD_PRE
-                and self.backward_prefetch == BackwardPrefetch.BACKWARD_PRE
-            )
-            or (
-                training_state == HandleTrainingState.BACKWARD_POST
-                and self.backward_prefetch == BackwardPrefetch.BACKWARD_POST
-            )
+            training_state == HandleTrainingState.BACKWARD_PRE
+            and self.backward_prefetch == BackwardPrefetch.BACKWARD_PRE
+        ) or (
+            training_state == HandleTrainingState.BACKWARD_POST
+            and self.backward_prefetch == BackwardPrefetch.BACKWARD_POST
         ):
             target_handles_keys = [
-                target_handles_key for target_handles_key in
-                eod.get_handles_to_backward_prefetch(current_handles_key)
+                target_handles_key
+                for target_handles_key in eod.get_handles_to_backward_prefetch(
+                    current_handles_key
+                )
                 if self._needs_pre_backward_unshard.get(target_handles_key, False)
                 and not self._handles_prefetched.get(target_handles_key, False)
             ]
-        elif (
-            training_state == HandleTrainingState.FORWARD
-            and self.forward_prefetch
-        ):
+        elif training_state == HandleTrainingState.FORWARD and self.forward_prefetch:
             target_handles_keys = [
-                target_handles_key for target_handles_key in
-                eod.get_handles_to_forward_prefetch(current_handles_key)
+                target_handles_key
+                for target_handles_key in eod.get_handles_to_forward_prefetch(
+                    current_handles_key
+                )
                 if self._needs_pre_forward_unshard.get(target_handles_key, False)
                 and not self._handles_prefetched.get(target_handles_key, False)
             ]
@@ -2089,7 +2128,7 @@
         training_states = set(handle._training_state for handle in handles_key)
         p_assert(
             len(training_states) == 1,
-            f"Expects uniform training state but got {training_states}"
+            f"Expects uniform training state but got {training_states}",
         )
         return next(iter(training_states))
 
@@ -2157,7 +2196,9 @@
                     "All FSDP modules should have the same type of state_dict_config."
                 )
 
-            expected_state_dict_config_type = _state_dict_type_to_config[state_dict_type]
+            expected_state_dict_config_type = _state_dict_type_to_config[
+                state_dict_type
+            ]
             if expected_state_dict_config_type != type(state_dict_config):
                 raise RuntimeError(
                     f"Expected state_dict_config of type {expected_state_dict_config_type} "
@@ -2200,10 +2241,11 @@
         prev_state_dict_type = None
         prev_state_dict_config = None
         try:
-            prev_state_dict_type, prev_state_dict_config = (
-                FullyShardedDataParallel.set_state_dict_type(
-                    module, state_dict_type, state_dict_config
-                )
+            (
+                prev_state_dict_type,
+                prev_state_dict_config,
+            ) = FullyShardedDataParallel.set_state_dict_type(
+                module, state_dict_type, state_dict_config
             )
             yield
         except Exception as e:
@@ -2233,18 +2275,14 @@
     def _param_fqns(self) -> Iterator[Tuple[str, str, str]]:
         if not self._has_params:
             return
-        for param_name, module_name in (
-            self._handles[0].parameter_module_names()
-        ):
+        for param_name, module_name in self._handles[0].parameter_module_names():
             module_name = self._convert_to_wrapped_module_name(module_name)
             fqn = f"{module_name}{param_name}"
             yield fqn, param_name, module_name
 
     @property
     def _shared_param_fqns(self) -> Iterator[Tuple[str, str, str]]:
-        for param_name, module_name in (
-            self._handles[0].shared_parameter_module_names()
-        ):
+        for param_name, module_name in self._handles[0].shared_parameter_module_names():
             module_name = self._convert_to_wrapped_module_name(module_name)
             fqn = f"{module_name}{param_name}"
             yield fqn, param_name, module_name
@@ -2297,17 +2335,21 @@
         if self._state_dict_type == StateDictType.FULL_STATE_DICT:
             # Get config args
             full_state_dict_config = (
-                self._state_dict_config if self._state_dict_config is not None
+                self._state_dict_config
+                if self._state_dict_config is not None
                 else FullStateDictConfig()
             )
             rank0_only = full_state_dict_config.rank0_only
             offload_to_cpu = full_state_dict_config.offload_to_cpu
             summon_ctx = (
                 self._summon_full_params(
-                    recurse=False, writeback=False, offload_to_cpu=offload_to_cpu, rank0_only=rank0_only
+                    recurse=False,
+                    writeback=False,
+                    offload_to_cpu=offload_to_cpu,
+                    rank0_only=rank0_only,
                 )
-                if self.training_state != TrainingState_.SUMMON_FULL_PARAMS else
-                contextlib.suppress()
+                if self.training_state != TrainingState_.SUMMON_FULL_PARAMS
+                else contextlib.suppress()
             )
             with summon_ctx:
                 # Since buffers are not sharded and stay casted, restore them to their
@@ -2316,10 +2358,7 @@
                 # buffers stay casted after forward/backward. We must have the
                 # call here instead of above because _summon_full_params itself
                 # calls _lazy_init() which would cast the buffers.
-                if (
-                    self._is_root
-                    and self._mixed_precision_enabled_for_buffers()
-                ):
+                if self._is_root and self._mixed_precision_enabled_for_buffers():
                     self._cast_buffers(
                         dtype=self._buffer_name_to_orig_dtype, recurse=False
                     )
@@ -2332,13 +2371,10 @@
                 return {}
 
         elif (
-            self._state_dict_type == StateDictType.LOCAL_STATE_DICT or
-            self._state_dict_type == StateDictType.SHARDED_STATE_DICT
+            self._state_dict_type == StateDictType.LOCAL_STATE_DICT
+            or self._state_dict_type == StateDictType.SHARDED_STATE_DICT
         ):
-            if (
-                self._has_params and
-                not self._handles[0].uses_sharded_strategy
-            ):
+            if self._has_params and not self._handles[0].uses_sharded_strategy:
                 raise RuntimeError(
                     "sharded_state_dict/local_state_dict can only be called "
                     "when parameters are flatten and sharded."
@@ -2352,17 +2388,22 @@
         Runs the forward pass for the wrapped module, inserting FSDP-specific
         pre- and post-forward sharding logic.
         """
-        with torch.autograd.profiler.record_function("FullyShardedDataParallel.forward"):
+        with torch.autograd.profiler.record_function(
+            "FullyShardedDataParallel.forward"
+        ):
             self._lazy_init()
             args, kwargs = self._fsdp_root_pre_forward(*args, **kwargs)
             unused = None
-            unshard_fn = functools.partial(self._pre_forward_unshard, handles=self._handles)
+            unshard_fn = functools.partial(
+                self._pre_forward_unshard, handles=self._handles
+            )
             # Do not free the root's parameters in the post-forward for
             # `FULL_SHARD` with the intention that they are immediately used
             # for backward computation (though this may not be true)
             free_unsharded_flat_params = [
                 not self._is_root
-                and handle._config.sharding_strategy == HandleShardingStrategy.FULL_SHARD
+                and handle._config.sharding_strategy
+                == HandleShardingStrategy.FULL_SHARD
                 for handle in self._handles
             ]
             reshard_fn = functools.partial(
@@ -2375,7 +2416,7 @@
                 p_assert(
                     handle.flat_param.device == self.compute_device,
                     "Expected `FlatParameter` to be on the compute device "
-                    f"{self.compute_device} but got {handle.flat_param.device}"
+                    f"{self.compute_device} but got {handle.flat_param.device}",
                 )
             output = self._fsdp_wrapped_module(*args, **kwargs)
             return self._post_forward(self._handles, reshard_fn, unused, unused, output)
@@ -2418,7 +2459,9 @@
     ) -> None:
         """Unshards parameters in the pre-forward."""
         if handles:
-            self._unshard(handles, self._streams["unshard"], self._streams["pre_unshard"])
+            self._unshard(
+                handles, self._streams["unshard"], self._streams["pre_unshard"]
+            )
             handles_key = tuple(handles)
             self._needs_pre_forward_unshard[handles_key] = False
             torch.cuda.current_stream().wait_stream(self._streams["unshard"])
@@ -2476,7 +2519,9 @@
         if self._mixed_precision_enabled_for_params():
             input_dtype = self.mixed_precision.param_dtype
             args, kwargs = self._cast_fp_inputs_to_dtype(
-                input_dtype, *args, **kwargs,
+                input_dtype,
+                *args,
+                **kwargs,
             )
         return args, kwargs
 
@@ -2525,7 +2570,7 @@
         offload_to_cpu: bool = False,
         with_grads: bool = False,
     ) -> Generator:
-        r""" A context manager to expose full params for FSDP instances.
+        r"""A context manager to expose full params for FSDP instances.
         Can be useful *after* forward/backward for a model to get
         the params for additional processing or checking. It can take a non-FSDP
         module and will summon full params for all contained FSDP modules as
@@ -2663,7 +2708,9 @@
             handle._training_state = HandleTrainingState.SUMMON_FULL_PARAMS
 
         self._clear_grads_if_needed()
-        free_unsharded_flat_params = [handle.needs_unshard() for handle in self._handles]
+        free_unsharded_flat_params = [
+            handle.needs_unshard() for handle in self._handles
+        ]
         # No need to call `wait_stream()` since we unshard in the computation
         # stream directly
         computation_stream = torch.cuda.current_stream()
@@ -2742,7 +2789,7 @@
                 handle.rank,
                 handle.world_size,
             )
-            handle.flat_param._local_shard[:param_shard.numel()].copy_(param_shard)
+            handle.flat_param._local_shard[: param_shard.numel()].copy_(param_shard)
             if writeback_grad:
                 existing_grad = handle.sharded_grad
                 if existing_grad is not None:
@@ -2751,7 +2798,7 @@
                         handle.rank,
                         handle.world_size,
                     )
-                    existing_grad[:grad_shard.numel()].copy_(grad_shard)
+                    existing_grad[: grad_shard.numel()].copy_(grad_shard)
 
     @contextlib.contextmanager
     def _unflatten_as_params(self) -> Generator:
@@ -2825,7 +2872,7 @@
         p_assert(
             len(self._handles) <= 1,
             "Expects <=1 handle per FSDP instance; needs to be refactored "
-            "for >1 handle (e.g. non-recursive wrapping)"
+            "for >1 handle (e.g. non-recursive wrapping)",
         )
         if not self._handles:
             return
@@ -2833,7 +2880,7 @@
         p_assert(
             handle._use_orig_params,
             f"Inconsistent `_use_orig_params` -- FSDP: {self._use_orig_params} "
-            f"handle: {handle._use_orig_params}"
+            f"handle: {handle._use_orig_params}",
         )
         handle._deregister_orig_params()
         self._register_flat_param()
@@ -2973,7 +3020,9 @@
 
                 # If the handles have been prefetched, this `_unshard()` simply
                 # switches to using the unsharded parameter
-                self._unshard(_handles, self._streams["unshard"], self._streams["pre_unshard"])
+                self._unshard(
+                    _handles, self._streams["unshard"], self._streams["pre_unshard"]
+                )
                 torch.cuda.current_stream().wait_stream(self._streams["unshard"])
 
                 # Set this to `False` to ensure that a mistargeted prefetch
@@ -3022,7 +3071,7 @@
             p_assert(
                 temp_flat_param.grad_fn is not None,
                 "The `grad_fn` is needed to access the `AccumulateGrad` and "
-                "register the post-backward hook"
+                "register the post-backward hook",
             )
             acc_grad = temp_flat_param.grad_fn.next_functions[0][0]
             hook_handle = acc_grad.register_hook(
@@ -3055,11 +3104,16 @@
         ):
             # First hook callback will see PRE state. If we have multiple params,
             # then subsequent hook callbacks will see POST state.
-            self._assert_state([TrainingState_.BACKWARD_PRE, TrainingState_.BACKWARD_POST])
+            self._assert_state(
+                [TrainingState_.BACKWARD_PRE, TrainingState_.BACKWARD_POST]
+            )
             self.training_state = TrainingState_.BACKWARD_POST
             handle._training_state = HandleTrainingState.BACKWARD_POST
 
-            if self._use_param_exec_order_policy() and self._param_exec_order_prep_stage:
+            if (
+                self._use_param_exec_order_policy()
+                and self._param_exec_order_prep_stage
+            ):
                 # In self._fsdp_params_exec_order, the parameters are ordered based on
                 # the execution order in the backward pass in the first iteration.
                 self._fsdp_params_exec_order.append(param)
@@ -3103,7 +3157,9 @@
                     # TODO: Make this a communication hook when communication hooks
                     # are implemented for FSDP. Note that this is a noop if the
                     # reduce_dtype matches the param dtype.
-                    param.grad.data = param.grad.data.to(self.mixed_precision.reduce_dtype)
+                    param.grad.data = param.grad.data.to(
+                        self.mixed_precision.reduce_dtype
+                    )
 
                 if self._exec_order_data.is_first_iter:
                     # For all sharding strategies communication is performed through `_communication_hook`:
@@ -3112,11 +3168,11 @@
                     # and `_communication_hook_state`, required for communication not `None`.`
                     p_assert(
                         self._communication_hook is not None,
-                        "Communication hook should not be None"
+                        "Communication hook should not be None",
                     )
                     p_assert(
                         self._communication_hook_state is not None,
-                        "Communication hook state should not be None"
+                        "Communication hook state should not be None",
                     )
                 grad = param.grad.data
                 if handle.uses_sharded_strategy:
@@ -3138,7 +3194,9 @@
                     num_pad = self.world_size * chunks[0].numel() - grad.numel()
                     input_flattened = F.pad(grad_flatten, [0, num_pad])
                     output = torch.zeros_like(chunks[0])
-                    self._communication_hook(self._communication_hook_state, input_flattened, output)
+                    self._communication_hook(
+                        self._communication_hook_state, input_flattened, output
+                    )
 
                     self._cast_grad_to_param_dtype(output, param)
 
@@ -3153,13 +3211,13 @@
                             param._saved_grad_shard.shape == output.shape,  # type: ignore[attr-defined]
                             "Shape mismatch when accumulating gradients: "  # type: ignore[attr-defined]
                             f"existing grad shape={param._saved_grad_shard.shape} "
-                            f"new grad shape={output.shape}"  # type: ignore[attr-defined]
+                            f"new grad shape={output.shape}",  # type: ignore[attr-defined]
                         )
                         p_assert(
                             param._saved_grad_shard.device == output.device,  # type: ignore[attr-defined]
                             "Device mismatch when accumulating gradients: "  # type: ignore[attr-defined]
                             f"existing grad device={param._saved_grad_shard.device} "
-                            f"new grad device={output.device}"  # type: ignore[attr-defined]
+                            f"new grad device={output.device}",  # type: ignore[attr-defined]
                         )
                         param._saved_grad_shard += output  # type: ignore[attr-defined]
                     else:
@@ -3167,7 +3225,9 @@
                     grad = param._saved_grad_shard  # type: ignore[attr-defined]
                 else:
                     if self.sharding_strategy == ShardingStrategy.NO_SHARD:
-                        self._communication_hook(self._communication_hook_state, param.grad)
+                        self._communication_hook(
+                            self._communication_hook_state, param.grad
+                        )
 
                     # For NO_SHARD keeping grads in the reduced precision, we
                     # can simply omit the cast as needed, we can't do this for
@@ -3221,12 +3281,9 @@
         dtype cast happens in the hook instead.
         """
         self._assert_state(TrainingState_.BACKWARD_POST)
-        if (
-            not self._low_precision_hook_enabled()
-            and (
-                self._mixed_precision_enabled_for_params()
-                or self._mixed_precision_enabled_for_reduce()
-            )
+        if not self._low_precision_hook_enabled() and (
+            self._mixed_precision_enabled_for_params()
+            or self._mixed_precision_enabled_for_reduce()
         ):
             low_prec_grad_data = grad.data
             grad.data = grad.data.to(dtype=param.dtype)
@@ -3236,9 +3293,8 @@
 
     def _should_free_unsharded_flat_param(self, handle: FlatParamHandle):
         return (
-            (self._sync_gradients and handle.uses_sharded_strategy)
-            or handle._config.sharding_strategy == HandleShardingStrategy.FULL_SHARD
-        )
+            self._sync_gradients and handle.uses_sharded_strategy
+        ) or handle._config.sharding_strategy == HandleShardingStrategy.FULL_SHARD
 
     def _queue_wait_for_post_backward(self) -> None:
         """
@@ -3247,7 +3303,7 @@
         """
         p_assert(
             self._is_root,
-            "`_queue_wait_for_post_backward()` should be called on the root FSDP instance"
+            "`_queue_wait_for_post_backward()` should be called on the root FSDP instance",
         )
         if self._post_backward_callback_queued:
             return
@@ -3295,18 +3351,21 @@
                     # TODO: This already-resharded check is brittle:
                     # https://github.com/pytorch/pytorch/issues/83956
                     already_resharded = (
-                        handle.flat_param.data_ptr() == handle.flat_param._local_shard.data_ptr()
+                        handle.flat_param.data_ptr()
+                        == handle.flat_param._local_shard.data_ptr()
                     )
                     if already_resharded:
                         continue
-                    free_unsharded_flat_params.append(self._should_free_unsharded_flat_param(handle))
+                    free_unsharded_flat_params.append(
+                        self._should_free_unsharded_flat_param(handle)
+                    )
                     handles_to_reshard.append(handle)
                 self._reshard(handles_to_reshard, free_unsharded_flat_params)
             except Exception as e:
                 p_assert(
                     False,
                     f"Got exception while resharding module {fsdp_module}: {str(e)}",
-                    raise_assertion_error=False
+                    raise_assertion_error=False,
                 )
                 raise e
 
@@ -3318,7 +3377,7 @@
                     if hasattr(p, "_post_backward_hook_state"):
                         p_assert(
                             len(p._post_backward_hook_state) == 2,  # type: ignore[attr-defined]
-                            "p._post_backward_hook_state fields are not valid."
+                            "p._post_backward_hook_state fields are not valid.",
                         )
                         p._post_backward_hook_state[1].remove()  # type: ignore[attr-defined]
                         delattr(p, "_post_backward_hook_state")
@@ -3331,8 +3390,8 @@
                         continue
                     handle.prepare_gradient_for_optim()
                     p_assert(
-                        hasattr(p, '_post_backward_called'),
-                        "Expected flag _post_backward_called to be set on param."
+                        hasattr(p, "_post_backward_called"),
+                        "Expected flag _post_backward_called to be set on param.",
                     )
                     # Reset _post_backward_called in preparation for the next iteration.
                     p._post_backward_called = False
@@ -3479,22 +3538,25 @@
         norm_type = float(norm_type)
         # Compute the local gradient norm (only including this rank's shard
         # of the gradients)
-        local_norm = _get_grad_norm(self.parameters(), norm_type).to(self.compute_device)
+        local_norm = _get_grad_norm(self.parameters(), norm_type).to(
+            self.compute_device
+        )
         # Reconstruct the total gradient norm depending on the norm type
         if norm_type == math.inf:
             total_norm = local_norm
-            dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=self.process_group)
+            dist.all_reduce(
+                total_norm, op=torch.distributed.ReduceOp.MAX, group=self.process_group
+            )
         else:
-            total_norm = local_norm ** norm_type
+            total_norm = local_norm**norm_type
             dist.all_reduce(total_norm, group=self.process_group)
             total_norm = total_norm ** (1.0 / norm_type)
         if self.cpu_offload.offload_params:
             total_norm = total_norm.cpu()
 
-        clip_coef = (
-            torch.tensor(max_norm, dtype=total_norm.dtype, device=total_norm.device)
-            / (total_norm + 1e-6)
-        )
+        clip_coef = torch.tensor(
+            max_norm, dtype=total_norm.dtype, device=total_norm.device
+        ) / (total_norm + 1e-6)
         # Multiplying by the clamped coefficient is meaningless when it is
         # equal to 1, but it avoids the host-device sync that would result from
         # `if clip_coef < 1`
@@ -3537,9 +3599,12 @@
     def full_optim_state_dict(
         model: torch.nn.Module,
         optim: torch.optim.Optimizer,
-        optim_input: Optional[Union[
-            List[Dict[str, Any]], Iterable[torch.nn.Parameter],
-        ]] = None,
+        optim_input: Optional[
+            Union[
+                List[Dict[str, Any]],
+                Iterable[torch.nn.Parameter],
+            ]
+        ] = None,
         rank0_only: bool = True,
         group: Optional[dist.ProcessGroup] = None,
     ) -> Dict[str, Any]:
@@ -3592,7 +3657,8 @@
         FullyShardedDataParallel._raise_on_use_orig_params_optim_checkpoint(model)
         FullyShardedDataParallel._warn_optim_input(optim_input)
         using_optim_input = FullyShardedDataParallel._is_using_optim_input(
-            optim_input, optim,
+            optim_input,
+            optim,
         )
         return _optim_state_dict(
             model=model,
@@ -3610,7 +3676,8 @@
         optim: torch.optim.Optimizer,
         optim_input: Optional[
             Union[
-                List[Dict[str, Any]], Iterable[torch.nn.Parameter],
+                List[Dict[str, Any]],
+                Iterable[torch.nn.Parameter],
             ]
         ] = None,
         group: Optional[dist.ProcessGroup] = None,
@@ -3629,7 +3696,8 @@
         FullyShardedDataParallel._raise_on_use_orig_params_optim_checkpoint(model)
         FullyShardedDataParallel._warn_optim_input(optim_input)
         using_optim_input = FullyShardedDataParallel._is_using_optim_input(
-            optim_input, optim,
+            optim_input,
+            optim,
         )
         # TODO: The ultimate goal of the optimizer state APIs should be the same
         # as state_dict/load_state_dict -- using one API to get optimizer states
@@ -3655,7 +3723,8 @@
         model: torch.nn.Module,
         optim_input: Optional[
             Union[
-                List[Dict[str, Any]], Iterable[torch.nn.Parameter],
+                List[Dict[str, Any]],
+                Iterable[torch.nn.Parameter],
             ]
         ] = None,
         optim: Optional[torch.optim.Optimizer] = None,
@@ -3717,13 +3786,20 @@
         FullyShardedDataParallel._raise_on_use_orig_params_optim_checkpoint(model)
         FullyShardedDataParallel._warn_optim_input(optim_input)
         using_optim_input = FullyShardedDataParallel._is_using_optim_input(
-            optim_input, optim,
+            optim_input,
+            optim,
         )
         sharded_osd = _flatten_optim_state_dict(
-            full_optim_state_dict, model, True,
+            full_optim_state_dict,
+            model,
+            True,
         )
         return _rekey_sharded_optim_state_dict(
-            sharded_osd, model, optim, optim_input, using_optim_input,
+            sharded_osd,
+            model,
+            optim,
+            optim_input,
+            using_optim_input,
         )
 
     @staticmethod
@@ -3732,7 +3808,8 @@
         model: torch.nn.Module,
         optim_input: Optional[
             Union[
-                List[Dict[str, Any]], Iterable[torch.nn.Parameter],
+                List[Dict[str, Any]],
+                Iterable[torch.nn.Parameter],
             ]
         ] = None,
         optim: Optional[torch.optim.Optimizer] = None,
@@ -3756,7 +3833,8 @@
         FullyShardedDataParallel._raise_on_use_orig_params_optim_checkpoint(model)
         FullyShardedDataParallel._warn_optim_input(optim_input)
         using_optim_input = FullyShardedDataParallel._is_using_optim_input(
-            optim_input, optim,
+            optim_input,
+            optim,
         )
         # TODO: The implementation is the same as ``shard_full_optim_state_dict``.
         # See the TODO in ``shard_full_optim_state_dict`` for the future
@@ -3767,16 +3845,23 @@
             shard_state=True,
         )
         return _rekey_sharded_optim_state_dict(
-            flattened_osd, model, optim, optim_input, using_optim_input,
+            flattened_osd,
+            model,
+            optim,
+            optim_input,
+            using_optim_input,
         )
 
     @staticmethod
     def scatter_full_optim_state_dict(
         full_optim_state_dict: Optional[Dict[str, Any]],
         model: torch.nn.Module,
-        optim_input: Optional[Union[
-            List[Dict[str, Any]], Iterable[torch.nn.Parameter],
-        ]] = None,
+        optim_input: Optional[
+            Union[
+                List[Dict[str, Any]],
+                Iterable[torch.nn.Parameter],
+            ]
+        ] = None,
         optim: Optional[torch.optim.Optimizer] = None,
         group: Optional[Any] = None,
     ) -> Dict[str, Any]:
@@ -3838,7 +3923,8 @@
         FullyShardedDataParallel._raise_on_use_orig_params_optim_checkpoint(model)
         FullyShardedDataParallel._warn_optim_input(optim_input)
         using_optim_input = FullyShardedDataParallel._is_using_optim_input(
-            optim_input, optim,
+            optim_input,
+            optim,
         )
         # Try to use the passed-in process group, the model's process group,
         # or the default process group (i.e. `None`) in that priority order
@@ -3848,8 +3934,9 @@
         world_size = dist.get_world_size(group)
         # Check for a valid broadcast device, preferring GPU when available
         using_nccl = dist.distributed_c10d._check_for_nccl_backend(group)
-        broadcast_device = torch.device("cuda") if torch.cuda.is_available() \
-            else torch.device("cpu")
+        broadcast_device = (
+            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+        )
         if using_nccl and not torch.cuda.is_available():
             raise RuntimeError("NCCL requires a GPU for collectives")
         # Flatten the optimizer state dict and construct a copy with the
@@ -3867,18 +3954,28 @@
         # Broadcast the optim state dict without positive-dimension tensor
         # state and the FSDP parameter IDs from rank 0 to all ranks
         processed_osd = _broadcast_processed_optim_state_dict(
-            processed_osd if rank == 0 else None, rank, group,
+            processed_osd if rank == 0 else None,
+            rank,
+            group,
         )
         # Broadcast positive-dimension tensor state (both sharded tensors for
         # FSDP parameters and unsharded tensors for non-FSDP parameters)
         sharded_osd = _broadcast_pos_dim_tensor_states(
-            processed_osd, flat_osd if rank == 0 else None, rank, world_size,
-            group, broadcast_device,
+            processed_osd,
+            flat_osd if rank == 0 else None,
+            rank,
+            world_size,
+            group,
+            broadcast_device,
         )
         # Rekey the optimizer state dict to use parameter IDs according to this
         # rank's `optim`
         sharded_osd = _rekey_sharded_optim_state_dict(
-            sharded_osd, model, optim, optim_input, using_optim_input,
+            sharded_osd,
+            model,
+            optim,
+            optim_input,
+            using_optim_input,
         )
         return sharded_osd
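
For context, these optimizer-state APIs are meant for a checkpoint round trip roughly like the sketch below; it assumes an already-initialized process group, an FSDP-wrapped `model` with optimizer `optim`, and a placeholder checkpoint path, none of which come from this PR.

```
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Save: consolidate the sharded, flattened optimizer state onto rank 0,
# keyed by unflattened parameter names.
full_osd = FSDP.full_optim_state_dict(model, optim)  # rank0_only=True by default
if dist.get_rank() == 0:
    torch.save(full_osd, "optim_ckpt.pt")

# Load: only rank 0 needs the full state dict; every rank receives its shard.
full_osd = torch.load("optim_ckpt.pt") if dist.get_rank() == 0 else None
sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, model)
optim.load_state_dict(sharded_osd)
```

`shard_full_optim_state_dict` is the non-collective alternative: each rank loads the full state dict itself and reshards locally instead of receiving a broadcast from rank 0.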
 
@@ -3887,9 +3984,12 @@
         optim_state_dict: Dict[str, Any],
         optim_state_key_type: OptimStateKeyType,
         model: torch.nn.Module,
-        optim_input: Optional[Union[
-            List[Dict[str, Any]], Iterable[torch.nn.Parameter],
-        ]] = None,
+        optim_input: Optional[
+            Union[
+                List[Dict[str, Any]],
+                Iterable[torch.nn.Parameter],
+            ]
+        ] = None,
         optim: Optional[torch.optim.Optimizer] = None,
     ) -> Dict[str, Any]:
         """
@@ -3926,30 +4026,30 @@
         """
         FullyShardedDataParallel._warn_optim_input(optim_input)
         using_optim_input = FullyShardedDataParallel._is_using_optim_input(
-            optim_input, optim,
+            optim_input,
+            optim,
         )
         assert optim_state_key_type in (
-            OptimStateKeyType.PARAM_NAME, OptimStateKeyType.PARAM_ID,
+            OptimStateKeyType.PARAM_NAME,
+            OptimStateKeyType.PARAM_ID,
         )
         osd = optim_state_dict  # alias
         # Validate that the existing parameter keys are uniformly typed
-        uses_param_name_mask = [
-            type(param_key) is str for param_key in osd["state"]
-        ]
-        uses_param_id_mask = [
-            type(param_key) is int for param_key in osd["state"]
-        ]
-        if (
-            (any(uses_param_name_mask) and not all(uses_param_name_mask))
-            or (any(uses_param_id_mask) and not all(uses_param_id_mask))
+        uses_param_name_mask = [type(param_key) is str for param_key in osd["state"]]
+        uses_param_id_mask = [type(param_key) is int for param_key in osd["state"]]
+        if (any(uses_param_name_mask) and not all(uses_param_name_mask)) or (
+            any(uses_param_id_mask) and not all(uses_param_id_mask)
         ):
             error_msg = f"Invalid parameter keys: {osd['state'].keys()}"
             raise ValueError(error_msg)
         # Return directly if the existing key type matches the target key type
-        if (optim_state_key_type == OptimStateKeyType.PARAM_NAME and
-            all(uses_param_name_mask)) or \
-            (optim_state_key_type == OptimStateKeyType.PARAM_ID and
-                all(uses_param_id_mask)):
+        if (
+            optim_state_key_type == OptimStateKeyType.PARAM_NAME
+            and all(uses_param_name_mask)
+        ) or (
+            optim_state_key_type == OptimStateKeyType.PARAM_ID
+            and all(uses_param_id_mask)
+        ):
             return osd
         # Otherwise, actually perform the re-keying
         new_osd = {}
@@ -3969,10 +4069,12 @@
             }
             new_osd["param_groups"] = copy.deepcopy(osd["param_groups"])
             for param_group in new_osd["param_groups"]:
-                param_group["params"] = sorted([
-                    param_id_to_param_name[param_id]
-                    for param_id in param_group["params"]
-                ])
+                param_group["params"] = sorted(
+                    [
+                        param_id_to_param_name[param_id]
+                        for param_id in param_group["params"]
+                    ]
+                )
             return new_osd
         elif optim_state_key_type == OptimStateKeyType.PARAM_ID:  # name -> ID
             param_name_to_param = _get_param_name_to_param(model)
@@ -3994,10 +4096,12 @@
             }
             new_osd["param_groups"] = copy.deepcopy(osd["param_groups"])
             for param_group in new_osd["param_groups"]:
-                param_group["params"] = sorted([
-                    param_name_to_param_id[param_name]
-                    for param_name in param_group["params"]
-                ])
+                param_group["params"] = sorted(
+                    [
+                        param_name_to_param_id[param_name]
+                        for param_name in param_group["params"]
+                    ]
+                )
             return new_osd
         return new_osd  # should never reach here
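
The re-keying itself amounts to composing two maps (parameter ID ↔ parameter ↔ parameter name). A toy single-process sketch with a hypothetical ID-keyed optimizer state dict (the FSDP implementation additionally handles `FlatParameter` FQNs):

```
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
# For a freshly constructed optimizer over model.parameters(), parameter IDs
# follow the named_parameters() order.
param_id_to_name = {i: name for i, (name, _) in enumerate(model.named_parameters())}

osd = {  # hypothetical ID-keyed optimizer state dict
    "state": {0: {"step": 10}, 1: {"step": 10}},
    "param_groups": [{"lr": 1e-3, "params": [0, 1]}],
}
new_osd = {
    "state": {param_id_to_name[pid]: s for pid, s in osd["state"].items()},
    "param_groups": [
        {**group, "params": sorted(param_id_to_name[pid] for pid in group["params"])}
        for group in osd["param_groups"]
    ],
}
print(list(new_osd["state"]))  # ['0.weight', '0.bias']
```

`rekey_optim_state_dict` performs this in both directions (`OptimStateKeyType.PARAM_NAME` ↔ `PARAM_ID`), using the optimizer or the deprecated `optim_input` to recover the ID ordering.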
 
@@ -4056,12 +4160,17 @@
 
         """
         if not self.check_is_root():
-            raise AssertionError("register_comm_hook can only be called on a root instance.")
+            raise AssertionError(
+                "register_comm_hook can only be called on a root instance."
+            )
         for submodule in self.fsdp_modules(self):
-            assert not submodule._hook_registered, "communication hook can be only registered once"
+            assert (
+                not submodule._hook_registered
+            ), "communication hook can be only registered once"
             submodule._hook_registered = True
-            assert submodule._communication_hook == self._get_default_comm_hook(),\
-                f"communication hook should be default, but it is {submodule._communication_hook.__name__} instead"
+            assert (
+                submodule._communication_hook == self._get_default_comm_hook()
+            ), f"communication hook should be default, but it is {submodule._communication_hook.__name__} instead"
             submodule._communication_hook_state = state
             submodule._communication_hook = hook
 
@@ -4073,10 +4182,7 @@
             assert (
                 auto_wrap_policy.tracing_config is None
             ), "tracing_config should be None when torch.fx is not enabled"
-        elif isinstance(
-            auto_wrap_policy.tracing_config,
-            TracingConfig
-        ):
+        elif isinstance(auto_wrap_policy.tracing_config, TracingConfig):
             tracer = auto_wrap_policy.tracing_config.tracer
             execution_info = _init_execution_info(module)
 
@@ -4110,8 +4216,7 @@
         # A list that stores the flattened parameters and their names based on the parameter execution order
         self._fsdp_params_exec_order: List[FlatParameter] = []
         if _TORCH_FX_AVAIL and isinstance(
-            auto_wrap_policy.tracing_config,
-            TracingConfig
+            auto_wrap_policy.tracing_config, TracingConfig
         ):
             # Initialize a dict that maps each module to its parent FSDP wrap
             module_to_fsdp: Dict[nn.Module, FullyShardedDataParallel] = dict()
@@ -4137,8 +4242,7 @@
 
     def _use_param_exec_order_policy(self) -> bool:
         return (
-            hasattr(self, "_param_exec_order_policy")
-            and self._param_exec_order_policy
+            hasattr(self, "_param_exec_order_policy") and self._param_exec_order_policy
         )
 
     def _is_param_exec_order_prep_stage(self) -> bool:
@@ -4148,8 +4252,8 @@
         )
         if not is_prep_stage:
             for p in self.parameters():
-                assert (
-                    not hasattr(p, "_params_exec_order_hook_handle")
+                assert not hasattr(
+                    p, "_params_exec_order_hook_handle"
                 ), "When not in execution order prep stage, all _params_exec_order_hook_handle should be removed."
         return is_prep_stage
 
@@ -4168,7 +4272,9 @@
     grads = [param.grad for param in params_with_grad]
     grad_dtypes = set(grad.dtype for grad in grads)
     if len(grad_dtypes) != 1:
-        raise ValueError(f"Requires uniform dtype across all gradients but got {grad_dtypes}")
+        raise ValueError(
+            f"Requires uniform dtype across all gradients but got {grad_dtypes}"
+        )
     # Compute the gradient norm in FP32, where we treat the gradients as a
     # single vector
     grad_norm = torch.linalg.vector_norm(
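
One way to compute such a "single vector" norm over a list of placeholder FP16 gradients, under the same uniform-dtype assumption (a sketch, not necessarily the exact FSDP helper):

```
import torch

grads = [
    torch.full((3,), 2.0, dtype=torch.float16),
    torch.full((2,), 2.0, dtype=torch.float16),
]
assert len({g.dtype for g in grads}) == 1  # uniform dtype, as required above
grad_norm = torch.linalg.vector_norm(
    torch.cat([g.detach().to(torch.float32).flatten() for g in grads]), ord=2
)
print(grad_norm.item())  # sqrt(5 * 2^2) ~= 4.4721
```
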
@@ -4206,15 +4312,14 @@
             in the module walk order; if ``False``, then includes all of the
             unflattened parameter names.
     """
+
     def module_fn(module, prefix, param_to_unflat_param_names):
         for param_name, param in module.named_parameters(recurse=False):
             module_prefixed_param_names = (
-                param._fqns if type(param) is FlatParameter
-                else [param_name]
+                param._fqns if type(param) is FlatParameter else [param_name]
             )  # prefixed from `module`
             fully_prefixed_param_names = [
-                clean_tensor_name(prefix + name)
-                for name in module_prefixed_param_names
+                clean_tensor_name(prefix + name) for name in module_prefixed_param_names
             ]  # fully prefixed from the top level including `prefix`
             # If this parameter has already been visited, then it is a
             # shared parameter; in that case, only take the first parameter name
@@ -4229,7 +4334,10 @@
 
     param_to_unflat_param_names: Dict[torch.nn.Parameter, List[str]] = {}
     return _apply_to_modules(
-        model, module_fn, return_fn, param_to_unflat_param_names,
+        model,
+        module_fn,
+        return_fn,
+        param_to_unflat_param_names,
     )
 
 
@@ -4250,16 +4358,16 @@
     """
     param_to_param_names = _get_param_to_unflat_param_names(model)
     for param_names in param_to_param_names.values():
-        assert len(param_names) > 0, "`_get_param_to_unflat_param_names()` " \
-            "should not construct empty lists"
+        assert len(param_names) > 0, (
+            "`_get_param_to_unflat_param_names()` " "should not construct empty lists"
+        )
         if len(param_names) > 1:
             raise RuntimeError(
                 "Each parameter should only map to one parameter name but got "
                 f"{len(param_names)}: {param_names}"
             )
     param_to_param_name = {
-        param: param_names[0]
-        for param, param_names in param_to_param_names.items()
+        param: param_names[0] for param, param_names in param_to_param_names.items()
     }
     return param_to_param_name
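
For a model with no `FlatParameter`s, the mapping these helpers build reduces to inverting `named_parameters()` into singleton name lists, which is exactly the invariant the assert above checks. A small sketch of that degenerate case:

```
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
param_to_unflat_names = {param: [name] for name, param in model.named_parameters()}
param_to_name = {param: names[0] for param, names in param_to_unflat_names.items()}
print(sorted(param_to_name.values()))  # ['0.bias', '0.weight', '2.bias', '2.weight']
```
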
 
diff --git a/torch/distributed/fsdp/sharded_grad_scaler.py b/torch/distributed/fsdp/sharded_grad_scaler.py
index 27ba44e..86dbfd7 100644
--- a/torch/distributed/fsdp/sharded_grad_scaler.py
+++ b/torch/distributed/fsdp/sharded_grad_scaler.py
@@ -1,12 +1,12 @@
-from collections import abc, defaultdict
 import logging
+from collections import abc, defaultdict
 from typing import Dict, List, Optional, Union
 
 import torch
-from torch.cuda import FloatTensor  # type: ignore[attr-defined]
-from torch.cuda.amp.grad_scaler import GradScaler, OptState, _MultiDeviceReplicator
-from torch.distributed.distributed_c10d import ProcessGroup
 import torch.distributed as dist
+from torch.cuda import FloatTensor  # type: ignore[attr-defined]
+from torch.cuda.amp.grad_scaler import _MultiDeviceReplicator, GradScaler, OptState
+from torch.distributed.distributed_c10d import ProcessGroup
 from torch.optim.sgd import SGD
 
 
@@ -23,6 +23,7 @@
     Lazily serves tensors to the requested device. This class extends
     _MultiDeviceReplicator to allow support for "cpu" as a device.
     """
+
     def __init__(self, master_tensor: torch.Tensor) -> None:
         assert _is_supported_device(master_tensor)
         self.master = master_tensor
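
The caching idea can be sketched standalone in a few lines (a hypothetical class, not the FSDP replicator, which additionally validates supported devices):

```
import torch

class PerDeviceCache:
    """Hypothetical sketch: lazily materialize copies of one master tensor per device."""

    def __init__(self, master: torch.Tensor) -> None:
        self.master = master
        self._per_device: dict = {}

    def get(self, device: str) -> torch.Tensor:
        # Copy the master tensor to a device the first time it is requested,
        # then serve the cached copy on subsequent calls.
        if device not in self._per_device:
            self._per_device[device] = self.master.to(
                device=torch.device(device), non_blocking=True, copy=True
            )
        return self._per_device[device]

scale = torch.full((1,), 65536.0)
print(PerDeviceCache(scale).get("cpu"))  # tensor([65536.])
```
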
@@ -77,9 +78,10 @@
         process_group (ProcessGroup, optional, default=torch.distributed.group.WORLD):
             process group for sharding
     """
+
     def __init__(
         self,
-        init_scale: float = 2.0 ** 16,
+        init_scale: float = 2.0**16,
         backoff_factor: float = 0.5,
         growth_factor: float = 2.0,
         growth_interval: int = 2000,
@@ -97,7 +99,9 @@
             self.process_group = process_group
             self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
 
-    def scale(self, outputs: Union[torch.Tensor, List[torch.Tensor]]) -> Union[torch.Tensor, List[torch.Tensor]]:
+    def scale(
+        self, outputs: Union[torch.Tensor, List[torch.Tensor]]
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
         if not self._enabled:
             return outputs
 
@@ -106,7 +110,9 @@
             if self._scale is None:
                 self._lazy_init_scale_growth_tracker(outputs.device)
             assert self._scale is not None
-            scaled_output = outputs * self._scale.to(device=outputs.device, non_blocking=True)
+            scaled_output = outputs * self._scale.to(
+                device=outputs.device, non_blocking=True
+            )
             # Here we ensure the return dtype is the same as the outputs dtype.
             # For the FSDP + Mixed Precision use case, the loss output is in the Mixed Precision
             # format (fp16, bf16) and so the scaled loss should be of the same dtype.
@@ -114,7 +120,9 @@
 
         stash: List[_GeneralMultiDeviceReplicator] = []
 
-        def apply_scale(val: Union[torch.Tensor, abc.Iterable]) -> Union[torch.Tensor, abc.Iterable]:
+        def apply_scale(
+            val: Union[torch.Tensor, abc.Iterable]
+        ) -> Union[torch.Tensor, abc.Iterable]:
             if isinstance(val, torch.Tensor):
                 assert _is_supported_device(val)
                 if len(stash) == 0:
@@ -150,20 +158,30 @@
         for grad in grads:
             for tensor in grad:
                 if tensor.device != expected_device:
-                    logging.error("tensor device is %s and expected device is %s" % (tensor.device, expected_device))
+                    logging.error(
+                        "tensor device is %s and expected device is %s"
+                        % (tensor.device, expected_device)
+                    )
                     raise ValueError("Gradients must be on the same device.")
 
                 # The check for non_overlapping_and_dense does not exist in the Python world,
                 # as remarked here https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/AmpKernels.cu#L108,
                 # so we assume the tensor is not MTA (multi tensor apply) safe and iterate through each item regardless of dtype
-                if torch.isinf(tensor).any().item() is True or torch.isnan(tensor).any().item() is True:
+                if (
+                    torch.isinf(tensor).any().item() is True
+                    or torch.isnan(tensor).any().item() is True
+                ):
                     found_inf.data = torch.tensor([1.0])
                     break
                 else:
                     tensor.data *= inv_scale.item()
 
     def _unscale_grads_(
-        self, optimizer: SGD, inv_scale: torch.Tensor, found_inf: torch.Tensor, allow_fp16: bool = True
+        self,
+        optimizer: SGD,
+        inv_scale: torch.Tensor,
+        found_inf: torch.Tensor,
+        allow_fp16: bool = True,
     ) -> Dict[torch.device, torch.Tensor]:
         per_device_inv_scale = _GeneralMultiDeviceReplicator(inv_scale)
         per_device_found_inf = _GeneralMultiDeviceReplicator(found_inf)
@@ -195,7 +213,9 @@
                     else:
                         to_unscale = param.grad
 
-                    per_device_and_dtype_grads[to_unscale.device][to_unscale.dtype].append(to_unscale)
+                    per_device_and_dtype_grads[to_unscale.device][
+                        to_unscale.dtype
+                    ].append(to_unscale)
 
             for device, per_dtype_grads in per_device_and_dtype_grads.items():
                 for grads in per_dtype_grads.values():
@@ -222,16 +242,22 @@
         optimizer_state = self._per_optimizer_states[id(optimizer)]
 
         if optimizer_state["stage"] is OptState.UNSCALED:
-            raise RuntimeError("unscale_() has already been called on this optimizer since the last update().")
+            raise RuntimeError(
+                "unscale_() has already been called on this optimizer since the last update()."
+            )
         elif optimizer_state["stage"] is OptState.STEPPED:
             raise RuntimeError("unscale_() is being called after step().")
 
         # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
         assert self._scale is not None
         inv_scale = self._scale.double().reciprocal().float()
-        found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=self._scale.device)
+        found_inf = torch.full(
+            (1,), 0.0, dtype=torch.float32, device=self._scale.device
+        )
 
-        optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, True)
+        optimizer_state["found_inf_per_device"] = self._unscale_grads_(
+            optimizer, inv_scale, found_inf, True
+        )
         optimizer_state["stage"] = OptState.UNSCALED
 
         # Synchronize the detected inf across the ranks
@@ -241,10 +267,18 @@
         for v in optimizer_state["found_inf_per_device"].values():
             if v.device.type == "cpu":
                 v_on_cuda = v.cuda()
-                future_handles.append(dist.all_reduce(v_on_cuda, async_op=True, group=self.process_group).get_future())
+                future_handles.append(
+                    dist.all_reduce(
+                        v_on_cuda, async_op=True, group=self.process_group
+                    ).get_future()
+                )
                 v.copy_(v_on_cuda.cpu())
             else:
-                future_handles.append(dist.all_reduce(v, async_op=True, group=self.process_group).get_future())
+                future_handles.append(
+                    dist.all_reduce(
+                        v, async_op=True, group=self.process_group
+                    ).get_future()
+                )
 
         # Make sure that the calls are done before moving out.
         if future_handles:
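
In use, `ShardedGradScaler` follows the same scale → (optional `unscale_`) → `step` → `update` protocol as `torch.cuda.amp.GradScaler`. A hedged sketch, assuming an initialized process group, an FSDP-wrapped `model` on GPU, an optimizer `optim`, and placeholder `loader`/`loss_fn` objects:

```
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

scaler = ShardedGradScaler()  # per-rank scaler that all-reduces inf/nan findings
for inputs, targets in loader:
    optim.zero_grad(set_to_none=True)
    loss = loss_fn(model(inputs), targets)
    scaler.scale(loss).backward()  # scale the loss before backward
    scaler.unscale_(optim)         # optional: unscale early, e.g. before clipping
    scaler.step(optim)             # skipped if any rank found inf/nan gradients
    scaler.update()                # grow/backoff the scale for the next step
```

`unscale_` is where the sharded scaler differs most from the base class: the per-device inf/nan flags are all-reduced across ranks (the hunks above), so every rank agrees on whether to skip the step.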
diff --git a/torch/distributed/fsdp/wrap.py b/torch/distributed/fsdp/wrap.py
index 8013da8..c529bcd 100644
--- a/torch/distributed/fsdp/wrap.py
+++ b/torch/distributed/fsdp/wrap.py
@@ -5,22 +5,11 @@
 
 import contextlib
 from dataclasses import dataclass
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    Generator,
-    Optional,
-    Set,
-    Tuple,
-    Type,
-    cast,
-)
+from typing import Any, Callable, cast, Dict, Generator, Optional, Set, Tuple, Type
 
 import torch.nn as nn
 from torch.nn.modules.batchnorm import _BatchNorm
 
-
 __all__ = [
     "always_wrap_policy",
     "lambda_auto_wrap_policy",
@@ -41,11 +30,9 @@
     """
     return True
 
+
 def lambda_auto_wrap_policy(
-    module: nn.Module,
-    recurse: bool,
-    unwrapped_params: int,
-    lambda_fn: Callable
+    module: nn.Module, recurse: bool, unwrapped_params: int, lambda_fn: Callable
 ) -> bool:
     """
     A convenient auto wrap policy to wrap submodules based on an arbitrary user
@@ -78,6 +65,7 @@
         # if not recursing, decide whether we should wrap for the leaf node or remainder
         return lambda_fn(module)
 
+
 def transformer_auto_wrap_policy(
     module: nn.Module,
     recurse: bool,
@@ -121,6 +109,7 @@
         # if not recursing, decide whether we should wrap for the leaf node or remainder
         return isinstance(module, tuple(transformer_layer_cls))
 
+
 def _wrap_batchnorm_individually(
     module: nn.Module,
     recurse: bool,
@@ -138,6 +127,7 @@
         # BN layer or not.
         return isinstance(module, _BatchNorm)
 
+
 def _or_policy(
     module: nn.Module,
     recurse: bool,
@@ -148,9 +138,7 @@
     A policy that wraps ``module`` if any policy in the passed in iterable of
     ``policies`` returns ``True``.
     """
-    return any(
-        policy(module, recurse, unwrapped_params) for policy in policies
-    )
+    return any(policy(module, recurse, unwrapped_params) for policy in policies)
 
 
 def size_based_auto_wrap_policy(
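
These policies are plain callables that FSDP queries as `policy(module, recurse, unwrapped_params)`; they are typically bound with `functools.partial` and passed to the FSDP constructor as `auto_wrap_policy`. A small sketch exercising two of them directly, with hypothetical thresholds and no actual wrapping:

```
import functools

import torch.nn as nn
from torch.distributed.fsdp.wrap import (
    lambda_auto_wrap_policy,
    size_based_auto_wrap_policy,
)

# Wrap every nn.Linear, or wrap once a subtree holds at least 100k unwrapped params.
wrap_linears = functools.partial(
    lambda_auto_wrap_policy, lambda_fn=lambda m: isinstance(m, nn.Linear)
)
wrap_large = functools.partial(size_based_auto_wrap_policy, min_num_params=100_000)

layer = nn.Linear(8, 8)  # 72 parameters
print(wrap_linears(layer, recurse=False, unwrapped_params=72))  # True
print(wrap_large(layer, recurse=False, unwrapped_params=72))    # False: below threshold
```

The same partials are what get passed as `auto_wrap_policy=...` when constructing `FullyShardedDataParallel`.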
@@ -333,13 +321,14 @@
     ``full``, ``full_like``, ``eye``, ``empty``, ``tensor``). For those cases,
     users can set ``tracing_config = None`` to disable symbolic tracing.
     """
+
     init_policy: Callable = always_wrap_policy
     tracing_config: Any = None
 
 
 def _wrap(module: nn.Module, wrapper_cls: Callable, **kwargs) -> nn.Module:
     assert wrapper_cls is not None
-    if hasattr(module, '_wrap_overrides'):
+    if hasattr(module, "_wrap_overrides"):
         # If module has a _wrap_overrides attribute, we force overriding the
         # FSDP config with these attributes for this module. Currently this
         # is only used to disable mixed precision for BatchNorm when
@@ -357,7 +346,7 @@
     ignored_modules: Set[nn.Module],
     ignored_params: Set[nn.Parameter],
     only_wrap_children: bool = False,
-    **kwargs: Any
+    **kwargs: Any,
 ) -> Tuple[nn.Module, int]:
     """
     Automatically wrap child modules of *module* that meet the given
@@ -389,9 +378,7 @@
             pass
 
     # We count all params, assuming none of them are already wrapped.
-    num_params = sum(
-        p.numel() for p in module.parameters() if p not in ignored_params
-    )
+    num_params = sum(p.numel() for p in module.parameters() if p not in ignored_params)
 
     assert auto_wrap_policy is not None
     if auto_wrap_policy(module=module, recurse=True, unwrapped_params=num_params):