Use store-based barrier in init_process_group. (#49419)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49419
As described in https://github.com/pytorch/pytorch/issues/48110, the
`barrier()` recently added at the end of `init_process_group` messes up
NCCL communicator state: it simulates a barrier by running an allreduce
on a set of default devices, so subsequent NCCL operations might not
behave as expected. This PR replaces that call with a store-based
barrier in both `init_process_group` and `new_group` for every backend
except MPI, which has no store.
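Conceptually, the new barrier has each rank increment a shared counter in
the c10d store and then poll until all ranks have checked in or the timeout
expires. A simplified sketch follows (not the exact `_store_based_barrier`
helper added below; the fixed key and the `store_barrier` name are
illustrative only, and the real helper also keys the counter by the process
group count):

    # Simplified sketch, assuming a c10d store (FileStore/TCPStore) shared
    # by all ranks.
    import time
    from datetime import timedelta

    def store_barrier(store, world_size, timeout=timedelta(minutes=30)):
        key = "store_based_barrier_key:1"  # real helper keys this by group count
        store.add(key, 1)                  # this rank checks in
        start = time.time()
        while int(store.get(key)) != world_size:
            time.sleep(0.01)               # poll until everyone has checked in
            if timedelta(seconds=time.time() - start) > timeout:
                raise RuntimeError("Timed out initializing process group")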
ghstack-source-id: 118861776
Test Plan:
1) Unit tests added.
2) waitforbuildbot
Reviewed By: mrshenli
Differential Revision: D25566550
fbshipit-source-id: ab083b67b634d7c515f4945deb228f959b27c936
diff --git a/test/distributed/test_c10d.py b/test/distributed/test_c10d.py
index 3fc6c93..970fdfa 100644
--- a/test/distributed/test_c10d.py
+++ b/test/distributed/test_c10d.py
@@ -4550,6 +4550,89 @@
         for root_rank in ranks:
             self._test_broadcast_coalesced(process_group, device, root_rank)

+    @requires_nccl()
+    @skip_if_lt_x_gpu(4)
+    def test_nccl_barrier(self):
+        store = c10d.FileStore(self.file_name, self.world_size)
+        c10d.init_process_group(
+            backend="nccl",
+            rank=self.rank,
+            world_size=self.world_size,
+            store=store)
+
+        t = torch.tensor([self.rank + 1] * 10).cuda(2 * self.rank)
+        c10d.all_reduce(t)
+        expected_tensor = torch.tensor([3] * 10).cuda(2 * self.rank)
+        self.assertEqual(expected_tensor, t)
+
+        # Test with new_group
+        pg = c10d.new_group([0, 1])
+        t = torch.tensor([self.rank + 1] * 10).cuda(2 * self.rank)
+        pg.allreduce(t).wait()
+        self.assertEqual(expected_tensor, t)
+
+        pg = c10d.new_group([0])
+        if self.rank == 0:
+            t = torch.tensor([self.rank + 1] * 10).cuda(2 * self.rank)
+            expected_tensor = torch.tensor([self.rank + 1] * 10).cuda(2 * self.rank)
+            pg.allreduce(t).wait()
+            self.assertEqual(expected_tensor, t)
+
+        pg = c10d.new_group([1])
+        if self.rank == 1:
+            t = torch.tensor([self.rank + 1] * 10).cuda(2 * self.rank)
+            expected_tensor = torch.tensor([self.rank + 1] * 10).cuda(2 * self.rank)
+            pg.allreduce(t).wait()
+            self.assertEqual(expected_tensor, t)
+
+    @requires_nccl()
+    @skip_if_lt_x_gpu(4)
+    def test_nccl_barrier_timeout(self):
+        store = c10d.FileStore(self.file_name, self.world_size)
+        if self.rank == 0:
+            with self.assertRaisesRegex(RuntimeError, "Timed out initializing process group"):
+                c10d.init_process_group(
+                    backend="nccl",
+                    rank=self.rank,
+                    world_size=self.world_size,
+                    store=store,
+                    timeout=timedelta(seconds=1))
+
+    @requires_nccl()
+    @skip_if_lt_x_gpu(4)
+    def test_nccl_barrier_timeout_new_group(self):
+        store = c10d.FileStore(self.file_name, self.world_size)
+        c10d.init_process_group(
+            backend="nccl",
+            rank=self.rank,
+            world_size=self.world_size,
+            store=store,
+            timeout=timedelta(seconds=1))
+
+        if self.rank == 0:
+            with self.assertRaisesRegex(RuntimeError, "Timed out initializing process group"):
+                c10d.new_group([0, 1], timeout=timedelta(seconds=1))
+
+            with self.assertRaisesRegex(RuntimeError, "Timed out initializing process group"):
+                c10d.new_group([0], timeout=timedelta(seconds=1))
+
+    @requires_nccl()
+    @skip_if_lt_x_gpu(4)
+    def test_nccl_barrier_timeout_new_group_non_member(self):
+        store = c10d.FileStore(self.file_name, self.world_size)
+        c10d.init_process_group(
+            backend="nccl",
+            rank=self.rank,
+            world_size=self.world_size,
+            store=store,
+            timeout=timedelta(seconds=1))
+
+        if self.rank == 1:
+            with self.assertRaisesRegex(RuntimeError, "Timed out initializing process group"):
+                c10d.new_group([0, 1], timeout=timedelta(seconds=1))
+
+            with self.assertRaisesRegex(RuntimeError, "Timed out initializing process group"):
+                c10d.new_group([0], timeout=timedelta(seconds=1))

 if __name__ == "__main__":
     assert (
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index 15d9a53..4185e2d 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -2,6 +2,7 @@
 import torch
 import warnings
 import contextlib
+import time
 from torch._six import string_classes
 from datetime import timedelta
 from typing import Dict, Optional, Tuple, Union
@@ -172,6 +173,26 @@
 # Process group count for default naming
 _group_count = 0

+STORE_BASED_BARRIER_PREFIX = "store_based_barrier_key"
+
+def _store_based_barrier(rank, store, timeout):
+    """
+    Store-based barrier used to synchronize processes after
+    ``init_process_group`` or ``new_group``. Intended only for those two
+    methods; it is not a generic alternative to ``barrier()``.
+    """
+    store_key = "{}:{}".format(STORE_BASED_BARRIER_PREFIX, _group_count)
+    store.add(store_key, 1)
+
+    # Now wait for all workers to check in with the store.
+    world_size = get_world_size()
+    worker_count = int(store.get(store_key))
+    start = time.time()
+    while worker_count != world_size:
+        time.sleep(0.01)
+        worker_count = int(store.get(store_key))
+        if timedelta(seconds=(time.time() - start)) > timeout:
+            raise RuntimeError("Timed out initializing process group")

 def _rank_not_in_group(group: ProcessGroup):
     """
@@ -475,7 +496,13 @@
     # barrier at the end to ensure that once we return from this method, all
     # process groups including global variables are updated correctly on all
     # ranks.
-    barrier()
+    if backend == Backend.MPI:
+        # MPI doesn't have a store.
+        barrier()
+    else:
+        # Use a store-based barrier here, since barrier() uses a bunch of
+        # default devices and messes up NCCL internal state.
+        _store_based_barrier(rank, store, timeout)

 def _new_process_group_helper(world_size,
                               rank,
@@ -2452,6 +2479,12 @@
     # barrier at the end to ensure that once we return from this method, all
     # process groups including global variables are updated correctly on all
     # ranks.
-    barrier()
+    if backend == Backend.MPI:
+        # MPI doesn't have a store.
+        barrier()
+    else:
+        # Use a store-based barrier here, since barrier() uses a bunch of
+        # default devices and messes up NCCL internal state.
+        _store_based_barrier(group_rank, default_store, timeout)

     return pg
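
For reference, a usage sketch of the pattern the new test_nccl_barrier test
exercises (assumptions: a 2-process job with at least 4 GPUs; `run`, `rank`
and `file_name` are illustrative names supplied by whatever launches the
processes, not part of this PR):

    # Usage sketch mirroring test_nccl_barrier.
    import torch
    import torch.distributed as c10d

    def run(rank, world_size, file_name):
        store = c10d.FileStore(file_name, world_size)
        c10d.init_process_group(
            backend="nccl", rank=rank, world_size=world_size, store=store)

        # Each rank uses a non-default GPU. With the store-based barrier,
        # init_process_group no longer issues the allreduce that previously
        # created NCCL communicators on default devices.
        t = torch.tensor([rank + 1] * 10).cuda(2 * rank)
        c10d.all_reduce(t)    # result on every rank: [3] * 10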