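[FSDP2] Simplified `_move_states_to_device` (#122907)

Drops the `mesh_info` parameter from `_move_states_to_device` (updating its call site in `fully_shard.py` accordingly) and reads the DTensor's mesh through `tensor.device_mesh` instead of `tensor._spec.mesh`.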
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122907
Approved by: https://github.com/Skylion007
diff --git a/torch/distributed/_composable/fsdp/_fsdp_init.py b/torch/distributed/_composable/fsdp/_fsdp_init.py
index a0a33df..39b3d8d 100644
--- a/torch/distributed/_composable/fsdp/_fsdp_init.py
+++ b/torch/distributed/_composable/fsdp/_fsdp_init.py
@@ -118,7 +118,6 @@
params: List[nn.Parameter],
buffers: List[torch.Tensor],
device: torch.device,
- mesh_info: FSDPMeshInfo,
) -> None:
"""
We have FSDP move states to device for simpler and faster initialization
@@ -133,7 +132,7 @@
# Keep meta-device tensors on meta device for deferred init
continue
if isinstance(tensor, DTensor):
- if (dtensor_mesh_type := tensor._spec.mesh.device_type) != device.type:
+ if (dtensor_mesh_type := tensor.device_mesh.device_type) != device.type:
raise ValueError(
"Requires DTensor to have mesh of the same type as the FSDP mesh "
f"but got {dtensor_mesh_type} for DTensor and {device.type} for FSDP"
diff --git a/torch/distributed/_composable/fsdp/fully_shard.py b/torch/distributed/_composable/fsdp/fully_shard.py
index fa7066a..ec7a4a2 100644
--- a/torch/distributed/_composable/fsdp/fully_shard.py
+++ b/torch/distributed/_composable/fsdp/fully_shard.py
@@ -113,7 +113,7 @@
managed_modules = _get_managed_modules(module)
params, buffers = _get_managed_states(managed_modules)
- _move_states_to_device(params, buffers, device, mesh_info)
+ _move_states_to_device(params, buffers, device)
if params:
state._fsdp_param_group = FSDPParamGroup(
params, module, mesh_info, post_forward_mesh_info, device, mp_policy