[FSDP()][14/N] Refactor pre-forward/post-backward (#87927)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87927
Approved by: https://github.com/mrshenli
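
The core of this refactor is moving per-instance FSDP runtime methods into module-level functions that take the wrapper as an explicit `state` argument. A minimal sketch of that pattern, using placeholder names that are not FSDP APIs:

# Illustrative only: the pattern applied throughout this patch, shown with
# placeholder names that are not FSDP APIs.

# Before: the logic is an instance method and reads `self` implicitly.
class Wrapper:
    def __init__(self) -> None:
        self.training_state = "IDLE"

    def _pre_forward(self) -> None:
        self.training_state = "FORWARD_BACKWARD"


# After: the same logic is a free function that receives the state explicitly,
# so it can live in a shared `_runtime_utils`-style module and be patched or
# reused independently of the wrapper class.
def _pre_forward(state) -> None:
    state.training_state = "FORWARD_BACKWARD"


class RefactoredWrapper:
    def __init__(self) -> None:
        self.training_state = "IDLE"

    def forward(self) -> None:
        _pre_forward(self)  # the wrapper passes itself as the state object
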
diff --git a/test/distributed/fsdp/test_fsdp_core.py b/test/distributed/fsdp/test_fsdp_core.py
index 93d5e4f..b0d5252 100644
--- a/test/distributed/fsdp/test_fsdp_core.py
+++ b/test/distributed/fsdp/test_fsdp_core.py
@@ -366,12 +366,14 @@
)
input = fsdp_model.module.get_input(torch.device("cuda"))
fsdp_model._register_pre_backward_hooks = mock.MagicMock(return_value=None)
- fsdp_model._register_post_backward_hooks = mock.MagicMock(return_value=None)
- self.assertFalse(fsdp_model._register_post_backward_hooks.called)
- self.assertFalse(fsdp_model._register_pre_backward_hooks.called)
- fsdp_model(*input)
- self.assertTrue(fsdp_model._register_post_backward_hooks.called)
- self.assertTrue(fsdp_model._register_pre_backward_hooks.called)
+ with mock.patch(
+ "torch.distributed.fsdp._runtime_utils._register_post_backward_hooks"
+ ) as register_post_bwd_mock:
+ self.assertFalse(fsdp_model._register_pre_backward_hooks.called)
+ self.assertFalse(register_post_bwd_mock.called)
+ fsdp_model(*input)
+ self.assertTrue(fsdp_model._register_pre_backward_hooks.called)
+ self.assertTrue(register_post_bwd_mock.called)
class TestNoGrad(FSDPTest):
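
Because `_register_post_backward_hooks` is now a module-level function rather than a method, the updated test patches it by its dotted import path. A self-contained sketch of that `unittest.mock.patch` technique, using `os.getcwd` as a stand-in target:

import os
from unittest import mock

# Patching by dotted path replaces the attribute on the owning module for the
# duration of the block, so any caller that looks it up through the module
# sees the MagicMock instead of the real function.
with mock.patch("os.getcwd") as getcwd_mock:
    assert not getcwd_mock.called
    os.getcwd()              # resolves to the mock
    assert getcwd_mock.called
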
diff --git a/test/distributed/fsdp/test_fsdp_param_exec_order_wrap.py b/test/distributed/fsdp/test_fsdp_param_exec_order_wrap.py
deleted file mode 100644
index a1c73d1..0000000
--- a/test/distributed/fsdp/test_fsdp_param_exec_order_wrap.py
+++ /dev/null
@@ -1,134 +0,0 @@
-# Owner(s): ["oncall: distributed"]
-
-from typing import Any, Callable
-
-import torch
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed.fsdp._symbolic_trace import TracingConfig
-from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy
-from torch.distributed.fsdp.wrap import always_wrap_policy, ParamExecOrderWrapPolicy
-from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
-from torch.testing._internal.common_fsdp import FSDPTest
-from torch.testing._internal.common_utils import (
- instantiate_parametrized_tests,
- parametrize,
- run_tests,
-)
-
-
-class Model(torch.nn.Module):
- def __init__(self) -> None:
- super().__init__()
- self.layer0 = torch.nn.Linear(6, 6)
- self.layer1 = torch.nn.Linear(6, 6, bias=False)
- self.layer2 = torch.nn.Sequential(
- torch.nn.Linear(6, 3, bias=False),
- torch.nn.ReLU(),
- torch.nn.Linear(3, 6, bias=False),
- )
- self.relu = torch.nn.ReLU()
-
- def forward(self, x: Any, use_all_params: bool = True):
- # `layer0` -> `layer2` -> `layer1`
- # the forward execution order is NOT consistent with the model definition order.
- z = self.relu(self.layer0(x))
- z = self.relu(self.layer2(z))
- if use_all_params:
- z = self.relu(self.layer1(z))
- return z
-
- def get_input(self, device: torch.device):
- return (torch.randn((8, 6)).to(device),)
-
- def get_loss(self, input, output):
- return (output - input[0]).sum()
-
- @staticmethod
- def wrap(
- sharding_strategy: ShardingStrategy,
- device: torch.device,
- wrap_policy: Callable,
- ) -> torch.nn.Module:
- model = Model()
- fsdp_model = FSDP(
- model, auto_wrap_policy=wrap_policy, sharding_strategy=sharding_strategy
- )
- return fsdp_model.to(device)
-
-
-class TestFSDPExecOrder(FSDPTest):
- @property
- def device(self):
- return torch.device("cuda")
-
- @skip_if_lt_x_gpu(2)
- @parametrize(
- "sharding_strategy",
- [ShardingStrategy.FULL_SHARD, ShardingStrategy.SHARD_GRAD_OP],
- )
- def test_fsdp_flatten_params_exec_order(
- self,
- sharding_strategy: ShardingStrategy,
- ):
- """
- Test ``_fsdp_params_exec_order`` with ``ParamExecOrderWrapPolicy``,
- after running one iteration of forward and backward pass.
- Here ``torch.fx`` is not enabled inside ``ParamExecOrderWrapPolicy``.
- """
- wrap_policy = ParamExecOrderWrapPolicy(init_policy=always_wrap_policy)
- fsdp_model = Model.wrap(sharding_strategy, self.device, wrap_policy=wrap_policy)
- self.assertTrue(fsdp_model._is_param_exec_order_prep_stage())
- # run one iteration to record the execution ordering
- input = fsdp_model.module.get_input(self.device)
- output = fsdp_model(*input)
- loss = fsdp_model.module.get_loss(input, output).to(self.device)
- loss.backward()
- params_list = list(fsdp_model.parameters())
- # Since the forward execution order is NOT consistent with
- # the model definition order, the ordering in flatten_named_params_exec_order
- # should be different from named_parameters.
- self.assertEqual(
- fsdp_model._fsdp_params_exec_order,
- [params_list[0], params_list[2], params_list[3], params_list[1]],
- )
- self.assertTrue(fsdp_model._use_param_exec_order_policy())
- self.assertTrue(not fsdp_model._is_param_exec_order_prep_stage())
-
- @skip_if_lt_x_gpu(2)
- @parametrize(
- "sharding_strategy",
- [ShardingStrategy.FULL_SHARD, ShardingStrategy.SHARD_GRAD_OP],
- )
- def test_fsdp_flatten_params_exec_order_symbolic_trace(
- self,
- sharding_strategy: ShardingStrategy,
- ):
- """
- Tests ``ParamExecOrderWrapPolicy`` with symbolic tracing.
- With symbolic tracing enabled, ``_is_param_exec_order_prep_stage``
- should always set as False.
- """
- wrap_policy = ParamExecOrderWrapPolicy(
- init_policy=always_wrap_policy,
- tracing_config=TracingConfig(concrete_args={"use_all_params": False}),
- )
- fsdp_model = Model.wrap(
- sharding_strategy,
- self.device,
- wrap_policy=wrap_policy,
- )
- params_list = list(fsdp_model.parameters())
- # Since the forward execution order is NOT consistent with the model definition order,
- # the ordering in flatten_named_params_exec_order should be different from named_parameters
- self.assertEqual(
- fsdp_model._fsdp_params_exec_order,
- [params_list[0], params_list[2], params_list[3]],
- )
- self.assertTrue(fsdp_model._use_param_exec_order_policy())
- self.assertTrue(not fsdp_model._is_param_exec_order_prep_stage())
-
-
-instantiate_parametrized_tests(TestFSDPExecOrder)
-
-if __name__ == "__main__":
- run_tests()
diff --git a/torch/distributed/fsdp/_common_utils.py b/torch/distributed/fsdp/_common_utils.py
index 6d3681f..34f45fc 100644
--- a/torch/distributed/fsdp/_common_utils.py
+++ b/torch/distributed/fsdp/_common_utils.py
@@ -2,8 +2,9 @@
This file includes private common utilities for FSDP.
"""
+import traceback
from enum import auto, Enum
-from typing import Callable, Dict, List, Union
+from typing import Callable, Dict, List, no_type_check, Union
import torch
import torch.distributed.fsdp.flat_param as flat_param_file
@@ -153,3 +154,25 @@
f(root_module, "", *args, **kwargs)
return return_fn(*args, **kwargs)
+
+
+@no_type_check
+def _assert_in_training_states(
+ state: _State,
+ training_states: List[TrainingState],
+) -> None:
+ """Asserts that FSDP is in the states ``_training_states``."""
+ # Raise a `ValueError` instead of using `assert` to ensure that these
+ # logical assertions run even if `assert`s are disabled
+ if state.training_state not in training_states:
+ msg = (
+ f"expected to be in states {training_states} but current state is "
+ f"{state.training_state}"
+ )
+ # Print the error on rank 0 in case this is called in the backward pass
+ if state.rank == 0:
+ if isinstance(state, nn.Module):
+ print(f"Asserting FSDP instance is: {state}")
+ print(f"ERROR: {msg}")
+ traceback.print_stack()
+ raise ValueError(msg)
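
For context, the helper raises `ValueError` because `assert` statements are stripped when Python runs with `-O`. A minimal sketch of the same kind of check, with made-up names rather than the FSDP helper itself:

def check_training_state(current_state, allowed_states):
    # An explicit raise survives `python -O`, unlike an `assert` statement.
    if current_state not in allowed_states:
        raise ValueError(
            f"expected to be in states {allowed_states} but current state is "
            f"{current_state}"
        )

check_training_state("FORWARD_BACKWARD", ["FORWARD_BACKWARD"])  # passes
# check_training_state("IDLE", ["FORWARD_BACKWARD"])            # raises ValueError
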
diff --git a/torch/distributed/fsdp/_runtime_utils.py b/torch/distributed/fsdp/_runtime_utils.py
index 4311e27..c79dfe6 100644
--- a/torch/distributed/fsdp/_runtime_utils.py
+++ b/torch/distributed/fsdp/_runtime_utils.py
@@ -1,9 +1,24 @@
-from typing import Any, List, no_type_check, Optional, Tuple
+import functools
+from typing import Any, Callable, List, no_type_check, Optional, Tuple
import torch
-from torch.distributed.fsdp._common_utils import _State
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.distributed.algorithms._comm_hooks import LOW_PRECISION_HOOKS
+from torch.distributed.fsdp._common_utils import (
+ _assert_in_training_states,
+ _State,
+ TrainingState,
+)
from torch.distributed.fsdp._utils import _apply_to_tensors, p_assert
-from torch.distributed.fsdp.flat_param import FlatParamHandle
+from torch.distributed.fsdp.api import BackwardPrefetch
+from torch.distributed.fsdp.flat_param import (
+ _HandlesKey,
+ FlatParameter,
+ FlatParamHandle,
+ HandleShardingStrategy,
+ HandleTrainingState,
+)
from torch.distributed.utils import _to_kwargs
@@ -91,6 +106,376 @@
handle.reshard_grad()
+@no_type_check
+def _pre_forward(
+ state: _State,
+ handles: List[FlatParamHandle],
+ unshard_fn: Callable,
+ module: nn.Module,
+ input: Any,
+):
+ """
+    Runs the pre-forward logic. This includes an opportunity to unshard
+    currently sharded parameters, such as those for the current forward, and
+    to register post-backward hooks for these current parameters.
+
+ Args:
+ handles (List[FlatParamHandle]): Handles giving the parameters used in
+ the current forward.
+ unshard_fn (Optional[Callable]): A callable to unshard any currently
+ sharded parameters or ``None`` to not do any unsharding.
+ module (nn.Module): Module whose forward this method runs right before.
+ input (Any): Unused; expected by the hook signature.
+ """
+ state.training_state = TrainingState.FORWARD_BACKWARD
+ state._exec_order_data.record_pre_forward(handles, module.training)
+ for handle in handles:
+ handle._training_state = HandleTrainingState.FORWARD
+ if unshard_fn is not None:
+ unshard_fn()
+ # Register post-backward hooks to reshard the parameters and reduce-scatter
+ # their gradients. They must be re-registered every forward pass in case
+ # the `grad_fn` is mutated.
+ _register_post_backward_hooks(state, handles)
+
+
+@no_type_check
+@torch.no_grad()
+def _post_backward_hook(
+ state: _State,
+ handle: FlatParamHandle,
+ *unused: Any,
+):
+ """
+ Reduce-scatters the gradient of ``handle`` 's ``FlatParameter``.
+
+ Precondition: The ``FlatParameter`` 's ``.grad`` attribute contains the
+ unsharded gradient for the local batch.
+
+ Postcondition:
+ - If using ``NO_SHARD``, then the ``.grad`` attribute is the reduced
+ unsharded gradient.
+ - Otherwise, the ``_saved_grad_shard`` attribute is the reduced sharded
+ gradient (accumulating with any existing gradient).
+ """
+ param = handle.flat_param
+ param._post_backward_called = True
+ with torch.autograd.profiler.record_function(
+ "FullyShardedDataParallel._post_backward_hook"
+ ):
+ _assert_in_training_states(state, [TrainingState.FORWARD_BACKWARD])
+ state.training_state = TrainingState.FORWARD_BACKWARD
+ p_assert(
+ handle._training_state == HandleTrainingState.BACKWARD_PRE,
+ f"Expects `BACKWARD_PRE` state but got {handle._training_state}",
+ )
+ handle._training_state = HandleTrainingState.BACKWARD_POST
+
+ if param.grad is None:
+ return
+ if param.grad.requires_grad:
+ raise RuntimeError("FSDP does not support gradients of gradients")
+
+ free_unsharded_flat_param = _should_free_in_backward(state, handle)
+ _reshard(state, [handle], [free_unsharded_flat_param])
+
+ # TODO: Post-backward prefetching does not support the multiple handles
+ # per module case since the post-backward hook runs per handle, not per
+ # group of handles.
+ handles_key = (handle,)
+ _prefetch_handles(state, handles_key)
+
+ if not state._sync_gradients:
+ return
+
+ # Wait for all ops in the current stream (e.g. gradient
+ # computation) to finish before reduce-scattering the gradient
+ state._streams["post_backward"].wait_stream(torch.cuda.current_stream())
+
+ with torch.cuda.stream(state._streams["post_backward"]):
+ unsharded_grad_data = param.grad.data
+ if state._exec_order_data.is_first_iter: # only check once
+ _check_comm_hook(
+ state._communication_hook, state._communication_hook_state
+ )
+ if handle._uses_reduce_mixed_precision and not _low_precision_hook_enabled(
+ state
+ ):
+ # TODO: Use the low precision communication hook directly
+ param.grad.data = param.grad.to(state.mixed_precision.reduce_dtype)
+
+ if handle.uses_sharded_strategy:
+ # We clear `.grad` to permit multiple backwards. This avoids a
+            # race where the second backward pass's computation runs ahead of
+            # the first backward pass's reduction, which is possible since the
+            # reduction is issued asynchronously in a separate stream and
+            # would then reduce the wrong gradient.
+ unsharded_grad = param.grad.data
+ param.grad = None
+ p_assert(
+ len(unsharded_grad.size()) == 1,
+ f"Expects gradient to be flattened but got {unsharded_grad.size()}",
+ )
+ chunks = list(unsharded_grad.chunk(state.world_size))
+ numel_to_pad = (
+ state.world_size * chunks[0].numel() - unsharded_grad.numel()
+ )
+ padded_unsharded_grad = F.pad(unsharded_grad, [0, numel_to_pad])
+ new_sharded_grad = torch.zeros_like(chunks[0]) # padded
+ state._communication_hook(
+ state._communication_hook_state,
+ padded_unsharded_grad,
+ new_sharded_grad,
+ )
+ _cast_grad_to_param_dtype(state, handle, new_sharded_grad, param)
+
+ # Save the sharded gradient in `_saved_grad_shard` to support
+ # gradient accumulation -- for multiple backwards, the gradient
+ # reductions may happen in arbitrary order
+ accumulate_grad = hasattr(param, "_saved_grad_shard")
+ if accumulate_grad:
+ _check_grad_to_accumulate(new_sharded_grad, param._saved_grad_shard)
+ param._saved_grad_shard += new_sharded_grad
+ else:
+ param._saved_grad_shard = new_sharded_grad
+ sharded_grad = param._saved_grad_shard
+ else:
+ state._communication_hook(state._communication_hook_state, param.grad)
+ # For `NO_SHARD`, we can keep the low precision gradients by
+ # simply omitting the cast altogether
+ if not handle._keep_low_precision_grads:
+ _cast_grad_to_param_dtype(state, handle, param.grad, param)
+ sharded_grad = param.grad.data
+
+ if handle._config.offload_params:
+ # Offload the gradient to CPU to ensure parameters and
+ # gradients are on the same device as required by the optimizer
+ param._cpu_grad.copy_( # type: ignore[attr-defined]
+ sharded_grad.detach(), non_blocking=True
+ ) # synchronized in the post-backward callback
+ # Since the sharded gradient is produced in the post-backward
+ # stream and consumed later in the computation stream, inform
+ # the caching allocator
+ sharded_grad.data.record_stream(torch.cuda.current_stream())
+
+ # Since the unsharded gradient is produced in the computation
+ # stream and consumed in the post-backward stream, inform the
+ # caching allocator (before it goes out of scope)
+ unsharded_grad_data.record_stream(state._streams["post_backward"])
+
+ if handle._use_orig_params:
+ # Since the handle's `FlatParameter` completed its gradient
+ # computation, we should reset the gradient noneness mask
+ handle._reset_is_grad_none()
+ # Delay using sharded gradient views until after the
+ # reduce-scatter instead of immediately after resharding
+ handle._use_sharded_grad_views()
+
+
+@no_type_check
+def _should_free_in_backward(
+ state: _State,
+ handle: FlatParamHandle,
+) -> bool:
+ """
+    Returns whether FSDP should free the unsharded flattened parameter in the
+    post-backward pass.
+ """
+ return (
+ state._sync_gradients and handle.uses_sharded_strategy
+ ) or handle._config.sharding_strategy == HandleShardingStrategy.FULL_SHARD
+
+
+@no_type_check
+def _cast_grad_to_param_dtype(
+ state: _State,
+ handle: FlatParamHandle,
+ sharded_grad: torch.Tensor,
+ param: FlatParameter,
+):
+ """
+ Casts ``sharded_grad`` back to the full parameter dtype so that the
+ optimizer step runs with that dtype. This performs an actual cast if
+ 1. parameters were in reduced precision during the forward since then
+ gradients would be in that reduced precision, or
+ 2. parameters were not in reduced precision but gradients were in
+ reduced precision for communication.
+ However, if a low precision communication hook is registered, then this
+ dtype cast happens in the hook instead.
+ """
+ _assert_in_training_states(state, [TrainingState.FORWARD_BACKWARD])
+ if not _low_precision_hook_enabled(state) and (
+ handle._uses_param_mixed_precision or handle._uses_reduce_mixed_precision
+ ):
+ low_prec_grad_data = sharded_grad.data
+ sharded_grad.data = sharded_grad.data.to(dtype=param.dtype)
+ # Since for `NO_SHARD`, the gradient is produced in the computation
+ # stream and consumed here in the post-backward stream, inform the
+ # caching allocator; for the sharded strategies, the gradient is
+ # produced in the post-backward stream, so this `record_stream()`
+ # should be a no-op
+ low_prec_grad_data.record_stream(torch.cuda.current_stream())
+
+
+def _check_comm_hook(
+ comm_hook: Any,
+ comm_hook_state: Any,
+) -> None:
+ p_assert(comm_hook is not None, "Communication hook should not be `None`")
+ p_assert(
+ comm_hook_state is not None, "Communication hook state should not be `None`"
+ )
+
+
+def _check_grad_to_accumulate(
+ new_sharded_grad: torch.Tensor,
+ accumulated_grad: torch.Tensor,
+) -> None:
+ p_assert(
+ accumulated_grad.shape == new_sharded_grad.shape,
+ "Shape mismatch when accumulating gradients: "
+ f"existing gradient shape={accumulated_grad.shape} "
+ f"new gradient shape={new_sharded_grad.shape}",
+ )
+ p_assert(
+ accumulated_grad.device == new_sharded_grad.device,
+ "Device mismatch when accumulating gradients: "
+ f"existing gradient device={accumulated_grad.device} "
+ f"new gradient device={new_sharded_grad.device}",
+ )
+
+
+@no_type_check
+def _low_precision_hook_enabled(state: _State) -> bool:
+ return state._communication_hook in LOW_PRECISION_HOOKS
+
+
+@no_type_check
+def _prefetch_handles(
+ state: _State,
+ current_handles_key: _HandlesKey,
+) -> None:
+ """
+ Prefetches the next handles if needed (without synchronization). An empty
+ handles key cannot prefetch.
+ """
+ if not current_handles_key:
+ return
+ handles_to_prefetch = _get_handles_to_prefetch(state, current_handles_key)
+ for handles_key in handles_to_prefetch:
+ # Prefetch the next set of handles without synchronizing to allow
+ # the sync to happen as late as possible to maximize overlap
+ _unshard(
+ state, handles_key, state._streams["unshard"], state._streams["pre_unshard"]
+ )
+ state._handles_prefetched[handles_key] = True
+
+
+@no_type_check
+def _get_handles_to_prefetch(
+ state: _State,
+ current_handles_key: _HandlesKey,
+) -> List[_HandlesKey]:
+ """
+ Returns a :class:`list` of the handles keys to prefetch for the next
+ module(s), where ``current_handles_key`` represents the current module.
+
+ "Prefetching" refers to running the unshard logic early (without
+ synchronization), and the "next" modules depend on the recorded execution
+ order and the current training state.
+ """
+ training_state = _get_training_state(current_handles_key)
+ valid_training_states = (
+ HandleTrainingState.BACKWARD_PRE,
+ HandleTrainingState.BACKWARD_POST,
+ HandleTrainingState.FORWARD,
+ )
+ p_assert(
+ training_state in valid_training_states,
+ f"Prefetching is only supported in {valid_training_states} but "
+ f"currently in {training_state}",
+ )
+ eod = state._exec_order_data
+ target_handles_keys: List[_HandlesKey] = []
+ if (
+ training_state == HandleTrainingState.BACKWARD_PRE
+ and state.backward_prefetch == BackwardPrefetch.BACKWARD_PRE
+ ) or (
+ training_state == HandleTrainingState.BACKWARD_POST
+ and state.backward_prefetch == BackwardPrefetch.BACKWARD_POST
+ ):
+ target_handles_keys = [
+ target_handles_key
+ for target_handles_key in eod.get_handles_to_backward_prefetch(
+ current_handles_key
+ )
+ if state._needs_pre_backward_unshard.get(target_handles_key, False)
+ and not state._handles_prefetched.get(target_handles_key, False)
+ ]
+ elif training_state == HandleTrainingState.FORWARD and state.forward_prefetch:
+ target_handles_keys = [
+ target_handles_key
+ for target_handles_key in eod.get_handles_to_forward_prefetch(
+ current_handles_key
+ )
+ if state._needs_pre_forward_unshard.get(target_handles_key, False)
+ and not state._handles_prefetched.get(target_handles_key, False)
+ ]
+ return target_handles_keys
+
+
+def _get_training_state(
+ handles_key: _HandlesKey,
+) -> HandleTrainingState:
+ """Returns the training state of the handles in ``handles_key``."""
+ p_assert(len(handles_key) > 0, "Expects a non-empty handles key")
+ training_states = set(handle._training_state for handle in handles_key)
+ p_assert(
+ len(training_states) == 1,
+ f"Expects uniform training state but got {training_states}",
+ )
+ return next(iter(training_states))
+
+
+def _register_post_backward_hooks(
+ state: _State,
+ handles: List[FlatParamHandle],
+) -> None:
+ """
+ Registers post-backward hooks on the ``FlatParameter`` s'
+ ``AccumulateGrad`` objects to reshard and to reduce-scatter gradients.
+
+ The ``AccumulateGrad`` object represents the last function that finalizes
+ the ``FlatParameter`` 's gradient, so it only runs after its entire
+ gradient computation has finished.
+
+ We register the post-backward hook only once in the *first* forward that a
+ ``FlatParameter`` participates in. This relies on the ``AccumulateGrad``
+ object being preserved through multiple forwards.
+ """
+ # If there is no gradient computation, then there is no need for
+ # post-backward logic
+ if not torch.is_grad_enabled():
+ return
+ for handle in handles:
+ flat_param = handle.flat_param
+ already_registered = hasattr(flat_param, "_post_backward_hook_state")
+ if already_registered or not flat_param.requires_grad:
+ continue
+ # Get the `AccumulateGrad` object
+ temp_flat_param = flat_param.expand_as(flat_param)
+ p_assert(
+ temp_flat_param.grad_fn is not None,
+ "The `grad_fn` is needed to access the `AccumulateGrad` and "
+ "register the post-backward hook",
+ )
+ acc_grad = temp_flat_param.grad_fn.next_functions[0][0]
+ hook_handle = acc_grad.register_hook(
+ functools.partial(_post_backward_hook, state, handle)
+ )
+ flat_param._post_backward_hook_state = (acc_grad, hook_handle) # type: ignore[attr-defined]
+
+
def _wait_for_computation_stream(
computation_stream: torch.cuda.Stream,
unshard_stream: torch.cuda.Stream,
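
The padding and chunking arithmetic in `_post_backward_hook` can be exercised without a process group. The sketch below uses made-up sizes and only notes where the reduce-scatter collective would run; it is an illustration, not FSDP code:

import torch
import torch.nn.functional as F

world_size = 4
unsharded_grad = torch.randn(10)                 # flattened gradient
chunks = list(unsharded_grad.chunk(world_size))  # chunk sizes: 3, 3, 3, 1
numel_to_pad = world_size * chunks[0].numel() - unsharded_grad.numel()  # 12 - 10 = 2
padded_unsharded_grad = F.pad(unsharded_grad, [0, numel_to_pad])        # 12 elements
new_sharded_grad = torch.zeros_like(chunks[0])   # per-rank output buffer (3 elements)
# A reduce-scatter collective (e.g. via the registered communication hook)
# would fill `new_sharded_grad` from `padded_unsharded_grad` here.
assert padded_unsharded_grad.numel() == world_size * new_sharded_grad.numel()
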
diff --git a/torch/distributed/fsdp/flat_param.py b/torch/distributed/fsdp/flat_param.py
index 5fd1308..50f3d5f 100644
--- a/torch/distributed/fsdp/flat_param.py
+++ b/torch/distributed/fsdp/flat_param.py
@@ -106,7 +106,7 @@
offload_params: bool
low_prec_param_dtype: Optional[torch.dtype]
low_prec_reduce_dtype: Optional[torch.dtype]
- keep_low_precision_grads: Optional[bool] = False
+ keep_low_precision_grads: bool = False
class FlatParameter(nn.Parameter):
@@ -1801,6 +1801,14 @@
return self._config.low_prec_param_dtype is not None
@property
+ def _uses_reduce_mixed_precision(self) -> bool:
+ return self._config.low_prec_reduce_dtype is not None
+
+ @property
+ def _keep_low_precision_grads(self) -> bool:
+ return self._config.keep_low_precision_grads
+
+ @property
def _force_full_precision(self) -> bool:
return (
self._training_state == HandleTrainingState.SUMMON_FULL_PARAMS
diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py
index 412a7e0..3d92a96 100644
--- a/torch/distributed/fsdp/fully_sharded_data_parallel.py
+++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py
@@ -24,7 +24,6 @@
import torch
import torch.distributed as dist
import torch.nn as nn
-import torch.nn.functional as F
from torch.autograd import Variable
from torch.distributed import ProcessGroup
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
@@ -54,9 +53,12 @@
)
from torch.distributed.fsdp._runtime_utils import (
_clear_grads_if_needed,
+ _pre_forward,
+ _prefetch_handles,
_prepare_forward_inputs,
_reshard,
_reshard_grads,
+ _should_free_in_backward,
_unshard,
_unshard_grads,
_wait_for_computation_stream,
@@ -87,12 +89,7 @@
_pre_load_state_dict_hook,
)
from ._utils import _apply_to_tensors, _free_storage, p_assert
-from .flat_param import (
- _HandlesKey,
- FlatParameter,
- FlatParamHandle,
- HandleShardingStrategy,
-)
+from .flat_param import FlatParameter, FlatParamHandle, HandleShardingStrategy
from .wrap import ParamExecOrderWrapPolicy
_TORCH_FX_AVAIL = True
@@ -926,92 +923,6 @@
# CPU offloading (H2D copy) and mixed precision (low precision cast).
self._streams["pre_unshard"] = torch.cuda.Stream()
- def _prefetch_handles(
- self,
- current_handles_key: _HandlesKey,
- ) -> None:
- """
- Prefetches the next handles if needed (without synchronization). An
- empty handles key cannot prefetch.
- """
- if not current_handles_key:
- return
- handles_to_prefetch = self._get_handles_to_prefetch(current_handles_key)
- for handles_key in handles_to_prefetch:
- # Prefetch the next set of handles without synchronizing to allow
- # the sync to happen as late as possible to maximize overlap
- _unshard(
- self,
- handles_key,
- self._streams["unshard"],
- self._streams["pre_unshard"],
- )
- self._handles_prefetched[handles_key] = True
-
- def _get_handles_to_prefetch(
- self,
- current_handles_key: _HandlesKey,
- ) -> List[_HandlesKey]:
- """
- Returns a :class:`list` of the handles keys to prefetch for the next
- module(s), where ``current_handles_key`` represents the current module.
-
- "Prefetching" refers to running the unshard logic early (without
- synchronization), and the "next" modules depend on the recorded
- execution order and the current training state.
- """
- training_state = self._get_training_state(current_handles_key)
- valid_training_states = (
- HandleTrainingState.BACKWARD_PRE,
- HandleTrainingState.BACKWARD_POST,
- HandleTrainingState.FORWARD,
- )
- p_assert(
- training_state in valid_training_states,
- f"Prefetching is only supported in {valid_training_states} but "
- f"currently in {training_state}",
- )
- eod = self._exec_order_data
- target_handles_keys: List[_HandlesKey] = []
- if (
- training_state == HandleTrainingState.BACKWARD_PRE
- and self.backward_prefetch == BackwardPrefetch.BACKWARD_PRE
- ) or (
- training_state == HandleTrainingState.BACKWARD_POST
- and self.backward_prefetch == BackwardPrefetch.BACKWARD_POST
- ):
- target_handles_keys = [
- target_handles_key
- for target_handles_key in eod.get_handles_to_backward_prefetch(
- current_handles_key
- )
- if self._needs_pre_backward_unshard.get(target_handles_key, False)
- and not self._handles_prefetched.get(target_handles_key, False)
- ]
- elif training_state == HandleTrainingState.FORWARD and self.forward_prefetch:
- target_handles_keys = [
- target_handles_key
- for target_handles_key in eod.get_handles_to_forward_prefetch(
- current_handles_key
- )
- if self._needs_pre_forward_unshard.get(target_handles_key, False)
- and not self._handles_prefetched.get(target_handles_key, False)
- ]
- return target_handles_keys
-
- def _get_training_state(
- self,
- handles_key: _HandlesKey,
- ) -> HandleTrainingState:
- """Returns the training state of the handles in ``handles_key``."""
- p_assert(len(handles_key) > 0, "Expects a non-empty handles key")
- training_states = set(handle._training_state for handle in handles_key)
- p_assert(
- len(training_states) == 1,
- f"Expects uniform training state but got {training_states}",
- )
- return next(iter(training_states))
-
@staticmethod
def set_state_dict_type(
module: nn.Module,
@@ -1291,7 +1202,9 @@
self._handles,
free_unsharded_flat_params,
)
- self._pre_forward(self._handles, unshard_fn, unused, unused)
+ _pre_forward(
+ self, self._handles, unshard_fn, self._fsdp_wrapped_module, unused
+ )
for handle in self._handles:
p_assert(
handle.flat_param.device == self.compute_device,
@@ -1301,38 +1214,6 @@
output = self._fsdp_wrapped_module(*args, **kwargs)
return self._post_forward(self._handles, reshard_fn, unused, unused, output)
- def _pre_forward(
- self,
- handles: List[FlatParamHandle],
- unshard_fn: Optional[Callable],
- module: nn.Module,
- input: Any,
- ):
- """
- Runs the pre-forward logic. This includes an opportunity to unshard
- currently sharded parameters such as those for the current forward and
- registering post-backward hooks for these current parameters.
-
- Args:
- handles (List[FlatParamHandle]): Handles giving the parameters
- used in the current forward.
- unshard_fn (Optional[Callable]): A callable to unshard any
- currently sharded parameters or ``None`` to not do any
- unsharding.
- module (nn.Module): Unused; expected by the hook signature.
- input (Any): Unused; expected by the hook signature.
- """
- self.training_state = TrainingState.FORWARD_BACKWARD
- self._exec_order_data.record_pre_forward(handles, self.training)
- for handle in handles:
- handle._training_state = HandleTrainingState.FORWARD
- if unshard_fn is not None:
- unshard_fn()
- # Register post-backward hooks to reshard the parameters and
- # reduce-scatter their gradients. They must be re-registered every
- # forward pass in case the `grad_fn` is mutated.
- self._register_post_backward_hooks(handles)
-
def _pre_forward_unshard(
self,
handles: List[FlatParamHandle],
@@ -1345,7 +1226,7 @@
handles_key = tuple(handles)
self._needs_pre_forward_unshard[handles_key] = False
torch.cuda.current_stream().wait_stream(self._streams["unshard"])
- self._prefetch_handles(handles_key)
+ _prefetch_handles(self, handles_key)
def _post_forward(
self,
@@ -1891,7 +1772,7 @@
# Set this to `False` to ensure that a mistargeted prefetch
# does not actually unshard these handles
self._needs_pre_backward_unshard[_handles_key] = False
- self._prefetch_handles(_handles_key)
+ _prefetch_handles(self, _handles_key)
for handle in _handles:
handle.prepare_gradient_for_backward()
self._ran_pre_backward_hook[_handles_key] = True
@@ -1904,261 +1785,6 @@
return _apply_to_tensors(_register_hook, outputs)
- def _register_post_backward_hooks(
- self,
- handles: List[FlatParamHandle],
- ) -> None:
- """
- Registers post-backward hooks on the ``FlatParameter`` s'
- ``AccumulateGrad`` objects to reshard and to reduce-scatter gradients.
-
- The ``AccumulateGrad`` object represents the last function that
- finalizes the ``FlatParameter`` 's gradient, so it only runs after its
- entire gradient computation has finished.
-
- We register the post-backward hook only once in the *first* forward
- that a ``FlatParameter`` participates in. This relies on the
- ``AccumulateGrad`` object being preserved through multiple forwards.
- """
- # If there is no gradient computation, then there is no need for
- # post-backward logic
- if not torch.is_grad_enabled():
- return
- for handle in handles:
- flat_param = handle.flat_param
- already_registered = hasattr(flat_param, "_post_backward_hook_state")
- if already_registered or not flat_param.requires_grad:
- continue
- # Get the `AccumulateGrad` object
- temp_flat_param = flat_param.expand_as(flat_param)
- p_assert(
- temp_flat_param.grad_fn is not None,
- "The `grad_fn` is needed to access the `AccumulateGrad` and "
- "register the post-backward hook",
- )
- acc_grad = temp_flat_param.grad_fn.next_functions[0][0]
- hook_handle = acc_grad.register_hook(
- functools.partial(self._post_backward_hook, handle)
- )
- flat_param._post_backward_hook_state = (acc_grad, hook_handle) # type: ignore[attr-defined]
-
- @torch.no_grad()
- def _post_backward_hook(
- self,
- handle: FlatParamHandle,
- *unused: Any,
- ) -> None:
- """
- Reduce-scatters the gradient of ``handle`` 's ``FlatParameter``.
-
- Precondition: The ``FlatParameter`` 's ``.grad`` attribute contains the
- unsharded gradient for the local batch.
-
- Postcondition:
- - If using ``NO_SHARD``, then the ``.grad`` attribute is the reduced
- unsharded gradient.
- - Otherwise, the ``_saved_grad_shard`` attribute is the reduced sharded
- gradient (accumulating with any existing gradient).
- """
- param = handle.flat_param
- param._post_backward_called = True
- with torch.autograd.profiler.record_function(
- "FullyShardedDataParallel._post_backward_hook"
- ):
- self._assert_state([TrainingState.FORWARD_BACKWARD])
- self.training_state = TrainingState.FORWARD_BACKWARD
- p_assert(
- handle._training_state == HandleTrainingState.BACKWARD_PRE,
- f"Expects `BACKWARD_PRE` state but got {handle._training_state}",
- )
- handle._training_state = HandleTrainingState.BACKWARD_POST
-
- if (
- self._use_param_exec_order_policy()
- and self._param_exec_order_prep_stage
- ):
- # In self._fsdp_params_exec_order, the parameters are ordered based on
- # the execution order in the backward pass in the first iteration.
- self._fsdp_params_exec_order.append(param)
-
- if param.grad is None:
- return
- if param.grad.requires_grad:
- raise RuntimeError(
- "FSDP only works with gradients that don't require gradients"
- )
-
- free_unsharded_flat_param = self._should_free_unsharded_flat_param(handle)
- _reshard(self, [handle], [free_unsharded_flat_param])
-
- # TODO (awgu): Post-backward prefetching does not support the
- # multiple handles per module case (which was why we keyed by
- # *tuple*). The post-backward hook runs per handle, not per group
- # of handles. To generalize this, we may need a 2-level mapping,
- # where we map each individual handle to its groups of handles and
- # then from the groups of handles to their indices in the order.
- handles_key = (handle,)
- self._prefetch_handles(handles_key)
-
- if not self._sync_gradients:
- return
-
- # Wait for all ops in the current stream (e.g. gradient
- # computation) to finish before reduce-scattering the gradient
- self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
-
- with torch.cuda.stream(self._streams["post_backward"]):
- orig_grad_data = param.grad.data
- if (
- self._mixed_precision_enabled_for_reduce()
- and not self._low_precision_hook_enabled()
- ):
- # Cast gradient to precision in which it should be communicated.
- # If a low precision hook is registered and reduce_dtype is specified
- # in `MixedPrecision`, communication hook will take care of
- # casting to lower precision and back.
- # TODO: Make this a communication hook when communication hooks
- # are implemented for FSDP. Note that this is a noop if the
- # reduce_dtype matches the param dtype.
- param.grad.data = param.grad.data.to(
- self.mixed_precision.reduce_dtype
- )
-
- if self._exec_order_data.is_first_iter:
- # For all sharding strategies communication is performed through `_communication_hook`:
- # default comm hooks are: `reduce_scatter` for sharded strategies and
- # `all_reduce` for non-sharded strategies. This checks asserts that `_communication_hook`
- # and `_communication_hook_state`, required for communication not `None`.`
- p_assert(
- self._communication_hook is not None,
- "Communication hook should not be None",
- )
- p_assert(
- self._communication_hook_state is not None,
- "Communication hook state should not be None",
- )
- grad = param.grad.data
- if handle.uses_sharded_strategy:
- # We clear `param.grad` to permit repeated gradient
- # computations when this FSDP module is called multiple times.
- # This is to avoid a race among multiple re-entrant backward
- # passes. For example, the second backward pass computation
- # precedes ahead of the first backward pass reduction, which is
- # possible since the reduction is in a different stream and is
- # async. Then, the first backward pass may be incorrectly
- # reducing the second backward pass's `param.grad`.
- # The reduced gradients are accumulated in
- # `param._saved_grad_shard`, and the gradient reductions can
- # happen in arbitrary order, though we tolerate this due to the
- # (approximate) commutativity of floating-point addition.
- param.grad = None
- grad_flatten = torch.flatten(grad)
- chunks = list(grad_flatten.chunk(self.world_size))
- num_pad = self.world_size * chunks[0].numel() - grad.numel()
- input_flattened = F.pad(grad_flatten, [0, num_pad])
- output = torch.zeros_like(chunks[0])
- self._communication_hook(
- self._communication_hook_state, input_flattened, output
- )
-
- self._cast_grad_to_param_dtype(output, param)
-
- # To support gradient accumulation outside `no_sync()`, we save
- # the gradient data to `param._saved_grad_shard` before the
- # backward pass, accumulate gradients into it here, and set
- # `param.grad` with the accumulated value at the end of the
- # backward pass in preparation for the optimizer step.
- accumulate_grad = hasattr(param, "_saved_grad_shard")
- if accumulate_grad:
- p_assert(
- param._saved_grad_shard.shape == output.shape, # type: ignore[attr-defined]
- "Shape mismatch when accumulating gradients: " # type: ignore[attr-defined]
- f"existing grad shape={param._saved_grad_shard.shape} "
- f"new grad shape={output.shape}", # type: ignore[attr-defined]
- )
- p_assert(
- param._saved_grad_shard.device == output.device, # type: ignore[attr-defined]
- "Device mismatch when accumulating gradients: " # type: ignore[attr-defined]
- f"existing grad device={param._saved_grad_shard.device} "
- f"new grad device={output.device}", # type: ignore[attr-defined]
- )
- param._saved_grad_shard += output # type: ignore[attr-defined]
- else:
- param._saved_grad_shard = output # type: ignore[attr-defined]
- grad = param._saved_grad_shard # type: ignore[attr-defined]
- else:
- if self.sharding_strategy == ShardingStrategy.NO_SHARD:
- self._communication_hook(
- self._communication_hook_state, param.grad
- )
-
- # For NO_SHARD keeping grads in the reduced precision, we
- # can simply omit the cast as needed, we can't do this for
- # other sharding strategies because grad field is assigned
- # in _finalize_params. TODO (rvarm1) this divergence in
- # logic is not ideal.
- if not self._mixed_precision_keep_low_precision_grads():
- self._cast_grad_to_param_dtype(param.grad, param)
-
- # Regardless of sharding or not, offload the grad to CPU if we are
- # offloading params. This is so param and grad reside on same device
- # which is needed for the optimizer step.
- if handle._config.offload_params:
- # We specify non_blocking=True
- # and ensure the appropriate synchronization is done by waiting
- # streams in _wait_for_post_backward.
- param._cpu_grad.copy_( # type: ignore[attr-defined]
- grad.detach(), non_blocking=True
- )
- # Don't let this memory get reused until after the transfer.
- grad.data.record_stream(torch.cuda.current_stream())
-
- # After _post_backward_hook returns, orig_grad_data will eventually
- # go out of scope, at which point it could otherwise be freed for
- # further reuse by the main stream while the div/reduce_scatter/copy
- # are underway in the post_backward stream. See:
- # github.com/NVIDIA/apex/blob/master/apex/parallel/distributed.py
- orig_grad_data.record_stream(self._streams["post_backward"])
-
- if handle._use_orig_params:
- # Since the handle's `FlatParameter` completed its gradient
- # computation, we should reset the gradient noneness mask
- handle._reset_is_grad_none()
- # Delay using sharded gradient views until after the
- # reduce-scatter instead of immediately after resharding
- handle._use_sharded_grad_views()
-
- def _cast_grad_to_param_dtype(
- self,
- grad: torch.Tensor,
- param: FlatParameter,
- ):
- """
- Casts gradient ``grad`` back to the full parameter dtype so that the
- optimizer step runs with that dtype. This performs an actual cast if
- 1. parameters were in reduced precision during the forward since then
- gradients would be in that reduced precision, or
- 2. parameters were not in reduced precision but gradients were in
- reduced precision for communication.
- However, if a low precision communication hook is registered, then this
- dtype cast happens in the hook instead.
- """
- self._assert_state(TrainingState.FORWARD_BACKWARD)
- if not self._low_precision_hook_enabled() and (
- self._mixed_precision_enabled_for_params()
- or self._mixed_precision_enabled_for_reduce()
- ):
- low_prec_grad_data = grad.data
- grad.data = grad.data.to(dtype=param.dtype)
- # Do not let the low precision gradient memory get reused until
- # the cast to full parameter precision completes
- low_prec_grad_data.record_stream(torch.cuda.current_stream())
-
- def _should_free_unsharded_flat_param(self, handle: FlatParamHandle):
- return (
- self._sync_gradients and handle.uses_sharded_strategy
- ) or handle._config.sharding_strategy == HandleShardingStrategy.FULL_SHARD
-
def _queue_wait_for_post_backward(self) -> None:
"""
Queues a post-backward callback from the root FSDP instance, which
@@ -2220,7 +1846,7 @@
if already_resharded:
continue
free_unsharded_flat_params.append(
- self._should_free_unsharded_flat_param(handle)
+ _should_free_in_backward(fsdp_module, handle)
)
handles_to_reshard.append(handle)
_reshard(self, handles_to_reshard, free_unsharded_flat_params)
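
The `AccumulateGrad` hook trick used by `_register_post_backward_hooks` can be reproduced on a plain parameter. The following standalone sketch assumes only standard PyTorch autograd APIs and is not FSDP-specific:

import torch

param = torch.nn.Parameter(torch.randn(4))
# Expanding a leaf to itself yields a view whose `grad_fn.next_functions`
# points back at the leaf's `AccumulateGrad` node.
acc_grad = param.expand_as(param).grad_fn.next_functions[0][0]
hook_handle = acc_grad.register_hook(lambda *unused: print("post-backward for param"))

loss = (param * 2).sum()
loss.backward()      # the hook fires after the gradient for `param` is finalized
hook_handle.remove()
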