[FSDP] Deduplicate `_orig_size` and `_unsharded_size` (#79984)

This removes the `_orig_size` attribute that is initialized in `fully_sharded_data_parallel.py` since it represents the same quantity as `_unsharded_size` in `flat_param.py`. Since the quantity is not sharding dependent, we keep its initialization in `FlatParameter.init_metadata()` instead of in `FullyShardedDataParallel._shard_parameters()`.

Differential Revision: [D37726062](https://our.internmc.facebook.com/intern/diff/D37726062)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79984
Approved by: https://github.com/rohan-varma
diff --git a/torch/distributed/fsdp/_optim_utils.py b/torch/distributed/fsdp/_optim_utils.py
index b9170f6..3f90c6b 100644
--- a/torch/distributed/fsdp/_optim_utils.py
+++ b/torch/distributed/fsdp/_optim_utils.py
@@ -161,9 +161,7 @@
             dist._all_gather_base(tensor_buffer, value, group=group)
             torch.cuda.synchronize()
             if to_save:
-                assert hasattr(flat_param, "_orig_size"), \
-                    "Sharded flattened parameter should have `_orig_size` set"
-                unpadded_numel = flat_param._orig_size.numel()  # type: ignore[attr-defined]
+                unpadded_numel = flat_param._unsharded_size.numel()  # type: ignore[attr-defined]
                 tensor_state[state_name] = tensor_buffer[:unpadded_numel].cpu()
         # Zero-dimension tensor state and non-tensor state: take this rank's
         # value directly
@@ -468,7 +466,7 @@
         in zip(pos_dim_tensors, unflat_param_shapes)
     ]
     flat_tensor = torch.cat(tensors)
-    flat_param_shape = flat_param._orig_size  # type: ignore[attr-defined]
+    flat_param_shape = flat_param._unsharded_size  # type: ignore[attr-defined]
     assert flat_tensor.shape == flat_param_shape, \
         f"tensor optim state: {flat_tensor.shape} " \
         f"flattened parameter: {flat_param_shape}"
diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py
index d5989d0..8cc7ff5 100644
--- a/torch/distributed/fsdp/fully_sharded_data_parallel.py
+++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py
@@ -1510,7 +1510,6 @@
                 self.world_size > 1
                 and self.sharding_strategy != ShardingStrategy.NO_SHARD
             )
-            p._orig_size = p.size()  # type: ignore[attr-defined]
 
             if not p._is_sharded:  # type: ignore[attr-defined]
                 self.numel_padded_per_param.append(0)
@@ -1668,14 +1667,13 @@
     @torch.no_grad()
     def _init_param_attributes(self, p: Parameter) -> None:
         """
-        We manage several attributes on each Parameter instance. The first two
-        are set by :func:`_shard_parameters`:
+        We manage several attributes on each Parameter instance. The first is
+        set by :func:`_shard_parameters`:
             ``_is_sharded``: ``True`` if the Parameter is sharded or ``False``
                 if the Parameter is intentionally not sharded (in which case we
                 will all-reduce grads for this param). Currently the way
                 `_is_sharded = False` is if world_size = 1 or sharding strategy
                 is NO_SHARD.
-            ``_orig_size``: the size of the original Parameter (before sharding)
         A few attributes are set here:
             ``_local_shard``: a single shard of the parameter. This is needed to
                 recover the shard after rebuilding full parameter in forward
@@ -1689,9 +1687,7 @@
             ``_shard_bwd_hook``: it holds the parameter's AccumulateGrad object
                 and the registered post hook handle.
         """
-        assert hasattr(p, "_is_sharded") and hasattr(
-            p, "_orig_size"
-        ), "Parameters should have been sharded during construction."
+        assert hasattr(p, "_is_sharded"), "Parameters should have been sharded during construction."
         # If _local_shard has been set in the first lazy init and
         # current parameter is pointed to _local_shard, no need to
         # set the _local_shard again.
@@ -3332,7 +3328,7 @@
         """
         p.data = output_tensor
         # Trim any padding and reshape to match original size.
-        p.data = p.data[: p._orig_size.numel()].view(p._orig_size)  # type: ignore[attr-defined]
+        p.data = p.data[:p._unsharded_size.numel()].view(p._unsharded_size)  # type: ignore[attr-defined]
 
     @torch.no_grad()
     def _rebuild_full_params(self) -> List[Tuple[torch.Tensor, bool]]:
@@ -3571,7 +3567,7 @@
         """Make sure p.grad has the correct size/device, otherwise set it to None."""
         for p in self.params:
             if p.grad is not None and (
-                p.grad.size() != p._orig_size  # type: ignore[attr-defined]
+                p.grad.size() != p._unsharded_size  # type: ignore[attr-defined]
                 or p.grad.device != p.device
             ):
                 offloaded: bool = p.grad.device != p.device