Use store-based barrier in init_process_group. (#49419)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49419

As described in https://github.com/pytorch/pytorch/issues/48110, the
`barrier()` recently introduced at the end of `init_process_group` messes up
NCCL communicator state, because it simulates a barrier by running an
allreduce across a set of default devices. As a result, subsequent NCCL
operations might not behave as expected. This change replaces that call with a
store-based barrier for every backend except MPI (which has no store): each
rank increments a per-group key in the store and then waits until the counter
reaches the world size, so initialization no longer touches NCCL communicators.
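
For context, the problematic pattern matches what the new unit tests below exercise: a collective on a non-default GPU issued right after initialization. A minimal standalone sketch of that pattern, assuming 2 ranks and 4 GPUs; the `RANK`/`WORLD_SIZE` environment variables and the store path are illustrative assumptions, not part of this diff:

```python
# Sketch only: assumes a launcher sets RANK/WORLD_SIZE and that 4 GPUs are visible.
import os
import torch
import torch.distributed as c10d

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

# Rendezvous through a FileStore, mirroring the new tests (path is arbitrary).
store = c10d.FileStore("/tmp/store_barrier_repro", world_size)
c10d.init_process_group(
    backend="nccl", rank=rank, world_size=world_size, store=store
)

# Each rank drives a non-default GPU (rank 0 -> cuda:0, rank 1 -> cuda:2).
# With the old allreduce-based barrier inside init_process_group, NCCL
# communicators could end up in an unexpected state for this collective.
t = torch.tensor([rank + 1] * 10).cuda(2 * rank)
c10d.all_reduce(t)
assert t.eq(3).all()  # 1 + 2 summed across the two ranks
```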
ghstack-source-id: 118861776

Test Plan:
1) Unit tests added.
2) waitforbuildbot

Reviewed By: mrshenli

Differential Revision: D25566550

fbshipit-source-id: ab083b67b634d7c515f4945deb228f959b27c936
diff --git a/test/distributed/test_c10d.py b/test/distributed/test_c10d.py
index 3fc6c93..970fdfa 100644
--- a/test/distributed/test_c10d.py
+++ b/test/distributed/test_c10d.py
@@ -4550,6 +4550,89 @@
         for root_rank in ranks:
             self._test_broadcast_coalesced(process_group, device, root_rank)
 
+    @requires_nccl()
+    @skip_if_lt_x_gpu(4)
+    def test_nccl_barrier(self):
+        store = c10d.FileStore(self.file_name, self.world_size)
+        c10d.init_process_group(
+            backend="nccl",
+            rank=self.rank,
+            world_size=self.world_size,
+            store=store)
+
+        t = torch.tensor([self.rank + 1] * 10).cuda(2 * self.rank)
+        c10d.all_reduce(t)
+        expected_tensor = torch.tensor([3] * 10).cuda(2 * self.rank)
+        self.assertEqual(expected_tensor, t)
+
+        # Test with new_group
+        pg = c10d.new_group([0, 1])
+        t = torch.tensor([self.rank + 1] * 10).cuda(2 * self.rank)
+        pg.allreduce(t).wait()
+        self.assertEqual(expected_tensor, t)
+
+        pg = c10d.new_group([0])
+        if self.rank == 0:
+            t = torch.tensor([self.rank + 1] * 10).cuda(2 * self.rank)
+            expected_tensor = torch.tensor([self.rank + 1] * 10).cuda(2 * self.rank)
+            pg.allreduce(t).wait()
+            self.assertEqual(expected_tensor, t)
+
+        pg = c10d.new_group([1])
+        if self.rank == 1:
+            t = torch.tensor([self.rank + 1] * 10).cuda(2 * self.rank)
+            expected_tensor = torch.tensor([self.rank + 1] * 10).cuda(2 * self.rank)
+            pg.allreduce(t).wait()
+            self.assertEqual(expected_tensor, t)
+
+    @requires_nccl()
+    @skip_if_lt_x_gpu(4)
+    def test_nccl_barrier_timeout(self):
+        store = c10d.FileStore(self.file_name, self.world_size)
+        if self.rank == 0:
+            with self.assertRaisesRegex(RuntimeError, "Timed out initializing process group"):
+                c10d.init_process_group(
+                    backend="nccl",
+                    rank=self.rank,
+                    world_size=self.world_size,
+                    store=store,
+                    timeout=timedelta(seconds=1))
+
+    @requires_nccl()
+    @skip_if_lt_x_gpu(4)
+    def test_nccl_barrier_timeout_new_group(self):
+        store = c10d.FileStore(self.file_name, self.world_size)
+        c10d.init_process_group(
+            backend="nccl",
+            rank=self.rank,
+            world_size=self.world_size,
+            store=store,
+            timeout=timedelta(seconds=1))
+
+        if self.rank == 0:
+            with self.assertRaisesRegex(RuntimeError, "Timed out initializing process group"):
+                c10d.new_group([0, 1], timeout=timedelta(seconds=1))
+
+            with self.assertRaisesRegex(RuntimeError, "Timed out initializing process group"):
+                c10d.new_group([0], timeout=timedelta(seconds=1))
+
+    @requires_nccl()
+    @skip_if_lt_x_gpu(4)
+    def test_nccl_barrier_timeout_new_group_non_member(self):
+        store = c10d.FileStore(self.file_name, self.world_size)
+        c10d.init_process_group(
+            backend="nccl",
+            rank=self.rank,
+            world_size=self.world_size,
+            store=store,
+            timeout=timedelta(seconds=1))
+
+        if self.rank == 1:
+            with self.assertRaisesRegex(RuntimeError, "Timed out initializing process group"):
+                c10d.new_group([0, 1], timeout=timedelta(seconds=1))
+
+            with self.assertRaisesRegex(RuntimeError, "Timed out initializing process group"):
+                c10d.new_group([0], timeout=timedelta(seconds=1))
 
 if __name__ == "__main__":
     assert (
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index 15d9a53..4185e2d 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -2,6 +2,7 @@
 import torch
 import warnings
 import contextlib
+import time
 from torch._six import string_classes
 from datetime import timedelta
 from typing import Dict, Optional, Tuple, Union
@@ -172,6 +173,26 @@
 # Process group count for default naming
 _group_count = 0
 
+STORE_BASED_BARRIER_PREFIX = "store_based_barrier_key"
+
+def _store_based_barrier(rank, store, timeout):
+    """
+    Store-based barrier for synchronizing processes after
+    ``init_process_group`` or ``new_group``. Intended to be used only with
+    those two methods; it is not a generic alternative to ``barrier()``.
+    """
+    store_key = "{}:{}".format(STORE_BASED_BARRIER_PREFIX, _group_count)
+    store.add(store_key, 1)
+
+    # Now wait for all workers to check in with the store.
+    world_size = get_world_size()
+    worker_count = int(store.get(store_key))
+    start = time.time()
+    while worker_count != world_size:
+        time.sleep(0.01)
+        worker_count = int(store.get(store_key))
+        if timedelta(seconds=(time.time() - start)) > timeout:
+            raise RuntimeError("Timed out initializing process group")
 
 def _rank_not_in_group(group: ProcessGroup):
     """
@@ -475,7 +496,13 @@
     # barrier at the end to ensure that once we return from this method, all
     # process groups including global variables are updated correctly on all
     # ranks.
-    barrier()
+    if backend == Backend.MPI:
+        # The MPI backend doesn't have a store.
+        barrier()
+    else:
+        # Use a store-based barrier here since barrier() uses a bunch of
+        # default devices and messes up NCCL internal state.
+        _store_based_barrier(rank, store, timeout)
 
 def _new_process_group_helper(world_size,
                               rank,
@@ -2452,6 +2479,12 @@
     # barrier at the end to ensure that once we return from this method, all
     # process groups including global variables are updated correctly on all
     # ranks.
-    barrier()
+    if backend == Backend.MPI:
+        # The MPI backend doesn't have a store.
+        barrier()
+    else:
+        # Use a store-based barrier here since barrier() uses a bunch of
+        # default devices and messes up NCCL internal state.
+        _store_based_barrier(group_rank, default_store, timeout)
 
     return pg