[Traceable FSDP2] Check hasattr('fsdp_pre_all_gather') only when not compiling (#127855)
Dynamo doesn't yet support tracing `hasattr(inner_tensor, "fsdp_pre_all_gather")` / `hasattr(inner_tensor, "fsdp_post_all_gather")` on the sharded local tensor, so skip these extension checks when compiled autograd is enabled. We will work on this support in Q3.
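A minimal sketch of the guard pattern applied in this PR (the helper name `_uses_all_gather_extension` is illustrative, not part of the FSDP2 source):

```python
import torch
import torch._dynamo.compiled_autograd as ca


def _uses_all_gather_extension(sharded_local_tensor: torch.Tensor) -> bool:
    # Dynamo cannot yet trace hasattr() on the subclass inner tensor, so the
    # all-gather extension path is taken only when compiled autograd is off.
    return not ca.compiled_autograd_enabled and hasattr(
        sharded_local_tensor, "fsdp_pre_all_gather"
    )
```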
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127855
Approved by: https://github.com/awgu
diff --git a/torch/distributed/_composable/fsdp/_fsdp_common.py b/torch/distributed/_composable/fsdp/_fsdp_common.py
index e765496..f372fcd 100644
--- a/torch/distributed/_composable/fsdp/_fsdp_common.py
+++ b/torch/distributed/_composable/fsdp/_fsdp_common.py
@@ -6,6 +6,7 @@
from typing import Any, cast, List, Optional, Tuple
import torch
+import torch._dynamo.compiled_autograd as ca
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.contract import _get_registry
@@ -122,7 +123,7 @@
it avoids some CPU overhead by avoiding default args and not being differentiable.
"""
- if not torch._dynamo.compiled_autograd.compiled_autograd_enabled:
+ if not ca.compiled_autograd_enabled:
spec = DTensorSpec(
device_mesh,
placements,
diff --git a/torch/distributed/_composable/fsdp/_fsdp_param.py b/torch/distributed/_composable/fsdp/_fsdp_param.py
index 5fed53f..cf28a8e 100644
--- a/torch/distributed/_composable/fsdp/_fsdp_param.py
+++ b/torch/distributed/_composable/fsdp/_fsdp_param.py
@@ -4,6 +4,7 @@
from typing import Any, cast, List, Optional, Sequence, Tuple
import torch
+import torch._dynamo.compiled_autograd as ca
import torch.nn as nn
from torch._prims_common import make_contiguous_strides_for
@@ -326,7 +327,9 @@
self._extensions_data.clear()
return
inner_tensor = self._sharded_local_tensor
- if hasattr(inner_tensor, "fsdp_post_all_gather"):
+ if not ca.compiled_autograd_enabled and hasattr(
+ inner_tensor, "fsdp_post_all_gather"
+ ):
all_gather_outputs = self._unflatten_all_gather_outputs()
(
unsharded_tensor,
@@ -496,7 +499,9 @@
def all_gather_inputs(self) -> List[torch.Tensor]: # 1D
self._assert_in_states(ShardedState.SHARDED, ShardedState.SHARDED_POST_FORWARD)
if self.sharded_state == ShardedState.SHARDED:
- if hasattr(self._sharded_local_tensor, "fsdp_pre_all_gather"):
+ if not ca.compiled_autograd_enabled and hasattr(
+ self._sharded_local_tensor, "fsdp_pre_all_gather"
+ ):
sharded_local_tensor = self._sharded_local_tensor
if self.offload_to_cpu:
sharded_local_tensor = sharded_local_tensor.to(
@@ -517,7 +522,9 @@
)
return [_to_dtype_if_needed(sharded_param_data, self.param_dtype)]
elif self.sharded_state == ShardedState.SHARDED_POST_FORWARD:
- if hasattr(self._sharded_local_tensor, "fsdp_pre_all_gather"):
+ if not ca.compiled_autograd_enabled and hasattr(
+ self._sharded_local_tensor, "fsdp_pre_all_gather"
+ ):
raise NotImplementedError
all_gather_input = _to_dtype_if_needed(
cast(torch.Tensor, self._sharded_post_forward_param_data),