[FSDP] Reduce CPU overhead (#96958)

I experimented with 200 `nn.Linear`s with `bias=True` (400 `nn.Parameter`s in total), all wrapped into the same FSDP instance with a world size of 2.
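
The benchmark script itself is not part of this PR; a minimal sketch of the setup (layer width, CUDA placement, and the measurement loop are assumptions) looks roughly like:

```python
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes a 2-rank process group is already initialized (world size of 2).
NUM_LINEARS = 200  # 200 weights + 200 biases = 400 `nn.Parameter`s

model = nn.Sequential(
    *(nn.Linear(8, 8, bias=True, device="cuda") for _ in range(NUM_LINEARS))
)
# No `auto_wrap_policy`, so all 400 parameters flatten into a single FSDP instance;
# `use_orig_params=True` exercises the `_use_*_views()` and
# `_writeback_orig_params()` code paths profiled below.
fsdp_model = FSDP(model, use_orig_params=True)
```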

**`unshard()` -> `_use_unsharded_views()`**
- (From the previous PR) unsafe `setattr`: 6.112 ms -> 4.268 ms
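
For context, the "unsafe" `setattr` from the previous PR bypasses `nn.Module.__setattr__` and writes straight into the module's `_parameters` dict. A rough sketch of the two paths (these helper names are illustrative, not the actual ones in `flat_param.py`):

```python
import torch.nn as nn

def safe_setattr_param(module: nn.Module, param_name: str, param: nn.Parameter) -> None:
    # Goes through `nn.Module.__setattr__`, which re-does hook and attribute
    # bookkeeping on every call -- the overhead the previous PR removed
    setattr(module, param_name, param)

def unsafe_setattr_param(module: nn.Module, param_name: str, param: nn.Parameter) -> None:
    # Registers the parameter directly, mirroring the
    # `module._parameters[param_name] = param_var` assignments in this diff
    module._parameters[param_name] = param
```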

**`pre_unshard()` -> `_writeback_orig_params()`**
- Factor out `flat_param` and `flat_param_grad` data pointers: ~1.8 ms -> 1.071 ms
    - Now dominated by calling `_typed_storage()` on each original parameter and its gradient
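
Concretely, the old `_same_storage(param, flat_param)` check did two `_typed_storage()` lookups per original parameter; the new `_same_storage_as_data_ptr()` helper lets the flat parameter's (and flat gradient's) data pointer be computed once before the loop. A self-contained illustration with stand-in tensors:

```python
import torch

def _same_storage(x: torch.Tensor, y: torch.Tensor) -> bool:
    # Old helper: two storage lookups per comparison, inside the per-parameter loop
    return x._typed_storage()._data_ptr() == y._typed_storage()._data_ptr()

def _same_storage_as_data_ptr(x: torch.Tensor, data_ptr: int) -> bool:
    # New helper: one storage lookup per comparison
    return x._typed_storage()._data_ptr() == data_ptr

flat_param = torch.randn(16)
orig_params = [flat_param[0:8], flat_param[8:16]]  # views into `flat_param`'s storage

# Hoisted out of the per-parameter loop, as in `_writeback_orig_params()`:
flat_param_data_ptr = flat_param.untyped_storage().data_ptr()
for param in orig_params:
    assert _same_storage_as_data_ptr(param, flat_param_data_ptr)
    assert _same_storage(param, flat_param)  # same answer, but two lookups
```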

**`reshard()` -> `_use_sharded_views()`**
- Factor out `torch.empty(0, ...)`: ~4.6 - 4.7 ms -> ~2.7 - 2.8 ms
    - Now dominated by `aten::slice()` and (unsafe) `setattr`, which are required
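
Concretely, `_use_sharded_views()` now builds the size-0 placeholder tensor once per call and reuses it for every original parameter whose elements lie outside the local shard, instead of allocating a fresh `torch.empty(0, ...)` per parameter. A simplified sketch, with the shard bookkeeping reduced to a boolean list:

```python
import torch
import torch.nn as nn

sharded_flat_param = torch.randn(8)  # stand-in for this rank's flat parameter shard
params = [nn.Parameter(torch.randn(4)) for _ in range(3)]
in_local_shard = [True, True, False]  # the real code derives this from shard offsets

# Construct once and reuse for all parameters not in the local shard
size_0_empty_tensor = torch.empty(
    0,
    dtype=sharded_flat_param.dtype,  # in case the flat parameter changed dtype
    device=sharded_flat_param.device,
    requires_grad=False,
)
offset = 0
for param, in_shard in zip(params, in_local_shard):
    if in_shard:
        numel_in_shard = param.numel()
        param.data = sharded_flat_param[offset : offset + numel_in_shard]
        offset += numel_in_shard
    else:
        # Previously, a new size-0 tensor was allocated here for each such parameter
        param.data = size_0_empty_tensor
```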

I removed some `assert`s that were only needed for mypy or whose failure would surface with the same error message from the subsequent call anyway. They have negligible overhead, but I think it is still worthwhile to remove them and skip the runtime check. We need to address type checking more holistically anyway.
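
Where those `assert`s existed only to narrow types for mypy, the diff decorates the whole method with `@no_type_check` instead (see `_use_unsharded_views()`, `_use_sharded_views()`, and `_writeback_orig_params()` below). The pattern, with a made-up class for illustration:

```python
from typing import List, Optional, no_type_check

import torch

class Handle:
    _tensors: Optional[List[Optional[torch.Tensor]]] = None

    @no_type_check
    def clear_saved_tensors(self) -> None:
        # Without the decorator, mypy would demand `assert self._tensors is not None`
        # before indexing; at runtime, indexing `None` already fails with a clear
        # `TypeError`, so the assert only adds per-call overhead.
        for i in range(len(self._tensors)):
            self._tensors[i] = None
```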

---

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96958
Approved by: https://github.com/rohan-varma
diff --git a/torch/distributed/fsdp/_utils.py b/torch/distributed/fsdp/_utils.py
index 45c8c45..378f547 100644
--- a/torch/distributed/fsdp/_utils.py
+++ b/torch/distributed/fsdp/_utils.py
@@ -21,6 +21,10 @@
     return x._typed_storage()._data_ptr() == y._typed_storage()._data_ptr()
 
 
+def _same_storage_as_data_ptr(x: torch.Tensor, data_ptr: int) -> bool:
+    return x._typed_storage()._data_ptr() == data_ptr
+
+
 def _no_dispatch_record_stream(tensor: torch.Tensor, stream: torch.cuda.Stream) -> None:
     with no_dispatch():
         tensor.record_stream(cast(torch._C.Stream, stream))
diff --git a/torch/distributed/fsdp/flat_param.py b/torch/distributed/fsdp/flat_param.py
index 260cb60..a1b195c 100644
--- a/torch/distributed/fsdp/flat_param.py
+++ b/torch/distributed/fsdp/flat_param.py
@@ -32,7 +32,7 @@
 from torch.distributed.utils import _alloc_storage, _free_storage, _p_assert
 
 from ._fsdp_extensions import _ext_post_unflatten_transform, _ext_pre_flatten_transform
-from ._utils import _no_dispatch_record_stream, _same_storage
+from ._utils import _no_dispatch_record_stream, _same_storage_as_data_ptr
 
 __all__ = [
     "FlatParameter",
@@ -1385,6 +1385,7 @@
         )
         return views
 
+    @no_type_check
     def _use_unsharded_views(self, as_params: bool) -> None:
         """
         Unflattens the unsharded flattened parameter by setting the original
@@ -1409,7 +1410,7 @@
                     # variable.
                     self._setattr_param(module, param_name, nn.Parameter(view))
                     continue
-                param = self.flat_param._params[i]  # type: ignore[index]
+                param = self.flat_param._params[i]
                 self._setattr_param(module, param_name, param)
                 param.data = view
             elif as_params:
@@ -1418,42 +1419,29 @@
                 param_var: Tensor = view
                 if self._use_orig_params:
                     if self._training_state == HandleTrainingState.FORWARD:
-                        assert self.flat_param._tensors is not None
                         # Save the `Tensor` for the pre-backward
                         self.flat_param._tensors[i] = view  # save for pre-backward
                     elif self._training_state == HandleTrainingState.BACKWARD_PRE:
                         # Use the saved `Tensor` variable from the forward to
                         # preserve the autograd graph so that the post-backward
                         # hook fires (e.g. for reentrant AC)
-                        assert self.flat_param._tensors is not None  # mypy
                         tensor = self.flat_param._tensors[i]
-                        _p_assert(
-                            tensor is not None,
-                            "Expects `Tensor` to have been saved in forward",
-                        )
-                        tensor.data = view  # type: ignore[union-attr]
-                        assert tensor is not None  # mypy
+                        tensor.data = view
                         param_var = tensor
                 self._setattr_tensor(module, param_name, param_var)
                 if (
                     self._use_orig_params
                     and self._training_state == HandleTrainingState.FORWARD
                 ):
-                    module._parameters[param_name] = param_var  # type: ignore[assignment]
+                    module._parameters[param_name] = param_var
         for i, (
             param_name,
             module,
             _,
             prim_param_name,
             prim_module,
-            prim_module_name,
+            _,
         ) in enumerate(self.flat_param._shared_param_infos):
-            if hasattr(module, param_name):
-                delattr(module, param_name)
-            _p_assert(
-                hasattr(prim_module, prim_param_name),
-                f"Module {prim_module_name} is missing parameter {prim_param_name}",
-            )
             prim_param: Union[Tensor, nn.Parameter] = getattr(
                 prim_module, prim_param_name
             )
@@ -1462,11 +1450,10 @@
                 f"as_params={as_params} type(prim_param)={type(prim_param)}",
             )
             if self._use_orig_params and as_params:
-                shared_param = self.flat_param._shared_params[i]  # type: ignore[index]
+                shared_param = self.flat_param._shared_params[i]
                 self._setattr_param(module, param_name, shared_param)
                 shared_param.data = prim_param
             elif as_params:
-                assert isinstance(prim_param, nn.Parameter)
                 self._setattr_param(module, param_name, prim_param)
             else:
                 self._setattr_tensor(module, param_name, prim_param)
@@ -1474,7 +1461,7 @@
                     self._use_orig_params
                     and self._training_state == HandleTrainingState.FORWARD
                 ):
-                    module._parameters[param_name] = prim_param  # type: ignore[assignment]
+                    module._parameters[param_name] = prim_param
 
     def _use_unsharded_grad_views(self) -> None:
         """
@@ -1554,6 +1541,7 @@
         finally:
             self._use_unsharded_views(as_params=False)
 
+    @no_type_check
     @torch.no_grad()
     def _use_sharded_views(self) -> None:
         """
@@ -1573,31 +1561,30 @@
             self._use_unsharded_views(as_params=True)
             return
         self._check_sharded(self.flat_param)
-        start, end = self.flat_param._shard_indices  # type: ignore[attr-defined]
+        start, end = self.flat_param._shard_indices
         offset = 0
-        assert self.flat_param._params is not None
+        # Construct once and reuse for all parameters not in the local shard
+        size_0_empty_tensor = torch.empty(
+            0,
+            dtype=self.flat_param.dtype,  # in case `flat_param` changed dtype
+            device=self.flat_param.device,
+            requires_grad=False,
+        )
         for i, (param, (param_name, module, _)) in enumerate(
             zip(self.flat_param._params, self.flat_param._param_infos)
         ):
             self._setattr_param(module, param_name, param)
             in_sharded_flat_param = (
-                i >= start
-                and i <= end
-                and self.flat_param._shard_param_offsets  # type: ignore[attr-defined]
+                i >= start and i <= end and self.flat_param._shard_param_offsets
             )
             if in_sharded_flat_param:
-                param_start, param_end = self.flat_param._shard_param_offsets[i - start]  # type: ignore[attr-defined]
+                param_start, param_end = self.flat_param._shard_param_offsets[i - start]
                 numel_in_shard = param_end - param_start + 1
                 param.data = self.flat_param[offset : offset + numel_in_shard]
                 offset += numel_in_shard
             else:
                 # Allow the original data to be freed via garbage collection
-                param.data = torch.empty(
-                    0,
-                    dtype=self.flat_param.dtype,  # in case `flat_param` changed dtype
-                    device=self.flat_param.device,
-                    requires_grad=False,
-                )
+                param.data = size_0_empty_tensor
         assert self.flat_param._shared_params is not None
         for i, (
             param,
@@ -1609,10 +1596,9 @@
             prim_param = getattr(prim_module, prim_param_name)
             param.data = prim_param  # could be both empty and non-empty
         if self._training_state == HandleTrainingState.BACKWARD_POST:
-            assert self.flat_param._tensors is not None  # mypy
             # Clear the saved `Tensor`s since they are unneeded now
             for i in range(len(self.flat_param._tensors)):
-                self.flat_param._tensors[i] = None  # type: ignore[index]
+                self.flat_param._tensors[i] = None
 
     @torch.no_grad()
     def _use_sharded_grad_views(self) -> None:
@@ -1681,6 +1667,7 @@
             else:
                 param.grad = None
 
+    @no_type_check
     @torch.no_grad()
     def _writeback_orig_params(self) -> bool:
         """
@@ -1698,10 +1685,25 @@
             # For `NO_SHARD`, we may still need to writeback
             return False
         flat_param = self.flat_param
-        start, end = flat_param._shard_indices  # type: ignore[attr-defined]
+        start, end = flat_param._shard_indices
         offset = 0
         assert flat_param._params is not None
         wroteback = False
+        flat_param_data_ptr = flat_param.untyped_storage().data_ptr()
+        # NOTE: Since this method is called in the pre-unshard, which is only
+        # called during computation in the pre-forward or pre-backward, the
+        # sharded gradient should be guaranteed to be in `.grad`, not in
+        # `._saved_grad_shard`.
+        flat_param_grad = (
+            flat_param.grad
+            if self.uses_sharded_strategy or not self._offload_params
+            else flat_param._cpu_grad
+        )
+        flat_param_grad_data_ptr = (
+            None
+            if flat_param_grad is None
+            else flat_param_grad.untyped_storage().data_ptr()
+        )
         for i, (param, (param_name, module, _)) in enumerate(
             zip(flat_param._params, flat_param._param_infos)
         ):
@@ -1710,20 +1712,20 @@
                 # (e.g. during model checkpointing)
                 continue
             in_sharded_flat_param = (
-                i >= start
-                and i <= end
-                and self.flat_param._shard_param_offsets  # type: ignore[attr-defined]
+                i >= start and i <= end and self.flat_param._shard_param_offsets
             )
             if not in_sharded_flat_param:
                 continue
-            param_start, param_end = flat_param._shard_param_offsets[i - start]  # type: ignore[attr-defined]
+            param_start, param_end = flat_param._shard_param_offsets[i - start]
             numel_in_shard = param_end - param_start + 1
 
             # Check for parameter writeback
             param_changed = getattr(module, param_name) is not param
             needs_param_writeback = (
                 param_changed  # changed parameter variable itself
-                or not _same_storage(param, flat_param)  # changed `.data`
+                or not _same_storage_as_data_ptr(
+                    param, flat_param_data_ptr
+                )  # changed `.data`
             )
             if param_changed:
                 # NOTE: The gradient is not preserved after a parameter change.
@@ -1737,10 +1739,6 @@
                 wroteback = True
 
             # Check for gradient writeback
-            # NOTE: Since this method is called in the pre-unshard, which is
-            # only called during computation in the pre-forward or
-            # pre-backward, the sharded gradient should be guaranteed to be in
-            # `.grad`, not in `._saved_grad_shard`.
             if param.grad is None and flat_param.grad is not None:
                 expected_shape = torch.Size([numel_in_shard])
                 self._writeback_tensor(
@@ -1750,13 +1748,11 @@
                 # For `NO_SHARD` + CPU offloading, `_cpu_grad` is always in
                 # memory and owns the gradient storage, so it will never
                 # require gradient writeback.
-                flat_param_grad = (
-                    flat_param.grad
-                    if self.uses_sharded_strategy or not self._offload_params
-                    else flat_param._cpu_grad  # type: ignore[attr-defined]
-                )
-                needs_grad_writeback = flat_param_grad is None or not _same_storage(
-                    param.grad, flat_param_grad
+                needs_grad_writeback = (
+                    flat_param_grad is None
+                    or not _same_storage_as_data_ptr(
+                        param.grad, flat_param_grad_data_ptr
+                    )
                 )
                 if needs_grad_writeback:
                     if flat_param_grad is None:
@@ -1766,6 +1762,10 @@
                         param.grad, flat_param_grad, i, expected_shape, offset, False
                     )
                     flat_param.grad = flat_param_grad
+                    flat_param_grad = flat_param.grad
+                    flat_param_grad_data_ptr = (
+                        flat_param_grad.untyped_storage().data_ptr()
+                    )
             offset += numel_in_shard
         # TODO (awgu): Handle shared parameters. We need to re-generate the
         # shared parameter data structures in case sharedness changed.