[FSDP] Skip `_use_sharded_views()` for `SHARD_GRAD_OP` (#98250)

This PR has `SHARD_GRAD_OP` (and `_HYBRID_SHARD_ZERO2`) skip `_use_sharded_views()` in the post-forward reshard: since these strategies do not free the unsharded flat parameter, the unsharded views can be preserved. This saves nontrivial CPU overhead in both the post-forward reshard (`_use_sharded_views()`) and the pre-backward unshard (`_use_unsharded_views()`).
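
To make the control flow concrete, here is a minimal standalone sketch of the new decision. The class, method names, and strategy tuple mirror `flat_param.py`, but they are simplified stand-ins for illustration, not the actual FSDP handle:

```python
# Sketch only: mirrors the skip logic added to FlatParamHandle in this PR.
from enum import Enum, auto


class HandleShardingStrategy(Enum):
    FULL_SHARD = auto()
    SHARD_GRAD_OP = auto()
    NO_SHARD = auto()
    HYBRID_SHARD = auto()
    _HYBRID_SHARD_ZERO2 = auto()


NO_RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES = (
    HandleShardingStrategy.SHARD_GRAD_OP,
    HandleShardingStrategy._HYBRID_SHARD_ZERO2,
)


class _HandleSketch:
    def __init__(self, sharding_strategy: HandleShardingStrategy) -> None:
        self._sharding_strategy = sharding_strategy
        self._skipped_use_sharded_views = False

    def post_forward_reshard(self) -> None:
        # These strategies keep the unsharded flat parameter alive after
        # forward, so the unsharded views stay valid and the expensive switch
        # to sharded views can be skipped.
        if self._sharding_strategy in NO_RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES:
            self._skipped_use_sharded_views = True
        else:
            self._use_sharded_views()

    def pre_backward_unshard(self) -> None:
        # The pre-backward `_use_unsharded_views()` only undoes the
        # post-forward `_use_sharded_views()`; if that was skipped, this is
        # skipped too, which is where the pre-backward CPU savings come from.
        if self._skipped_use_sharded_views:
            return
        self._use_unsharded_views()

    def _use_sharded_views(self) -> None:
        print("switching original parameters to sharded views")

    def _use_unsharded_views(self) -> None:
        print("switching original parameters to unsharded views")
```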

<details>
<summary>(Before) Pre-backward hook: 4.356 ms</summary>

<img width="812" alt="Screenshot 2023-04-03 at 6 32 19 PM" src="https://user-images.githubusercontent.com/31054793/229641309-778cf1f9-4b5b-42ec-b2d8-0a1e6e7ce330.png">

</details>

<details>
<summary>(After) Pre-backward hook: 1.044 ms</summary>

![Screenshot 2023-04-04 at 9 05 53 AM](https://user-images.githubusercontent.com/31054793/229800917-9580ce6b-3721-469a-9212-f0cbfd8cbb52.png)

</details>
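
One user-visible consequence (also captured by the new warning added to `fully_sharded_data_parallel.py` below): with `use_orig_params=True`, `SHARD_GRAD_OP` now exposes unsharded parameters and no gradient views between forward and backward. A hedged usage sketch for inspecting gradients after backward via `summon_full_params(with_grads=True)`, assuming the process group and CUDA device are already set up and using a placeholder single-linear model:

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
)


def inspect_grads_after_backward(rank: int) -> None:
    # Placeholder model; assumes init_process_group() and
    # torch.cuda.set_device(rank) have already been called.
    model = FSDP(
        nn.Linear(8, 8, device="cuda"),
        sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
        use_orig_params=True,
    )
    out = model(torch.randn(4, 8, device="cuda"))
    # Between forward and backward, the original parameters are unsharded
    # views and no gradient views are exposed for SHARD_GRAD_OP.
    out.sum().backward()
    # Unshard both parameters and gradients to inspect them on each rank.
    with FSDP.summon_full_params(model, with_grads=True):
        for name, param in model.named_parameters():
            grad_norm = None if param.grad is None else param.grad.norm().item()
            print(rank, name, grad_norm)
```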

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98250
Approved by: https://github.com/rohan-varma
diff --git a/test/distributed/fsdp/test_fsdp_use_orig_params.py b/test/distributed/fsdp/test_fsdp_use_orig_params.py
index 10bf09b..3100135 100644
--- a/test/distributed/fsdp/test_fsdp_use_orig_params.py
+++ b/test/distributed/fsdp/test_fsdp_use_orig_params.py
@@ -18,6 +18,7 @@
     ShardingStrategy,
 )
 from torch.distributed.fsdp._common_utils import clean_tensor_name
+from torch.distributed.fsdp._init_utils import NO_RESHARD_AFTER_FORWARD_STRATEGIES
 from torch.distributed.fsdp.wrap import always_wrap_policy, ModuleWrapPolicy
 from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
 from torch.nn.parallel.distributed import DistributedDataParallel as DDP
@@ -500,6 +501,10 @@
 
         # Check that FSDP correctly exposes gradients even after forward
         # (namely, `None` for weights and non-`None` for biases)
+        if sharding_strategy in NO_RESHARD_AFTER_FORWARD_STRATEGIES:
+            # Skip the check since we do not expose the gradients after forward
+            # for these strategies
+            return
         for (ddp_n, ddp_p), (fsdp_n, fsdp_p) in zip(
             ddp_model.module.named_parameters(),
             fsdp_model.named_parameters(),
@@ -759,7 +764,9 @@
             def get_loss(self, inp, out):
                 return out.sum()
 
-        def check_parameter_parity(ddp_model, fsdp_model):
+        def check_parameter_parity(
+            ddp_model: DDP, fsdp_model: FSDP, between_fwd_and_bwd: bool
+        ):
             assert self.rank in (
                 0,
                 1,
@@ -773,6 +780,13 @@
                     # For `NO_SHARD`, do nothing since the original parameters
                     # are unflattened
                     pass
+                elif (
+                    between_fwd_and_bwd
+                    and sharding_strategy in NO_RESHARD_AFTER_FORWARD_STRATEGIES
+                ):
+                    # For no reshard after forward strategies, do nothing since
+                    # FSDP did not use sharded views after forward
+                    pass
                 # Otherwise, case on the parameter (see the model definition)
                 elif n1 == "lin1.weight":
                     if self.rank == 0:
@@ -806,7 +820,7 @@
         inp = fsdp_model.get_input(device)
         ddp_out = ddp_model(*inp)
         fsdp_out = fsdp_model(*inp)
-        check_parameter_parity(ddp_model, fsdp_model)
+        check_parameter_parity(ddp_model, fsdp_model, True)
 
         ddp_loss = ddp_model.module.get_loss(inp, ddp_out)
         fsdp_loss = fsdp_model.get_loss(inp, fsdp_out)
@@ -814,23 +828,23 @@
         fsdp_loss.backward()
         ddp_optim.step()
         fsdp_optim.step()
-        check_parameter_parity(ddp_model, fsdp_model)
+        check_parameter_parity(ddp_model, fsdp_model, False)
 
         inp = fsdp_model.get_input(device)
         ddp_out = ddp_model(*inp)
         fsdp_out = fsdp_model(*inp)
-        check_parameter_parity(ddp_model, fsdp_model)
+        check_parameter_parity(ddp_model, fsdp_model, True)
 
 
 class TestFSDPUseOrigParamsWriteback(FSDPTest):
     """Tests parameter and gradient writeback."""
 
     class Model(nn.Module):
-        def __init__(self):
+        def __init__(self, device: torch.device):
             super().__init__()
             torch.manual_seed(42)
-            self.lin1 = nn.Linear(5, 5, bias=True)
-            self.lin2 = nn.Linear(5, 7, bias=True)
+            self.lin1 = nn.Linear(5, 5, bias=True, device=device)
+            self.lin2 = nn.Linear(5, 7, bias=True, device=device)
 
         def forward(self, x: torch.Tensor) -> torch.Tensor:
             z = self.lin1(x)
@@ -876,10 +890,12 @@
 
         # Check that the writeback propagates
         ddp_model = DDP(
-            TestFSDPUseOrigParamsWriteback.Model().cuda(), device_ids=[self.rank]
+            TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")),
+            device_ids=[self.rank],
         )
         fsdp_model = FSDP(
-            TestFSDPUseOrigParamsWriteback.Model().cuda(), use_orig_params=True
+            TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")),
+            use_orig_params=True,
         )
         ddp = ddp_model.module  # for brevity
         fsdp = fsdp_model.module
@@ -927,10 +943,12 @@
             return None if set_to_none else torch.ones_like(param) * 2
 
         ddp_model = DDP(
-            TestFSDPUseOrigParamsWriteback.Model().cuda(), device_ids=[self.rank]
+            TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")),
+            device_ids=[self.rank],
         )
         fsdp_model = FSDP(
-            TestFSDPUseOrigParamsWriteback.Model().cuda(), use_orig_params=True
+            TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")),
+            use_orig_params=True,
         )
         LR = 1e-2
         # TODO: If we add `summon_full_params(with_grads=True)`, then replace
@@ -982,7 +1000,8 @@
     @skip_if_lt_x_gpu(2)
     def test_writeback_shape_mismatch(self):
         fsdp_model = FSDP(
-            TestFSDPUseOrigParamsWriteback.Model().cuda(), use_orig_params=True
+            TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")),
+            use_orig_params=True,
         )
         # Check that writing back with mismatched shape errors
         fsdp = fsdp_model.module  # for brevity
@@ -1019,6 +1038,41 @@
             with FSDP.summon_full_params(fsdp_model):  # triggers a writeback
                 ...
 
+    @skip_if_lt_x_gpu(2)
+    def test_writeback_between_fwd_and_bwd_for_no_reshard_raises(self):
+        fsdp_kwargs = {
+            "sharding_strategy": ShardingStrategy.SHARD_GRAD_OP,
+            "auto_wrap_policy": ModuleWrapPolicy({nn.Linear}),
+            "use_orig_params": True,
+        }
+        fsdp_wrapper = functools.partial(FSDP, **fsdp_kwargs)
+
+        # Test changing the parameter storage to no longer be a view into the
+        # flat parameter
+        fsdp_model = fsdp_wrapper(
+            TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda"))
+        )
+        inp = fsdp_model.get_input(torch.device("cuda"))
+        loss = fsdp_model(*inp).sum()
+        fsdp_model.lin1.weight.data = fsdp_model.lin1.weight.clone()
+        assert_msg = (
+            "FSDP does not support changing the parameters between forward and backward"
+        )
+        with self.assertRaisesRegex(AssertionError, assert_msg):
+            loss.backward()
+
+        # Test changing the parameter variable itself
+        fsdp_model = fsdp_wrapper(
+            TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda"))
+        )
+        inp = fsdp_model.get_input(torch.device("cuda"))
+        loss = fsdp_model(*inp).sum()
+        fsdp_model.lin1._fsdp_wrapped_module.weight = nn.Parameter(
+            fsdp_model.lin1.weight.clone()
+        )
+        with self.assertRaisesRegex(AssertionError, assert_msg):
+            loss.backward()
+
 
 class TestFSDPUseOrigParamsFQNs(FSDPTest):
     @skip_if_lt_x_gpu(2)
diff --git a/torch/distributed/fsdp/_init_utils.py b/torch/distributed/fsdp/_init_utils.py
index 8f97bcf..3c97b30 100644
--- a/torch/distributed/fsdp/_init_utils.py
+++ b/torch/distributed/fsdp/_init_utils.py
@@ -75,11 +75,14 @@
     ShardingStrategy.HYBRID_SHARD: HandleShardingStrategy.HYBRID_SHARD,
     ShardingStrategy._HYBRID_SHARD_ZERO2: HandleShardingStrategy._HYBRID_SHARD_ZERO2,
 }
-
 HYBRID_SHARDING_STRATEGIES = {
     ShardingStrategy.HYBRID_SHARD,
     ShardingStrategy._HYBRID_SHARD_ZERO2,
 }
+NO_RESHARD_AFTER_FORWARD_STRATEGIES = (
+    ShardingStrategy.SHARD_GRAD_OP,
+    ShardingStrategy._HYBRID_SHARD_ZERO2,
+)
 
 
 # NOTE: Since non-self attributes cannot be type annotated, several attributes
diff --git a/torch/distributed/fsdp/_runtime_utils.py b/torch/distributed/fsdp/_runtime_utils.py
index 511d172..c89bcd4 100644
--- a/torch/distributed/fsdp/_runtime_utils.py
+++ b/torch/distributed/fsdp/_runtime_utils.py
@@ -35,13 +35,10 @@
     FlatParamHandle,
     HandleShardingStrategy,
     HandleTrainingState,
+    RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES,
 )
 from torch.distributed.utils import _apply_to_tensors, _p_assert, _to_kwargs
 
-RESHARD_AFTER_FORWARD_STRATEGIES = {
-    HandleShardingStrategy.FULL_SHARD,
-    HandleShardingStrategy.HYBRID_SHARD,
-}
 
 # Do not include "process_group" to enable hybrid shard and MoE cases
 HOMOGENEOUS_ATTR_NAMES = (
@@ -501,7 +498,7 @@
     # computation (though this may not be true)
     free_unsharded_flat_params = [
         not state._is_root
-        and handle._sharding_strategy in RESHARD_AFTER_FORWARD_STRATEGIES
+        and handle._sharding_strategy in RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES
         for handle in handles
     ]
     _reshard(state, handles, free_unsharded_flat_params)
@@ -830,7 +827,7 @@
     # higher throughput.
     return (
         state._sync_gradients
-        or handle._sharding_strategy in RESHARD_AFTER_FORWARD_STRATEGIES
+        or handle._sharding_strategy in RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES
     )
 
 
diff --git a/torch/distributed/fsdp/flat_param.py b/torch/distributed/fsdp/flat_param.py
index 97f81f6..3cc1a6c 100644
--- a/torch/distributed/fsdp/flat_param.py
+++ b/torch/distributed/fsdp/flat_param.py
@@ -81,6 +81,25 @@
 _FLAT_PARAM_PADDING_VALUE = 42
 
 
+# TODO: Define this for now to avoid circular imports. See if we can remove.
+class HandleShardingStrategy(Enum):
+    FULL_SHARD = auto()
+    SHARD_GRAD_OP = auto()
+    NO_SHARD = auto()
+    HYBRID_SHARD = auto()
+    _HYBRID_SHARD_ZERO2 = auto()
+
+
+RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES = (
+    HandleShardingStrategy.FULL_SHARD,
+    HandleShardingStrategy.HYBRID_SHARD,
+)
+NO_RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES = (
+    HandleShardingStrategy.SHARD_GRAD_OP,
+    HandleShardingStrategy._HYBRID_SHARD_ZERO2,
+)
+
+
 class ParamInfo(NamedTuple):
     """Information for an original parameter."""
 
@@ -144,17 +163,6 @@
     param_offsets: Tuple[Tuple[int, int], ...]
 
 
-# TODO (awgu): Prefix these with "Handle" for now to avoid circular imports and
-# inadvertent misuses; coalesce with those in fully_sharded_data_parallel.py
-# later
-class HandleShardingStrategy(Enum):
-    FULL_SHARD = auto()
-    SHARD_GRAD_OP = auto()
-    NO_SHARD = auto()
-    HYBRID_SHARD = auto()
-    _HYBRID_SHARD_ZERO2 = auto()
-
-
 class FlatParameter(nn.Parameter):
     """
     This is the flat parameter used by :class:`FullyShardedDataParallel`. It is
@@ -417,6 +425,11 @@
         self._training_state = HandleTrainingState.IDLE
         self._debug_level = dist.get_debug_level()
         self._fully_sharded_module = fully_sharded_module
+        # NOTE: For the code path using this flag, we only skip calling
+        # `_use_sharded_views()` and do not skip switching to the sharded flat
+        # parameter since whether `self.flat_param` uses the sharded or
+        # unsharded flat parameter parameterizes behavior.
+        self._skipped_use_sharded_views = False
         # Optimistically assume a valid input `params` and set dtype attributes
         # before `_init_flat_param()`, which performs the actual validation
         self._orig_param_dtype = params[0].dtype
@@ -1199,6 +1212,11 @@
         in_forward = self._training_state == HandleTrainingState.FORWARD
         in_pre_backward = self._training_state == HandleTrainingState.BACKWARD_PRE
         if self._use_orig_params:
+            if self._skipped_use_sharded_views and in_pre_backward:
+                # This call corresponds to the complementary pre-backward
+                # `_use_unsharded_views()` to the skipped pre-forward
+                # `_use_sharded_views()`, so we should skip this one too.
+                return
             # We use `Tensor` views in the forward so that they are tracked by
             # autograd. We use them in the pre-backward as well to support
             # reentrant activation checkpointing, which needs the views to be
@@ -1511,12 +1529,26 @@
             )
         flat_param.data = flat_param._local_shard  # type: ignore[attr-defined]
         if self._use_orig_params:
-            self._use_sharded_views()
+            in_forward = self._training_state == HandleTrainingState.FORWARD
+            if (
+                in_forward
+                and self._sharding_strategy
+                in NO_RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES
+            ):
+                self._skipped_use_sharded_views = True
+            else:
+                self._use_sharded_views()
             # For the post-forward reshard, we may try to use sharded gradient
             # views (or unsharded gradient views if a gradient was accumulated
             # in `no_sync()`), but for the post-backward reshard, we delay the
             # call to after the reduce-scatter.
-            if self._training_state == HandleTrainingState.FORWARD:
+            if (
+                in_forward
+                # Skip using gradient views if skipped using sharded views
+                # since exposing unsharded parameters with sharded gradients
+                # may be confusing to the user
+                and not self._skipped_use_sharded_views
+            ):
                 # TODO: Change `_unpadded_unsharded_size` if we change the
                 # gradient to be computed directly with padding.
                 accumulated_grad_in_no_sync = (
@@ -1755,6 +1787,7 @@
         printability. Parameters whose data is present must preserve their
         variables to be passable to an optimizer.
         """
+        self._skipped_use_sharded_views = False
         if not self.uses_sharded_strategy:
             # For `NO_SHARD`, use the *unflattened* unsharded views since we
             # have the unsharded parameter
@@ -1869,12 +1902,21 @@
             but no longer has the expected flattened shape.
         Returns: ``True`` if some writeback happened, and ``False`` otherwise.
         """
-        if self.uses_sharded_strategy and not self.is_sharded(self.flat_param):
+        if (
+            self.uses_sharded_strategy
+            and not self.is_sharded(self.flat_param)
+            and not self._skipped_use_sharded_views
+        ):
             # For `NO_SHARD`, we may still need to writeback
             return False
         flat_param = self.flat_param
         wroteback = False
-        flat_param_data_ptr = flat_param.untyped_storage().data_ptr()
+        if self._skipped_use_sharded_views and self.uses_sharded_strategy:
+            flat_param_data_ptr = (
+                self._get_padded_unsharded_flat_param().untyped_storage().data_ptr()
+            )
+        else:
+            flat_param_data_ptr = flat_param.untyped_storage().data_ptr()
         # NOTE: Since this method is called in the pre-unshard, which is only
         # called during computation in the pre-forward or pre-backward, the
         # sharded gradient should be guaranteed to be in `.grad`, not in
@@ -1908,6 +1950,12 @@
                 continue
 
             # Check for parameter writeback
+            if self._skipped_use_sharded_views:
+                param = flat_param._tensors[i]
+                _p_assert(
+                    param is not None,
+                    f"Expects to have saved tensor for {flat_param._fqns[i]}",
+                )
             param_changed = getattr(module, param_name) is not param
             needs_param_writeback = (
                 param_changed  # changed parameter variable itself
@@ -1915,6 +1963,13 @@
                     param, flat_param_data_ptr
                 )  # changed `.data`
             )
+            if self._skipped_use_sharded_views and (
+                param_changed or needs_param_writeback
+            ):
+                raise AssertionError(
+                    "FSDP does not support changing the parameters between "
+                    f"forward and backward for {self._sharding_strategy}"
+                )
             if param_changed:
                 # NOTE: The gradient is not preserved after a parameter change.
                 param = getattr(module, param_name)
@@ -1927,6 +1982,10 @@
                 wroteback = True
 
             # Check for gradient writeback
+            if self._skipped_use_sharded_views:
+                # Skip the writeback check because we do not expose gradients
+                # when we skipped using sharded views
+                continue
             if param.grad is None and flat_param.grad is not None:
                 expected_shape = torch.Size([numel_in_shard])
                 self._writeback_tensor(
diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py
index 2e90c84..c110c16 100644
--- a/torch/distributed/fsdp/fully_sharded_data_parallel.py
+++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py
@@ -199,6 +199,16 @@
         accessing the original parameters between forward and backward may
         raise an illegal memory access.
 
+    .. warning::
+        For ``use_orig_params=True``, ``ShardingStrategy.SHARD_GRAD_OP``
+        exposes the unsharded parameters, not the sharded parameters, after
+        forward since it does not free the unsharded ones, unlike
+        ``ShardingStrategy.FULL_SHARD``. One caveat is that, since gradients
+        are always sharded or ``None``, ``ShardingStrategy.SHARD_GRAD_OP`` will
+        not expose the sharded gradients with the unsharded parameters after
+        forward. If you want to inspect the gradients, try
+        :meth:`summon_full_params` with ``with_grads=True``.
+
     Args:
         module (nn.Module):
             This is the module to be wrapped with FSDP.
diff --git a/torch/testing/_internal/common_fsdp.py b/torch/testing/_internal/common_fsdp.py
index 75e648f..ae85b49 100644
--- a/torch/testing/_internal/common_fsdp.py
+++ b/torch/testing/_internal/common_fsdp.py
@@ -14,6 +14,7 @@
 import torch.nn as nn
 from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp._common_utils import TrainingState
+from torch.distributed.fsdp._init_utils import NO_RESHARD_AFTER_FORWARD_STRATEGIES
 from torch.distributed.fsdp.fully_sharded_data_parallel import (
     BackwardPrefetch,
     MixedPrecision,
@@ -843,7 +844,14 @@
                         input = tuple(x.half() for x in input)
                 output = model(*input)
                 # Post-forward, if CPU offloading model param should be on CPU.
-                if cpu_offload_params and isinstance(model, FSDP):
+                if (
+                    cpu_offload_params
+                    and isinstance(model, FSDP)
+                    # If not resharding after forward, the parameters are still
+                    # exposed as unsharded views into the GPU flat parameter
+                    and model.sharding_strategy
+                    not in NO_RESHARD_AFTER_FORWARD_STRATEGIES
+                ):
                     for p in model.parameters():
                         # Params should always be on CPU
                         self.assertEqual(p.device, torch.device("cpu"))