[Traceable FSDP2] Check hasattr('fsdp_pre_all_gather') only when not compiling (#127855)

Dynamo doesn't yet support `hasattr(inner_tensor, "fsdp_post_all_gather")`, so this PR skips the `fsdp_pre_all_gather` / `fsdp_post_all_gather` extension checks when compiled autograd is enabled. We will work on this support in Q3.
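For context, the gating applied in both hunks below reduces to the following minimal sketch (the `uses_all_gather_extension` helper is hypothetical and for illustration only; `ca.compiled_autograd_enabled` is the flag the patch imports):

```python
import torch
import torch._dynamo.compiled_autograd as ca

def uses_all_gather_extension(inner_tensor: torch.Tensor) -> bool:
    # Hypothetical helper; the real checks live in
    # torch/distributed/_composable/fsdp/_fsdp_param.py.
    # Only consult the extension hook in eager mode: Dynamo cannot
    # trace hasattr() on the tensor subclass yet, so the compiled
    # path always falls back to the plain all-gather inputs.
    return not ca.compiled_autograd_enabled and hasattr(
        inner_tensor, "fsdp_pre_all_gather"
    )
```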

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127855
Approved by: https://github.com/awgu
diff --git a/torch/distributed/_composable/fsdp/_fsdp_common.py b/torch/distributed/_composable/fsdp/_fsdp_common.py
index e765496..f372fcd 100644
--- a/torch/distributed/_composable/fsdp/_fsdp_common.py
+++ b/torch/distributed/_composable/fsdp/_fsdp_common.py
@@ -6,6 +6,7 @@
 from typing import Any, cast, List, Optional, Tuple
 
 import torch
+import torch._dynamo.compiled_autograd as ca
 import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed._composable.contract import _get_registry
@@ -122,7 +123,7 @@
     it avoids some CPU overhead by avoiding default args and not being differentiable.
     """
 
-    if not torch._dynamo.compiled_autograd.compiled_autograd_enabled:
+    if not ca.compiled_autograd_enabled:
         spec = DTensorSpec(
             device_mesh,
             placements,
diff --git a/torch/distributed/_composable/fsdp/_fsdp_param.py b/torch/distributed/_composable/fsdp/_fsdp_param.py
index 5fed53f..cf28a8e 100644
--- a/torch/distributed/_composable/fsdp/_fsdp_param.py
+++ b/torch/distributed/_composable/fsdp/_fsdp_param.py
@@ -4,6 +4,7 @@
 from typing import Any, cast, List, Optional, Sequence, Tuple
 
 import torch
+import torch._dynamo.compiled_autograd as ca
 import torch.nn as nn
 
 from torch._prims_common import make_contiguous_strides_for
@@ -326,7 +327,9 @@
             self._extensions_data.clear()
             return
         inner_tensor = self._sharded_local_tensor
-        if hasattr(inner_tensor, "fsdp_post_all_gather"):
+        if not ca.compiled_autograd_enabled and hasattr(
+            inner_tensor, "fsdp_post_all_gather"
+        ):
             all_gather_outputs = self._unflatten_all_gather_outputs()
             (
                 unsharded_tensor,
@@ -496,7 +499,9 @@
     def all_gather_inputs(self) -> List[torch.Tensor]:  # 1D
         self._assert_in_states(ShardedState.SHARDED, ShardedState.SHARDED_POST_FORWARD)
         if self.sharded_state == ShardedState.SHARDED:
-            if hasattr(self._sharded_local_tensor, "fsdp_pre_all_gather"):
+            if not ca.compiled_autograd_enabled and hasattr(
+                self._sharded_local_tensor, "fsdp_pre_all_gather"
+            ):
                 sharded_local_tensor = self._sharded_local_tensor
                 if self.offload_to_cpu:
                     sharded_local_tensor = sharded_local_tensor.to(
@@ -517,7 +522,9 @@
                 )
             return [_to_dtype_if_needed(sharded_param_data, self.param_dtype)]
         elif self.sharded_state == ShardedState.SHARDED_POST_FORWARD:
-            if hasattr(self._sharded_local_tensor, "fsdp_pre_all_gather"):
+            if not ca.compiled_autograd_enabled and hasattr(
+                self._sharded_local_tensor, "fsdp_pre_all_gather"
+            ):
                 raise NotImplementedError
             all_gather_input = _to_dtype_if_needed(
                 cast(torch.Tensor, self._sharded_post_forward_param_data),