[FSDP()][14/N] Refactor pre-forward/post-backward (#87927)

This moves the pre-forward and post-backward runtime logic (`_pre_forward`, `_post_backward_hook`, `_register_post_backward_hooks`, `_prefetch_handles`, `_get_handles_to_prefetch`, `_get_training_state`, `_cast_grad_to_param_dtype`, and `_should_free_unsharded_flat_param`, renamed to `_should_free_in_backward`) out of `FullyShardedDataParallel` into free functions in `torch/distributed/fsdp/_runtime_utils.py` that take the FSDP state as an explicit argument, adds `_assert_in_training_states` to `_common_utils.py`, and removes the `test_fsdp_param_exec_order_wrap.py` test file.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87927
Approved by: https://github.com/mrshenli
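
For readers skimming the diff, the core pattern of the refactor is sketched below in a minimal, self-contained form (toy names only, not FSDP APIs): per-instance runtime methods become free functions that take the state object explicitly, which is why call sites change from `self._foo(...)` to `_foo(self, ...)`.

```python
# Minimal, self-contained sketch of the refactor pattern in this PR (toy
# names, not FSDP APIs): instance methods that implement runtime logic are
# rewritten as free functions taking the state as an explicit first argument.
class ToyWrapper:
    """Stands in for the nn.Module wrapper that used to own the logic."""

    def __init__(self) -> None:
        self.training_state = "IDLE"

    # Before: logic is a method, implicitly bound to the wrapper instance.
    def _pre_forward_method(self) -> None:
        self.training_state = "FORWARD_BACKWARD"


# After: the same logic is a free function over an explicit state object,
# so it can be reused by code that does not subclass the wrapper.
def _pre_forward(state) -> None:
    state.training_state = "FORWARD_BACKWARD"


wrapper = ToyWrapper()
_pre_forward(wrapper)  # the wrapper instance itself serves as the state
assert wrapper.training_state == "FORWARD_BACKWARD"
```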
diff --git a/test/distributed/fsdp/test_fsdp_core.py b/test/distributed/fsdp/test_fsdp_core.py
index 93d5e4f..b0d5252 100644
--- a/test/distributed/fsdp/test_fsdp_core.py
+++ b/test/distributed/fsdp/test_fsdp_core.py
@@ -366,12 +366,14 @@
         )
         input = fsdp_model.module.get_input(torch.device("cuda"))
         fsdp_model._register_pre_backward_hooks = mock.MagicMock(return_value=None)
-        fsdp_model._register_post_backward_hooks = mock.MagicMock(return_value=None)
-        self.assertFalse(fsdp_model._register_post_backward_hooks.called)
-        self.assertFalse(fsdp_model._register_pre_backward_hooks.called)
-        fsdp_model(*input)
-        self.assertTrue(fsdp_model._register_post_backward_hooks.called)
-        self.assertTrue(fsdp_model._register_pre_backward_hooks.called)
+        with mock.patch(
+            "torch.distributed.fsdp._runtime_utils._register_post_backward_hooks"
+        ) as register_post_bwd_mock:
+            self.assertFalse(fsdp_model._register_pre_backward_hooks.called)
+            self.assertFalse(register_post_bwd_mock.called)
+            fsdp_model(*input)
+            self.assertTrue(fsdp_model._register_pre_backward_hooks.called)
+            self.assertTrue(register_post_bwd_mock.called)
 
 
 class TestNoGrad(FSDPTest):
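
Since `_register_post_backward_hooks` is now a module-level function rather than a method on the FSDP instance, the test above intercepts it with `mock.patch` on its import path instead of assigning a `MagicMock` attribute. A minimal, self-contained sketch of that testing pattern (toy names, unrelated to FSDP) for reference:

```python
# Minimal, self-contained sketch (toy names, unrelated to FSDP) of the testing
# pattern used above: an attribute that still lives on the instance can be
# replaced by assigning a MagicMock, while logic that moved to a module-level
# function must be intercepted with mock.patch on its import path.
from unittest import mock


def helper() -> None:
    """Stands in for a module-level function such as a registration helper."""


class Worker:
    def _instance_hook(self) -> None:
        """Stands in for an instance method that still exists on the object."""

    def step(self) -> None:
        self._instance_hook()  # instance attribute lookup
        helper()  # module-level (global) lookup


def test_step_calls_both_hooks() -> None:
    worker = Worker()
    worker._instance_hook = mock.MagicMock(return_value=None)  # patch the instance
    with mock.patch(f"{__name__}.helper") as helper_mock:  # patch the module global
        worker.step()
    assert worker._instance_hook.called
    assert helper_mock.called


if __name__ == "__main__":
    test_step_calls_both_hooks()
```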
diff --git a/test/distributed/fsdp/test_fsdp_param_exec_order_wrap.py b/test/distributed/fsdp/test_fsdp_param_exec_order_wrap.py
deleted file mode 100644
index a1c73d1..0000000
--- a/test/distributed/fsdp/test_fsdp_param_exec_order_wrap.py
+++ /dev/null
@@ -1,134 +0,0 @@
-# Owner(s): ["oncall: distributed"]
-
-from typing import Any, Callable
-
-import torch
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed.fsdp._symbolic_trace import TracingConfig
-from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy
-from torch.distributed.fsdp.wrap import always_wrap_policy, ParamExecOrderWrapPolicy
-from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
-from torch.testing._internal.common_fsdp import FSDPTest
-from torch.testing._internal.common_utils import (
-    instantiate_parametrized_tests,
-    parametrize,
-    run_tests,
-)
-
-
-class Model(torch.nn.Module):
-    def __init__(self) -> None:
-        super().__init__()
-        self.layer0 = torch.nn.Linear(6, 6)
-        self.layer1 = torch.nn.Linear(6, 6, bias=False)
-        self.layer2 = torch.nn.Sequential(
-            torch.nn.Linear(6, 3, bias=False),
-            torch.nn.ReLU(),
-            torch.nn.Linear(3, 6, bias=False),
-        )
-        self.relu = torch.nn.ReLU()
-
-    def forward(self, x: Any, use_all_params: bool = True):
-        # `layer0` -> `layer2` -> `layer1`
-        # the forward execution order is NOT consistent with the model definition order.
-        z = self.relu(self.layer0(x))
-        z = self.relu(self.layer2(z))
-        if use_all_params:
-            z = self.relu(self.layer1(z))
-        return z
-
-    def get_input(self, device: torch.device):
-        return (torch.randn((8, 6)).to(device),)
-
-    def get_loss(self, input, output):
-        return (output - input[0]).sum()
-
-    @staticmethod
-    def wrap(
-        sharding_strategy: ShardingStrategy,
-        device: torch.device,
-        wrap_policy: Callable,
-    ) -> torch.nn.Module:
-        model = Model()
-        fsdp_model = FSDP(
-            model, auto_wrap_policy=wrap_policy, sharding_strategy=sharding_strategy
-        )
-        return fsdp_model.to(device)
-
-
-class TestFSDPExecOrder(FSDPTest):
-    @property
-    def device(self):
-        return torch.device("cuda")
-
-    @skip_if_lt_x_gpu(2)
-    @parametrize(
-        "sharding_strategy",
-        [ShardingStrategy.FULL_SHARD, ShardingStrategy.SHARD_GRAD_OP],
-    )
-    def test_fsdp_flatten_params_exec_order(
-        self,
-        sharding_strategy: ShardingStrategy,
-    ):
-        """
-        Test ``_fsdp_params_exec_order`` with ``ParamExecOrderWrapPolicy``,
-        after running one iteration of forward and backward pass.
-        Here ``torch.fx`` is not enabled inside ``ParamExecOrderWrapPolicy``.
-        """
-        wrap_policy = ParamExecOrderWrapPolicy(init_policy=always_wrap_policy)
-        fsdp_model = Model.wrap(sharding_strategy, self.device, wrap_policy=wrap_policy)
-        self.assertTrue(fsdp_model._is_param_exec_order_prep_stage())
-        # run one iteration to record the execution ordering
-        input = fsdp_model.module.get_input(self.device)
-        output = fsdp_model(*input)
-        loss = fsdp_model.module.get_loss(input, output).to(self.device)
-        loss.backward()
-        params_list = list(fsdp_model.parameters())
-        # Since the forward execution order is NOT consistent with
-        # the model definition order, the ordering in flatten_named_params_exec_order
-        # should be different from named_parameters.
-        self.assertEqual(
-            fsdp_model._fsdp_params_exec_order,
-            [params_list[0], params_list[2], params_list[3], params_list[1]],
-        )
-        self.assertTrue(fsdp_model._use_param_exec_order_policy())
-        self.assertTrue(not fsdp_model._is_param_exec_order_prep_stage())
-
-    @skip_if_lt_x_gpu(2)
-    @parametrize(
-        "sharding_strategy",
-        [ShardingStrategy.FULL_SHARD, ShardingStrategy.SHARD_GRAD_OP],
-    )
-    def test_fsdp_flatten_params_exec_order_symbolic_trace(
-        self,
-        sharding_strategy: ShardingStrategy,
-    ):
-        """
-        Tests ``ParamExecOrderWrapPolicy`` with symbolic tracing.
-        With symbolic tracing enabled, ``_is_param_exec_order_prep_stage``
-        should always set as False.
-        """
-        wrap_policy = ParamExecOrderWrapPolicy(
-            init_policy=always_wrap_policy,
-            tracing_config=TracingConfig(concrete_args={"use_all_params": False}),
-        )
-        fsdp_model = Model.wrap(
-            sharding_strategy,
-            self.device,
-            wrap_policy=wrap_policy,
-        )
-        params_list = list(fsdp_model.parameters())
-        # Since the forward execution order is NOT consistent with the model definition order,
-        # the ordering in flatten_named_params_exec_order should be different from named_parameters
-        self.assertEqual(
-            fsdp_model._fsdp_params_exec_order,
-            [params_list[0], params_list[2], params_list[3]],
-        )
-        self.assertTrue(fsdp_model._use_param_exec_order_policy())
-        self.assertTrue(not fsdp_model._is_param_exec_order_prep_stage())
-
-
-instantiate_parametrized_tests(TestFSDPExecOrder)
-
-if __name__ == "__main__":
-    run_tests()
diff --git a/torch/distributed/fsdp/_common_utils.py b/torch/distributed/fsdp/_common_utils.py
index 6d3681f..34f45fc 100644
--- a/torch/distributed/fsdp/_common_utils.py
+++ b/torch/distributed/fsdp/_common_utils.py
@@ -2,8 +2,9 @@
 This file includes private common utilities for FSDP.
 """
 
+import traceback
 from enum import auto, Enum
-from typing import Callable, Dict, List, Union
+from typing import Callable, Dict, List, no_type_check, Union
 
 import torch
 import torch.distributed.fsdp.flat_param as flat_param_file
@@ -153,3 +154,25 @@
 
     f(root_module, "", *args, **kwargs)
     return return_fn(*args, **kwargs)
+
+
+@no_type_check
+def _assert_in_training_states(
+    state: _State,
+    training_states: List[TrainingState],
+) -> None:
+    """Asserts that FSDP is in the states ``_training_states``."""
+    # Raise a `ValueError` instead of using `assert` to ensure that these
+    # logical assertions run even if `assert`s are disabled
+    if state.training_state not in training_states:
+        msg = (
+            f"expected to be in states {training_states} but current state is "
+            f"{state.training_state}"
+        )
+        # Print the error on rank 0 in case this is called in the backward pass
+        if state.rank == 0:
+            if isinstance(state, nn.Module):
+                print(f"Asserting FSDP instance is: {state}")
+            print(f"ERROR: {msg}")
+            traceback.print_stack()
+        raise ValueError(msg)
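
As the comment in `_assert_in_training_states` notes, plain `assert` statements are stripped when the interpreter runs with `-O`, so raising `ValueError` keeps the invariant check active. A standalone illustration (hypothetical helper names) for reference:

```python
# Standalone illustration (hypothetical names) of why the helper raises
# ValueError rather than using `assert`: under `python -O`, assert statements
# are compiled out, so only the explicit raise still enforces the invariant.
def check_with_assert(training_state: str) -> None:
    assert training_state == "FORWARD_BACKWARD", "bad state"  # no-op under -O


def check_with_raise(training_state: str) -> None:
    if training_state != "FORWARD_BACKWARD":
        raise ValueError(f"expected FORWARD_BACKWARD but got {training_state}")


if __name__ == "__main__":
    # With `python -O`, the first call silently passes; the second always raises.
    try:
        check_with_assert("IDLE")
    except AssertionError:
        print("assert fired (only without -O)")
    try:
        check_with_raise("IDLE")
    except ValueError as e:
        print(f"raise fired regardless of -O: {e}")
```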
diff --git a/torch/distributed/fsdp/_runtime_utils.py b/torch/distributed/fsdp/_runtime_utils.py
index 4311e27..c79dfe6 100644
--- a/torch/distributed/fsdp/_runtime_utils.py
+++ b/torch/distributed/fsdp/_runtime_utils.py
@@ -1,9 +1,24 @@
-from typing import Any, List, no_type_check, Optional, Tuple
+import functools
+from typing import Any, Callable, List, no_type_check, Optional, Tuple
 
 import torch
-from torch.distributed.fsdp._common_utils import _State
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.distributed.algorithms._comm_hooks import LOW_PRECISION_HOOKS
+from torch.distributed.fsdp._common_utils import (
+    _assert_in_training_states,
+    _State,
+    TrainingState,
+)
 from torch.distributed.fsdp._utils import _apply_to_tensors, p_assert
-from torch.distributed.fsdp.flat_param import FlatParamHandle
+from torch.distributed.fsdp.api import BackwardPrefetch
+from torch.distributed.fsdp.flat_param import (
+    _HandlesKey,
+    FlatParameter,
+    FlatParamHandle,
+    HandleShardingStrategy,
+    HandleTrainingState,
+)
 from torch.distributed.utils import _to_kwargs
 
 
@@ -91,6 +106,376 @@
         handle.reshard_grad()
 
 
+@no_type_check
+def _pre_forward(
+    state: _State,
+    handles: List[FlatParamHandle],
+    unshard_fn: Callable,
+    module: nn.Module,
+    input: Any,
+):
+    """
+    Runs the pre-forward logic: optionally unshards the currently sharded
+    parameters (e.g. those needed for the current forward) via ``unshard_fn``
+    and registers post-backward hooks for those parameters.
+
+    Args:
+        handles (List[FlatParamHandle]): Handles giving the parameters used in
+            the current forward.
+        unshard_fn (Optional[Callable]): A callable to unshard any currently
+            sharded parameters or ``None`` to not do any unsharding.
+        module (nn.Module): Module whose forward this method runs right before.
+        input (Any): Unused; expected by the hook signature.
+    """
+    state.training_state = TrainingState.FORWARD_BACKWARD
+    state._exec_order_data.record_pre_forward(handles, module.training)
+    for handle in handles:
+        handle._training_state = HandleTrainingState.FORWARD
+    if unshard_fn is not None:
+        unshard_fn()
+    # Register post-backward hooks to reshard the parameters and reduce-scatter
+    # their gradients. They must be re-registered every forward pass in case
+    # the `grad_fn` is mutated.
+    _register_post_backward_hooks(state, handles)
+
+
+@no_type_check
+@torch.no_grad()
+def _post_backward_hook(
+    state: _State,
+    handle: FlatParamHandle,
+    *unused: Any,
+):
+    """
+    Reduce-scatters the gradient of ``handle`` 's ``FlatParameter``.
+
+    Precondition: The ``FlatParameter`` 's ``.grad`` attribute contains the
+    unsharded gradient for the local batch.
+
+    Postcondition:
+    - If using ``NO_SHARD``, then the ``.grad`` attribute is the reduced
+    unsharded gradient.
+    - Otherwise, the ``_saved_grad_shard`` attribute is the reduced sharded
+    gradient (accumulating with any existing gradient).
+    """
+    param = handle.flat_param
+    param._post_backward_called = True
+    with torch.autograd.profiler.record_function(
+        "FullyShardedDataParallel._post_backward_hook"
+    ):
+        _assert_in_training_states(state, [TrainingState.FORWARD_BACKWARD])
+        state.training_state = TrainingState.FORWARD_BACKWARD
+        p_assert(
+            handle._training_state == HandleTrainingState.BACKWARD_PRE,
+            f"Expects `BACKWARD_PRE` state but got {handle._training_state}",
+        )
+        handle._training_state = HandleTrainingState.BACKWARD_POST
+
+        if param.grad is None:
+            return
+        if param.grad.requires_grad:
+            raise RuntimeError("FSDP does not support gradients of gradients")
+
+        free_unsharded_flat_param = _should_free_in_backward(state, handle)
+        _reshard(state, [handle], [free_unsharded_flat_param])
+
+        # TODO: Post-backward prefetching does not support the multiple handles
+        # per module case since the post-backward hook runs per handle, not per
+        # group of handles.
+        handles_key = (handle,)
+        _prefetch_handles(state, handles_key)
+
+        if not state._sync_gradients:
+            return
+
+        # Wait for all ops in the current stream (e.g. gradient
+        # computation) to finish before reduce-scattering the gradient
+        state._streams["post_backward"].wait_stream(torch.cuda.current_stream())
+
+        with torch.cuda.stream(state._streams["post_backward"]):
+            unsharded_grad_data = param.grad.data
+            if state._exec_order_data.is_first_iter:  # only check once
+                _check_comm_hook(
+                    state._communication_hook, state._communication_hook_state
+                )
+            if handle._uses_reduce_mixed_precision and not _low_precision_hook_enabled(
+                state
+            ):
+                # TODO: Use the low precision communication hook directly
+                param.grad.data = param.grad.to(state.mixed_precision.reduce_dtype)
+
+            if handle.uses_sharded_strategy:
+                # We clear `.grad` to permit multiple backwards. This avoids a
+                # race where the second backward pass computation runs ahead
+                # of the first backward pass reduction, which is possible
+                # since the reduction is issued asynchronously in a separate
+                # stream and would otherwise reduce the wrong gradient.
+                unsharded_grad = param.grad.data
+                param.grad = None
+                p_assert(
+                    len(unsharded_grad.size()) == 1,
+                    f"Expects gradient to be flattened but got {unsharded_grad.size()}",
+                )
+                chunks = list(unsharded_grad.chunk(state.world_size))
+                numel_to_pad = (
+                    state.world_size * chunks[0].numel() - unsharded_grad.numel()
+                )
+                padded_unsharded_grad = F.pad(unsharded_grad, [0, numel_to_pad])
+                new_sharded_grad = torch.zeros_like(chunks[0])  # padded
+                state._communication_hook(
+                    state._communication_hook_state,
+                    padded_unsharded_grad,
+                    new_sharded_grad,
+                )
+                _cast_grad_to_param_dtype(state, handle, new_sharded_grad, param)
+
+                # Save the sharded gradient in `_saved_grad_shard` to support
+                # gradient accumulation -- for multiple backwards, the gradient
+                # reductions may happen in arbitrary order
+                accumulate_grad = hasattr(param, "_saved_grad_shard")
+                if accumulate_grad:
+                    _check_grad_to_accumulate(new_sharded_grad, param._saved_grad_shard)
+                    param._saved_grad_shard += new_sharded_grad
+                else:
+                    param._saved_grad_shard = new_sharded_grad
+                sharded_grad = param._saved_grad_shard
+            else:
+                state._communication_hook(state._communication_hook_state, param.grad)
+                # For `NO_SHARD`, we can keep the low precision gradients by
+                # simply omitting the cast altogether
+                if not handle._keep_low_precision_grads:
+                    _cast_grad_to_param_dtype(state, handle, param.grad, param)
+                sharded_grad = param.grad.data
+
+            if handle._config.offload_params:
+                # Offload the gradient to CPU to ensure parameters and
+                # gradients are on the same device as required by the optimizer
+                param._cpu_grad.copy_(  # type: ignore[attr-defined]
+                    sharded_grad.detach(), non_blocking=True
+                )  # synchronized in the post-backward callback
+                # Since the sharded gradient is produced in the post-backward
+                # stream and consumed later in the computation stream, inform
+                # the caching allocator
+                sharded_grad.data.record_stream(torch.cuda.current_stream())
+
+            # Since the unsharded gradient is produced in the computation
+            # stream and consumed in the post-backward stream, inform the
+            # caching allocator (before it goes out of scope)
+            unsharded_grad_data.record_stream(state._streams["post_backward"])
+
+            if handle._use_orig_params:
+                # Since the handle's `FlatParameter` completed its gradient
+                # computation, we should reset the gradient noneness mask
+                handle._reset_is_grad_none()
+                # Delay using sharded gradient views until after the
+                # reduce-scatter instead of immediately after resharding
+                handle._use_sharded_grad_views()
+
+
+@no_type_check
+def _should_free_in_backward(
+    state: _State,
+    handle: FlatParamHandle,
+) -> bool:
+    """
+    Returns whether FSDP should free the unsharded flattened parameter in the
+    post-backward pass.
+    """
+    return (
+        state._sync_gradients and handle.uses_sharded_strategy
+    ) or handle._config.sharding_strategy == HandleShardingStrategy.FULL_SHARD
+
+
+@no_type_check
+def _cast_grad_to_param_dtype(
+    state: _State,
+    handle: FlatParamHandle,
+    sharded_grad: torch.Tensor,
+    param: FlatParameter,
+):
+    """
+    Casts ``sharded_grad`` back to the full parameter dtype so that the
+    optimizer step runs with that dtype. This performs an actual cast if
+    1. parameters were in reduced precision during the forward since then
+    gradients would be in that reduced precision, or
+    2. parameters were not in reduced precision but gradients were in
+    reduced precision for communication.
+    However, if a low precision communication hook is registered, then this
+    dtype cast happens in the hook instead.
+    """
+    _assert_in_training_states(state, [TrainingState.FORWARD_BACKWARD])
+    if not _low_precision_hook_enabled(state) and (
+        handle._uses_param_mixed_precision or handle._uses_reduce_mixed_precision
+    ):
+        low_prec_grad_data = sharded_grad.data
+        sharded_grad.data = sharded_grad.data.to(dtype=param.dtype)
+        # Since for `NO_SHARD`, the gradient is produced in the computation
+        # stream and consumed here in the post-backward stream, inform the
+        # caching allocator; for the sharded strategies, the gradient is
+        # produced in the post-backward stream, so this `record_stream()`
+        # should be a no-op
+        low_prec_grad_data.record_stream(torch.cuda.current_stream())
+
+
+def _check_comm_hook(
+    comm_hook: Any,
+    comm_hook_state: Any,
+) -> None:
+    p_assert(comm_hook is not None, "Communication hook should not be `None`")
+    p_assert(
+        comm_hook_state is not None, "Communication hook state should not be `None`"
+    )
+
+
+def _check_grad_to_accumulate(
+    new_sharded_grad: torch.Tensor,
+    accumulated_grad: torch.Tensor,
+) -> None:
+    p_assert(
+        accumulated_grad.shape == new_sharded_grad.shape,
+        "Shape mismatch when accumulating gradients: "
+        f"existing gradient shape={accumulated_grad.shape} "
+        f"new gradient shape={new_sharded_grad.shape}",
+    )
+    p_assert(
+        accumulated_grad.device == new_sharded_grad.device,
+        "Device mismatch when accumulating gradients: "
+        f"existing gradient device={accumulated_grad.device} "
+        f"new gradient device={new_sharded_grad.device}",
+    )
+
+
+@no_type_check
+def _low_precision_hook_enabled(state: _State) -> bool:
+    return state._communication_hook in LOW_PRECISION_HOOKS
+
+
+@no_type_check
+def _prefetch_handles(
+    state: _State,
+    current_handles_key: _HandlesKey,
+) -> None:
+    """
+    Prefetches the next handles if needed (without synchronization). An empty
+    handles key cannot prefetch.
+    """
+    if not current_handles_key:
+        return
+    handles_to_prefetch = _get_handles_to_prefetch(state, current_handles_key)
+    for handles_key in handles_to_prefetch:
+        # Prefetch the next set of handles without synchronizing to allow
+        # the sync to happen as late as possible to maximize overlap
+        _unshard(
+            state, handles_key, state._streams["unshard"], state._streams["pre_unshard"]
+        )
+        state._handles_prefetched[handles_key] = True
+
+
+@no_type_check
+def _get_handles_to_prefetch(
+    state: _State,
+    current_handles_key: _HandlesKey,
+) -> List[_HandlesKey]:
+    """
+    Returns a :class:`list` of the handles keys to prefetch for the next
+    module(s), where ``current_handles_key`` represents the current module.
+
+    "Prefetching" refers to running the unshard logic early (without
+    synchronization), and the "next" modules depend on the recorded execution
+    order and the current training state.
+    """
+    training_state = _get_training_state(current_handles_key)
+    valid_training_states = (
+        HandleTrainingState.BACKWARD_PRE,
+        HandleTrainingState.BACKWARD_POST,
+        HandleTrainingState.FORWARD,
+    )
+    p_assert(
+        training_state in valid_training_states,
+        f"Prefetching is only supported in {valid_training_states} but "
+        f"currently in {training_state}",
+    )
+    eod = state._exec_order_data
+    target_handles_keys: List[_HandlesKey] = []
+    if (
+        training_state == HandleTrainingState.BACKWARD_PRE
+        and state.backward_prefetch == BackwardPrefetch.BACKWARD_PRE
+    ) or (
+        training_state == HandleTrainingState.BACKWARD_POST
+        and state.backward_prefetch == BackwardPrefetch.BACKWARD_POST
+    ):
+        target_handles_keys = [
+            target_handles_key
+            for target_handles_key in eod.get_handles_to_backward_prefetch(
+                current_handles_key
+            )
+            if state._needs_pre_backward_unshard.get(target_handles_key, False)
+            and not state._handles_prefetched.get(target_handles_key, False)
+        ]
+    elif training_state == HandleTrainingState.FORWARD and state.forward_prefetch:
+        target_handles_keys = [
+            target_handles_key
+            for target_handles_key in eod.get_handles_to_forward_prefetch(
+                current_handles_key
+            )
+            if state._needs_pre_forward_unshard.get(target_handles_key, False)
+            and not state._handles_prefetched.get(target_handles_key, False)
+        ]
+    return target_handles_keys
+
+
+def _get_training_state(
+    handles_key: _HandlesKey,
+) -> HandleTrainingState:
+    """Returns the training state of the handles in ``handles_key``."""
+    p_assert(len(handles_key) > 0, "Expects a non-empty handles key")
+    training_states = set(handle._training_state for handle in handles_key)
+    p_assert(
+        len(training_states) == 1,
+        f"Expects uniform training state but got {training_states}",
+    )
+    return next(iter(training_states))
+
+
+def _register_post_backward_hooks(
+    state: _State,
+    handles: List[FlatParamHandle],
+) -> None:
+    """
+    Registers post-backward hooks on the ``FlatParameter`` s'
+    ``AccumulateGrad`` objects to reshard and to reduce-scatter gradients.
+
+    The ``AccumulateGrad`` object represents the last function that finalizes
+    the ``FlatParameter`` 's gradient, so it only runs after its entire
+    gradient computation has finished.
+
+    We register the post-backward hook only once in the *first* forward that a
+    ``FlatParameter`` participates in. This relies on the ``AccumulateGrad``
+    object being preserved through multiple forwards.
+    """
+    # If there is no gradient computation, then there is no need for
+    # post-backward logic
+    if not torch.is_grad_enabled():
+        return
+    for handle in handles:
+        flat_param = handle.flat_param
+        already_registered = hasattr(flat_param, "_post_backward_hook_state")
+        if already_registered or not flat_param.requires_grad:
+            continue
+        # Get the `AccumulateGrad` object
+        temp_flat_param = flat_param.expand_as(flat_param)
+        p_assert(
+            temp_flat_param.grad_fn is not None,
+            "The `grad_fn` is needed to access the `AccumulateGrad` and "
+            "register the post-backward hook",
+        )
+        acc_grad = temp_flat_param.grad_fn.next_functions[0][0]
+        hook_handle = acc_grad.register_hook(
+            functools.partial(_post_backward_hook, state, handle)
+        )
+        flat_param._post_backward_hook_state = (acc_grad, hook_handle)  # type: ignore[attr-defined]
+
+
 def _wait_for_computation_stream(
     computation_stream: torch.cuda.Stream,
     unshard_stream: torch.cuda.Stream,
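
For the sharded-strategy branch of `_post_backward_hook` above, the padding arithmetic is easier to see with concrete numbers. The sketch below (single process, no collective actually issued) assumes only `torch` and mirrors the chunk/pad steps performed before the communication hook would reduce-scatter the gradient:

```python
# Single-process sketch (no collective actually issued) of the shape handling
# in the sharded-strategy branch of `_post_backward_hook`: the flattened
# unsharded gradient is padded so it splits into `world_size` equal shards.
import torch
import torch.nn.functional as F

world_size = 4
unsharded_grad = torch.arange(10, dtype=torch.float32)  # numel=10, not divisible by 4

chunks = list(unsharded_grad.chunk(world_size))  # chunk sizes: 3, 3, 3, 1
numel_to_pad = world_size * chunks[0].numel() - unsharded_grad.numel()  # 4 * 3 - 10 = 2
padded_unsharded_grad = F.pad(unsharded_grad, [0, numel_to_pad])  # numel=12
new_sharded_grad = torch.zeros_like(chunks[0])  # per-rank output buffer, numel=3

# In FSDP, `state._communication_hook(state._communication_hook_state,
# padded_unsharded_grad, new_sharded_grad)` would reduce-scatter the padded
# gradient so each rank receives its 3-element shard in `new_sharded_grad`.
assert padded_unsharded_grad.numel() == world_size * new_sharded_grad.numel()
```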
diff --git a/torch/distributed/fsdp/flat_param.py b/torch/distributed/fsdp/flat_param.py
index 5fd1308..50f3d5f 100644
--- a/torch/distributed/fsdp/flat_param.py
+++ b/torch/distributed/fsdp/flat_param.py
@@ -106,7 +106,7 @@
     offload_params: bool
     low_prec_param_dtype: Optional[torch.dtype]
     low_prec_reduce_dtype: Optional[torch.dtype]
-    keep_low_precision_grads: Optional[bool] = False
+    keep_low_precision_grads: bool = False
 
 
 class FlatParameter(nn.Parameter):
@@ -1801,6 +1801,14 @@
         return self._config.low_prec_param_dtype is not None
 
     @property
+    def _uses_reduce_mixed_precision(self) -> bool:
+        return self._config.low_prec_reduce_dtype is not None
+
+    @property
+    def _keep_low_precision_grads(self) -> bool:
+        return self._config.keep_low_precision_grads
+
+    @property
     def _force_full_precision(self) -> bool:
         return (
             self._training_state == HandleTrainingState.SUMMON_FULL_PARAMS
diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py
index 412a7e0..3d92a96 100644
--- a/torch/distributed/fsdp/fully_sharded_data_parallel.py
+++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py
@@ -24,7 +24,6 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-import torch.nn.functional as F
 from torch.autograd import Variable
 from torch.distributed import ProcessGroup
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
@@ -54,9 +53,12 @@
 )
 from torch.distributed.fsdp._runtime_utils import (
     _clear_grads_if_needed,
+    _pre_forward,
+    _prefetch_handles,
     _prepare_forward_inputs,
     _reshard,
     _reshard_grads,
+    _should_free_in_backward,
     _unshard,
     _unshard_grads,
     _wait_for_computation_stream,
@@ -87,12 +89,7 @@
     _pre_load_state_dict_hook,
 )
 from ._utils import _apply_to_tensors, _free_storage, p_assert
-from .flat_param import (
-    _HandlesKey,
-    FlatParameter,
-    FlatParamHandle,
-    HandleShardingStrategy,
-)
+from .flat_param import FlatParameter, FlatParamHandle, HandleShardingStrategy
 from .wrap import ParamExecOrderWrapPolicy
 
 _TORCH_FX_AVAIL = True
@@ -926,92 +923,6 @@
         # CPU offloading (H2D copy) and mixed precision (low precision cast).
         self._streams["pre_unshard"] = torch.cuda.Stream()
 
-    def _prefetch_handles(
-        self,
-        current_handles_key: _HandlesKey,
-    ) -> None:
-        """
-        Prefetches the next handles if needed (without synchronization). An
-        empty handles key cannot prefetch.
-        """
-        if not current_handles_key:
-            return
-        handles_to_prefetch = self._get_handles_to_prefetch(current_handles_key)
-        for handles_key in handles_to_prefetch:
-            # Prefetch the next set of handles without synchronizing to allow
-            # the sync to happen as late as possible to maximize overlap
-            _unshard(
-                self,
-                handles_key,
-                self._streams["unshard"],
-                self._streams["pre_unshard"],
-            )
-            self._handles_prefetched[handles_key] = True
-
-    def _get_handles_to_prefetch(
-        self,
-        current_handles_key: _HandlesKey,
-    ) -> List[_HandlesKey]:
-        """
-        Returns a :class:`list` of the handles keys to prefetch for the next
-        module(s), where ``current_handles_key`` represents the current module.
-
-        "Prefetching" refers to running the unshard logic early (without
-        synchronization), and the "next" modules depend on the recorded
-        execution order and the current training state.
-        """
-        training_state = self._get_training_state(current_handles_key)
-        valid_training_states = (
-            HandleTrainingState.BACKWARD_PRE,
-            HandleTrainingState.BACKWARD_POST,
-            HandleTrainingState.FORWARD,
-        )
-        p_assert(
-            training_state in valid_training_states,
-            f"Prefetching is only supported in {valid_training_states} but "
-            f"currently in {training_state}",
-        )
-        eod = self._exec_order_data
-        target_handles_keys: List[_HandlesKey] = []
-        if (
-            training_state == HandleTrainingState.BACKWARD_PRE
-            and self.backward_prefetch == BackwardPrefetch.BACKWARD_PRE
-        ) or (
-            training_state == HandleTrainingState.BACKWARD_POST
-            and self.backward_prefetch == BackwardPrefetch.BACKWARD_POST
-        ):
-            target_handles_keys = [
-                target_handles_key
-                for target_handles_key in eod.get_handles_to_backward_prefetch(
-                    current_handles_key
-                )
-                if self._needs_pre_backward_unshard.get(target_handles_key, False)
-                and not self._handles_prefetched.get(target_handles_key, False)
-            ]
-        elif training_state == HandleTrainingState.FORWARD and self.forward_prefetch:
-            target_handles_keys = [
-                target_handles_key
-                for target_handles_key in eod.get_handles_to_forward_prefetch(
-                    current_handles_key
-                )
-                if self._needs_pre_forward_unshard.get(target_handles_key, False)
-                and not self._handles_prefetched.get(target_handles_key, False)
-            ]
-        return target_handles_keys
-
-    def _get_training_state(
-        self,
-        handles_key: _HandlesKey,
-    ) -> HandleTrainingState:
-        """Returns the training state of the handles in ``handles_key``."""
-        p_assert(len(handles_key) > 0, "Expects a non-empty handles key")
-        training_states = set(handle._training_state for handle in handles_key)
-        p_assert(
-            len(training_states) == 1,
-            f"Expects uniform training state but got {training_states}",
-        )
-        return next(iter(training_states))
-
     @staticmethod
     def set_state_dict_type(
         module: nn.Module,
@@ -1291,7 +1202,9 @@
                 self._handles,
                 free_unsharded_flat_params,
             )
-            self._pre_forward(self._handles, unshard_fn, unused, unused)
+            _pre_forward(
+                self, self._handles, unshard_fn, self._fsdp_wrapped_module, unused
+            )
             for handle in self._handles:
                 p_assert(
                     handle.flat_param.device == self.compute_device,
@@ -1301,38 +1214,6 @@
             output = self._fsdp_wrapped_module(*args, **kwargs)
             return self._post_forward(self._handles, reshard_fn, unused, unused, output)
 
-    def _pre_forward(
-        self,
-        handles: List[FlatParamHandle],
-        unshard_fn: Optional[Callable],
-        module: nn.Module,
-        input: Any,
-    ):
-        """
-        Runs the pre-forward logic. This includes an opportunity to unshard
-        currently sharded parameters such as those for the current forward and
-        registering post-backward hooks for these current parameters.
-
-        Args:
-            handles (List[FlatParamHandle]): Handles giving the parameters
-                used in the current forward.
-            unshard_fn (Optional[Callable]): A callable to unshard any
-                currently sharded parameters or ``None`` to not do any
-                unsharding.
-            module (nn.Module): Unused; expected by the hook signature.
-            input (Any): Unused; expected by the hook signature.
-        """
-        self.training_state = TrainingState.FORWARD_BACKWARD
-        self._exec_order_data.record_pre_forward(handles, self.training)
-        for handle in handles:
-            handle._training_state = HandleTrainingState.FORWARD
-        if unshard_fn is not None:
-            unshard_fn()
-        # Register post-backward hooks to reshard the parameters and
-        # reduce-scatter their gradients. They must be re-registered every
-        # forward pass in case the `grad_fn` is mutated.
-        self._register_post_backward_hooks(handles)
-
     def _pre_forward_unshard(
         self,
         handles: List[FlatParamHandle],
@@ -1345,7 +1226,7 @@
             handles_key = tuple(handles)
             self._needs_pre_forward_unshard[handles_key] = False
             torch.cuda.current_stream().wait_stream(self._streams["unshard"])
-            self._prefetch_handles(handles_key)
+            _prefetch_handles(self, handles_key)
 
     def _post_forward(
         self,
@@ -1891,7 +1772,7 @@
                 # Set this to `False` to ensure that a mistargeted prefetch
                 # does not actually unshard these handles
                 self._needs_pre_backward_unshard[_handles_key] = False
-                self._prefetch_handles(_handles_key)
+                _prefetch_handles(self, _handles_key)
                 for handle in _handles:
                     handle.prepare_gradient_for_backward()
                 self._ran_pre_backward_hook[_handles_key] = True
@@ -1904,261 +1785,6 @@
 
         return _apply_to_tensors(_register_hook, outputs)
 
-    def _register_post_backward_hooks(
-        self,
-        handles: List[FlatParamHandle],
-    ) -> None:
-        """
-        Registers post-backward hooks on the ``FlatParameter`` s'
-        ``AccumulateGrad`` objects to reshard and to reduce-scatter gradients.
-
-        The ``AccumulateGrad`` object represents the last function that
-        finalizes the ``FlatParameter`` 's gradient, so it only runs after its
-        entire gradient computation has finished.
-
-        We register the post-backward hook only once in the *first* forward
-        that a ``FlatParameter`` participates in. This relies on the
-        ``AccumulateGrad`` object being preserved through multiple forwards.
-        """
-        # If there is no gradient computation, then there is no need for
-        # post-backward logic
-        if not torch.is_grad_enabled():
-            return
-        for handle in handles:
-            flat_param = handle.flat_param
-            already_registered = hasattr(flat_param, "_post_backward_hook_state")
-            if already_registered or not flat_param.requires_grad:
-                continue
-            # Get the `AccumulateGrad` object
-            temp_flat_param = flat_param.expand_as(flat_param)
-            p_assert(
-                temp_flat_param.grad_fn is not None,
-                "The `grad_fn` is needed to access the `AccumulateGrad` and "
-                "register the post-backward hook",
-            )
-            acc_grad = temp_flat_param.grad_fn.next_functions[0][0]
-            hook_handle = acc_grad.register_hook(
-                functools.partial(self._post_backward_hook, handle)
-            )
-            flat_param._post_backward_hook_state = (acc_grad, hook_handle)  # type: ignore[attr-defined]
-
-    @torch.no_grad()
-    def _post_backward_hook(
-        self,
-        handle: FlatParamHandle,
-        *unused: Any,
-    ) -> None:
-        """
-        Reduce-scatters the gradient of ``handle`` 's ``FlatParameter``.
-
-        Precondition: The ``FlatParameter`` 's ``.grad`` attribute contains the
-        unsharded gradient for the local batch.
-
-        Postcondition:
-        - If using ``NO_SHARD``, then the ``.grad`` attribute is the reduced
-        unsharded gradient.
-        - Otherwise, the ``_saved_grad_shard`` attribute is the reduced sharded
-        gradient (accumulating with any existing gradient).
-        """
-        param = handle.flat_param
-        param._post_backward_called = True
-        with torch.autograd.profiler.record_function(
-            "FullyShardedDataParallel._post_backward_hook"
-        ):
-            self._assert_state([TrainingState.FORWARD_BACKWARD])
-            self.training_state = TrainingState.FORWARD_BACKWARD
-            p_assert(
-                handle._training_state == HandleTrainingState.BACKWARD_PRE,
-                f"Expects `BACKWARD_PRE` state but got {handle._training_state}",
-            )
-            handle._training_state = HandleTrainingState.BACKWARD_POST
-
-            if (
-                self._use_param_exec_order_policy()
-                and self._param_exec_order_prep_stage
-            ):
-                # In self._fsdp_params_exec_order, the parameters are ordered based on
-                # the execution order in the backward pass in the first iteration.
-                self._fsdp_params_exec_order.append(param)
-
-            if param.grad is None:
-                return
-            if param.grad.requires_grad:
-                raise RuntimeError(
-                    "FSDP only works with gradients that don't require gradients"
-                )
-
-            free_unsharded_flat_param = self._should_free_unsharded_flat_param(handle)
-            _reshard(self, [handle], [free_unsharded_flat_param])
-
-            # TODO (awgu): Post-backward prefetching does not support the
-            # multiple handles per module case (which was why we keyed by
-            # *tuple*). The post-backward hook runs per handle, not per group
-            # of handles. To generalize this, we may need a 2-level mapping,
-            # where we map each individual handle to its groups of handles and
-            # then from the groups of handles to their indices in the order.
-            handles_key = (handle,)
-            self._prefetch_handles(handles_key)
-
-            if not self._sync_gradients:
-                return
-
-            # Wait for all ops in the current stream (e.g. gradient
-            # computation) to finish before reduce-scattering the gradient
-            self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
-
-            with torch.cuda.stream(self._streams["post_backward"]):
-                orig_grad_data = param.grad.data
-                if (
-                    self._mixed_precision_enabled_for_reduce()
-                    and not self._low_precision_hook_enabled()
-                ):
-                    # Cast gradient to precision in which it should be communicated.
-                    # If a low precision hook is registered and reduce_dtype is specified
-                    # in `MixedPrecision`, communication hook will take care of
-                    # casting to lower precision and back.
-                    # TODO: Make this a communication hook when communication hooks
-                    # are implemented for FSDP. Note that this is a noop if the
-                    # reduce_dtype matches the param dtype.
-                    param.grad.data = param.grad.data.to(
-                        self.mixed_precision.reduce_dtype
-                    )
-
-                if self._exec_order_data.is_first_iter:
-                    # For all sharding strategies communication is performed through `_communication_hook`:
-                    # default comm hooks are: `reduce_scatter` for sharded strategies and
-                    # `all_reduce` for non-sharded strategies. This checks asserts that `_communication_hook`
-                    # and `_communication_hook_state`, required for communication not `None`.`
-                    p_assert(
-                        self._communication_hook is not None,
-                        "Communication hook should not be None",
-                    )
-                    p_assert(
-                        self._communication_hook_state is not None,
-                        "Communication hook state should not be None",
-                    )
-                grad = param.grad.data
-                if handle.uses_sharded_strategy:
-                    # We clear `param.grad` to permit repeated gradient
-                    # computations when this FSDP module is called multiple times.
-                    # This is to avoid a race among multiple re-entrant backward
-                    # passes. For example, the second backward pass computation
-                    # precedes ahead of the first backward pass reduction, which is
-                    # possible since the reduction is in a different stream and is
-                    # async. Then, the first backward pass may be incorrectly
-                    # reducing the second backward pass's `param.grad`.
-                    # The reduced gradients are accumulated in
-                    # `param._saved_grad_shard`, and the gradient reductions can
-                    # happen in arbitrary order, though we tolerate this due to the
-                    # (approximate) commutativity of floating-point addition.
-                    param.grad = None
-                    grad_flatten = torch.flatten(grad)
-                    chunks = list(grad_flatten.chunk(self.world_size))
-                    num_pad = self.world_size * chunks[0].numel() - grad.numel()
-                    input_flattened = F.pad(grad_flatten, [0, num_pad])
-                    output = torch.zeros_like(chunks[0])
-                    self._communication_hook(
-                        self._communication_hook_state, input_flattened, output
-                    )
-
-                    self._cast_grad_to_param_dtype(output, param)
-
-                    # To support gradient accumulation outside `no_sync()`, we save
-                    # the gradient data to `param._saved_grad_shard` before the
-                    # backward pass, accumulate gradients into it here, and set
-                    # `param.grad` with the accumulated value at the end of the
-                    # backward pass in preparation for the optimizer step.
-                    accumulate_grad = hasattr(param, "_saved_grad_shard")
-                    if accumulate_grad:
-                        p_assert(
-                            param._saved_grad_shard.shape == output.shape,  # type: ignore[attr-defined]
-                            "Shape mismatch when accumulating gradients: "  # type: ignore[attr-defined]
-                            f"existing grad shape={param._saved_grad_shard.shape} "
-                            f"new grad shape={output.shape}",  # type: ignore[attr-defined]
-                        )
-                        p_assert(
-                            param._saved_grad_shard.device == output.device,  # type: ignore[attr-defined]
-                            "Device mismatch when accumulating gradients: "  # type: ignore[attr-defined]
-                            f"existing grad device={param._saved_grad_shard.device} "
-                            f"new grad device={output.device}",  # type: ignore[attr-defined]
-                        )
-                        param._saved_grad_shard += output  # type: ignore[attr-defined]
-                    else:
-                        param._saved_grad_shard = output  # type: ignore[attr-defined]
-                    grad = param._saved_grad_shard  # type: ignore[attr-defined]
-                else:
-                    if self.sharding_strategy == ShardingStrategy.NO_SHARD:
-                        self._communication_hook(
-                            self._communication_hook_state, param.grad
-                        )
-
-                    # For NO_SHARD keeping grads in the reduced precision, we
-                    # can simply omit the cast as needed, we can't do this for
-                    # other sharding strategies because grad field is assigned
-                    # in _finalize_params. TODO (rvarm1) this divergence in
-                    # logic is not ideal.
-                    if not self._mixed_precision_keep_low_precision_grads():
-                        self._cast_grad_to_param_dtype(param.grad, param)
-
-                # Regardless of sharding or not, offload the grad to CPU if we are
-                # offloading params. This is so param and grad reside on same device
-                # which is needed for the optimizer step.
-                if handle._config.offload_params:
-                    # We specify non_blocking=True
-                    # and ensure the appropriate synchronization is done by waiting
-                    # streams in _wait_for_post_backward.
-                    param._cpu_grad.copy_(  # type: ignore[attr-defined]
-                        grad.detach(), non_blocking=True
-                    )
-                    # Don't let this memory get reused until after the transfer.
-                    grad.data.record_stream(torch.cuda.current_stream())
-
-                # After _post_backward_hook returns, orig_grad_data will eventually
-                # go out of scope, at which point it could otherwise be freed for
-                # further reuse by the main stream while the div/reduce_scatter/copy
-                # are underway in the post_backward stream. See:
-                # github.com/NVIDIA/apex/blob/master/apex/parallel/distributed.py
-                orig_grad_data.record_stream(self._streams["post_backward"])
-
-                if handle._use_orig_params:
-                    # Since the handle's `FlatParameter` completed its gradient
-                    # computation, we should reset the gradient noneness mask
-                    handle._reset_is_grad_none()
-                    # Delay using sharded gradient views until after the
-                    # reduce-scatter instead of immediately after resharding
-                    handle._use_sharded_grad_views()
-
-    def _cast_grad_to_param_dtype(
-        self,
-        grad: torch.Tensor,
-        param: FlatParameter,
-    ):
-        """
-        Casts gradient ``grad`` back to the full parameter dtype so that the
-        optimizer step runs with that dtype. This performs an actual cast if
-        1. parameters were in reduced precision during the forward since then
-        gradients would be in that reduced precision, or
-        2. parameters were not in reduced precision but gradients were in
-        reduced precision for communication.
-        However, if a low precision communication hook is registered, then this
-        dtype cast happens in the hook instead.
-        """
-        self._assert_state(TrainingState.FORWARD_BACKWARD)
-        if not self._low_precision_hook_enabled() and (
-            self._mixed_precision_enabled_for_params()
-            or self._mixed_precision_enabled_for_reduce()
-        ):
-            low_prec_grad_data = grad.data
-            grad.data = grad.data.to(dtype=param.dtype)
-            # Do not let the low precision gradient memory get reused until
-            # the cast to full parameter precision completes
-            low_prec_grad_data.record_stream(torch.cuda.current_stream())
-
-    def _should_free_unsharded_flat_param(self, handle: FlatParamHandle):
-        return (
-            self._sync_gradients and handle.uses_sharded_strategy
-        ) or handle._config.sharding_strategy == HandleShardingStrategy.FULL_SHARD
-
     def _queue_wait_for_post_backward(self) -> None:
         """
         Queues a post-backward callback from the root FSDP instance, which
@@ -2220,7 +1846,7 @@
                     if already_resharded:
                         continue
                     free_unsharded_flat_params.append(
-                        self._should_free_unsharded_flat_param(handle)
+                        _should_free_in_backward(fsdp_module, handle)
                     )
                     handles_to_reshard.append(handle)
                 _reshard(self, handles_to_reshard, free_unsharded_flat_params)