[FSDP] Fix `device_id` when buffer-only module (#103504)

There was an internally reported issue that, with `sync_module_states=True`, if the model had buffers on CPU, then even with `device_id` specified, FSDP would try to broadcast the CPU buffers, leading to an error like:
```
RuntimeError: No backend type associated with device type cpu
```
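
For reference, here is a minimal sketch of the kind of setup that triggers this. The module, buffer names, and sizes are illustrative (not from the original report), and it assumes a 2+ GPU launch via `torchrun` with a NCCL process group:
```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import ModuleWrapPolicy


class BufferOnlyRoot(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # The root owns only a buffer; all parameters live in child modules
        self.register_buffer("scale", torch.ones(1))
        self.seq = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))


dist.init_process_group("nccl")  # assumes launch via torchrun with >= 2 GPUs
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = BufferOnlyRoot()  # constructed on CPU
fsdp_model = FSDP(
    model,
    auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),  # children wrapped first
    device_id=torch.cuda.current_device(),
    sync_module_states=True,
    # Before this fix, the root's CPU buffer was never moved to `device_id`
    # (the root has no unmanaged parameters), so the buffer broadcast failed
    # with the error above.
)
```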

After some investigation, I determined that we should _not_ fix this by moving the buffers to GPU just for the broadcast and then back to CPU. Instead, we should fix our `device_id` logic.

The issue is that we always used the _parameters_ as the proxy for deciding whether to move module states to the device specified by `device_id`. However, a module (often the root) may have no parameters but still have buffers! In that case, the buffers were left on CPU even when `device_id` was specified. This PR fixes that by considering both parameters and buffers when moving module states to `device_id`.
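
As a rough illustration of the new movement logic, here is a simplified sketch of what `_move_module_to_device()` now does (the helper name and the in-place `.data` move are mine; the actual code in the diff below goes through `_move_states_to_device()`):
```python
import collections
from typing import Deque, List

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel


def move_cpu_states_to_device(module: nn.Module, device: torch.device) -> None:
    """BFS over non-FSDP submodules and move CPU parameters *and* buffers."""
    cpu_device = torch.device("cpu")
    queue: Deque[nn.Module] = collections.deque([module])
    params: List[nn.Parameter] = []
    buffers: List[torch.Tensor] = []
    while queue:
        curr = queue.popleft()
        # Only tensors still on CPU are moved; tensors on a mismatched CUDA
        # device are left alone so that a later check can raise an error.
        params.extend(p for p in curr.parameters(recurse=False) if p.device == cpu_device)
        buffers.extend(b for b in curr.buffers(recurse=False) if b.device == cpu_device)
        for child in curr.children():
            # Do not descend into nested FSDP instances; they already handled
            # their own parameters/buffers during their construction.
            if not isinstance(child, FullyShardedDataParallel):
                queue.append(child)
    for tensor in params + buffers:
        tensor.data = tensor.data.to(device)
```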

Note that this PR preserves the existing semantics that `ignored_modules` / `ignored_parameters` are not excluded from this movement, meaning that ignored parameters are still moved to `device_id`.

Note also that I had to move the unit test back from MTPG to the normal PG, since otherwise I could not reproduce the original error. (It seems that MTPG does not complain if we try to use `dist._broadcast_coalesced()` with CPU tensors.)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103504
Approved by: https://github.com/rohan-varma
diff --git a/test/distributed/fsdp/test_fsdp_misc.py b/test/distributed/fsdp/test_fsdp_misc.py
index 79bb1a4..a8d9b3b 100644
--- a/test/distributed/fsdp/test_fsdp_misc.py
+++ b/test/distributed/fsdp/test_fsdp_misc.py
@@ -290,6 +290,62 @@
         inp = fsdp_model.module.get_input(device=torch.device("cpu"))
         fsdp_model(*inp).sum().backward()
 
+    @skip_if_lt_x_gpu(2)
+    def test_cpu_init_with_sync_module_states(self):
+        """
+        Tests that passing ``sync_module_states=True`` raises an error for
+        a CPU module since the synchronization requires GPU communication,
+        while additionally passing ``device_id`` does not raise an error, even
+        when the model has CPU buffers.
+        """
+
+        def init_nested_wrapped_module():
+            return NestedWrappedModule.init(
+                self.process_group,
+                FSDPInitMode.NO_FSDP,
+                CUDAInitMode.CUDA_NEVER,
+            )
+
+        with self.assertRaisesRegex(
+            ValueError,
+            "The module has CPU parameters or buffers when `sync_module_states=True`",
+        ):
+            FSDP(
+                init_nested_wrapped_module(),
+                self.process_group,
+                sync_module_states=True,
+            )
+
+        # Check that `device_id` with `sync_module_states=True` works
+        nested_wrapped_module = init_nested_wrapped_module()
+        nested_wrapped_module.register_buffer(
+            "buf", torch.ones((2, 2), device="cpu") * self.rank
+        )
+        nested_wrapped_module.module[0].register_buffer(
+            "buf", torch.ones((3, 2), device="cpu") * self.rank
+        )
+        nested_wrapped_module = FSDP(
+            nested_wrapped_module,
+            self.process_group,
+            auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),
+            device_id=torch.cuda.current_device(),
+            sync_module_states=True,
+        )
+        # Each rank's buffers should be 0s since rank 0 is the source, and they
+        # should be on GPU since we specified `device_id`
+        self.assertEqual(
+            nested_wrapped_module.buf.device,
+            torch.device("cuda", torch.cuda.current_device()),
+        )
+        self.assertEqual(nested_wrapped_module.buf, torch.zeros((2, 2)))
+        self.assertEqual(
+            nested_wrapped_module.module.module[0].buf.device,
+            torch.device("cuda", torch.cuda.current_device()),
+        )
+        self.assertEqual(
+            nested_wrapped_module.module.module[0].buf, torch.zeros((3, 2))
+        )
+
 
 class TestFSDPMiscMultiThread(FSDPTestMultiThread):
     @property
@@ -477,29 +533,6 @@
             FSDP(no_params, device_id=0)
 
     @skip_if_lt_x_gpu(2)
-    def test_cpu_init_with_sync_module_states(self):
-        """Tests that passing ``sync_module_states=True`` raises an error for
-        a CPU module since the synchronization requires GPU communication,
-        while additionally passing ``device_id`` does not raise an error."""
-        nested_wrapped_module = NestedWrappedModule.init(
-            self.process_group,
-            FSDPInitMode.RECURSIVE,
-            CUDAInitMode.CUDA_NEVER,
-        )
-        with self.assertRaisesRegex(
-            ValueError, "The module has CPU parameters when `sync_module_states=True`"
-        ):
-            FSDP(nested_wrapped_module, self.process_group, sync_module_states=True)
-
-        # Specifying device_id with sync_module_states=True works.
-        FSDP(
-            nested_wrapped_module,
-            self.process_group,
-            device_id=torch.cuda.current_device(),
-            sync_module_states=True,
-        )
-
-    @skip_if_lt_x_gpu(2)
     def test_fsdp_same_model_across_ranks(self):
         """
         FSDP broadcasts model from rank 0 to ensure it starts off with the same
diff --git a/torch/distributed/fsdp/_init_utils.py b/torch/distributed/fsdp/_init_utils.py
index 0d11ab7..9cb95c0 100644
--- a/torch/distributed/fsdp/_init_utils.py
+++ b/torch/distributed/fsdp/_init_utils.py
@@ -595,7 +595,6 @@
         state.process_group,
         state._use_orig_params,
     )
-    # TODO: Can simplify call `shard()` in the `FlatParamHandle` ctor
     handle.shard()
     assert handle not in state._handles
     state.params.append(handle.flat_param)
@@ -854,32 +853,40 @@
 
     Precondition: ``_check_single_device_module()``.
     """
-    param = next(_get_orig_params(module, ignored_params), None)
-    if param is None:
-        return  # no original parameters to manage
     cpu_device = torch.device("cpu")
-    # TODO: This only checks the parameter's device, not any buffers. Thus, a
-    # buffer-only module will not get offloaded to CPU.
     if device_from_device_id is not None:
-        if param.device == cpu_device:
-            # BFS from `module` without traversing any nested FSDP instances to
-            # collect the parameters/buffers that have not yet been managed
-            queue: Deque[nn.Module] = collections.deque()
-            queue.append(module)
-            params: List[nn.Parameter] = []
-            buffers: List[torch.Tensor] = []
-            while queue:
-                curr_module = queue.popleft()
-                params.extend(curr_module.parameters(recurse=False))
-                buffers.extend(curr_module.buffers(recurse=False))
-                for submodule in curr_module.children():
-                    if not isinstance(submodule, fsdp_file.FullyShardedDataParallel):
-                        queue.append(submodule)
-            # NOTE: This includes moving ignored modules' parameters. If we
-            # decide to change the semantics in the future, simply filter based
-            # on the ignored parameters (and buffers).
-            _move_states_to_device(params, buffers, device_from_device_id)
-    elif param.device == cpu_device:
+        # BFS from `module` without traversing any nested FSDP instances to
+        # collect the parameters/buffers that have not yet been managed
+        queue: Deque[nn.Module] = collections.deque()
+        queue.append(module)
+        params: List[nn.Parameter] = []
+        buffers: List[torch.Tensor] = []
+        while queue:
+            curr_module = queue.popleft()
+            # NOTE: We include a check to only move parameters/buffers that are
+            # on CPU device. If they are on a CUDA device different from the
+            # one specified by `device_id`, then this does NOT move them. This
+            # is so that we can raise an error in `_get_compute_device()`.
+            params.extend(
+                param
+                for param in curr_module.parameters(recurse=False)
+                if param.device == cpu_device
+            )
+            buffers.extend(
+                buffer
+                for buffer in curr_module.buffers(recurse=False)
+                if buffer.device == cpu_device
+            )
+            for submodule in curr_module.children():
+                if not isinstance(submodule, fsdp_file.FullyShardedDataParallel):
+                    queue.append(submodule)
+        # NOTE: This includes moving ignored modules' parameters. If we
+        # decide to change the semantics in the future, simply filter based
+        # on the ignored parameters (and buffers).
+        _move_states_to_device(params, buffers, device_from_device_id)
+        return
+    param = next(_get_orig_params(module, ignored_params), None)
+    if param is not None and param.device == cpu_device:
         _warn_cpu_init()
 
 
@@ -973,7 +980,6 @@
     Precondition: ``sync_module_states == True`` and ``self.process_group`` has
     been set.
     """
-    _check_params_for_sync_module_states(params)
     module_states: List[torch.Tensor] = []
     for buffer in module.buffers():
         # Avoid re-synchronizing buffers in case of nested wrapping
@@ -981,6 +987,7 @@
             setattr(buffer, FSDP_SYNCED, True)
             module_states.append(buffer.detach())
     module_states.extend(param.detach() for param in params)
+    _check_module_states_for_sync_module_states(module_states)
     _sync_params_and_buffers(
         process_group,
         module_states,
@@ -994,12 +1001,12 @@
     buffers: List[torch.Tensor],
     process_group: dist.ProcessGroup,
 ) -> None:
-    _check_params_for_sync_module_states(params)
     # Assumes that each call to this method passes in disjoint `params` and
     # and `buffers` across calls, so there is no chance of re-synchronizing
     params_and_buffers = [param.detach() for param in params] + [
         buffer.detach() for buffer in buffers
     ]
+    _check_module_states_for_sync_module_states(params_and_buffers)
     _sync_params_and_buffers(
         process_group,
         params_and_buffers,
@@ -1008,15 +1015,16 @@
     )
 
 
-def _check_params_for_sync_module_states(
-    params: List[nn.Parameter],
+def _check_module_states_for_sync_module_states(
+    module_states: List[torch.Tensor],
 ) -> None:
-    if params and any(param.device == torch.device("cpu") for param in params):
+    if module_states and any(
+        tensor.device == torch.device("cpu") for tensor in module_states
+    ):
         raise ValueError(
-            "The module has CPU parameters when `sync_module_states=True`, "
-            "which only works when all parameters are on GPU. Please specify "
-            "the `device_id` argument or move the module to GPU before passing "
-            "into FSDP."
+            "The module has CPU parameters or buffers when `sync_module_states=True`, "
+            "which requires them to be on GPU. Please specify the `device_id` argument "
+            "or move the module to GPU before passing it to FSDP."
         )