[FSDP] Fix `device_id` when buffer-only module (#103504)
An issue was reported internally: with `sync_module_states=True`, if the model had buffers on CPU, then even with `device_id` specified, FSDP would try to broadcast the CPU buffers, leading to an error like:
```
RuntimeError: No backend type associated with device type cpu
```
After some investigation, I determined that we should _not_ fix this by moving the buffers to GPU just for the broadcast and then back to CPU. Instead, we should fix our `device_id` logic.
The issue is that we always used the _parameters_ as the proxy for whether we should move module states to the device specified by `device_id`. However, a module (often the root) may not have any parameters of its own to manage yet still have buffers. In that case, the buffers are left on CPU even when `device_id` is specified. This PR fixes the problem by considering both parameters and buffers when moving module states to `device_id`.
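For reference, a minimal sketch of the failure mode (hypothetical `BufferOnlyRoot` module, assuming a default NCCL process group is already initialized and a GPU is available): the auto-wrap policy wraps the `nn.Linear` submodules into nested FSDP instances first, so the root is left with only a buffer to manage, and the old parameter-based check skipped moving it to `device_id`.
```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import ModuleWrapPolicy


class BufferOnlyRoot(nn.Module):
    # Hypothetical module for illustration: the root owns a buffer but no
    # parameters; all parameters live in `nn.Linear` submodules.
    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("buf", torch.zeros(2, 2))
        self.seq = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))


# The `nn.Linear` submodules become nested FSDP instances via the auto-wrap
# policy, so by the time the root is initialized, it has no unmanaged
# parameters, only `buf`. Before this fix, `buf` stayed on CPU and the
# broadcast for `sync_module_states=True` failed with the error above.
fsdp_model = FSDP(
    BufferOnlyRoot(),
    auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),
    device_id=torch.cuda.current_device(),
    sync_module_states=True,
)
```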
Note that this PR preserves the existing semantics that `ignored_modules` / `ignored_parameters` do not exempt states from this movement, meaning that ignored parameters are still moved to `device_id`.
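A short sketch of that preserved behavior (hypothetical `model` with an `extra` submodule, again assuming an initialized NCCL process group): even though `extra` is passed via `ignored_modules`, its CPU parameters are still moved to the device given by `device_id`.
```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Hypothetical model for illustration: `extra` will be ignored by FSDP.
model = nn.Module()
model.main = nn.Linear(4, 4)
model.extra = nn.Linear(4, 4)

fsdp_model = FSDP(
    model,
    ignored_modules=[model.extra],
    device_id=torch.cuda.current_device(),
)
# The ignored submodule's parameters were still moved to the GPU given by
# `device_id`, even though FSDP does not shard or manage them.
assert all(p.device.type == "cuda" for p in model.extra.parameters())
```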
Note also that I had to move the unit test back from MTPG (the multi-threaded process group) to the normal process group since otherwise I could not reproduce the original error. (It seems that MTPG does not complain if we call `dist._broadcast_coalesced()` with CPU tensors.)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103504
Approved by: https://github.com/rohan-varma
diff --git a/test/distributed/fsdp/test_fsdp_misc.py b/test/distributed/fsdp/test_fsdp_misc.py
index 79bb1a4..a8d9b3b 100644
--- a/test/distributed/fsdp/test_fsdp_misc.py
+++ b/test/distributed/fsdp/test_fsdp_misc.py
@@ -290,6 +290,62 @@
inp = fsdp_model.module.get_input(device=torch.device("cpu"))
fsdp_model(*inp).sum().backward()
+ @skip_if_lt_x_gpu(2)
+ def test_cpu_init_with_sync_module_states(self):
+ """
+ Tests that passing ``sync_module_states=True`` raises an error for
+ a CPU module since the synchronization requires GPU communication,
+ while additionally passing ``device_id`` does not raise an error, even
+ when the model has CPU buffers.
+ """
+
+ def init_nested_wrapped_module():
+ return NestedWrappedModule.init(
+ self.process_group,
+ FSDPInitMode.NO_FSDP,
+ CUDAInitMode.CUDA_NEVER,
+ )
+
+ with self.assertRaisesRegex(
+ ValueError,
+ "The module has CPU parameters or buffers when `sync_module_states=True`",
+ ):
+ FSDP(
+ init_nested_wrapped_module(),
+ self.process_group,
+ sync_module_states=True,
+ )
+
+ # Check that `device_id` with `sync_module_states=True` works
+ nested_wrapped_module = init_nested_wrapped_module()
+ nested_wrapped_module.register_buffer(
+ "buf", torch.ones((2, 2), device="cpu") * self.rank
+ )
+ nested_wrapped_module.module[0].register_buffer(
+ "buf", torch.ones((3, 2), device="cpu") * self.rank
+ )
+ nested_wrapped_module = FSDP(
+ nested_wrapped_module,
+ self.process_group,
+ auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),
+ device_id=torch.cuda.current_device(),
+ sync_module_states=True,
+ )
+ # Each rank's buffers should be 0s since rank 0 is the source, and they
+ # should be on GPU since we specified `device_id`
+ self.assertEqual(
+ nested_wrapped_module.buf.device,
+ torch.device("cuda", torch.cuda.current_device()),
+ )
+ self.assertEqual(nested_wrapped_module.buf, torch.zeros((2, 2)))
+ self.assertEqual(
+ nested_wrapped_module.module.module[0].buf.device,
+ torch.device("cuda", torch.cuda.current_device()),
+ )
+ self.assertEqual(
+ nested_wrapped_module.module.module[0].buf, torch.zeros((3, 2))
+ )
+
class TestFSDPMiscMultiThread(FSDPTestMultiThread):
@property
@@ -477,29 +533,6 @@
FSDP(no_params, device_id=0)
@skip_if_lt_x_gpu(2)
- def test_cpu_init_with_sync_module_states(self):
- """Tests that passing ``sync_module_states=True`` raises an error for
- a CPU module since the synchronization requires GPU communication,
- while additionally passing ``device_id`` does not raise an error."""
- nested_wrapped_module = NestedWrappedModule.init(
- self.process_group,
- FSDPInitMode.RECURSIVE,
- CUDAInitMode.CUDA_NEVER,
- )
- with self.assertRaisesRegex(
- ValueError, "The module has CPU parameters when `sync_module_states=True`"
- ):
- FSDP(nested_wrapped_module, self.process_group, sync_module_states=True)
-
- # Specifying device_id with sync_module_states=True works.
- FSDP(
- nested_wrapped_module,
- self.process_group,
- device_id=torch.cuda.current_device(),
- sync_module_states=True,
- )
-
- @skip_if_lt_x_gpu(2)
def test_fsdp_same_model_across_ranks(self):
"""
FSDP broadcasts model from rank 0 to ensure it starts off with the same
diff --git a/torch/distributed/fsdp/_init_utils.py b/torch/distributed/fsdp/_init_utils.py
index 0d11ab7..9cb95c0 100644
--- a/torch/distributed/fsdp/_init_utils.py
+++ b/torch/distributed/fsdp/_init_utils.py
@@ -595,7 +595,6 @@
state.process_group,
state._use_orig_params,
)
- # TODO: Can simplify call `shard()` in the `FlatParamHandle` ctor
handle.shard()
assert handle not in state._handles
state.params.append(handle.flat_param)
@@ -854,32 +853,40 @@
Precondition: ``_check_single_device_module()``.
"""
- param = next(_get_orig_params(module, ignored_params), None)
- if param is None:
- return # no original parameters to manage
cpu_device = torch.device("cpu")
- # TODO: This only checks the parameter's device, not any buffers. Thus, a
- # buffer-only module will not get offloaded to CPU.
if device_from_device_id is not None:
- if param.device == cpu_device:
- # BFS from `module` without traversing any nested FSDP instances to
- # collect the parameters/buffers that have not yet been managed
- queue: Deque[nn.Module] = collections.deque()
- queue.append(module)
- params: List[nn.Parameter] = []
- buffers: List[torch.Tensor] = []
- while queue:
- curr_module = queue.popleft()
- params.extend(curr_module.parameters(recurse=False))
- buffers.extend(curr_module.buffers(recurse=False))
- for submodule in curr_module.children():
- if not isinstance(submodule, fsdp_file.FullyShardedDataParallel):
- queue.append(submodule)
- # NOTE: This includes moving ignored modules' parameters. If we
- # decide to change the semantics in the future, simply filter based
- # on the ignored parameters (and buffers).
- _move_states_to_device(params, buffers, device_from_device_id)
- elif param.device == cpu_device:
+ # BFS from `module` without traversing any nested FSDP instances to
+ # collect the parameters/buffers that have not yet been managed
+ queue: Deque[nn.Module] = collections.deque()
+ queue.append(module)
+ params: List[nn.Parameter] = []
+ buffers: List[torch.Tensor] = []
+ while queue:
+ curr_module = queue.popleft()
+ # NOTE: We include a check to only move parameters/buffers that are
+ # on CPU device. If they are on a CUDA device different from the
+ # one specified by `device_id`, then this does NOT move them. This
+ # is so that we can raise an error in `_get_compute_device()`.
+ params.extend(
+ param
+ for param in curr_module.parameters(recurse=False)
+ if param.device == cpu_device
+ )
+ buffers.extend(
+ buffer
+ for buffer in curr_module.buffers(recurse=False)
+ if buffer.device == cpu_device
+ )
+ for submodule in curr_module.children():
+ if not isinstance(submodule, fsdp_file.FullyShardedDataParallel):
+ queue.append(submodule)
+ # NOTE: This includes moving ignored modules' parameters. If we
+ # decide to change the semantics in the future, simply filter based
+ # on the ignored parameters (and buffers).
+ _move_states_to_device(params, buffers, device_from_device_id)
+ return
+ param = next(_get_orig_params(module, ignored_params), None)
+ if param is not None and param.device == cpu_device:
_warn_cpu_init()
@@ -973,7 +980,6 @@
Precondition: ``sync_module_states == True`` and ``self.process_group`` has
been set.
"""
- _check_params_for_sync_module_states(params)
module_states: List[torch.Tensor] = []
for buffer in module.buffers():
# Avoid re-synchronizing buffers in case of nested wrapping
@@ -981,6 +987,7 @@
setattr(buffer, FSDP_SYNCED, True)
module_states.append(buffer.detach())
module_states.extend(param.detach() for param in params)
+ _check_module_states_for_sync_module_states(module_states)
_sync_params_and_buffers(
process_group,
module_states,
@@ -994,12 +1001,12 @@
buffers: List[torch.Tensor],
process_group: dist.ProcessGroup,
) -> None:
- _check_params_for_sync_module_states(params)
# Assumes that each call to this method passes in disjoint `params` and
# and `buffers` across calls, so there is no chance of re-synchronizing
params_and_buffers = [param.detach() for param in params] + [
buffer.detach() for buffer in buffers
]
+ _check_module_states_for_sync_module_states(params_and_buffers)
_sync_params_and_buffers(
process_group,
params_and_buffers,
@@ -1008,15 +1015,16 @@
)
-def _check_params_for_sync_module_states(
- params: List[nn.Parameter],
+def _check_module_states_for_sync_module_states(
+ module_states: List[torch.Tensor],
) -> None:
- if params and any(param.device == torch.device("cpu") for param in params):
+ if module_states and any(
+ tensor.device == torch.device("cpu") for tensor in module_states
+ ):
raise ValueError(
- "The module has CPU parameters when `sync_module_states=True`, "
- "which only works when all parameters are on GPU. Please specify "
- "the `device_id` argument or move the module to GPU before passing "
- "into FSDP."
+ "The module has CPU parameters or buffers when `sync_module_states=True`, "
+ "which requires them to be on GPU. Please specify the `device_id` argument "
+ "or move the module to GPU before passing it to FSDP."
)