[SPMD] Error out SPMD mode (#54454)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54454

Following the pitch in https://github.com/pytorch/pytorch/issues/47012, this PR makes the following changes:

1. Let DDP error out if `device_ids` contains more than one device.
2. If `device_ids` is not specified, DDP uses the provided model (the `module` argument of the DDP constructor) as-is, regardless of whether the model lives on a single GPU, spans multiple GPUs, or is on CPU (see the usage sketch after this list).
3. Remove the assertion that prevented SPMD in DDP's `join()` method, since SPMD is now already rejected by the constructor. Also remove the corresponding unit test `test_ddp_uneven_inputs_replicated_error`.
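
For reference, a minimal sketch of the constructor contract after this change, assuming a process group has already been initialized (e.g. via `torch.distributed.init_process_group`); the `nn.Linear` modules and the `rank` value below are placeholders, not taken from this PR:

```python
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

rank = 0  # placeholder: this process's local rank

# Single-device GPU module: device_ids holds exactly one entry ...
ddp = DDP(nn.Linear(2, 4).cuda(rank), device_ids=[rank])

# ... or is left as None, in which case the module (and the forward-pass
# inputs) must already be on the correct device.
ddp = DDP(nn.Linear(2, 4).cuda(rank))

# CPU modules (and likewise multi-device modules): device_ids must be None
# and the module is used as-is.
ddp_cpu = DDP(nn.Linear(2, 4))

# Passing more than one device id no longer enables SPMD; it now raises:
#   ValueError: device_ids can only be None or contain a single element.
# DDP(nn.Linear(2, 4).cuda(rank), device_ids=[0, 1])
```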

#Closes: https://github.com/pytorch/pytorch/issues/47012

ghstack-source-id: 125644392

Test Plan:
buck test mode/dev-nosan caffe2/test/distributed:distributed_gloo_spawn -- test_cuda
buck test mode/dev-nosan caffe2/test/distributed:distributed_gloo_spawn -- test_rnn

buck test mode/dev-nosan caffe2/test/distributed:c10d -- test_nccl_backend_multi_device_ids_not_allowed
buck test mode/dev-nosan caffe2/test/distributed:c10d -- test_nccl_backend_single_device_module_device_ids_None
buck test mode/dev-nosan caffe2/test/distributed:c10d -- test_nccl_backend_multi_device_module_device_ids_None

buck test mode/dev-nosan caffe2/test/distributed:c10d -- test_ddp_multi_device_module_config

waitforbuildbot

Reviewed By: pritamdamania87

Differential Revision: D27226092

fbshipit-source-id: 3ee1e4bc46e5e362fc82cf7a24b2fafb34fcf1b9
diff --git a/test/distributed/test_c10d.py b/test/distributed/test_c10d.py
index b4d4d38..0b48cf3 100644
--- a/test/distributed/test_c10d.py
+++ b/test/distributed/test_c10d.py
@@ -2265,18 +2265,19 @@
         gradient_as_bucket_view=False,
     ):
         model = Net()
+        device = devices[0] if devices else torch.device("cuda:%d" % self.rank)
         ddp_model = DistributedDataParallel(
-            copy.deepcopy(model).to(devices[0]),
+            copy.deepcopy(model).to(device),
             device_ids=device_ids,
             process_group=process_group,
             bucket_cap_mb=0.001,
             gradient_as_bucket_view=gradient_as_bucket_view,
         )
 
-        model.to(devices[0])
+        model.to(device)
 
-        input = torch.randn(global_batch_size, 2).to(devices[0])
-        target = torch.randn(global_batch_size, 4).to(devices[0])
+        input = torch.randn(global_batch_size, 2).to(device)
+        target = torch.randn(global_batch_size, 4).to(device)
 
         return model, ddp_model, input, target
 
@@ -2325,7 +2326,7 @@
         The `devices` argument is used to control placement of the model and
         must always be specified as list of `torch.Device` instances.
         """
-        local_batch_size = len(devices)
+        local_batch_size = 1 if devices is None else len(devices)
         global_batch_size = self.world_size * local_batch_size
 
         if multi_device:
@@ -2405,11 +2406,11 @@
 
     @requires_gloo()
     def test_gloo_backend_cpu_module(self):
-        self._test_gloo_backend([torch.device("cpu")], [])
+        self._test_gloo_backend([torch.device("cpu")], None)
 
     @requires_gloo()
     def test_gloo_backend_cpu_module_grad_is_view(self):
-        self._test_gloo_backend([torch.device("cpu")], [], gradient_as_bucket_view=True)
+        self._test_gloo_backend([torch.device("cpu")], None, gradient_as_bucket_view=True)
 
     @requires_gloo()
     @skip_if_lt_x_gpu(2)
@@ -2430,14 +2431,14 @@
     def test_gloo_backend_2gpu_module(self):
         int_devices = gpus_for_rank(self.world_size)[self.rank][:2]
         devices = [torch.device("cuda:" + str(i)) for i in int_devices]
-        self._test_gloo_backend(devices, [], multi_device=True)
+        self._test_gloo_backend(devices, None, multi_device=True)
 
     @requires_gloo()
     @skip_if_lt_x_gpu(8)
     def test_gloo_backend_4gpu_module(self):
         int_devices = gpus_for_rank(self.world_size)[self.rank][:4]
         devices = [torch.device("cuda:" + str(i)) for i in int_devices]
-        self._test_gloo_backend(devices, [], multi_device=True)
+        self._test_gloo_backend(devices, None, multi_device=True)
 
     def _test_nccl_backend(
         self, devices, device_ids, multi_device=False, gradient_as_bucket_view=False
@@ -2450,6 +2451,34 @@
 
     @requires_nccl()
     @skip_if_lt_x_gpu(2)
+    def test_nccl_backend_multi_device_ids_not_allowed(self):
+        int_devices = list(range(torch.cuda.device_count()))
+        devices = [torch.device("cuda:" + str(i)) for i in int_devices]
+        with self.assertRaisesRegex(ValueError, "device_ids can only be None or contain a single element."):
+            self._test_nccl_backend(devices, int_devices)
+
+    @requires_nccl()
+    @skip_if_lt_x_gpu(2)
+    def test_nccl_backend_single_device_module_device_ids_None(self):
+        self._test_nccl_backend(None, None)
+
+    @requires_nccl()
+    @skip_if_lt_x_gpu(2)
+    def test_nccl_backend_single_device_module_empty_device_ids(self):
+        # This tests the backward compatibility of accepting an empty list as `device_ids`,
+        # although we no longer document this in favor of the default value of `None`,
+        # which is consistent with multi-device modules and CPU modules.
+        self._test_nccl_backend(None, [])
+
+    @requires_nccl()
+    @skip_if_lt_x_gpu(4)
+    def test_nccl_backend_multi_device_module_device_ids_None(self):
+        int_devices = gpus_for_rank(self.world_size)[self.rank][:2]
+        devices = [torch.device("cuda:" + str(i)) for i in int_devices]
+        self._test_nccl_backend(devices, None, multi_device=True)
+
+    @requires_nccl()
+    @skip_if_lt_x_gpu(2)
     def test_nccl_backend_1gpu_module_device_ids_integer_list(self):
         int_devices = gpus_for_rank(self.world_size)[self.rank][:1]
         devices = [torch.device("cuda:" + str(i)) for i in int_devices]
@@ -2467,14 +2496,14 @@
     def test_nccl_backend_2gpu_module(self):
         int_devices = gpus_for_rank(self.world_size)[self.rank][:2]
         devices = [torch.device("cuda:" + str(i)) for i in int_devices]
-        self._test_nccl_backend(devices, [], multi_device=True)
+        self._test_nccl_backend(devices, None, multi_device=True)
 
     @requires_nccl()
     @skip_if_lt_x_gpu(8)
     def test_nccl_backend_4gpu_module(self):
         int_devices = gpus_for_rank(self.world_size)[self.rank][:4]
         devices = [torch.device("cuda:" + str(i)) for i in int_devices]
-        self._test_nccl_backend(devices, [], multi_device=True)
+        self._test_nccl_backend(devices, None, multi_device=True)
 
     @requires_nccl()
     @skip_if_lt_x_gpu(4)
@@ -2496,7 +2525,7 @@
                 model, output_device=gpus[1], process_group=process_group
             )
 
-        with self.assertRaisesRegex(AssertionError, "device_ids .* single-device GPU"):
+        with self.assertRaisesRegex(ValueError, "device_ids can only be None or contain a single element."):
             ddp_model = DistributedDataParallel(
                 model, device_ids=gpus, process_group=process_group
             )
@@ -2508,7 +2537,7 @@
             ddp_model = DistributedDataParallel(model, process_group=process_group)
 
         model = model.cpu()
-        with self.assertRaisesRegex(AssertionError, "device_ids .* single-device GPU"):
+        with self.assertRaisesRegex(ValueError, "device_ids can only be None or contain a single element."):
             ddp_model = DistributedDataParallel(
                 model, device_ids=gpus, process_group=process_group
             )
diff --git a/test/distributed/test_c10d_spawn.py b/test/distributed/test_c10d_spawn.py
index 433266c..363da4c 100644
--- a/test/distributed/test_c10d_spawn.py
+++ b/test/distributed/test_c10d_spawn.py
@@ -241,11 +241,7 @@
         store = c10d.FileStore(self.file.name, self.world_size)
         process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size)
         if inp[0].is_cuda:
-            num_gpus = torch.cuda.device_count()
-            batch_size = inp[0].size(0)
-            # batch_size must be evenly divisible by num_gpus_used, take the largest one
-            num_gpus_used = [i for i in range(1, num_gpus + 1) if batch_size % i == 0][-1]
-            device_ids = list(range(num_gpus_used))
+            device_ids = [torch.cuda.current_device()]
         else:
             device_ids = None
 
diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py
index 83b22c0..a2b4b34 100644
--- a/torch/nn/parallel/distributed.py
+++ b/torch/nn/parallel/distributed.py
@@ -19,7 +19,7 @@
 if torch.distributed.rpc.is_available():
     RPC_AVAILABLE = True
     from torch.distributed.rpc import RRef
-from torch._utils import _get_device_index, _get_all_device_indices
+from torch._utils import _get_device_index
 
 from ..modules import Module
 from ._functions import _get_stream
@@ -305,14 +305,18 @@
 
     Args:
         module (Module): module to be parallelized
-        device_ids (list of int or torch.device): CUDA devices. This should
-                   only be provided when the input module resides on a single
-                   CUDA device. For single-device modules, the i'th
-                   :attr:`module` replica is placed on ``device_ids[i]``. For
-                   multi-device modules and CPU modules, ``device_ids`` must be
-                   ``None`` or an empty list, and input data for the forward
-                   pass must be placed on the correct device. (default: all
-                   visible devices for single-device modules)
+        device_ids (list of int or torch.device): CUDA devices.
+                   1) For single-device modules, ``device_ids`` can
+                   contain exactly one device id, which represents the only
+                   CUDA device where the input module corresponding to this process resides.
+                   Alternatively, ``device_ids`` can also be ``None``.
+                   2) For multi-device modules and CPU modules,
+                   ``device_ids`` must be ``None``.
+
+                   When ``device_ids`` is ``None`` for both cases,
+                   both the input data for the forward pass and the actual module
+                   must be placed on the correct device.
+                   (default: ``None``)
         output_device (int or torch.device): Device location of output for
                       single-device CUDA modules. For multi-device modules and
                       CPU modules, it must be ``None``, and the module itself
@@ -388,6 +392,9 @@
             "doesn't have any parameter that requires a gradient."
         )
 
+        if device_ids is not None and len(device_ids) > 1:
+            raise ValueError("device_ids can only be None or contain a single element.")
+
         self.is_multi_device_module = len({p.device for p in module.parameters()}) > 1
         distinct_device_types = {p.device.type for p in module.parameters()}
         assert len(distinct_device_types) == 1, (
@@ -396,7 +403,12 @@
         ).format(distinct_device_types)
         self.device_type = list(distinct_device_types)[0]
 
-        if self.device_type == "cpu" or self.is_multi_device_module:
+        if (
+            device_ids is None
+            or len(device_ids) == 0  # For backward compatibility.
+            or self.device_type == "cpu"
+            or self.is_multi_device_module
+        ):
             assert not device_ids and not output_device, (
                 "DistributedDataParallel device_ids and output_device arguments "
                 "only work with single-device GPU modules, but got "
@@ -406,10 +418,6 @@
             self.device_ids = None
             self.output_device = None
         else:
-            # Use all devices by default for single-device GPU modules
-            if device_ids is None:
-                device_ids = _get_all_device_indices()
-
             self.device_ids = [_get_device_index(x, True) for x in device_ids]
 
             if output_device is None:
@@ -967,11 +975,6 @@
         modifications to the model or data loading is required.
 
         .. warning::
-            This module works only with the multi-process, single-device usage
-            of :class:`torch.nn.parallel.DistributedDataParallel`,
-            which means that a single process works on a single GPU.
-
-        .. warning::
             This module currently does not support custom distributed collective
             operations in the forward pass, such as ``SyncBatchNorm`` or other
             custom defined collectives in the model's forward pass.
@@ -1030,12 +1033,6 @@
         # Log uneven input API usage.
         self.logger._set_uneven_input_join()
         try:
-            if self.device_ids and len(self.device_ids) > 1:
-                raise ValueError(
-                    """DDP join() API does not support Single-Process Multi-GPU
-                    mode training. The recommended approach for DDP training is
-                    to spawn a single process that works on a single GPU."""
-                )
             has_error = False
             self.ddp_uneven_inputs_config = _DDPUnevenInputsConfig(
                 ddp_join_enabled=enable,
diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py
index c02c972..c25b1f4 100644
--- a/torch/testing/_internal/distributed/distributed_test.py
+++ b/torch/testing/_internal/distributed/distributed_test.py
@@ -4504,29 +4504,6 @@
                     loss = out.sum()
                     loss.backward()
 
-        @require_backend({"gloo", "nccl"})
-        @require_backends_available({"gloo", "nccl"})
-        @skip_if_lt_x_gpu(4)
-        def test_ddp_uneven_inputs_replicated_error(self):
-            # Tests that the context manager errors out in SPMD mode.
-            group = dist.new_group([0, 1])
-            if self.rank < 2:
-                model = nn.Linear(1, 1, bias=False)
-                rank_to_device = {0: [0, 1], 1: [2, 3]}
-
-                devices = rank_to_device[self.rank]
-                net = torch.nn.parallel.DistributedDataParallel(
-                    model.cuda(devices[0]), device_ids=devices, process_group=group
-                )
-                with self.assertRaisesRegex(
-                    ValueError, r"DDP join\(\) API does not support Single-Process Multi-GPU"
-                ):
-                    with net.join():
-                        pass
-            # We need a barrier since otherwise non-participating processes exit too early
-            # and cause a timeout.
-            self._barrier(timeout=60)
-
         @require_backend({"nccl", "gloo"})
         @require_n_gpus_for_nccl_backend(int(os.environ["WORLD_SIZE"]), os.environ["BACKEND"])
         def test_broadcast_object_list(self):