[c10d] Move pg wrapper tests to their own file. (#59840)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59840

Moves the process group wrapper tests (AbstractProcessGroupWrapperTest and its Gloo/NCCL subclasses) out of test_c10d_common.py, test_c10d_gloo.py, and test_c10d_nccl.py into a standalone file, test/distributed/test_pg_wrapper.py. No meaningful code changes.
ghstack-source-id: 131359162

Test Plan: CI
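
The relocated tests can also be run locally; the commands below simply mirror the invocations this diff adds to the CI scripts and to torch/distributed/CONTRIBUTING.md:

    # via the test runner
    python test/run_test.py --verbose -i distributed/test_pg_wrapper
    # or directly
    python test/distributed/test_pg_wrapper.py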

Reviewed By: cbalioglu

Differential Revision: D29012664

fbshipit-source-id: 348870016509a6ed7e69240fa82bccef4a12d674
diff --git a/.jenkins/pytorch/multigpu-test.sh b/.jenkins/pytorch/multigpu-test.sh
index daa0fb8..7ed43a9 100755
--- a/.jenkins/pytorch/multigpu-test.sh
+++ b/.jenkins/pytorch/multigpu-test.sh
@@ -26,6 +26,7 @@
 time python test/run_test.py --verbose -i distributed/test_c10d_spawn_gloo
 time python test/run_test.py --verbose -i distributed/test_c10d_spawn_nccl
 time python test/run_test.py --verbose -i distributed/test_store
+time python test/run_test.py --verbose -i distributed/test_pg_wrapper
 time python test/run_test.py --verbose -i distributed/rpc/cuda/test_process_group_agent
 time python test/run_test.py --verbose -i distributed/rpc/cuda/test_tensorpipe_agent
 assert_git_not_dirty
diff --git a/.jenkins/pytorch/win-test-helpers/test_distributed.bat b/.jenkins/pytorch/win-test-helpers/test_distributed.bat
index a50c153..53ebee8 100644
--- a/.jenkins/pytorch/win-test-helpers/test_distributed.bat
+++ b/.jenkins/pytorch/win-test-helpers/test_distributed.bat
@@ -22,3 +22,6 @@
 
 %1\python.exe test/run_test.py --verbose -i distributed/test_store
 if %errorlevel% neq 0 ( exit /b %errorlevel% )
+
+%1\python.exe test/run_test.py --verbose -i distributed/test_pg_wrapper
+if %errorlevel% neq 0 ( exit /b %errorlevel% )
diff --git a/test/distributed/test_c10d_common.py b/test/distributed/test_c10d_common.py
index 4d2dc0e..a25ef9a 100644
--- a/test/distributed/test_c10d_common.py
+++ b/test/distributed/test_c10d_common.py
@@ -233,125 +233,6 @@
         return F.softmax(self.embedding(x), dim=1)
 
 
-class AbstractProcessGroupWrapperTest(MultiProcessTestCase):
-    def setUp(self):
-        super(AbstractProcessGroupWrapperTest, self).setUp()
-        # For Windows platform, Python does not support fork, change it to spawn here.
-        if sys.platform == "win32":
-            self._spawn_processes()
-        else:
-            self._fork_processes()
-
-    def _test_collective_hang(self, wrapper_pg, use_cuda=False):
-        # All ranks besides 1 call allreduce and wrapper_pg should detect a hang
-        # and report an issue with rank 1.
-        faulty_rank = 1
-        if self.rank != faulty_rank:
-            tensor = torch.randn(20, 10)
-            if use_cuda:
-                tensor = tensor.to(self.rank)
-
-            if self.rank == 0:
-                # Rank 0 reports faulty ranks
-                err = f"Ranks {faulty_rank} failed to pass monitoredBarrier"
-            else:
-                err = "Please check rank 0 logs for faulty rank"
-            with self.assertRaisesRegex(RuntimeError, err):
-                wrapper_pg.allreduce([tensor])
-
-    def _test_collectives_op_mismatch(self, wrapper_pg, use_cuda=False):
-        tensor = torch.randn(20, 10)
-        if use_cuda:
-            tensor = tensor.to(self.rank)
-        works = []
-        # Run a few successful collectives
-        for _ in range(10):
-            work = wrapper_pg.allreduce([tensor])
-            works.append(work)
-
-        for w in works:
-            w.wait()
-
-        # Simulate mismatch: allreduce vs reduce.
-        with self.assertRaisesRegex(
-            RuntimeError, "Mismatch between collective operation types"
-        ):
-            if self.rank == 0:
-                wrapper_pg.allreduce([tensor])
-            else:
-                wrapper_pg.reduce([tensor])
-
-        # Check additional mismatches
-
-        with self.assertRaisesRegex(
-            RuntimeError, "Mismatch between collective operation types"
-        ):
-            if self.rank == 0:
-                wrapper_pg.reduce([tensor])
-            else:
-                wrapper_pg.barrier()
-
-        with self.assertRaisesRegex(
-            RuntimeError, "Mismatch between collective operation types"
-        ):
-            scatter_result = [torch.ones(4) * i for i in range(self.world_size)]
-            scattered_tensor = torch.empty(4)
-            if self.rank == 0:
-                wrapper_pg.scatter(scattered_tensor, scatter_result, 0)
-            else:
-                wrapper_pg.reduce_scatter(scattered_tensor, scatter_result)
-
-        with self.assertRaisesRegex(
-            RuntimeError, "Mismatch between collective operation types"
-        ):
-            if self.rank == 0:
-                wrapper_pg.broadcast(tensor, 0)
-            else:
-                output_tensors = [
-                    torch.zeros_like(tensor) for _ in range(self.world_size)
-                ]
-                wrapper_pg.allgather([output_tensors], [tensor])
-
-    def _test_collective_shape_mismatch(self, wrapper_pg, use_cuda=False):
-        wrapper_pg.barrier()
-        dim = 2 if self.rank == 0 else 10
-        tensor = torch.randn(20, dim)
-        if use_cuda:
-            tensor = tensor.to(self.rank)
-        with self.assertRaisesRegex(RuntimeError, "Error when verifying shape tensors"):
-            wrapper_pg.allreduce([tensor])
-        # Check errors are raised when dimensionality of shapes is different
-        tensor = torch.randn(20, 10, 2) if self.rank == 0 else torch.randn(20, 10)
-        if use_cuda:
-            tensor = tensor.to(self.rank)
-        with self.assertRaisesRegex(RuntimeError, "Error when verifying shape tensors"):
-            wrapper_pg.allreduce([tensor])
-
-        # Check shape errors with scatter
-        input = [
-            torch.tensor(
-                [self.rank] if self.rank == 0 else [self.rank, self.rank],
-                device=self.rank if use_cuda else "cpu",
-            )
-            for _ in range(self.world_size)
-        ]
-        outputs = [
-            torch.tensor(
-                [-1] if self.rank == 0 else [-1, -1],
-                device=self.rank if use_cuda else "cpu",
-            )
-            for _ in range(self.world_size)
-        ]
-        root_rank = 0
-        opts = c10d.ScatterOptions()
-        opts.rootRank = root_rank
-        with self.assertRaisesRegex(RuntimeError, "Error when verifying shape tensors"):
-            if self.rank == root_rank:
-                wrapper_pg.scatter([outputs[self.rank]], [input], opts).wait()
-            else:
-                wrapper_pg.scatter([outputs[self.rank]], [], opts).wait()
-
-
 class AbstractDistributedDataParallelTest(object):
     def tearDown(self):
         # DistributedDataParallel test doesn't seem to call FileStore destructor
diff --git a/test/distributed/test_c10d_gloo.py b/test/distributed/test_c10d_gloo.py
index 8ce6443..a4c2b85 100644
--- a/test/distributed/test_c10d_gloo.py
+++ b/test/distributed/test_c10d_gloo.py
@@ -29,7 +29,6 @@
     simple_sparse_reduce_tests,
     skip_if_win32,
     create_device,
-    with_dist_debug_levels,
     verify_ddp_error_logged,
 )
 from torch.testing._internal.common_utils import (
@@ -45,7 +44,6 @@
     Task,
     ModuleForDdpCommHook,
     SparseGradientModule,
-    AbstractProcessGroupWrapperTest,
 )
 
 
@@ -208,92 +206,6 @@
     TEST_WITH_TSAN,
     "TSAN is not fork-safe since we're forking in a multi-threaded environment",
 )
-class ProcessGroupGlooWrapperTest(AbstractProcessGroupWrapperTest):
-    def setUp(self):
-        super(ProcessGroupGlooWrapperTest, self).setUp()
-
-    def opts(self, threads=2, timeout=10.0):
-        opts = c10d.ProcessGroupGloo._Options()
-        opts._timeout = timeout
-        opts._devices = [create_device(interface=LOOPBACK)]
-        opts._threads = threads
-        return opts
-
-    def _create_wrapper_pg(self, with_new_group=False, timeout=10.0):
-        store = c10d.FileStore(self.file_name, self.world_size)
-        c10d.init_process_group(
-            backend="gloo", rank=self.rank, world_size=self.world_size, store=store
-        )
-        if with_new_group:
-            pg = c10d.new_group(backend="gloo")
-        else:
-            _pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts(timeout=timeout))
-            pg = c10d._create_process_group_wrapper(
-                _pg,
-                "unused",
-                store,
-                self.rank,
-                self.world_size,
-                timeout=timeout,
-            )
-        return pg
-
-    def test_collective_hang(self):
-        pg = self._create_wrapper_pg(timeout=2.0)
-        self._test_collective_hang(pg)
-
-    # NOTE: these tests are separated by debug level instead of combined into
-    # one due to https://github.com/pytorch/pytorch/issues/55967, they can be
-    # combined after that is resolved.
-    @with_dist_debug_levels(levels=["DETAIL"])
-    def test_collectives_op_mismatch_debug_mode(self):
-        pg = self._create_wrapper_pg(with_new_group=True)
-        self._test_collectives_op_mismatch(pg)
-
-    @with_dist_debug_levels(levels=["OFF"])
-    def test_collectives_op_mismatch(self):
-        pg = self._create_wrapper_pg(with_new_group=False)
-        self._test_collectives_op_mismatch(pg)
-
-    @with_dist_debug_levels(levels=["DETAIL"])
-    def test_collective_shape_mismatch_debug_mode(self):
-        pg = self._create_wrapper_pg(with_new_group=True)
-        self._test_collective_shape_mismatch(pg)
-
-    @with_dist_debug_levels(levels=["OFF"])
-    def test_collective_shape_mismatch(self):
-        pg = self._create_wrapper_pg(with_new_group=False)
-        self._test_collective_shape_mismatch(pg)
-
-    @skip_if_lt_x_gpu(4)
-    @with_dist_debug_levels(levels=["DETAIL"])
-    def test_collectives_op_mismatch_cuda_debug_mode(self):
-        pg = self._create_wrapper_pg(with_new_group=True)
-        self._test_collectives_op_mismatch(pg, use_cuda=True)
-
-    @skip_if_lt_x_gpu(4)
-    @with_dist_debug_levels(levels=["OFF"])
-    def test_collectives_op_mismatch_cuda(self):
-        pg = self._create_wrapper_pg(with_new_group=False)
-        self._test_collectives_op_mismatch(pg, use_cuda=True)
-
-    @skip_if_lt_x_gpu(4)
-    @with_dist_debug_levels(levels=["DETAIL"])
-    def test_collective_shape_mismatch_cuda_debug_mode(self):
-        pg = self._create_wrapper_pg(with_new_group=True)
-        self._test_collective_shape_mismatch(pg, use_cuda=True)
-
-    @skip_if_lt_x_gpu(4)
-    @with_dist_debug_levels(levels=["OFF"])
-    def test_collective_shape_mismatch_cuda(self):
-        pg = self._create_wrapper_pg(with_new_group=False)
-        self._test_collective_shape_mismatch(pg, use_cuda=True)
-
-@requires_gloo()
-@unittest.skipIf(
-    TEST_WITH_TSAN,
-    "TSAN is not fork-safe since we're forking in a multi-threaded environment",
-)
 class ProcessGroupGlooTest(MultiProcessTestCase):
     def setUp(self):
         super(ProcessGroupGlooTest, self).setUp()
diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py
index 244b506..5583cbb 100644
--- a/test/distributed/test_c10d_nccl.py
+++ b/test/distributed/test_c10d_nccl.py
@@ -30,7 +30,6 @@
 from torch.utils.checkpoint import checkpoint
 from torch.testing._internal.common_distributed import (
     MultiProcessTestCase,
-    requires_gloo,
     requires_nccl,
     requires_nccl_version,
     skip_if_lt_x_gpu,
@@ -46,7 +45,7 @@
     TEST_WITH_TSAN,
 )
 import test_c10d_common
-from test_c10d_common import gpus_for_rank, DoubleGpuNet, ConvNet, ModuleForDdpCommHook, AbstractProcessGroupWrapperTest
+from test_c10d_common import gpus_for_rank, DoubleGpuNet, ConvNet, ModuleForDdpCommHook
 
 
 class RendezvousEnvTest(TestCase):
@@ -159,88 +158,6 @@
             raise unittest.SkipTest("No GPUs available, skipping test")
         self._test_default_store_timeout("nccl")
 
-@requires_gloo()
-@requires_nccl()
-@unittest.skipIf(
-    TEST_WITH_TSAN,
-    "TSAN is not fork-safe since we're forking in a multi-threaded environment",
-)
-class ProcessGroupNCCLWrapperTest(AbstractProcessGroupWrapperTest):
-    def setUp(self):
-        self.num_gpus = torch.cuda.device_count()
-        if self.num_gpus < 2:
-            raise unittest.SkipTest("NCCL test requires 2+ GPUs")
-        super(AbstractProcessGroupWrapperTest, self).setUp()
-        self._spawn_processes()
-        # NCCL_BLOCKING_WAIT overrides NCCL_ASYNC_ERROR_HANDLING hence tests
-        # that use NCCL_BLOCKING_WAIT will test it as expected.
-        os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
-
-    @property
-    def world_size(self) -> int:
-        return 2
-
-    def _create_wrapper_pg(self, with_new_group=False, timeout=10.0):
-        store = c10d.FileStore(self.file_name, self.world_size)
-        c10d.init_process_group(
-            backend="nccl",
-            rank=self.rank,
-            world_size=self.world_size,
-            store=store,
-            timeout=timedelta(seconds=timeout)
-        )
-        if with_new_group:
-            pg = c10d.new_group(backend="nccl", timeout=timedelta(seconds=timeout))
-        else:
-            _pg = c10d.ProcessGroupNCCL(store, self.rank, self.world_size, timeout=timedelta(seconds=timeout))
-            pg = c10d._create_process_group_wrapper(
-                _pg,
-                "unused",
-                store,
-                self.rank,
-                self.world_size,
-                timeout=timeout,
-            )
-        return pg
-
-    @requires_nccl()
-    @skip_if_lt_x_gpu(2)
-    def test_collective_hang(self):
-        pg = self._create_wrapper_pg(timeout=2.0)
-        self._test_collective_hang(pg)
-
-    # NOTE: these tests are separated by debug level instead of combined into
-    # one due to https://github.com/pytorch/pytorch/issues/55967, they can be
-    # combined after that is resolved.
-    @requires_nccl()
-    @skip_if_lt_x_gpu(2)
-    @with_dist_debug_levels(levels=["DETAIL"])
-    def test_collectives_op_mismatch_debug_mode(self):
-        pg = self._create_wrapper_pg(with_new_group=True)
-        self._test_collectives_op_mismatch(pg, use_cuda=True)
-
-    @requires_nccl()
-    @skip_if_lt_x_gpu(2)
-    @with_dist_debug_levels(levels=["OFF"])
-    def test_collectives_op_mismatch(self):
-        pg = self._create_wrapper_pg(with_new_group=False)
-        self._test_collectives_op_mismatch(pg, use_cuda=True)
-
-    @requires_nccl()
-    @skip_if_lt_x_gpu(2)
-    @with_dist_debug_levels(levels=["DETAIL"])
-    def test_collective_shape_mismatch_debug_mode(self):
-        pg = self._create_wrapper_pg(with_new_group=True)
-        self._test_collective_shape_mismatch(pg, use_cuda=True)
-
-    @requires_nccl()
-    @skip_if_lt_x_gpu(2)
-    @with_dist_debug_levels(levels=["OFF"])
-    def test_collective_shape_mismatch(self):
-        pg = self._create_wrapper_pg(with_new_group=False)
-        self._test_collective_shape_mismatch(pg, use_cuda=True)
-
-
 class ProcessGroupNCCLNoGPUTest(TestCase):
     MAIN_PROCESS_RANK = 0
 
diff --git a/test/distributed/test_pg_wrapper.py b/test/distributed/test_pg_wrapper.py
new file mode 100644
index 0000000..aa32a5b
--- /dev/null
+++ b/test/distributed/test_pg_wrapper.py
@@ -0,0 +1,324 @@
+import os
+import sys
+import unittest
+from datetime import timedelta
+
+import torch
+import torch.distributed as c10d
+
+if not c10d.is_available():
+    print("c10d not available, skipping tests", file=sys.stderr)
+    sys.exit(0)
+
+from torch.testing._internal.common_distributed import (
+    MultiProcessTestCase,
+    requires_nccl,
+    requires_gloo,
+    skip_if_lt_x_gpu,
+    with_dist_debug_levels,
+    create_device,
+)
+from torch.testing._internal.common_utils import (
+    run_tests,
+    TEST_WITH_TSAN,
+)
+from test_c10d_common import LOOPBACK
+
+
+class AbstractProcessGroupWrapperTest(MultiProcessTestCase):
+    def setUp(self):
+        super(AbstractProcessGroupWrapperTest, self).setUp()
+        # For Windows platform, Python does not support fork, change it to spawn here.
+        if sys.platform == "win32":
+            self._spawn_processes()
+        else:
+            self._fork_processes()
+
+    def _test_collective_hang(self, wrapper_pg, use_cuda=False):
+        # All ranks besides 1 call allreduce and wrapper_pg should detect a hang
+        # and report an issue with rank 1.
+        faulty_rank = 1
+        if self.rank != faulty_rank:
+            tensor = torch.randn(20, 10)
+            if use_cuda:
+                tensor = tensor.to(self.rank)
+
+            if self.rank == 0:
+                # Rank 0 reports faulty ranks
+                err = f"Ranks {faulty_rank} failed to pass monitoredBarrier"
+            else:
+                err = "Please check rank 0 logs for faulty rank"
+            with self.assertRaisesRegex(RuntimeError, err):
+                wrapper_pg.allreduce([tensor])
+
+    def _test_collectives_op_mismatch(self, wrapper_pg, use_cuda=False):
+        tensor = torch.randn(20, 10)
+        if use_cuda:
+            tensor = tensor.to(self.rank)
+        works = []
+        # Run a few successful collectives
+        for _ in range(10):
+            work = wrapper_pg.allreduce([tensor])
+            works.append(work)
+
+        for w in works:
+            w.wait()
+
+        # Simulate mismatch: allreduce vs reduce.
+        with self.assertRaisesRegex(
+            RuntimeError, "Mismatch between collective operation types"
+        ):
+            if self.rank == 0:
+                wrapper_pg.allreduce([tensor])
+            else:
+                wrapper_pg.reduce([tensor])
+
+        # Check additional mismatches
+
+        with self.assertRaisesRegex(
+            RuntimeError, "Mismatch between collective operation types"
+        ):
+            if self.rank == 0:
+                wrapper_pg.reduce([tensor])
+            else:
+                wrapper_pg.barrier()
+
+        with self.assertRaisesRegex(
+            RuntimeError, "Mismatch between collective operation types"
+        ):
+            scatter_result = [torch.ones(4) * i for i in range(self.world_size)]
+            scattered_tensor = torch.empty(4)
+            if self.rank == 0:
+                wrapper_pg.scatter(scattered_tensor, scatter_result, 0)
+            else:
+                wrapper_pg.reduce_scatter(scattered_tensor, scatter_result)
+
+        with self.assertRaisesRegex(
+            RuntimeError, "Mismatch between collective operation types"
+        ):
+            if self.rank == 0:
+                wrapper_pg.broadcast(tensor, 0)
+            else:
+                output_tensors = [
+                    torch.zeros_like(tensor) for _ in range(self.world_size)
+                ]
+                wrapper_pg.allgather([output_tensors], [tensor])
+
+    def _test_collective_shape_mismatch(self, wrapper_pg, use_cuda=False):
+        wrapper_pg.barrier()
+        dim = 2 if self.rank == 0 else 10
+        tensor = torch.randn(20, dim)
+        if use_cuda:
+            tensor = tensor.to(self.rank)
+        with self.assertRaisesRegex(RuntimeError, "Error when verifying shape tensors"):
+            wrapper_pg.allreduce([tensor])
+        # Check errors are raised when dimensionality of shapes is different
+        tensor = torch.randn(20, 10, 2) if self.rank == 0 else torch.randn(20, 10)
+        if use_cuda:
+            tensor = tensor.to(self.rank)
+        with self.assertRaisesRegex(RuntimeError, "Error when verifying shape tensors"):
+            wrapper_pg.allreduce([tensor])
+
+        # Check shape errors with scatter
+        input = [
+            torch.tensor(
+                [self.rank] if self.rank == 0 else [self.rank, self.rank],
+                device=self.rank if use_cuda else "cpu",
+            )
+            for _ in range(self.world_size)
+        ]
+        outputs = [
+            torch.tensor(
+                [-1] if self.rank == 0 else [-1, -1],
+                device=self.rank if use_cuda else "cpu",
+            )
+            for _ in range(self.world_size)
+        ]
+        root_rank = 0
+        opts = c10d.ScatterOptions()
+        opts.rootRank = root_rank
+        with self.assertRaisesRegex(RuntimeError, "Error when verifying shape tensors"):
+            if self.rank == root_rank:
+                wrapper_pg.scatter([outputs[self.rank]], [input], opts).wait()
+            else:
+                wrapper_pg.scatter([outputs[self.rank]], [], opts).wait()
+
+
+@requires_gloo()
+@requires_nccl()
+@unittest.skipIf(
+    TEST_WITH_TSAN,
+    "TSAN is not fork-safe since we're forking in a multi-threaded environment",
+)
+class ProcessGroupNCCLWrapperTest(AbstractProcessGroupWrapperTest):
+    def setUp(self):
+        self.num_gpus = torch.cuda.device_count()
+        if self.num_gpus < 2:
+            raise unittest.SkipTest("NCCL test requires 2+ GPUs")
+        super(AbstractProcessGroupWrapperTest, self).setUp()
+        self._spawn_processes()
+        # NCCL_BLOCKING_WAIT overrides NCCL_ASYNC_ERROR_HANDLING hence tests
+        # that use NCCL_BLOCKING_WAIT will test it as expected.
+        os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
+
+    @property
+    def world_size(self) -> int:
+        return 2
+
+    def _create_wrapper_pg(self, with_new_group=False, timeout=10.0):
+        store = c10d.FileStore(self.file_name, self.world_size)
+        c10d.init_process_group(
+            backend="nccl",
+            rank=self.rank,
+            world_size=self.world_size,
+            store=store,
+            timeout=timedelta(seconds=timeout),
+        )
+        if with_new_group:
+            pg = c10d.new_group(backend="nccl", timeout=timedelta(seconds=timeout))
+        else:
+            _pg = c10d.ProcessGroupNCCL(
+                store, self.rank, self.world_size, timeout=timedelta(seconds=timeout)
+            )
+            pg = c10d._create_process_group_wrapper(
+                _pg,
+                "unused",
+                store,
+                self.rank,
+                self.world_size,
+                timeout=timeout,
+            )
+        return pg
+
+    @requires_nccl()
+    @skip_if_lt_x_gpu(2)
+    def test_collective_hang(self):
+        pg = self._create_wrapper_pg(timeout=2.0)
+        self._test_collective_hang(pg)
+
+    # NOTE: these tests are separated by debug level instead of combined into
+    # one due to https://github.com/pytorch/pytorch/issues/55967, they can be
+    # combined after that is resolved.
+    @requires_nccl()
+    @skip_if_lt_x_gpu(2)
+    @with_dist_debug_levels(levels=["DETAIL"])
+    def test_collectives_op_mismatch_debug_mode(self):
+        pg = self._create_wrapper_pg(with_new_group=True)
+        self._test_collectives_op_mismatch(pg, use_cuda=True)
+
+    @requires_nccl()
+    @skip_if_lt_x_gpu(2)
+    @with_dist_debug_levels(levels=["OFF"])
+    def test_collectives_op_mismatch(self):
+        pg = self._create_wrapper_pg(with_new_group=False)
+        self._test_collectives_op_mismatch(pg, use_cuda=True)
+
+    @requires_nccl()
+    @skip_if_lt_x_gpu(2)
+    @with_dist_debug_levels(levels=["DETAIL"])
+    def test_collective_shape_mismatch_debug_mode(self):
+        pg = self._create_wrapper_pg(with_new_group=True)
+        self._test_collective_shape_mismatch(pg, use_cuda=True)
+
+    @requires_nccl()
+    @skip_if_lt_x_gpu(2)
+    @with_dist_debug_levels(levels=["OFF"])
+    def test_collective_shape_mismatch(self):
+        pg = self._create_wrapper_pg(with_new_group=False)
+        self._test_collective_shape_mismatch(pg, use_cuda=True)
+
+
+@requires_gloo()
+@unittest.skipIf(
+    TEST_WITH_TSAN,
+    "TSAN is not fork-safe since we're forking in a multi-threaded environment",
+)
+class ProcessGroupGlooWrapperTest(AbstractProcessGroupWrapperTest):
+    def setUp(self):
+        super(ProcessGroupGlooWrapperTest, self).setUp()
+
+    def opts(self, threads=2, timeout=10.0):
+        opts = c10d.ProcessGroupGloo._Options()
+        opts._timeout = timeout
+        opts._devices = [create_device(interface=LOOPBACK)]
+        opts._threads = threads
+        return opts
+
+    def _create_wrapper_pg(self, with_new_group=False, timeout=10.0):
+        store = c10d.FileStore(self.file_name, self.world_size)
+        c10d.init_process_group(
+            backend="gloo", rank=self.rank, world_size=self.world_size, store=store
+        )
+        if with_new_group:
+            pg = c10d.new_group(backend="gloo")
+        else:
+            _pg = c10d.ProcessGroupGloo(
+                store, self.rank, self.world_size, self.opts(timeout=timeout)
+            )
+            pg = c10d._create_process_group_wrapper(
+                _pg,
+                "unused",
+                store,
+                self.rank,
+                self.world_size,
+                timeout=timeout,
+            )
+        return pg
+
+    def test_collective_hang(self):
+        pg = self._create_wrapper_pg(timeout=2.0)
+        self._test_collective_hang(pg)
+
+    # NOTE: these tests are separated by debug level instead of combined into
+    # one due to https://github.com/pytorch/pytorch/issues/55967, they can be
+    # combined after that is resolved.
+    @with_dist_debug_levels(levels=["DETAIL"])
+    def test_collectives_op_mismatch_debug_mode(self):
+        pg = self._create_wrapper_pg(with_new_group=True)
+        self._test_collectives_op_mismatch(pg)
+
+    @with_dist_debug_levels(levels=["OFF"])
+    def test_collectives_op_mismatch(self):
+        pg = self._create_wrapper_pg(with_new_group=False)
+        self._test_collectives_op_mismatch(pg)
+
+    @with_dist_debug_levels(levels=["DETAIL"])
+    def test_collective_shape_mismatch_debug_mode(self):
+        pg = self._create_wrapper_pg(with_new_group=True)
+        self._test_collective_shape_mismatch(pg)
+
+    @with_dist_debug_levels(levels=["OFF"])
+    def test_collective_shape_mismatch(self):
+        pg = self._create_wrapper_pg(with_new_group=False)
+        self._test_collective_shape_mismatch(pg)
+
+    @skip_if_lt_x_gpu(4)
+    @with_dist_debug_levels(levels=["DETAIL"])
+    def test_collectives_op_mismatch_cuda_debug_mode(self):
+        pg = self._create_wrapper_pg(with_new_group=True)
+        self._test_collectives_op_mismatch(pg, use_cuda=True)
+
+    @skip_if_lt_x_gpu(4)
+    @with_dist_debug_levels(levels=["OFF"])
+    def test_collectives_op_mismatch_cuda(self):
+        pg = self._create_wrapper_pg(with_new_group=False)
+        self._test_collectives_op_mismatch(pg, use_cuda=True)
+
+    @skip_if_lt_x_gpu(4)
+    @with_dist_debug_levels(levels=["DETAIL"])
+    def test_collective_shape_mismatch_cuda_debug_mode(self):
+        pg = self._create_wrapper_pg(with_new_group=True)
+        self._test_collective_shape_mismatch(pg, use_cuda=True)
+
+    @skip_if_lt_x_gpu(4)
+    @with_dist_debug_levels(levels=["OFF"])
+    def test_collective_shape_mismatch_cuda(self):
+        pg = self._create_wrapper_pg(with_new_group=False)
+        self._test_collective_shape_mismatch(pg, use_cuda=True)
+
+if __name__ == "__main__":
+    assert (
+        not torch.cuda._initialized
+    ), "test_pg_wrapper must not have initialized CUDA context on main process"
+
+    run_tests()
diff --git a/test/run_test.py b/test/run_test.py
index 38563b2..5670da3 100755
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -49,6 +49,7 @@
     'distributed/test_c10d_spawn_gloo',
     'distributed/test_c10d_spawn_nccl',
     'distributed/test_store',
+    'distributed/test_pg_wrapper',
     'test_cuda',
     'test_jit_cuda_fuser',
     'test_cuda_primary_ctx',
@@ -311,6 +312,7 @@
     'distributed/test_c10d_spawn_gloo',
     'distributed/test_c10d_spawn_nccl',
     'distributed/test_store',
+    'distributed/test_pg_wrapper',
     'test_quantization',
     'test_pruning_op',
     'test_determination',
diff --git a/torch/distributed/CONTRIBUTING.md b/torch/distributed/CONTRIBUTING.md
index f621333..0f4428a 100644
--- a/torch/distributed/CONTRIBUTING.md
+++ b/torch/distributed/CONTRIBUTING.md
@@ -81,6 +81,9 @@
 # Run the Store tests.
 python test/distributed/test_store.py
 
+# Run Process Group Wrapper tests.
+python test/distributed/test_pg_wrapper.py
+
 # Run distributed tests, including tests for Distributed Data Parallel.
 python test/run_test.py --verbose -i distributed/test_distributed_fork
 python test/run_test.py --verbose -i distributed/test_distributed_spawn