[Traceable FSDP2] Make FSDPParam._unsharded_param creation traceable (#127245)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127245
Approved by: https://github.com/awgu
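This makes the per-all-gather creation/refresh of `FSDPParam._unsharded_param` traceable by torch.compile with compiled autograd:

- Under compile (`ca.compiled_autograd_enabled`), `init_all_gather_outputs()` is called with a new `force_recreate=True` flag so `self.all_gather_outputs` is recreated on every all-gather instead of being reused, keeping the buffer creation inside the graph.
- `init_unsharded_param()` no longer early-returns after the first all-gather under compile. If `_unsharded_param` already exists, its storage is re-allocated and its contents are refreshed in place via `copy_()` under `torch.no_grad()`, so the same `nn.Parameter` object is kept across all-gathers.
- `free_unsharded_param()` additionally clears `self.all_gather_outputs` and `self._unsharded_inner_tensors` under compile so they are not captured as graph outputs.
- The contiguous strides for the original parameter size are precomputed once in `__init__` (`self._contiguous_orig_stride`) rather than recomputed on every unshard.

See [Note: Invariants for torch.compile Traceable FSDP2] added in `init_unsharded_param()` for the resulting invariants.

The in-place refresh pattern is sketched below. This is an illustrative sketch only: `populate_unsharded_param` and the `state` object are hypothetical stand-ins for `FSDPParam`, and `untyped_storage().resize_()` stands in for the real `alloc_storage()` helper; the actual logic lives in `torch/distributed/_composable/fsdp/_fsdp_param.py`.

```python
import torch
import torch._dynamo.compiled_autograd as ca
from torch import nn


def populate_unsharded_param(
    state, unsharded: torch.Tensor, requires_grad: bool
) -> None:
    # Keep the nn.Parameter object stable across all-gathers and only refresh
    # its storage/contents, so compile sees a fixed placeholder instead of a
    # freshly constructed parameter on every unshard.
    if hasattr(state, "_unsharded_param"):
        # Only the compiled path re-enters here; eager early-returns earlier.
        assert ca.compiled_autograd_enabled
        with torch.no_grad():
            # Simplified stand-in for alloc_storage(): grow the (possibly
            # freed) storage back to full size before writing into it.
            state._unsharded_param.untyped_storage().resize_(
                unsharded.untyped_storage().size()
            )
            state._unsharded_param.copy_(unsharded)
    else:
        # First unshard (or eager mode): create the parameter once.
        state._unsharded_param = nn.Parameter(unsharded, requires_grad=requires_grad)
```

With these changes, the only tensors that show up as graph inputs are the sharded parameters and the placeholders for the `_unsharded_param` parameters, as stated in the new note.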
diff --git a/torch/distributed/_composable/fsdp/_fsdp_collectives.py b/torch/distributed/_composable/fsdp/_fsdp_collectives.py
index 99b69cd..ac50848 100644
--- a/torch/distributed/_composable/fsdp/_fsdp_collectives.py
+++ b/torch/distributed/_composable/fsdp/_fsdp_collectives.py
@@ -1,6 +1,7 @@
from typing import List, NamedTuple, Optional, Tuple, Union

import torch
+import torch._dynamo.compiled_autograd as ca
import torch.distributed as dist
from torch.distributed._tensor import DTensor
from torch.distributed.distributed_c10d import ReduceOp
@@ -102,10 +103,21 @@
for all_gather_input_numels, all_gather_input_dtypes, fsdp_param in zip(
param_all_gather_input_numels, param_all_gather_input_dtypes, fsdp_params
):
- fsdp_param.init_all_gather_outputs(
- all_gather_input_numels, all_gather_input_dtypes, world_size, device
- ) # no-op after 1st call
- fsdp_param.alloc_all_gather_outputs()
+ if ca.compiled_autograd_enabled:
+ fsdp_param.init_all_gather_outputs(
+ all_gather_input_numels,
+ all_gather_input_dtypes,
+ world_size,
+ device,
+ # NOTE: Under compile, make sure we always recreate all_gather_outputs
+ # per AllGather. See [Note: Invariants for torch.compile Traceable FSDP2].
+ force_recreate=True,
+ )
+ else:
+ fsdp_param.init_all_gather_outputs(
+ all_gather_input_numels, all_gather_input_dtypes, world_size, device
+ ) # no-op after 1st call
+ fsdp_param.alloc_all_gather_outputs()
all_gather_output = all_gather_output.view(world_size, -1)
gen = (t for fsdp_param in fsdp_params for t in fsdp_param.all_gather_outputs)
if all_gather_output.dtype == torch.uint8:
diff --git a/torch/distributed/_composable/fsdp/_fsdp_param.py b/torch/distributed/_composable/fsdp/_fsdp_param.py
index 81596fe..c56dc79 100644
--- a/torch/distributed/_composable/fsdp/_fsdp_param.py
+++ b/torch/distributed/_composable/fsdp/_fsdp_param.py
@@ -240,6 +240,7 @@
)
param_data = param
self._orig_size = param_data.size()
+ self._contiguous_orig_stride = make_contiguous_strides_for(self._orig_size)
shard_rank = self.mesh_info.shard_mesh_rank
shard_world_size = self.mesh_info.shard_mesh_size
chunks = _chunk_with_empty(param_data, shard_world_size, dim=0)
@@ -311,8 +312,9 @@
all_gather_input_dtypes: List[torch.dtype],
world_size: int,
device: torch.device,
+ force_recreate: bool = False,
):
- if self.all_gather_outputs:
+ if not force_recreate and len(self.all_gather_outputs) > 0:
return # already initialized
self.all_gather_outputs = [
torch.empty(torch.Size([numel * world_size]), dtype=dtype, device=device)
@@ -320,7 +322,24 @@
]

def init_unsharded_param(self):
- if hasattr(self, "_unsharded_param"): # after the 1st all-gather
+ """
+ [Note: Invariants for torch.compile Traceable FSDP2]
+ 1. Under compile, we always re-populate the content of `self._unsharded_param`
+ per AllGather using the slow path.
+ 2. Under compile, we always recreate `self.all_gather_outputs` per AllGather.
+ This is to ensure the buffer creation is internal to the graph and
+ avoid `self.all_gather_outputs` being captured as a graph input.
+ 3. Under compile, at the end of `free_unsharded_param()`, we always clean up
+ `self.all_gather_outputs` and `self._unsharded_inner_tensors`,
+ to avoid them being captured as graph outputs.
+
+ With these invariants, only these tensors will be inputs to the graph:
+ - Sharded parameters
+ - Placeholders for the `self._unsharded_param` nn.Parameter
+ """
+ if not ca.compiled_autograd_enabled and hasattr(
+ self, "_unsharded_param"
+ ): # after the 1st all-gather
inner_tensor = self._sharded_local_tensor
if not hasattr(inner_tensor, "fsdp_post_all_gather"):
return # already initialized
@@ -357,13 +376,20 @@
unsharded_param = torch.as_strided(
unsharded_tensor,
self._orig_size,
- make_contiguous_strides_for(self._orig_size),
+ self._contiguous_orig_stride,
storage_offset=0,
)
if self.is_dtensor:
unsharded_param = _from_local_no_grad(unsharded_param, self._tp_spec)
- self._unsharded_param = nn.Parameter(unsharded_param)
- self._unsharded_param.requires_grad_(self.sharded_param.requires_grad)
+ if hasattr(self, "_unsharded_param"):
+ assert ca.compiled_autograd_enabled
+ with torch.no_grad():
+ alloc_storage(self._unsharded_param)
+ self._unsharded_param.copy_(unsharded_param)
+ else:
+ self._unsharded_param = nn.Parameter(
+ unsharded_param, requires_grad=self.sharded_param.requires_grad
+ )

def _unflatten_all_gather_outputs(self) -> Tuple[torch.Tensor, ...]:
return tuple(
@@ -493,6 +519,9 @@
self.all_gather_outputs, self._unsharded_inner_tensors
):
free_storage(tensor)
+ if ca.compiled_autograd_enabled:
+ self.all_gather_outputs = []
+ self._unsharded_inner_tensors = []

@property
def all_gather_inputs(self) -> List[torch.Tensor]: # 1D