add test_c10d_spawn_ucc.py (#86508)

Initial PR to create a UCC equivalent of https://github.com/pytorch/pytorch/blob/master/test/distributed/test_c10d_spawn_gloo.py and
https://github.com/pytorch/pytorch/blob/master/test/distributed/test_c10d_spawn_nccl.py. For now, only the common ops are covered.
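
For reference, a minimal sketch (not part of this PR) of driving the UCC backend that these tests exercise. It assumes PyTorch was built with UCC support and that at least two CUDA devices are visible; the rendezvous address/port and helper names are placeholders:

    import os
    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp

    def _worker(rank, world_size):
        # Rendezvous settings are placeholders; any reachable host/free port works.
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = "29500"
        # "ucc" is only registered when c10d was compiled with the UCC backend.
        dist.init_process_group("ucc", rank=rank, world_size=world_size)
        t = torch.ones(2, 2, device=f"cuda:{rank}") * rank
        dist.broadcast(t, src=0)  # every rank ends up with rank 0's tensor
        dist.destroy_process_group()

    if __name__ == "__main__":
        if dist.is_ucc_available():
            mp.spawn(_worker, args=(2,), nprocs=2)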

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86508
Approved by: https://github.com/kwen2501
diff --git a/test/distributed/test_c10d_spawn_ucc.py b/test/distributed/test_c10d_spawn_ucc.py
new file mode 100644
index 0000000..eabd7e1
--- /dev/null
+++ b/test/distributed/test_c10d_spawn_ucc.py
@@ -0,0 +1,110 @@
+# Owner(s): ["oncall: distributed"]
+
+import sys
+import test_c10d_spawn
+import torch
+import torch.distributed as c10d
+from test_c10d_spawn import _torch_dist_nn_available, TestDistributedNNFunctions
+from torch.testing._internal.common_cuda import TEST_MULTIGPU
+from torch.testing._internal.common_distributed import (
+    requires_ucc,
+    skip_if_lt_x_gpu,
+)
+from torch.testing._internal.common_utils import (
+    TestCase,
+    run_tests,
+    sandcastle_skip,
+    sandcastle_skip_if,
+    TEST_WITH_DEV_DBG_ASAN,
+)
+
+NO_UCC = not hasattr(c10d, "ProcessGroupUCC")
+
+# Fails on Python-3.9, see https://github.com/pytorch/pytorch/issues/51619
+if sys.version_info < (3, 9):
+
+    class ProcessGroupShareTensorTest(
+        test_c10d_spawn.AbstractProcessGroupShareTensorTest, TestCase
+    ):
+        @classmethod
+        def _init_pg_ucc(cls, rank, filename, world_size):
+            store = c10d.FileStore(filename, world_size)
+            return c10d.ProcessGroupUCC(store, rank, world_size)
+
+        @sandcastle_skip_if(not TEST_MULTIGPU, "At least 2 CUDA GPUs needed")
+        @sandcastle_skip_if(NO_UCC, "UCC needed")
+        def test_shared_broadcast_ucc(self):
+            self._test_multiprocess(
+                ProcessGroupShareTensorTest._test_broadcast_process,
+                [torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
+                ProcessGroupShareTensorTest._init_pg_ucc,
+                1,
+            )
+
+        @sandcastle_skip_if(not TEST_MULTIGPU, "At least 2 CUDA GPUs needed")
+        @sandcastle_skip_if(NO_UCC, "UCC needed")
+        def test_shared_allreduce_ucc(self):
+            self._test_multiprocess(
+                ProcessGroupShareTensorTest._test_allreduce_process,
+                [torch.ones(2, 2).to(i) for i in range(self.world_size)],
+                ProcessGroupShareTensorTest._init_pg_ucc,
+                1,
+            )
+
+        @sandcastle_skip_if(not TEST_MULTIGPU, "At least 2 CUDA GPUs needed")
+        @sandcastle_skip_if(NO_UCC, "UCC needed")
+        def test_shared_allgather_ucc(self):
+            self._test_multiprocess(
+                ProcessGroupShareTensorTest._test_allgather_process,
+                [torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
+                ProcessGroupShareTensorTest._init_pg_ucc,
+                self.world_size,
+            )
+
+
+# Skip dev-asan as torch + multiprocessing spawn have known issues
+if not TEST_WITH_DEV_DBG_ASAN:
+
+    class TestDistributedNNFunctionsUcc(TestDistributedNNFunctions):
+        # Test Common Ops First.
+        @requires_ucc()
+        @skip_if_lt_x_gpu(2)
+        @sandcastle_skip_if(
+            not _torch_dist_nn_available, "torch.distributed.nn is not available"
+        )
+        def test_broadcast(self):
+            self._test_broadcast("ucc")
+
+        @requires_ucc()
+        @skip_if_lt_x_gpu(2)
+        @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
+        def test_reduce(self):
+            self._test_reduce("ucc")
+
+        @requires_ucc()
+        @skip_if_lt_x_gpu(2)
+        @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
+        def test_allreduce(self):
+            self._test_allreduce("ucc")
+
+        @requires_ucc()
+        @skip_if_lt_x_gpu(2)
+        @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
+        @sandcastle_skip("runs into illegal memory access on first assertEqual check when run locally")
+        def test_all_gather(self):
+            self._test_all_gather("ucc")
+
+        @requires_ucc()
+        @skip_if_lt_x_gpu(2)
+        @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
+        def test_all_to_all(self):
+            self._test_all_to_all("ucc")
+
+        @requires_ucc()
+        @skip_if_lt_x_gpu(2)
+        @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
+        def test_all_to_all_single(self):
+            self._test_all_to_all_single("ucc")
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/test/run_test.py b/test/run_test.py
index 8a25a2e..6bf98a0 100755
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -785,6 +785,7 @@
     "distributed/test_c10d_common": get_run_test_with_subprocess_fn(),
     "distributed/test_c10d_spawn_gloo": get_run_test_with_subprocess_fn(),
     "distributed/test_c10d_spawn_nccl": get_run_test_with_subprocess_fn(),
+    "distributed/test_c10d_spawn_ucc": get_run_test_with_subprocess_fn(),
     "distributed/test_store": get_run_test_with_subprocess_fn(),
     "distributed/test_pg_wrapper": get_run_test_with_subprocess_fn(),
     "distributed/rpc/test_faulty_agent": get_run_test_with_subprocess_fn(),
diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py
index 883a48a..9dcb71a 100644
--- a/torch/testing/_internal/common_distributed.py
+++ b/torch/testing/_internal/common_distributed.py
@@ -304,6 +304,11 @@
         "c10d was not compiled with the NCCL backend",
     )
 
+def requires_ucc():
+    return sandcastle_skip_if(
+        not c10d.is_ucc_available(),
+        "c10d was not compiled with the UCC backend",
+    )
 
 def requires_mpi():
     return sandcastle_skip_if(