[c10d] Move pg wrapper tests to their own file. (#59840)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59840
Moves these tests into their own standalone file; no meaningful code changes.
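
For reference, after this change the relocated tests can be run either directly or through the test runner, mirroring the commands this diff adds to CONTRIBUTING.md and the CI scripts:

```sh
# Run the process group wrapper tests directly.
python test/distributed/test_pg_wrapper.py

# Or run them through the test runner.
python test/run_test.py --verbose -i distributed/test_pg_wrapper
```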
ghstack-source-id: 131359162
Test Plan: CI
Reviewed By: cbalioglu
Differential Revision: D29012664
fbshipit-source-id: 348870016509a6ed7e69240fa82bccef4a12d674
diff --git a/.jenkins/pytorch/multigpu-test.sh b/.jenkins/pytorch/multigpu-test.sh
index daa0fb8..7ed43a9 100755
--- a/.jenkins/pytorch/multigpu-test.sh
+++ b/.jenkins/pytorch/multigpu-test.sh
@@ -26,6 +26,7 @@
time python test/run_test.py --verbose -i distributed/test_c10d_spawn_gloo
time python test/run_test.py --verbose -i distributed/test_c10d_spawn_nccl
time python test/run_test.py --verbose -i distributed/test_store
+time python test/run_test.py --verbose -i distributed/test_pg_wrapper
time python test/run_test.py --verbose -i distributed/rpc/cuda/test_process_group_agent
time python test/run_test.py --verbose -i distributed/rpc/cuda/test_tensorpipe_agent
assert_git_not_dirty
diff --git a/.jenkins/pytorch/win-test-helpers/test_distributed.bat b/.jenkins/pytorch/win-test-helpers/test_distributed.bat
index a50c153..53ebee8 100644
--- a/.jenkins/pytorch/win-test-helpers/test_distributed.bat
+++ b/.jenkins/pytorch/win-test-helpers/test_distributed.bat
@@ -22,3 +22,6 @@
%1\python.exe test/run_test.py --verbose -i distributed/test_store
if %errorlevel% neq 0 ( exit /b %errorlevel% )
+
+%1\python.exe test/run_test.py --verbose -i distributed/test_pg_wrapper
+if %errorlevel% neq 0 ( exit /b %errorlevel% )
diff --git a/test/distributed/test_c10d_common.py b/test/distributed/test_c10d_common.py
index 4d2dc0e..a25ef9a 100644
--- a/test/distributed/test_c10d_common.py
+++ b/test/distributed/test_c10d_common.py
@@ -233,125 +233,6 @@
return F.softmax(self.embedding(x), dim=1)
-class AbstractProcessGroupWrapperTest(MultiProcessTestCase):
- def setUp(self):
- super(AbstractProcessGroupWrapperTest, self).setUp()
- # For Windows platform, Python does not support fork, change it to spawn here.
- if sys.platform == "win32":
- self._spawn_processes()
- else:
- self._fork_processes()
-
- def _test_collective_hang(self, wrapper_pg, use_cuda=False):
- # All ranks besides 1 call allreduce and wrapper_pg should detect a hang
- # and report an issue with rank 1.
- faulty_rank = 1
- if self.rank != faulty_rank:
- tensor = torch.randn(20, 10)
- if use_cuda:
- tensor = tensor.to(self.rank)
-
- if self.rank == 0:
- # Rank 0 reports faulty ranks
- err = f"Ranks {faulty_rank} failed to pass monitoredBarrier"
- else:
- err = "Please check rank 0 logs for faulty rank"
- with self.assertRaisesRegex(RuntimeError, err):
- wrapper_pg.allreduce([tensor])
-
- def _test_collectives_op_mismatch(self, wrapper_pg, use_cuda=False):
- tensor = torch.randn(20, 10)
- if use_cuda:
- tensor = tensor.to(self.rank)
- works = []
- # Run a few successful collectives
- for _ in range(10):
- work = wrapper_pg.allreduce([tensor])
- works.append(work)
-
- for w in works:
- w.wait()
-
- # Simulate mismatch: allreduce vs reduce.
- with self.assertRaisesRegex(
- RuntimeError, "Mismatch between collective operation types"
- ):
- if self.rank == 0:
- wrapper_pg.allreduce([tensor])
- else:
- wrapper_pg.reduce([tensor])
-
- # Check additional mismatches
-
- with self.assertRaisesRegex(
- RuntimeError, "Mismatch between collective operation types"
- ):
- if self.rank == 0:
- wrapper_pg.reduce([tensor])
- else:
- wrapper_pg.barrier()
-
- with self.assertRaisesRegex(
- RuntimeError, "Mismatch between collective operation types"
- ):
- scatter_result = [torch.ones(4) * i for i in range(self.world_size)]
- scattered_tensor = torch.empty(4)
- if self.rank == 0:
- wrapper_pg.scatter(scattered_tensor, scatter_result, 0)
- else:
- wrapper_pg.reduce_scatter(scattered_tensor, scatter_result)
-
- with self.assertRaisesRegex(
- RuntimeError, "Mismatch between collective operation types"
- ):
- if self.rank == 0:
- wrapper_pg.broadcast(tensor, 0)
- else:
- output_tensors = [
- torch.zeros_like(tensor) for _ in range(self.world_size)
- ]
- wrapper_pg.allgather([output_tensors], [tensor])
-
- def _test_collective_shape_mismatch(self, wrapper_pg, use_cuda=False):
- wrapper_pg.barrier()
- dim = 2 if self.rank == 0 else 10
- tensor = torch.randn(20, dim)
- if use_cuda:
- tensor = tensor.to(self.rank)
- with self.assertRaisesRegex(RuntimeError, "Error when verifying shape tensors"):
- wrapper_pg.allreduce([tensor])
- # Check errors are raised when dimensionality of shapes is different
- tensor = torch.randn(20, 10, 2) if self.rank == 0 else torch.randn(20, 10)
- if use_cuda:
- tensor = tensor.to(self.rank)
- with self.assertRaisesRegex(RuntimeError, "Error when verifying shape tensors"):
- wrapper_pg.allreduce([tensor])
-
- # Check shape errors with scatter
- input = [
- torch.tensor(
- [self.rank] if self.rank == 0 else [self.rank, self.rank],
- device=self.rank if use_cuda else "cpu",
- )
- for _ in range(self.world_size)
- ]
- outputs = [
- torch.tensor(
- [-1] if self.rank == 0 else [-1, -1],
- device=self.rank if use_cuda else "cpu",
- )
- for _ in range(self.world_size)
- ]
- root_rank = 0
- opts = c10d.ScatterOptions()
- opts.rootRank = root_rank
- with self.assertRaisesRegex(RuntimeError, "Error when verifying shape tensors"):
- if self.rank == root_rank:
- wrapper_pg.scatter([outputs[self.rank]], [input], opts).wait()
- else:
- wrapper_pg.scatter([outputs[self.rank]], [], opts).wait()
-
-
class AbstractDistributedDataParallelTest(object):
def tearDown(self):
# DistributedDataParallel test doesn't seem to call FileStore destructor
diff --git a/test/distributed/test_c10d_gloo.py b/test/distributed/test_c10d_gloo.py
index 8ce6443..a4c2b85 100644
--- a/test/distributed/test_c10d_gloo.py
+++ b/test/distributed/test_c10d_gloo.py
@@ -29,7 +29,6 @@
simple_sparse_reduce_tests,
skip_if_win32,
create_device,
- with_dist_debug_levels,
verify_ddp_error_logged,
)
from torch.testing._internal.common_utils import (
@@ -45,7 +44,6 @@
Task,
ModuleForDdpCommHook,
SparseGradientModule,
- AbstractProcessGroupWrapperTest,
)
@@ -208,92 +206,6 @@
TEST_WITH_TSAN,
"TSAN is not fork-safe since we're forking in a multi-threaded environment",
)
-class ProcessGroupGlooWrapperTest(AbstractProcessGroupWrapperTest):
- def setUp(self):
- super(ProcessGroupGlooWrapperTest, self).setUp()
-
- def opts(self, threads=2, timeout=10.0):
- opts = c10d.ProcessGroupGloo._Options()
- opts._timeout = timeout
- opts._devices = [create_device(interface=LOOPBACK)]
- opts._threads = threads
- return opts
-
- def _create_wrapper_pg(self, with_new_group=False, timeout=10.0):
- store = c10d.FileStore(self.file_name, self.world_size)
- c10d.init_process_group(
- backend="gloo", rank=self.rank, world_size=self.world_size, store=store
- )
- if with_new_group:
- pg = c10d.new_group(backend="gloo")
- else:
- _pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts(timeout=timeout))
- pg = c10d._create_process_group_wrapper(
- _pg,
- "unused",
- store,
- self.rank,
- self.world_size,
- timeout=timeout,
- )
- return pg
-
- def test_collective_hang(self):
- pg = self._create_wrapper_pg(timeout=2.0)
- self._test_collective_hang(pg)
-
- # NOTE: these tests are separated by debug level instead of combined into
- # one due to https://github.com/pytorch/pytorch/issues/55967, they can be
- # combined after that is resolved.
- @with_dist_debug_levels(levels=["DETAIL"])
- def test_collectives_op_mismatch_debug_mode(self):
- pg = self._create_wrapper_pg(with_new_group=True)
- self._test_collectives_op_mismatch(pg)
-
- @with_dist_debug_levels(levels=["OFF"])
- def test_collectives_op_mismatch(self):
- pg = self._create_wrapper_pg(with_new_group=False)
- self._test_collectives_op_mismatch(pg)
-
- @with_dist_debug_levels(levels=["DETAIL"])
- def test_collective_shape_mismatch_debug_mode(self):
- pg = self._create_wrapper_pg(with_new_group=True)
- self._test_collective_shape_mismatch(pg)
-
- @with_dist_debug_levels(levels=["OFF"])
- def test_collective_shape_mismatch(self):
- pg = self._create_wrapper_pg(with_new_group=False)
- self._test_collective_shape_mismatch(pg)
-
- @skip_if_lt_x_gpu(4)
- @with_dist_debug_levels(levels=["DETAIL"])
- def test_collectives_op_mismatch_cuda_debug_mode(self):
- pg = self._create_wrapper_pg(with_new_group=True)
- self._test_collectives_op_mismatch(pg, use_cuda=True)
-
- @skip_if_lt_x_gpu(4)
- @with_dist_debug_levels(levels=["OFF"])
- def test_collectives_op_mismatch_cuda(self):
- pg = self._create_wrapper_pg(with_new_group=False)
- self._test_collectives_op_mismatch(pg, use_cuda=True)
-
- @skip_if_lt_x_gpu(4)
- @with_dist_debug_levels(levels=["DETAIL"])
- def test_collective_shape_mismatch_cuda_debug_mode(self):
- pg = self._create_wrapper_pg(with_new_group=True)
- self._test_collective_shape_mismatch(pg, use_cuda=True)
-
- @skip_if_lt_x_gpu(4)
- @with_dist_debug_levels(levels=["OFF"])
- def test_collective_shape_mismatch_cuda(self):
- pg = self._create_wrapper_pg(with_new_group=False)
- self._test_collective_shape_mismatch(pg, use_cuda=True)
-
-@requires_gloo()
-@unittest.skipIf(
- TEST_WITH_TSAN,
- "TSAN is not fork-safe since we're forking in a multi-threaded environment",
-)
class ProcessGroupGlooTest(MultiProcessTestCase):
def setUp(self):
super(ProcessGroupGlooTest, self).setUp()
diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py
index 244b506..5583cbb 100644
--- a/test/distributed/test_c10d_nccl.py
+++ b/test/distributed/test_c10d_nccl.py
@@ -30,7 +30,6 @@
from torch.utils.checkpoint import checkpoint
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
- requires_gloo,
requires_nccl,
requires_nccl_version,
skip_if_lt_x_gpu,
@@ -46,7 +45,7 @@
TEST_WITH_TSAN,
)
import test_c10d_common
-from test_c10d_common import gpus_for_rank, DoubleGpuNet, ConvNet, ModuleForDdpCommHook, AbstractProcessGroupWrapperTest
+from test_c10d_common import gpus_for_rank, DoubleGpuNet, ConvNet, ModuleForDdpCommHook
class RendezvousEnvTest(TestCase):
@@ -159,88 +158,6 @@
raise unittest.SkipTest("No GPUs available, skipping test")
self._test_default_store_timeout("nccl")
-@requires_gloo()
-@requires_nccl()
-@unittest.skipIf(
- TEST_WITH_TSAN,
- "TSAN is not fork-safe since we're forking in a multi-threaded environment",
-)
-class ProcessGroupNCCLWrapperTest(AbstractProcessGroupWrapperTest):
- def setUp(self):
- self.num_gpus = torch.cuda.device_count()
- if self.num_gpus < 2:
- raise unittest.SkipTest("NCCL test requires 2+ GPUs")
- super(AbstractProcessGroupWrapperTest, self).setUp()
- self._spawn_processes()
- # NCCL_BLOCKING_WAIT overrides NCCL_ASYNC_ERROR_HANDLING hence tests
- # that use NCCL_BLOCKING_WAIT will test it as expected.
- os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
-
- @property
- def world_size(self) -> int:
- return 2
-
- def _create_wrapper_pg(self, with_new_group=False, timeout=10.0):
- store = c10d.FileStore(self.file_name, self.world_size)
- c10d.init_process_group(
- backend="nccl",
- rank=self.rank,
- world_size=self.world_size,
- store=store,
- timeout=timedelta(seconds=timeout)
- )
- if with_new_group:
- pg = c10d.new_group(backend="nccl", timeout=timedelta(seconds=timeout))
- else:
- _pg = c10d.ProcessGroupNCCL(store, self.rank, self.world_size, timeout=timedelta(seconds=timeout))
- pg = c10d._create_process_group_wrapper(
- _pg,
- "unused",
- store,
- self.rank,
- self.world_size,
- timeout=timeout,
- )
- return pg
-
- @requires_nccl()
- @skip_if_lt_x_gpu(2)
- def test_collective_hang(self):
- pg = self._create_wrapper_pg(timeout=2.0)
- self._test_collective_hang(pg)
-
- # NOTE: these tests are separated by debug level instead of combined into
- # one due to https://github.com/pytorch/pytorch/issues/55967, they can be
- # combined after that is resolved.
- @requires_nccl()
- @skip_if_lt_x_gpu(2)
- @with_dist_debug_levels(levels=["DETAIL"])
- def test_collectives_op_mismatch_debug_mode(self):
- pg = self._create_wrapper_pg(with_new_group=True)
- self._test_collectives_op_mismatch(pg, use_cuda=True)
-
- @requires_nccl()
- @skip_if_lt_x_gpu(2)
- @with_dist_debug_levels(levels=["OFF"])
- def test_collectives_op_mismatch(self):
- pg = self._create_wrapper_pg(with_new_group=False)
- self._test_collectives_op_mismatch(pg, use_cuda=True)
-
- @requires_nccl()
- @skip_if_lt_x_gpu(2)
- @with_dist_debug_levels(levels=["DETAIL"])
- def test_collective_shape_mismatch_debug_mode(self):
- pg = self._create_wrapper_pg(with_new_group=True)
- self._test_collective_shape_mismatch(pg, use_cuda=True)
-
- @requires_nccl()
- @skip_if_lt_x_gpu(2)
- @with_dist_debug_levels(levels=["OFF"])
- def test_collective_shape_mismatch(self):
- pg = self._create_wrapper_pg(with_new_group=False)
- self._test_collective_shape_mismatch(pg, use_cuda=True)
-
-
class ProcessGroupNCCLNoGPUTest(TestCase):
MAIN_PROCESS_RANK = 0
diff --git a/test/distributed/test_pg_wrapper.py b/test/distributed/test_pg_wrapper.py
new file mode 100644
index 0000000..aa32a5b
--- /dev/null
+++ b/test/distributed/test_pg_wrapper.py
@@ -0,0 +1,324 @@
+import os
+import sys
+import unittest
+from datetime import timedelta
+
+import torch
+import torch.distributed as c10d
+
+if not c10d.is_available():
+ print("c10d not available, skipping tests", file=sys.stderr)
+ sys.exit(0)
+
+from torch.testing._internal.common_distributed import (
+ MultiProcessTestCase,
+ requires_nccl,
+ requires_gloo,
+ skip_if_lt_x_gpu,
+ with_dist_debug_levels,
+ create_device,
+)
+from torch.testing._internal.common_utils import (
+ run_tests,
+ TEST_WITH_TSAN,
+)
+from test_c10d_common import LOOPBACK
+
+
+class AbstractProcessGroupWrapperTest(MultiProcessTestCase):
+ def setUp(self):
+ super(AbstractProcessGroupWrapperTest, self).setUp()
+ # For Windows platform, Python does not support fork, change it to spawn here.
+ if sys.platform == "win32":
+ self._spawn_processes()
+ else:
+ self._fork_processes()
+
+ def _test_collective_hang(self, wrapper_pg, use_cuda=False):
+ # All ranks besides 1 call allreduce and wrapper_pg should detect a hang
+ # and report an issue with rank 1.
+ faulty_rank = 1
+ if self.rank != faulty_rank:
+ tensor = torch.randn(20, 10)
+ if use_cuda:
+ tensor = tensor.to(self.rank)
+
+ if self.rank == 0:
+ # Rank 0 reports faulty ranks
+ err = f"Ranks {faulty_rank} failed to pass monitoredBarrier"
+ else:
+ err = "Please check rank 0 logs for faulty rank"
+ with self.assertRaisesRegex(RuntimeError, err):
+ wrapper_pg.allreduce([tensor])
+
+ def _test_collectives_op_mismatch(self, wrapper_pg, use_cuda=False):
+ tensor = torch.randn(20, 10)
+ if use_cuda:
+ tensor = tensor.to(self.rank)
+ works = []
+ # Run a few successful collectives
+ for _ in range(10):
+ work = wrapper_pg.allreduce([tensor])
+ works.append(work)
+
+ for w in works:
+ w.wait()
+
+ # Simulate mismatch: allreduce vs reduce.
+ with self.assertRaisesRegex(
+ RuntimeError, "Mismatch between collective operation types"
+ ):
+ if self.rank == 0:
+ wrapper_pg.allreduce([tensor])
+ else:
+ wrapper_pg.reduce([tensor])
+
+ # Check additional mismatches
+
+ with self.assertRaisesRegex(
+ RuntimeError, "Mismatch between collective operation types"
+ ):
+ if self.rank == 0:
+ wrapper_pg.reduce([tensor])
+ else:
+ wrapper_pg.barrier()
+
+ with self.assertRaisesRegex(
+ RuntimeError, "Mismatch between collective operation types"
+ ):
+ scatter_result = [torch.ones(4) * i for i in range(self.world_size)]
+ scattered_tensor = torch.empty(4)
+ if self.rank == 0:
+ wrapper_pg.scatter(scattered_tensor, scatter_result, 0)
+ else:
+ wrapper_pg.reduce_scatter(scattered_tensor, scatter_result)
+
+ with self.assertRaisesRegex(
+ RuntimeError, "Mismatch between collective operation types"
+ ):
+ if self.rank == 0:
+ wrapper_pg.broadcast(tensor, 0)
+ else:
+ output_tensors = [
+ torch.zeros_like(tensor) for _ in range(self.world_size)
+ ]
+ wrapper_pg.allgather([output_tensors], [tensor])
+
+ def _test_collective_shape_mismatch(self, wrapper_pg, use_cuda=False):
+ wrapper_pg.barrier()
+ dim = 2 if self.rank == 0 else 10
+ tensor = torch.randn(20, dim)
+ if use_cuda:
+ tensor = tensor.to(self.rank)
+ with self.assertRaisesRegex(RuntimeError, "Error when verifying shape tensors"):
+ wrapper_pg.allreduce([tensor])
+ # Check errors are raised when dimensionality of shapes is different
+ tensor = torch.randn(20, 10, 2) if self.rank == 0 else torch.randn(20, 10)
+ if use_cuda:
+ tensor = tensor.to(self.rank)
+ with self.assertRaisesRegex(RuntimeError, "Error when verifying shape tensors"):
+ wrapper_pg.allreduce([tensor])
+
+ # Check shape errors with scatter
+ input = [
+ torch.tensor(
+ [self.rank] if self.rank == 0 else [self.rank, self.rank],
+ device=self.rank if use_cuda else "cpu",
+ )
+ for _ in range(self.world_size)
+ ]
+ outputs = [
+ torch.tensor(
+ [-1] if self.rank == 0 else [-1, -1],
+ device=self.rank if use_cuda else "cpu",
+ )
+ for _ in range(self.world_size)
+ ]
+ root_rank = 0
+ opts = c10d.ScatterOptions()
+ opts.rootRank = root_rank
+ with self.assertRaisesRegex(RuntimeError, "Error when verifying shape tensors"):
+ if self.rank == root_rank:
+ wrapper_pg.scatter([outputs[self.rank]], [input], opts).wait()
+ else:
+ wrapper_pg.scatter([outputs[self.rank]], [], opts).wait()
+
+
+@requires_gloo()
+@requires_nccl()
+@unittest.skipIf(
+ TEST_WITH_TSAN,
+ "TSAN is not fork-safe since we're forking in a multi-threaded environment",
+)
+class ProcessGroupNCCLWrapperTest(AbstractProcessGroupWrapperTest):
+ def setUp(self):
+ self.num_gpus = torch.cuda.device_count()
+ if self.num_gpus < 2:
+ raise unittest.SkipTest("NCCL test requires 2+ GPUs")
+ super(AbstractProcessGroupWrapperTest, self).setUp()
+ self._spawn_processes()
+ # NCCL_BLOCKING_WAIT overrides NCCL_ASYNC_ERROR_HANDLING hence tests
+ # that use NCCL_BLOCKING_WAIT will test it as expected.
+ os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
+
+ @property
+ def world_size(self) -> int:
+ return 2
+
+ def _create_wrapper_pg(self, with_new_group=False, timeout=10.0):
+ store = c10d.FileStore(self.file_name, self.world_size)
+ c10d.init_process_group(
+ backend="nccl",
+ rank=self.rank,
+ world_size=self.world_size,
+ store=store,
+ timeout=timedelta(seconds=timeout),
+ )
+ if with_new_group:
+ pg = c10d.new_group(backend="nccl", timeout=timedelta(seconds=timeout))
+ else:
+ _pg = c10d.ProcessGroupNCCL(
+ store, self.rank, self.world_size, timeout=timedelta(seconds=timeout)
+ )
+ pg = c10d._create_process_group_wrapper(
+ _pg,
+ "unused",
+ store,
+ self.rank,
+ self.world_size,
+ timeout=timeout,
+ )
+ return pg
+
+ @requires_nccl()
+ @skip_if_lt_x_gpu(2)
+ def test_collective_hang(self):
+ pg = self._create_wrapper_pg(timeout=2.0)
+ self._test_collective_hang(pg)
+
+ # NOTE: these tests are separated by debug level instead of combined into
+ # one due to https://github.com/pytorch/pytorch/issues/55967, they can be
+ # combined after that is resolved.
+ @requires_nccl()
+ @skip_if_lt_x_gpu(2)
+ @with_dist_debug_levels(levels=["DETAIL"])
+ def test_collectives_op_mismatch_debug_mode(self):
+ pg = self._create_wrapper_pg(with_new_group=True)
+ self._test_collectives_op_mismatch(pg, use_cuda=True)
+
+ @requires_nccl()
+ @skip_if_lt_x_gpu(2)
+ @with_dist_debug_levels(levels=["OFF"])
+ def test_collectives_op_mismatch(self):
+ pg = self._create_wrapper_pg(with_new_group=False)
+ self._test_collectives_op_mismatch(pg, use_cuda=True)
+
+ @requires_nccl()
+ @skip_if_lt_x_gpu(2)
+ @with_dist_debug_levels(levels=["DETAIL"])
+ def test_collective_shape_mismatch_debug_mode(self):
+ pg = self._create_wrapper_pg(with_new_group=True)
+ self._test_collective_shape_mismatch(pg, use_cuda=True)
+
+ @requires_nccl()
+ @skip_if_lt_x_gpu(2)
+ @with_dist_debug_levels(levels=["OFF"])
+ def test_collective_shape_mismatch(self):
+ pg = self._create_wrapper_pg(with_new_group=False)
+ self._test_collective_shape_mismatch(pg, use_cuda=True)
+
+
+@requires_gloo()
+@unittest.skipIf(
+ TEST_WITH_TSAN,
+ "TSAN is not fork-safe since we're forking in a multi-threaded environment",
+)
+class ProcessGroupGlooWrapperTest(AbstractProcessGroupWrapperTest):
+ def setUp(self):
+ super(ProcessGroupGlooWrapperTest, self).setUp()
+
+ def opts(self, threads=2, timeout=10.0):
+ opts = c10d.ProcessGroupGloo._Options()
+ opts._timeout = timeout
+ opts._devices = [create_device(interface=LOOPBACK)]
+ opts._threads = threads
+ return opts
+
+ def _create_wrapper_pg(self, with_new_group=False, timeout=10.0):
+ store = c10d.FileStore(self.file_name, self.world_size)
+ c10d.init_process_group(
+ backend="gloo", rank=self.rank, world_size=self.world_size, store=store
+ )
+ if with_new_group:
+ pg = c10d.new_group(backend="gloo")
+ else:
+ _pg = c10d.ProcessGroupGloo(
+ store, self.rank, self.world_size, self.opts(timeout=timeout)
+ )
+ pg = c10d._create_process_group_wrapper(
+ _pg,
+ "unused",
+ store,
+ self.rank,
+ self.world_size,
+ timeout=timeout,
+ )
+ return pg
+
+ def test_collective_hang(self):
+ pg = self._create_wrapper_pg(timeout=2.0)
+ self._test_collective_hang(pg)
+
+ # NOTE: these tests are separated by debug level instead of combined into
+ # one due to https://github.com/pytorch/pytorch/issues/55967, they can be
+ # combined after that is resolved.
+ @with_dist_debug_levels(levels=["DETAIL"])
+ def test_collectives_op_mismatch_debug_mode(self):
+ pg = self._create_wrapper_pg(with_new_group=True)
+ self._test_collectives_op_mismatch(pg)
+
+ @with_dist_debug_levels(levels=["OFF"])
+ def test_collectives_op_mismatch(self):
+ pg = self._create_wrapper_pg(with_new_group=False)
+ self._test_collectives_op_mismatch(pg)
+
+ @with_dist_debug_levels(levels=["DETAIL"])
+ def test_collective_shape_mismatch_debug_mode(self):
+ pg = self._create_wrapper_pg(with_new_group=True)
+ self._test_collective_shape_mismatch(pg)
+
+ @with_dist_debug_levels(levels=["OFF"])
+ def test_collective_shape_mismatch(self):
+ pg = self._create_wrapper_pg(with_new_group=False)
+ self._test_collective_shape_mismatch(pg)
+
+ @skip_if_lt_x_gpu(4)
+ @with_dist_debug_levels(levels=["DETAIL"])
+ def test_collectives_op_mismatch_cuda_debug_mode(self):
+ pg = self._create_wrapper_pg(with_new_group=True)
+ self._test_collectives_op_mismatch(pg, use_cuda=True)
+
+ @skip_if_lt_x_gpu(4)
+ @with_dist_debug_levels(levels=["OFF"])
+ def test_collectives_op_mismatch_cuda(self):
+ pg = self._create_wrapper_pg(with_new_group=False)
+ self._test_collectives_op_mismatch(pg, use_cuda=True)
+
+ @skip_if_lt_x_gpu(4)
+ @with_dist_debug_levels(levels=["DETAIL"])
+ def test_collective_shape_mismatch_cuda_debug_mode(self):
+ pg = self._create_wrapper_pg(with_new_group=True)
+ self._test_collective_shape_mismatch(pg, use_cuda=True)
+
+ @skip_if_lt_x_gpu(4)
+ @with_dist_debug_levels(levels=["OFF"])
+ def test_collective_shape_mismatch_cuda(self):
+ pg = self._create_wrapper_pg(with_new_group=False)
+ self._test_collective_shape_mismatch(pg, use_cuda=True)
+
+if __name__ == "__main__":
+ assert (
+ not torch.cuda._initialized
+ ), "test_pg_wrapper must not have initialized CUDA context on main process"
+
+ run_tests()
diff --git a/test/run_test.py b/test/run_test.py
index 38563b2..5670da3 100755
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -49,6 +49,7 @@
'distributed/test_c10d_spawn_gloo',
'distributed/test_c10d_spawn_nccl',
'distributed/test_store',
+ 'distributed/test_pg_wrapper',
'test_cuda',
'test_jit_cuda_fuser',
'test_cuda_primary_ctx',
@@ -311,6 +312,7 @@
'distributed/test_c10d_spawn_gloo',
'distributed/test_c10d_spawn_nccl',
'distributed/test_store',
+ 'distributed/test_pg_wrapper',
'test_quantization',
'test_pruning_op',
'test_determination',
diff --git a/torch/distributed/CONTRIBUTING.md b/torch/distributed/CONTRIBUTING.md
index f621333..0f4428a 100644
--- a/torch/distributed/CONTRIBUTING.md
+++ b/torch/distributed/CONTRIBUTING.md
@@ -81,6 +81,9 @@
# Run the Store tests.
python test/distributed/test_store.py
+# Run Process Group Wrapper tests.
+python test/distributed/test_pg_wrapper.py
+
# Run distributed tests, including tests for Distributed Data Parallel.
python test/run_test.py --verbose -i distributed/test_distributed_fork
python test/run_test.py --verbose -i distributed/test_distributed_spawn