[FSDP] Deduplicate `_orig_size` and `_unsharded_size` (#79984)
This removes the `_orig_size` attribute initialized in `fully_sharded_data_parallel.py`, since it represents the same quantity as `_unsharded_size` in `flat_param.py`. Because the quantity is not sharding-dependent, we keep its initialization in `FlatParameter.init_metadata()` rather than in `FullyShardedDataParallel._shard_parameters()`.
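For illustration, here is a minimal sketch (not the actual FSDP/`FlatParameter` implementation; the helpers `flatten_params`, `shard`, and `unshard` are hypothetical) showing why a single pre-shard size attribute suffices: it is recorded once when the flat parameter is constructed and then reused to trim all-gather padding back to the original shape.

```python
import torch
import torch.nn as nn

def flatten_params(params):
    # Flatten a list of parameters into one 1-D "flat parameter" and record its
    # pre-shard size once, at construction time (analogous to recording
    # `_unsharded_size` in `FlatParameter.init_metadata()`).
    flat = nn.Parameter(torch.cat([p.detach().reshape(-1) for p in params]))
    flat._unsharded_size = flat.size()  # sharding-independent metadata
    return flat

def shard(flat, rank, world_size):
    # Pad so the flat tensor splits evenly, then keep only this rank's chunk.
    numel = flat._unsharded_size.numel()
    padded_numel = -(-numel // world_size) * world_size  # round up to a multiple
    buf = torch.zeros(padded_numel, dtype=flat.dtype)
    buf[:numel] = flat.detach()
    return buf.chunk(world_size)[rank].clone()

def unshard(shards, unsharded_size):
    # After gathering all shards, trim the padding and restore the original
    # shape, mirroring `p.data[:p._unsharded_size.numel()].view(p._unsharded_size)`.
    gathered = torch.cat(shards)
    return gathered[: unsharded_size.numel()].view(unsharded_size)

# Round-trip check on a toy parameter list (2 "ranks", with 1 element of padding).
params = [nn.Parameter(torch.randn(3, 3)), nn.Parameter(torch.randn(2))]
flat = flatten_params(params)
shards = [shard(flat, r, world_size=2) for r in range(2)]
restored = unshard(shards, flat._unsharded_size)
assert torch.equal(restored, flat.detach())
```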
Differential Revision: [D37726062](https://our.internmc.facebook.com/intern/diff/D37726062)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79984
Approved by: https://github.com/rohan-varma
diff --git a/torch/distributed/fsdp/_optim_utils.py b/torch/distributed/fsdp/_optim_utils.py
index b9170f6..3f90c6b 100644
--- a/torch/distributed/fsdp/_optim_utils.py
+++ b/torch/distributed/fsdp/_optim_utils.py
@@ -161,9 +161,7 @@
dist._all_gather_base(tensor_buffer, value, group=group)
torch.cuda.synchronize()
if to_save:
- assert hasattr(flat_param, "_orig_size"), \
- "Sharded flattened parameter should have `_orig_size` set"
- unpadded_numel = flat_param._orig_size.numel() # type: ignore[attr-defined]
+ unpadded_numel = flat_param._unsharded_size.numel() # type: ignore[attr-defined]
tensor_state[state_name] = tensor_buffer[:unpadded_numel].cpu()
# Zero-dimension tensor state and non-tensor state: take this rank's
# value directly
@@ -468,7 +466,7 @@
in zip(pos_dim_tensors, unflat_param_shapes)
]
flat_tensor = torch.cat(tensors)
- flat_param_shape = flat_param._orig_size # type: ignore[attr-defined]
+ flat_param_shape = flat_param._unsharded_size # type: ignore[attr-defined]
assert flat_tensor.shape == flat_param_shape, \
f"tensor optim state: {flat_tensor.shape} " \
f"flattened parameter: {flat_param_shape}"
diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py
index d5989d0..8cc7ff5 100644
--- a/torch/distributed/fsdp/fully_sharded_data_parallel.py
+++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py
@@ -1510,7 +1510,6 @@
self.world_size > 1
and self.sharding_strategy != ShardingStrategy.NO_SHARD
)
- p._orig_size = p.size() # type: ignore[attr-defined]
if not p._is_sharded: # type: ignore[attr-defined]
self.numel_padded_per_param.append(0)
@@ -1668,14 +1667,13 @@
@torch.no_grad()
def _init_param_attributes(self, p: Parameter) -> None:
"""
- We manage several attributes on each Parameter instance. The first two
- are set by :func:`_shard_parameters`:
+ We manage several attributes on each Parameter instance. The first is
+ set by :func:`_shard_parameters`:
``_is_sharded``: ``True`` if the Parameter is sharded or ``False``
if the Parameter is intentionally not sharded (in which case we
will all-reduce grads for this param). Currently,
`_is_sharded = False` only when world_size = 1 or the sharding
strategy is NO_SHARD.
- ``_orig_size``: the size of the original Parameter (before sharding)
A few attributes are set here:
``_local_shard``: a single shard of the parameter. This is needed to
recover the shard after rebuilding full parameter in forward
@@ -1689,9 +1687,7 @@
``_shard_bwd_hook``: it holds the parameter's AccumulateGrad object
and the registered post hook handle.
"""
- assert hasattr(p, "_is_sharded") and hasattr(
- p, "_orig_size"
- ), "Parameters should have been sharded during construction."
+ assert hasattr(p, "_is_sharded"), "Parameters should have been sharded during construction."
# If _local_shard has been set in the first lazy init and
# current parameter is pointed to _local_shard, no need to
# set the _local_shard again.
@@ -3332,7 +3328,7 @@
"""
p.data = output_tensor
# Trim any padding and reshape to match original size.
- p.data = p.data[: p._orig_size.numel()].view(p._orig_size) # type: ignore[attr-defined]
+ p.data = p.data[:p._unsharded_size.numel()].view(p._unsharded_size) # type: ignore[attr-defined]
@torch.no_grad()
def _rebuild_full_params(self) -> List[Tuple[torch.Tensor, bool]]:
@@ -3571,7 +3567,7 @@
"""Make sure p.grad has the correct size/device, otherwise set it to None."""
for p in self.params:
if p.grad is not None and (
- p.grad.size() != p._orig_size # type: ignore[attr-defined]
+ p.grad.size() != p._unsharded_size # type: ignore[attr-defined]
or p.grad.device != p.device
):
offloaded: bool = p.grad.device != p.device