Move nccl scatter and gather to C++ (#9117)

Summary:
As I work on replicating DataParallel (DP) in C++, I need to move some functions from Python into C++. This PR ports the scatter and gather primitives from Python in torch/cuda/comm.py to C++ in torch/csrc/cuda/comm.cpp. The basic infrastructure was already in place, since apaszke had previously rewritten broadcast in C++.
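For reference, the user-facing Python API is unchanged by this port. A minimal usage sketch of the ported primitives (assuming at least two visible CUDA devices; the shapes and device indices are illustrative):

```python
import torch
import torch.cuda.comm as comm

# Split a tensor along dim 0 across two devices (even chunks by default).
full = torch.randn(8, 16, device='cuda:0')
chunks = comm.scatter(full, devices=[0, 1], dim=0)

# Or with explicit per-device chunk sizes (must sum to full.size(0)).
sized = comm.scatter(full, devices=[0, 1], chunk_sizes=[6, 2], dim=0)

# gather concatenates the per-device chunks back along dim 0, here onto device 0.
restored = comm.gather(chunks, dim=0, destination=0)
```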

I'm not very familiar with this code, so let me know if I'm doing something wrong; for the most part I translated the Python code literally.

I don't know how "public" `torch.cuda.comm` is, but I think the `destination_index` parameter of `gather` should be changed so that `None` indicates the CPU and `-1` indicates the default CUDA device, rather than `-1` indicating the CPU. That would make the code clearer, in my opinion.
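To make the ambiguity concrete, here is how the current convention (which this PR preserves) reads from the caller's side. A sketch only, again assuming two visible CUDA devices; the proposed convention below is not implemented in this PR:

```python
import torch
import torch.cuda.comm as comm

x = torch.randn(2, 4, device='cuda:0')
y = torch.randn(3, 4, device='cuda:1')

# Current convention:
#   destination=None -> current CUDA device
#   destination=-1   -> CPU
on_gpu = comm.gather([x, y], dim=0)                  # lands on the current CUDA device
on_cpu = comm.gather([x, y], dim=0, destination=-1)  # lands on the CPU

# Proposed convention (not implemented here):
#   destination=None -> CPU
#   destination=-1   -> default CUDA device
```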

apaszke colesbury teng-li pietern
Closes https://github.com/pytorch/pytorch/pull/9117

Differential Revision: D8721729

Pulled By: goldsborough

fbshipit-source-id: 1844a488079d21fa209b32e2c73e48632cbe9e68
diff --git a/aten/src/THC/THCGeneral.cpp b/aten/src/THC/THCGeneral.cpp
index 6f4fdf0..2f2258b 100644
--- a/aten/src/THC/THCGeneral.cpp
+++ b/aten/src/THC/THCGeneral.cpp
@@ -431,7 +431,7 @@
   return res->sparseHandles[handle - 1];
 }
 
-static THCStream* THCState_getStreamOnDevice(THCState* state, int device)
+THCStream* THCState_getStreamOnDevice(THCState* state, int device)
 {
   THCThreadLocal local = state->currentStreams[device];
   THCStream* stream = (THCStream*)THCThreadLocal_get(local);
@@ -443,7 +443,7 @@
   return stream;
 }
 
-static void THCState_setStreamOnDevice(THCState *state, int device, THCStream *stream)
+void THCState_setStreamOnDevice(THCState *state, int device, THCStream *stream)
 {
   THAssert(stream);
   if (stream->device != device) {
diff --git a/aten/src/THC/THCGeneral.h.in b/aten/src/THC/THCGeneral.h.in
index 7d9f5fc..98f02cf 100644
--- a/aten/src/THC/THCGeneral.h.in
+++ b/aten/src/THC/THCGeneral.h.in
@@ -112,6 +112,8 @@
 THC_API cudaStream_t THCState_getCurrentStream(THCState *state);
 THC_API struct THCStream* THCState_getStream(THCState *state);
 THC_API void THCState_setStream(THCState *state, struct THCStream* stream);
+THC_API THCStream* THCState_getStreamOnDevice(THCState* state, int device);
+THC_API void THCState_setStreamOnDevice(THCState *state, int device, THCStream *stream);
 
 THC_API void THCState_reserveBlasHandles(THCState* state, int numHandles);
 THC_API int THCState_getNumBlasHandles(THCState* state);
diff --git a/torch/csrc/cuda/comm.cpp b/torch/csrc/cuda/comm.cpp
index bc99018..d7c3b76 100644
--- a/torch/csrc/cuda/comm.cpp
+++ b/torch/csrc/cuda/comm.cpp
@@ -1,14 +1,21 @@
-#include "comm.h"
+#include <torch/csrc/cuda/comm.h>
 
-#include "torch/csrc/utils/tensor_flatten.h"
-#include "torch/csrc/cuda/device_set.h"
+#include <torch/csrc/cuda/device_set.h>
+#include <torch/csrc/utils/tensor_flatten.h>
+
 #ifdef USE_NCCL
-#include "torch/csrc/cuda/nccl.h"
+#include <torch/csrc/cuda/nccl.h>
 #endif
 
+#include <torch/csrc/utils/auto_stream.h>
+
+#include <THC/THC.h>
+
 #include <ATen/ATen.h>
+#include <ATen/optional.h>
 
 #include <cstddef>
+#include <vector>
 
 namespace torch { namespace cuda {
 
@@ -111,4 +118,85 @@
   return outputs;
 }
 
-}}
+std::vector<at::Tensor> scatter(
+    const at::Tensor& tensor,
+    at::IntList devices,
+    const at::optional<std::vector<int64_t>>& chunk_sizes,
+    int64_t dim,
+    const at::optional<std::vector<THCStream*>>& streams) {
+  std::vector<at::Tensor> chunks;
+  if (chunk_sizes) {
+    const int64_t chunk_size_sum =
+        std::accumulate(chunk_sizes->begin(), chunk_sizes->end(), 0);
+    AT_CHECK(
+      chunk_size_sum == tensor.size(dim),
+      "given chunk sizes don't sum up to the tensor's size ",
+      "(sum(chunk_sizes) == ", chunk_size_sum,
+      ", but expected ", tensor.size(dim), ")");
+    chunks.reserve(chunk_sizes->size());
+    int64_t chunk_start = 0;
+    for (size_t chunk = 0; chunk < chunk_sizes->size(); ++chunk) {
+      const int64_t chunk_size = (*chunk_sizes)[chunk];
+      AT_CHECK(chunk_size > 0, "Chunk size must be positive");
+      chunks.push_back(tensor.narrow(dim, chunk_start, chunk_size));
+      chunk_start += chunk_size;
+    }
+    AT_ASSERT(chunks.size() == chunk_sizes->size());
+  } else {
+    chunks = tensor.chunk(/*chunks=*/devices.size(), /*dim=*/dim);
+  }
+  auto* thc_state = at::globalContext().lazyInitCUDA();
+  for (size_t chunk = 0; chunk < chunks.size(); ++chunk) {
+    const int32_t device_index = devices[chunk];
+    // We must set the current device before setting the current stream.
+    const at::DeviceGuard device_guard({at::kCUDA, device_index});
+    const AutoStream stream_guard(
+        streams ? (*streams)[chunk]
+                : THCState_getStreamOnDevice(thc_state, device_index));
+    // Copy the chunk from its current device to its destination device, which
+    // we set as the default device above, thus specified as -1.
+    chunks[chunk] =
+        chunks[chunk].contiguous().to({at::kCUDA, -1}, /*non_blocking=*/true);
+  }
+  return chunks;
+}
+
+at::Tensor gather(
+    at::TensorList tensors,
+    int64_t dim,
+    at::optional<int32_t> destination_index) {
+  AT_ASSERT(!tensors.empty());
+  at::Tensor result;
+  int64_t total_size = 0;
+  auto& first = tensors.front();
+  const auto first_size = first.sizes();
+  std::vector<int64_t> expected_size(first_size.begin(), first_size.end());
+  for (const auto& tensor : tensors) {
+    AT_CHECK(
+        tensor.type().is_cuda(), "Gather expects all inputs to have CUDA type");
+    AT_CHECK(tensor.ndimension() == static_cast<int64_t>(expected_size.size()));
+    expected_size[dim] = tensor.size(dim);
+    for (size_t dimension = 0; dimension < expected_size.size(); ++dimension) {
+      AT_CHECK(
+          expected_size[dimension] == tensor.size(dimension),
+          "Gather got an input of invalid size: got ",
+          tensor.sizes(), ", but expected ", at::IntList(expected_size));
+    }
+    total_size += tensor.size(dim);
+  }
+  expected_size[dim] = total_size;
+  at::Device device(at::kCPU);
+  if (!destination_index || *destination_index != -1) {
+    device = at::Device(at::kCUDA, destination_index ? *destination_index : -1);
+  }
+  result = at::empty(expected_size, first.options().device(device));
+
+  int64_t chunk_start = 0;
+  for (const auto& tensor : tensors) {
+    result.narrow(dim, chunk_start, tensor.size(dim))
+        .copy_(tensor, /*non_blocking=*/true);
+    chunk_start += tensor.size(dim);
+  }
+  return result;
+}
+}} // namespace torch::cuda
diff --git a/torch/csrc/cuda/comm.h b/torch/csrc/cuda/comm.h
index eb0f287..a87cc45 100644
--- a/torch/csrc/cuda/comm.h
+++ b/torch/csrc/cuda/comm.h
@@ -1,7 +1,12 @@
-#include "torch/csrc/assertions.h"
+#pragma once
+
+#include <THC/THC.h>
 
 #include <ATen/ATen.h>
-#include <unordered_map>
+#include <ATen/optional.h>
+
+#include <cstddef>
+#include <vector>
 
 namespace torch { namespace cuda {
 
@@ -11,4 +16,15 @@
 tensor_list2d broadcast_coalesced(at::TensorList tensors, at::IntList devices,
                                   size_t buffer_size);
 
+std::vector<at::Tensor> scatter(
+    const at::Tensor& tensor,
+    at::IntList devices,
+    const at::optional<std::vector<int64_t>>& chunk_sizes = at::nullopt,
+    int64_t dim = 0,
+    const at::optional<std::vector<THCStream*>>& streams = at::nullopt);
+
+at::Tensor gather(
+    at::TensorList tensors,
+    int64_t dim,
+    at::optional<int32_t> destination_index);
 }}
diff --git a/torch/csrc/cuda/python_comm.cpp b/torch/csrc/cuda/python_comm.cpp
index 9002556..902d5b9 100644
--- a/torch/csrc/cuda/python_comm.cpp
+++ b/torch/csrc/cuda/python_comm.cpp
@@ -1,10 +1,17 @@
 #include "torch/csrc/utils/pybind.h"
 #include "torch/csrc/cuda/comm.h"
+#include "torch/csrc/cuda/Stream.h"
+#include "torch/csrc/cuda/THCP.h"
+#include "torch/csrc/utils/auto_gil.h"
 
-#include <chrono>
+#include <ATen/ATen.h>
+
+#include <THC/THC.h>
+
+#include <cstddef>
+#include <vector>
 
 namespace torch { namespace cuda { namespace python {
-
 void initCommMethods(PyObject *module) {
   auto m = py::cast<py::module>(module);
   m.def("_broadcast_coalesced", [](std::vector<at::Tensor>& tensors, std::vector<int64_t> devices, size_t buffer_size) {
@@ -13,7 +20,36 @@
       py::call_guard<py::gil_scoped_release>())
    .def("_broadcast", [](at::Tensor& tensor, std::vector<int64_t> devices) {
      return broadcast(tensor, devices);
-   }, py::call_guard<py::gil_scoped_release>());
+   }, py::call_guard<py::gil_scoped_release>())
+   .def("_scatter", [](
+     at::Tensor& tensor,
+     std::vector<int64_t>& devices,
+     at::optional<std::vector<int64_t>> chunk_sizes,
+     int64_t dim,
+     at::optional<py::object> py_streams) {
+     at::optional<std::vector<THCStream*>> streams;
+     if (py_streams) {
+       py::handle handle = *py_streams;
+       streams = THPUtils_PySequence_to_THCStreamList(handle.ptr());
+     }
+     // Note: We're holding the GIL up to here.
+     AutoNoGIL no_gil;
+     return scatter(tensor, devices, chunk_sizes, dim, streams);
+   },
+   py::arg("tensor"),
+   py::arg("devices"),
+   py::arg("chunk_sizes"),
+   py::arg("dim"),
+   py::arg("streams"))
+   .def("_gather", [](
+     std::vector<at::Tensor>& tensors,
+     int64_t dim,
+     at::optional<int32_t> destination_index) {
+     return gather(tensors, dim, destination_index);
+   },
+   py::arg("tensors"),
+   py::arg("dim"),
+   py::arg("destination_index"),
+   py::call_guard<py::gil_scoped_release>());
 }
-
 }}}
diff --git a/torch/cuda/comm.py b/torch/cuda/comm.py
index 919fe8f..ab7f1a2 100644
--- a/torch/cuda/comm.py
+++ b/torch/cuda/comm.py
@@ -139,24 +139,7 @@
         A tuple containing chunks of the ``tensor``, spread across given
         ``devices``.
     """
-    if chunk_sizes is None:
-        chunks = tensor.chunk(len(devices), dim)
-    else:
-        assert sum(chunk_sizes) == tensor.size(dim), "given chunk sizes " \
-            "don't sum up to the tensor's size (sum(chunk_sizes) == {}, but " \
-            "expected {})".format(sum(chunk_sizes), tensor.size(dim))
-        assert min(chunk_sizes) > 0, "got a negative chunk_size"
-        chunks = [tensor.narrow(dim, start - size, size)
-                  for start, size in zip(_accumulate(chunk_sizes), chunk_sizes)]
-    chunks = tuple(chunk.contiguous() for chunk in chunks)
-    # TODO: copy to a pinned buffer first (if copying from CPU)
-    if streams is None:
-        streams = [None] * len(devices)
-    outputs = []
-    for device, chunk, stream in zip(devices, chunks, streams):
-        with torch.cuda.device(device), torch.cuda.stream(stream):
-            outputs.append(chunk.cuda(device, non_blocking=True))
-    return tuple(outputs)
+    return tuple(torch._C._scatter(tensor, devices, chunk_sizes, dim, streams))
 
 
 def gather(tensors, dim=0, destination=None):
@@ -174,30 +157,4 @@
         A tensor located on ``destination`` device, that is a result of
         concatenating ``tensors`` along ``dim``.
     """
-    total_size = 0
-    expected_size = list(tensors[0].size())
-    for tensor in tensors:
-        assert tensor.is_cuda, "gather expects all inputs to be on GPUs"
-        expected_size[dim] = tensor.size(dim)
-        if list(tensor.size()) != expected_size:
-            got = 'x'.join(str(x) for x in tensor.size())
-            expected = 'x'.join(str(x) for x in expected_size)
-            raise ValueError("gather got an input of invalid size: got {}, "
-                             "but expected {}".format(got, expected))
-        total_size += tensor.size(dim)
-    expected_size[dim] = total_size
-    expected_size = torch.Size(expected_size)
-    if destination is None:
-        destination = torch.cuda.current_device()
-    if destination == -1:
-        result = tensors[0].new().cpu().resize_(expected_size)
-    else:
-        result = tensors[0].new(expected_size, device=destination)
-
-    chunk_start = 0
-    # TODO: if copying to CPU, allocate a pinned buffer, do async copies to it,
-    # and copy it to regular memory
-    for tensor in tensors:
-        result.narrow(dim, chunk_start, tensor.size(dim)).copy_(tensor, True)
-        chunk_start += tensor.size(dim)
-    return result
+    return torch._C._gather(tensors, dim, destination)
diff --git a/torch/lib/c10d/private/CUDAUtils.hpp b/torch/lib/c10d/private/CUDAUtils.hpp
index be57725..b2a847b 100644
--- a/torch/lib/c10d/private/CUDAUtils.hpp
+++ b/torch/lib/c10d/private/CUDAUtils.hpp
@@ -7,7 +7,7 @@
 #include <cuda_runtime.h>
 
 #include <ATen/ATen.h>
-#include <THCStream.h>
+#include <THC/THCStream.h>
 
 #include "../CUDAUtils.hpp"