Add support for NCCL alltoall (#44374)

Summary:
NCCL `alltoall_single` was already added in https://github.com/pytorch/pytorch/issues/42514. This PR adds NCCL `alltoall`.

The difference between `alltoall_single` and `alltoall` is that `alltoall_single` works on a single tensor and sends/receives slices of that tensor, while `alltoall` works on a list of tensors and sends/receives the tensors in that list.
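
For illustration, a minimal usage sketch of the two collectives on the NCCL backend (shapes and values are illustrative only, not taken from the PR's tests; a process group is assumed to be initialized and each process bound to its own CUDA device):

```python
import torch
import torch.distributed as dist

# Assumes dist.init_process_group("nccl", ...) has already been called
# and torch.cuda.set_device(...) picked this process's GPU.
world_size = dist.get_world_size()
rank = dist.get_rank()

# all_to_all_single: one flat tensor per rank, exchanged as equal slices.
send = torch.arange(world_size, dtype=torch.float, device="cuda") + rank * world_size
recv = torch.empty_like(send)
dist.all_to_all_single(recv, send)

# all_to_all (enabled for NCCL by this PR): one tensor per peer, passed as lists.
send_list = [torch.full((1,), rank, dtype=torch.float, device="cuda") for _ in range(world_size)]
recv_list = [torch.empty(1, device="cuda") for _ in range(world_size)]
dist.all_to_all(recv_list, send_list)
```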

cc: ptrblck ngimel

Pull Request resolved: https://github.com/pytorch/pytorch/pull/44374

Reviewed By: zhangguanheng66, mrshenli

Differential Revision: D24455427

Pulled By: srinivas212

fbshipit-source-id: 42fdebdd14f8340098e2c34ef645bd40603552b1
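
For reference, the list variant is implemented in `torch/csrc/cuda/nccl.cpp` (see the diff below) with grouped `ncclSend`/`ncclRecv` calls, one pair per peer. Semantically, on every rank `i`, `output_tensors[j]` ends up holding the tensor that rank `j` supplied as `input_tensors[i]`. A framework-free sketch of that semantics (hypothetical helper, not part of the PR):

```python
def simulate_all_to_all(inputs_per_rank):
    # inputs_per_rank[i][j] is the tensor rank i sends to rank j.
    # Returns outputs_per_rank, where outputs_per_rank[i][j] is what
    # rank i receives from rank j, i.e. inputs_per_rank[j][i].
    world_size = len(inputs_per_rank)
    outputs_per_rank = [[None] * world_size for _ in range(world_size)]
    for i in range(world_size):
        for j in range(world_size):
            outputs_per_rank[i][j] = inputs_per_rank[j][i]
    return outputs_per_rank
```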
diff --git a/torch/csrc/cuda/nccl.cpp b/torch/csrc/cuda/nccl.cpp
index 5efb77e..0248e81 100644
--- a/torch/csrc/cuda/nccl.cpp
+++ b/torch/csrc/cuda/nccl.cpp
@@ -71,11 +71,8 @@
   }
 }
 
-ncclDataType_t to_nccl_data_type(const at::Tensor& t) {
-  if (!t.is_cuda()) {
-    throw std::runtime_error("Unconvertible NCCL type");
-  }
-  switch (t.scalar_type()) {
+ncclDataType_t to_nccl_data_type(c10::ScalarType type) {
+  switch (type) {
     case at::kFloat:
       return ncclDataType_t::ncclFloat;
     case at::kHalf:
@@ -89,16 +86,25 @@
     case at::kChar:
       return ncclDataType_t::ncclChar;
     case at::kByte:
-      return ncclDataType_t::ncclChar;
+      return ncclDataType_t::ncclUint8;
+    case at::kBool:
+      return ncclDataType_t::ncclUint8;
 #if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION >= 301
     case at::kBFloat16:
       return ncclDataType_t::ncclBfloat16;
 #endif
     default:
-      throw std::runtime_error("Unconvertible NCCL type");
+      TORCH_CHECK(false, "Unconvertible NCCL type ", type);
   }
 }
 
+ncclDataType_t to_nccl_data_type(const at::Tensor& t) {
+  if (!t.is_cuda()) {
+    TORCH_CHECK(false, "NCCL only supports CUDA tensors, but got a tensor on ", t.device());
+  }
+  return to_nccl_data_type(t.scalar_type());
+}
+
 ncclRedOp_t to_nccl_red_op(int var) {
   return (ncclRedOp_t)(var);
 }
@@ -625,7 +631,7 @@
 #endif
 }
 
-void all2all(at::Tensor& input,
+void all2all_single_equal_split(at::Tensor& input,
              at::Tensor& output,
              int size,
              ncclComm_t _comm,
@@ -660,6 +666,98 @@
 #endif
 }
 
+void all2all_single_unequal_split(
+    void* sendbuff,
+    const size_t* sendcounts,
+    const size_t* senddispls,
+    void* recvbuff,
+    const size_t* recvcounts,
+    const size_t* recvdispls,
+    size_t size,
+    c10::ScalarType _type,
+    ncclComm_t _comm,
+    at::cuda::CUDAStream& stream) {
+#ifdef USE_NCCL
+#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && (NCCL_MAJOR * 10 + NCCL_MINOR) >= 27
+  using namespace torch::cuda::nccl::detail;
+
+  auto type = to_nccl_data_type(_type);
+  auto comm = to_nccl_comm(_comm);
+  int numranks;
+  NCCL_CHECK(ncclCommCount(comm, &numranks));
+  NCCL_CHECK(ncclGroupStart());
+  for (int r = 0; r < numranks; r++) {
+    // NCCL uses 0 byte message for synchronization
+    // Avoid send/recv when message size is zero
+    if (sendcounts[r] != 0) {
+      NCCL_CHECK(ncclSend(
+          ((char*)sendbuff) + senddispls[r] * size,
+          sendcounts[r],
+          type,
+          r,
+          comm,
+          stream));
+    }
+    if (recvcounts[r] != 0) {
+      NCCL_CHECK(ncclRecv(
+          ((char*)recvbuff) + recvdispls[r] * size,
+          recvcounts[r],
+          type,
+          r,
+          comm,
+          stream));
+    }
+  }
+  NCCL_CHECK(ncclGroupEnd());
+#else
+  AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0");
+#endif
+#else
+  AT_ERROR("PyTorch built without NCCL support");
+#endif
+}
+
+void all2all(std::vector<at::Tensor>& outputTensors,
+             std::vector<at::Tensor>& inputTensors,
+             ncclComm_t _comm,
+             at::cuda::CUDAStream& stream) {
+#ifdef USE_NCCL
+#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && (NCCL_MAJOR * 10 + NCCL_MINOR) >= 27
+  using namespace torch::cuda::nccl::detail;
+  auto comm = to_nccl_comm(_comm);
+
+  NCCL_CHECK(ncclGroupStart());
+  for (size_t r = 0; r < outputTensors.size(); r++) {
+    at::Tensor &input = inputTensors[r];
+    at::Tensor &output = outputTensors[r];
+    if (input.numel() != 0) {
+      NCCL_CHECK(ncclSend(
+          input.data_ptr(),
+          input.numel(),
+          to_nccl_data_type(input),
+          r,
+          comm,
+          stream.stream()));
+    }
+    if (output.numel() != 0) {
+      NCCL_CHECK(ncclRecv(
+          output.data_ptr(),
+          output.numel(),
+          to_nccl_data_type(output),
+          r,
+          comm,
+          stream.stream()));
+    }
+  }
+  NCCL_CHECK(ncclGroupEnd());
+#else
+  AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0");
+#endif
+#else
+  AT_ERROR("PyTorch built without NCCL support");
+#endif
+}
+
 void send(
     const at::Tensor& input,
     ncclComm_t comm,
diff --git a/torch/csrc/cuda/nccl.h b/torch/csrc/cuda/nccl.h
index 4cbae2e..1c81171 100644
--- a/torch/csrc/cuda/nccl.h
+++ b/torch/csrc/cuda/nccl.h
@@ -136,13 +136,31 @@
     const stream_list& streams = {},
     const comm_list& user_comms = {});
 
-TORCH_CUDA_API void all2all(
+TORCH_CUDA_API void all2all_single_equal_split(
     at::Tensor& input,
     at::Tensor& output,
     int size,
     ncclComm_t comm,
     at::cuda::CUDAStream& stream);
 
+TORCH_CUDA_API void all2all_single_unequal_split(
+    void* sendbuff,
+    const size_t* sendcounts,
+    const size_t* senddispls,
+    void* recvbuff,
+    const size_t* recvcounts,
+    const size_t* recvdispls,
+    size_t size,
+    c10::ScalarType type,
+    ncclComm_t comm,
+    at::cuda::CUDAStream& stream);
+
+TORCH_CUDA_API void all2all(
+    std::vector<at::Tensor>& outputTensors,
+    std::vector<at::Tensor>& inputTensors,
+    ncclComm_t _comm,
+    at::cuda::CUDAStream& stream);
+
 TORCH_CUDA_API void send(
     const at::Tensor& input,
     ncclComm_t comm,
diff --git a/torch/lib/c10d/ProcessGroupNCCL.cpp b/torch/lib/c10d/ProcessGroupNCCL.cpp
index 473fcc0..19bcb67 100644
--- a/torch/lib/c10d/ProcessGroupNCCL.cpp
+++ b/torch/lib/c10d/ProcessGroupNCCL.cpp
@@ -166,49 +166,6 @@
   return std::string(kNCCLAbortedCommStoreKey) + ":" + ncclIdStr;
 }
 
-#ifdef ENABLE_NCCL_P2P_SUPPORT
-
-ncclResult_t ncclAlltoallv(
-    void* sendbuff,
-    const size_t* sendcounts,
-    const size_t* senddispls,
-    void* recvbuff,
-    const size_t* recvcounts,
-    const size_t* recvdispls,
-    size_t size,
-    ncclDataType_t type,
-    ncclComm_t comm,
-    cudaStream_t stream) {
-  int numranks;
-  C10D_NCCL_CHECK(ncclCommCount(comm, &numranks));
-  C10D_NCCL_CHECK(ncclGroupStart());
-  for (int r = 0; r < numranks; r++) {
-    // NCCL uses 0 byte message for synchronization
-    // Avoid send/recv when message size is zero
-    if (sendcounts[r] != 0) {
-      C10D_NCCL_CHECK(ncclSend(
-          ((char*)sendbuff) + senddispls[r] * size,
-          sendcounts[r],
-          type,
-          r,
-          comm,
-          stream));
-    }
-    if (recvcounts[r] != 0) {
-      C10D_NCCL_CHECK(ncclRecv(
-          ((char*)recvbuff) + recvdispls[r] * size,
-          recvcounts[r],
-          type,
-          r,
-          comm,
-          stream));
-    }
-  }
-  C10D_NCCL_CHECK(ncclGroupEnd());
-  return ncclSuccess;
-}
-#endif
-
 } // namespace
 
 const int64_t ProcessGroupNCCL::kWatchdogThreadSleepMillis = 10000;
@@ -1474,7 +1431,7 @@
         // See [Sync Streams].
         c10::cuda::CUDACachingAllocator::recordStream(
               output.storage().data_ptr(), stream);
-        torch::cuda::nccl::all2all(
+        torch::cuda::nccl::all2all_single_equal_split(
               input,
               output,
               this->getSize(),
@@ -1507,7 +1464,7 @@
           // See [Sync Streams].
           c10::cuda::CUDACachingAllocator::recordStream(
               output.storage().data_ptr(), stream);
-          return ncclAlltoallv(
+          torch::cuda::nccl::all2all_single_unequal_split(
               input.data_ptr(),
               send_lengths.data(),
               send_offsets.data(),
@@ -1515,15 +1472,42 @@
               recv_lengths.data(),
               recv_offsets.data(),
               input.element_size(),
-              getNcclDataType(input.scalar_type()),
+              input.scalar_type(),
               comm,
-              stream.stream());
+              stream);
+          return ncclSuccess;
         },
         OpType::ALLTOALL_BASE,
         "nccl:all_to_all");
   }
 }
 
+c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::alltoall(
+    std::vector<at::Tensor>& outputTensors,
+    std::vector<at::Tensor>& inputTensors,
+    const AllToAllOptions& /* unused */) {
+  auto device = outputTensors[0].device();
+  for (size_t r = 0; r < outputTensors.size(); r++) {
+    check_gpu_single_tensor(outputTensors[r]);
+    check_gpu_single_tensor(inputTensors[r]);
+    TORCH_CHECK(device == outputTensors[r].device() && device == inputTensors[r].device(),
+      "Tensors must be on the same device")
+  }
+  std::vector<at::Tensor> inputTensor0 = {inputTensors[0]};
+  std::vector<at::Tensor> outputTensor0 = {outputTensors[0]};
+  return collective(
+    inputTensor0,
+    outputTensor0,
+    [&](at::Tensor& /* unused */,
+        at::Tensor& /* unused */,
+        ncclComm_t comm,
+        at::cuda::CUDAStream& stream) {
+      torch::cuda::nccl::all2all(outputTensors, inputTensors, comm, stream);
+      return ncclSuccess;
+    },
+    OpType::ALLTOALL);
+}
+
 c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::send(
     std::vector<at::Tensor>& tensors,
     int dstRank,
@@ -1572,6 +1556,14 @@
       "ProcessGroupNCCL only supports alltoall* for NCCL lib version >= 2.7.0");
 }
 
+c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::alltoall(
+    std::vector<at::Tensor>& /* unused */,
+    std::vector<at::Tensor>& /* unused */,
+    const AllToAllOptions& /* unused */) {
+  throw std::runtime_error(
+      "ProcessGroupNCCL only supports alltoall* for NCCL lib version >= 2.7.0");
+}
+
 c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::send(
     std::vector<at::Tensor>& /* unused */,
     int /* unused */,
@@ -1603,13 +1595,6 @@
   --ncclActiveGroupCounter_;
 }
 
-c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::alltoall(
-    std::vector<at::Tensor>& /* unused */,
-    std::vector<at::Tensor>& /* unused */,
-    const AllToAllOptions& /* unused */) {
-  throw std::runtime_error("ProcessGroupNCCL does not support alltoall");
-}
-
 c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::gather(
     std::vector<std::vector<at::Tensor>>& /* unused */,
     std::vector<at::Tensor>& /* unused */,
diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py
index 3b9882d..30966cf 100644
--- a/torch/testing/_internal/distributed/distributed_test.py
+++ b/torch/testing/_internal/distributed/distributed_test.py
@@ -2191,7 +2191,14 @@
                 self.assertEqual(out_tensor, expected_tensor)
             self._barrier()
 
-        def _test_all_to_all_helper(self, group, group_id, rank):
+        def _test_all_to_all_helper(
+            self,
+            group,
+            group_id,
+            rank,
+            cuda=False,
+            rank_to_GPU=None,
+        ):
             if group_id is not None:
                 size = len(group)
                 in_splits = [i + 1 for i in group]
@@ -2200,6 +2207,10 @@
                 ]
                 out_tensors = [torch.ones([(rank + 1), size]) for _ in group]
                 expected_tensors = [torch.ones([rank + 1, size]) * i for i in group]
+                if cuda:
+                    in_tensors = [t.cuda(rank_to_GPU[rank][0]) for t in in_tensors]
+                    expected_tensors = [t.cuda(rank_to_GPU[rank][0]) for t in expected_tensors]
+                    out_tensors = [t.cuda(rank_to_GPU[rank][0]) for t in out_tensors]
                 dist.all_to_all(out_tensors, in_tensors, group=group_id)
                 for t1, t2 in zip(out_tensors, expected_tensors):
                     self.assertEqual(t1, t2)
@@ -2212,7 +2223,6 @@
             group, group_id, rank = self._init_global_test()
             self._test_all_to_all_single_equal_split_helper(group, group_id, rank)
 
-        @unittest.skip("NCCL A2A is not enabled for OSS builds")
         @unittest.skipIf(
             BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
         )
@@ -2236,7 +2246,6 @@
             group, group_id, rank = self._init_global_test()
             self._test_all_to_all_single_unequal_split_helper(group, group_id, rank)
 
-        @unittest.skip("NCCL A2A is not enabled for OSS builds")
         @unittest.skipIf(
             BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
         )
@@ -2258,6 +2267,13 @@
             group, group_id, rank = self._init_global_test()
             self._test_all_to_all_helper(group, group_id, rank)
 
+        @unittest.skipIf(BACKEND != "nccl", "Only NCCL supports CUDA all_to_all")
+        @skip_if_rocm
+        def test_all_to_all_cuda(self):
+            group, group_id, rank = self._init_global_test()
+            rank_to_GPU = self._init_multigpu_helper()
+            self._test_all_to_all_helper(group, group_id, rank, True, rank_to_GPU)
+
         @unittest.skipIf(
             BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
         )
@@ -2266,7 +2282,6 @@
             group, group_id, rank = self._init_group_test()
             self._test_all_to_all_single_equal_split_helper(group, group_id, rank)
 
-        @unittest.skip("NCCL A2A is not enabled for OSS builds")
         @unittest.skipIf(
             BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
         )
@@ -2292,7 +2307,6 @@
             group, group_id, rank = self._init_group_test()
             self._test_all_to_all_single_unequal_split_helper(group, group_id, rank)
 
-        @unittest.skip("NCCL A2A is not enabled for OSS builds")
         @unittest.skipIf(
             BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
         )
@@ -2317,13 +2331,27 @@
             self._test_all_to_all_helper(group, group_id, rank)
 
         @unittest.skipIf(
+            BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
+        )
+        @skip_if_small_worldsize
+        @skip_if_rocm
+        def test_all_to_all_group_cuda(self):
+            group, group_id, rank = self._init_group_test()
+            rank_to_GPU = self._init_multigpu_helper()
+            self._test_all_to_all_helper(
+                group,
+                group_id,
+                rank,
+                True,
+                rank_to_GPU)
+
+        @unittest.skipIf(
             BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
         )
         def test_all_to_all_single_equal_split_full_group(self):
             group, group_id, rank = self._init_full_group_test()
             self._test_all_to_all_single_equal_split_helper(group, group_id, rank)
 
-        @unittest.skip("NCCL A2A is not enabled for OSS builds")
         @unittest.skipIf(
             BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
         )
@@ -2347,7 +2375,6 @@
             group, group_id, rank = self._init_full_group_test()
             self._test_all_to_all_single_unequal_split_helper(group, group_id, rank)
 
-        @unittest.skip("NCCL A2A is not enabled for OSS builds")
         @unittest.skipIf(
             BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
         )
@@ -2369,6 +2396,13 @@
             group, group_id, rank = self._init_full_group_test()
             self._test_all_to_all_helper(group, group_id, rank)
 
+        @unittest.skipIf(BACKEND != "nccl", "Only NCCL supports CUDA all_to_all")
+        @skip_if_rocm
+        def test_all_to_all_full_group_cuda(self):
+            group, group_id, rank = self._init_full_group_test()
+            rank_to_GPU = self._init_multigpu_helper()
+            self._test_all_to_all_helper(group, group_id, rank, True, rank_to_GPU)
+
         # BARRIER
         def _test_barrier_helper(
                 self, group, group_id, rank, cuda=False, rank_to_GPU=None):