[Traceable FSDP2] Make FSDPParam._unsharded_param creation traceable (#127245)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127245
Approved by: https://github.com/awgu
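
The core behavioral change in `init_unsharded_param`: under compiled autograd, the
existing `self._unsharded_param` nn.Parameter object is reused and its contents are
repopulated in place, instead of a fresh nn.Parameter being constructed on every
all-gather. A minimal stand-in sketch with plain tensors (not the FSDP2 code itself;
`gathered` / `new_gathered` are hypothetical placeholders for all-gather results):

    import torch
    import torch.nn as nn

    # First all-gather: wrap the gathered data in a new nn.Parameter,
    # matching the sharded parameter's requires_grad.
    gathered = torch.randn(4, 4)
    unsharded_param = nn.Parameter(gathered, requires_grad=True)

    # Later all-gathers under compiled autograd: keep the same nn.Parameter
    # object and repopulate its data in place, without recording autograd
    # history for the copy itself.
    new_gathered = torch.randn(4, 4)
    with torch.no_grad():
        unsharded_param.copy_(new_gathered)

    assert torch.equal(unsharded_param.detach(), new_gathered)

Keeping the Parameter object stable is what lets `self._unsharded_param` appear to the
compiled graph as a fixed placeholder across iterations; see the invariants note added
to `_fsdp_param.py` below.
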
diff --git a/torch/distributed/_composable/fsdp/_fsdp_collectives.py b/torch/distributed/_composable/fsdp/_fsdp_collectives.py
index 99b69cd..ac50848 100644
--- a/torch/distributed/_composable/fsdp/_fsdp_collectives.py
+++ b/torch/distributed/_composable/fsdp/_fsdp_collectives.py
@@ -1,6 +1,7 @@
 from typing import List, NamedTuple, Optional, Tuple, Union
 
 import torch
+import torch._dynamo.compiled_autograd as ca
 import torch.distributed as dist
 from torch.distributed._tensor import DTensor
 from torch.distributed.distributed_c10d import ReduceOp
@@ -102,10 +103,21 @@
     for all_gather_input_numels, all_gather_input_dtypes, fsdp_param in zip(
         param_all_gather_input_numels, param_all_gather_input_dtypes, fsdp_params
     ):
-        fsdp_param.init_all_gather_outputs(
-            all_gather_input_numels, all_gather_input_dtypes, world_size, device
-        )  # no-op after 1st call
-        fsdp_param.alloc_all_gather_outputs()
+        if ca.compiled_autograd_enabled:
+            fsdp_param.init_all_gather_outputs(
+                all_gather_input_numels,
+                all_gather_input_dtypes,
+                world_size,
+                device,
+                # NOTE: Under compile, make sure we always recreate all_gather_outputs
+                # per AllGather. See [Note: Invariants for torch.compile Traceable FSDP2].
+                force_recreate=True,
+            )
+        else:
+            fsdp_param.init_all_gather_outputs(
+                all_gather_input_numels, all_gather_input_dtypes, world_size, device
+            )  # no-op after 1st call
+            fsdp_param.alloc_all_gather_outputs()
     all_gather_output = all_gather_output.view(world_size, -1)
     gen = (t for fsdp_param in fsdp_params for t in fsdp_param.all_gather_outputs)
     if all_gather_output.dtype == torch.uint8:
diff --git a/torch/distributed/_composable/fsdp/_fsdp_param.py b/torch/distributed/_composable/fsdp/_fsdp_param.py
index 81596fe..c56dc79 100644
--- a/torch/distributed/_composable/fsdp/_fsdp_param.py
+++ b/torch/distributed/_composable/fsdp/_fsdp_param.py
@@ -240,6 +240,7 @@
             )
             param_data = param
         self._orig_size = param_data.size()
+        self._contiguous_orig_stride = make_contiguous_strides_for(self._orig_size)
         shard_rank = self.mesh_info.shard_mesh_rank
         shard_world_size = self.mesh_info.shard_mesh_size
         chunks = _chunk_with_empty(param_data, shard_world_size, dim=0)
@@ -311,8 +312,9 @@
         all_gather_input_dtypes: List[torch.dtype],
         world_size: int,
         device: torch.device,
+        force_recreate: bool = False,
     ):
-        if self.all_gather_outputs:
+        if not force_recreate and len(self.all_gather_outputs) > 0:
             return  # already initialized
         self.all_gather_outputs = [
             torch.empty(torch.Size([numel * world_size]), dtype=dtype, device=device)
@@ -320,7 +322,24 @@
         ]
 
     def init_unsharded_param(self):
-        if hasattr(self, "_unsharded_param"):  # after the 1st all-gather
+        """
+        [Note: Invariants for torch.compile Traceable FSDP2]
+        1. Under compile, we always re-populate the content of `self._unsharded_param`
+           per AllGather using the slow path.
+        2. Under compile, we always recreate `self.all_gather_outputs` per AllGather.
+           This ensures that the buffer creation is internal to the graph and
+           avoids `self.all_gather_outputs` being captured as a graph input.
+        3. Under compile, at the end of `free_unsharded_param()`, we always clean up
+           `self.all_gather_outputs` and `self._unsharded_inner_tensors`,
+           to avoid them being captured as graph outputs.
+
+        With these invariants, only these tensors will be inputs to the graph:
+        - Sharded parameters
+        - Placeholders for the `self._unsharded_param` nn.Parameter
+        """
+        if not ca.compiled_autograd_enabled and hasattr(
+            self, "_unsharded_param"
+        ):  # after the 1st all-gather
             inner_tensor = self._sharded_local_tensor
             if not hasattr(inner_tensor, "fsdp_post_all_gather"):
                 return  # already initialized
@@ -357,13 +376,20 @@
         unsharded_param = torch.as_strided(
             unsharded_tensor,
             self._orig_size,
-            make_contiguous_strides_for(self._orig_size),
+            self._contiguous_orig_stride,
             storage_offset=0,
         )
         if self.is_dtensor:
             unsharded_param = _from_local_no_grad(unsharded_param, self._tp_spec)
-        self._unsharded_param = nn.Parameter(unsharded_param)
-        self._unsharded_param.requires_grad_(self.sharded_param.requires_grad)
+        if hasattr(self, "_unsharded_param"):
+            assert ca.compiled_autograd_enabled
+            with torch.no_grad():
+                alloc_storage(self._unsharded_param)
+                self._unsharded_param.copy_(unsharded_param)
+        else:
+            self._unsharded_param = nn.Parameter(
+                unsharded_param, requires_grad=self.sharded_param.requires_grad
+            )
 
     def _unflatten_all_gather_outputs(self) -> Tuple[torch.Tensor, ...]:
         return tuple(
@@ -493,6 +519,9 @@
             self.all_gather_outputs, self._unsharded_inner_tensors
         ):
             free_storage(tensor)
+        if ca.compiled_autograd_enabled:
+            self.all_gather_outputs = []
+            self._unsharded_inner_tensors = []
 
     @property
     def all_gather_inputs(self) -> List[torch.Tensor]:  # 1D