Move nccl scatter and gather to C++ (#9117)
Summary:
As I work on replicating DataParallel (DP) in C++, I need to move some functions from Python into C++. This PR ports the scatter and gather primitives from Python (torch/cuda/comm.py) to C++ (torch/csrc/cuda/comm.cpp). The basic infrastructure was already in place, since apaszke had already rewritten broadcast in C++.
I'm not very familiar with this code, so let me know if I'm doing something wrong. For the most part, this is a literal translation of the Python code.
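For illustration, here is a minimal sketch of how the ported primitives can be called from C++ (the wrapper function and device indices are hypothetical, not part of this PR):

```cpp
#include <torch/csrc/cuda/comm.h>

void example(const at::Tensor& tensor) {
  // Split `tensor` along dim 0 across CUDA devices 0 and 1; chunk sizes,
  // dim, and streams all take their defaults (equal chunks, dim 0, and
  // each device's current stream).
  std::vector<at::Tensor> chunks = torch::cuda::scatter(tensor, /*devices=*/{0, 1});

  // Concatenate the chunks back along dim 0 onto CUDA device 0.
  at::Tensor gathered =
      torch::cuda::gather(chunks, /*dim=*/0, /*destination_index=*/0);
}
```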
I don't know how "public" `torch.cuda.comm` is, but I feel the `destination_index` parameter of `gather` should be changed: `None` should indicate the CPU and `-1` the default CUDA device, instead of `-1` indicating the CPU. That would make the code clearer IMO.
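For reference, the semantics this PR implements for `destination_index` in the C++ `gather` are sketched below (the wrapper function is hypothetical; assumes `#include <torch/csrc/cuda/comm.h>`):

```cpp
// destination_index == -1          -> result lands on the CPU
// destination_index == at::nullopt -> result lands on the current CUDA device
// destination_index == i (i >= 0)  -> result lands on CUDA device i
void gather_examples(at::TensorList tensors) {
  auto on_cpu     = torch::cuda::gather(tensors, /*dim=*/0, /*destination_index=*/-1);
  auto on_current = torch::cuda::gather(tensors, /*dim=*/0, at::nullopt);
  auto on_gpu1    = torch::cuda::gather(tensors, /*dim=*/0, /*destination_index=*/1);
}
```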
apaszke colesbury teng-li pietern
Closes https://github.com/pytorch/pytorch/pull/9117
Differential Revision: D8721729
Pulled By: goldsborough
fbshipit-source-id: 1844a488079d21fa209b32e2c73e48632cbe9e68
diff --git a/aten/src/THC/THCGeneral.cpp b/aten/src/THC/THCGeneral.cpp
index 6f4fdf0..2f2258b 100644
--- a/aten/src/THC/THCGeneral.cpp
+++ b/aten/src/THC/THCGeneral.cpp
@@ -431,7 +431,7 @@
return res->sparseHandles[handle - 1];
}
-static THCStream* THCState_getStreamOnDevice(THCState* state, int device)
+THCStream* THCState_getStreamOnDevice(THCState* state, int device)
{
THCThreadLocal local = state->currentStreams[device];
THCStream* stream = (THCStream*)THCThreadLocal_get(local);
@@ -443,7 +443,7 @@
return stream;
}
-static void THCState_setStreamOnDevice(THCState *state, int device, THCStream *stream)
+void THCState_setStreamOnDevice(THCState *state, int device, THCStream *stream)
{
THAssert(stream);
if (stream->device != device) {
diff --git a/aten/src/THC/THCGeneral.h.in b/aten/src/THC/THCGeneral.h.in
index 7d9f5fc..98f02cf 100644
--- a/aten/src/THC/THCGeneral.h.in
+++ b/aten/src/THC/THCGeneral.h.in
@@ -112,6 +112,8 @@
THC_API cudaStream_t THCState_getCurrentStream(THCState *state);
THC_API struct THCStream* THCState_getStream(THCState *state);
THC_API void THCState_setStream(THCState *state, struct THCStream* stream);
+THC_API THCStream* THCState_getStreamOnDevice(THCState* state, int device);
+THC_API void THCState_setStreamOnDevice(THCState *state, int device, THCStream *stream);
THC_API void THCState_reserveBlasHandles(THCState* state, int numHandles);
THC_API int THCState_getNumBlasHandles(THCState* state);
diff --git a/torch/csrc/cuda/comm.cpp b/torch/csrc/cuda/comm.cpp
index bc99018..d7c3b76 100644
--- a/torch/csrc/cuda/comm.cpp
+++ b/torch/csrc/cuda/comm.cpp
@@ -1,14 +1,21 @@
-#include "comm.h"
+#include <torch/csrc/cuda/comm.h>
-#include "torch/csrc/utils/tensor_flatten.h"
-#include "torch/csrc/cuda/device_set.h"
+#include <torch/csrc/cuda/device_set.h>
+#include <torch/csrc/utils/tensor_flatten.h>
+
#ifdef USE_NCCL
-#include "torch/csrc/cuda/nccl.h"
+#include <torch/csrc/cuda/nccl.h>
#endif
+#include <torch/csrc/utils/auto_stream.h>
+
+#include <THC/THC.h>
+
#include <ATen/ATen.h>
+#include <ATen/optional.h>
#include <cstddef>
+#include <numeric>
+#include <vector>
namespace torch { namespace cuda {
@@ -111,4 +118,85 @@
return outputs;
}
-}}
+std::vector<at::Tensor> scatter(
+ const at::Tensor& tensor,
+ at::IntList devices,
+ const at::optional<std::vector<int64_t>>& chunk_sizes,
+ int64_t dim,
+ const at::optional<std::vector<THCStream*>>& streams) {
+ std::vector<at::Tensor> chunks;
+ if (chunk_sizes) {
+ const int64_t chunk_size_sum =
+      std::accumulate(chunk_sizes->begin(), chunk_sizes->end(), int64_t{0});
+ AT_CHECK(
+ chunk_size_sum == tensor.size(dim),
+ "given chunk sizes don't sum up to the tensor's size ",
+ "(sum(chunk_sizes) == ", chunk_size_sum,
+ ", but expected ", tensor.size(dim), ")");
+ chunks.reserve(chunk_sizes->size());
+ int64_t chunk_start = 0;
+ for (size_t chunk = 0; chunk < chunk_sizes->size(); ++chunk) {
+ const int64_t chunk_size = (*chunk_sizes)[chunk];
+ AT_CHECK(chunk_size > 0, "Chunk size must be positive");
+ chunks.push_back(tensor.narrow(dim, chunk_start, chunk_size));
+ chunk_start += chunk_size;
+ }
+ AT_ASSERT(chunks.size() == chunk_sizes->size());
+ } else {
+ chunks = tensor.chunk(/*chunks=*/devices.size(), /*dim=*/dim);
+ }
+ auto* thc_state = at::globalContext().lazyInitCUDA();
+ for (size_t chunk = 0; chunk < chunks.size(); ++chunk) {
+ const int32_t device_index = devices[chunk];
+ // We must set the current device before setting the current stream.
+ const at::DeviceGuard device_guard({at::kCUDA, device_index});
+ const AutoStream stream_guard(
+ streams ? (*streams)[chunk]
+ : THCState_getStreamOnDevice(thc_state, device_index));
+    // Copy the chunk from its current device to its destination device. Since
+    // we made the destination the current device above, we can pass -1 here.
+ chunks[chunk] =
+ chunks[chunk].contiguous().to({at::kCUDA, -1}, /*non_blocking=*/true);
+ }
+ return chunks;
+}
+
+at::Tensor gather(
+ at::TensorList tensors,
+ int64_t dim,
+ at::optional<int32_t> destination_index) {
+ AT_ASSERT(!tensors.empty());
+ at::Tensor result;
+ int64_t total_size = 0;
+ auto& first = tensors.front();
+ const auto first_size = first.sizes();
+ std::vector<int64_t> expected_size(first_size.begin(), first_size.end());
+ for (const auto& tensor : tensors) {
+ AT_CHECK(
+ tensor.type().is_cuda(), "Gather expects all inputs to have CUDA type");
+ AT_CHECK(tensor.ndimension() == static_cast<int64_t>(expected_size.size()));
+ expected_size[dim] = tensor.size(dim);
+ for (size_t dimension = 0; dimension < expected_size.size(); ++dimension) {
+ AT_CHECK(
+ expected_size[dimension] == tensor.size(dimension),
+ "Gather got an input of invalid size: got ",
+ tensor.sizes(), ", but expected ", at::IntList(expected_size));
+ }
+ total_size += tensor.size(dim);
+ }
+ expected_size[dim] = total_size;
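+  // A destination index of -1 means the CPU; otherwise we build a CUDA device,
+  // where an index of -1 (from a nullopt destination) means the current device.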
+ at::Device device(at::kCPU);
+ if (!destination_index || *destination_index != -1) {
+ device = at::Device(at::kCUDA, destination_index ? *destination_index : -1);
+ }
+ result = at::empty(expected_size, first.options().device(device));
+
+ int64_t chunk_start = 0;
+ for (const auto& tensor : tensors) {
+ result.narrow(dim, chunk_start, tensor.size(dim))
+ .copy_(tensor, /*non_blocking=*/true);
+ chunk_start += tensor.size(dim);
+ }
+ return result;
+}
+}} // namespace torch::cuda
diff --git a/torch/csrc/cuda/comm.h b/torch/csrc/cuda/comm.h
index eb0f287..a87cc45 100644
--- a/torch/csrc/cuda/comm.h
+++ b/torch/csrc/cuda/comm.h
@@ -1,7 +1,12 @@
-#include "torch/csrc/assertions.h"
+#pragma once
+
+#include <THC/THC.h>
#include <ATen/ATen.h>
-#include <unordered_map>
+#include <ATen/optional.h>
+
+#include <cstddef>
+#include <vector>
namespace torch { namespace cuda {
@@ -11,4 +16,15 @@
tensor_list2d broadcast_coalesced(at::TensorList tensors, at::IntList devices,
size_t buffer_size);
+std::vector<at::Tensor> scatter(
+ const at::Tensor& tensor,
+ at::IntList devices,
+ const at::optional<std::vector<int64_t>>& chunk_sizes = at::nullopt,
+ int64_t dim = 0,
+ const at::optional<std::vector<THCStream*>>& streams = at::nullopt);
+
+at::Tensor gather(
+ at::TensorList tensors,
+ int64_t dim,
+ at::optional<int32_t> destination_index);
}}
diff --git a/torch/csrc/cuda/python_comm.cpp b/torch/csrc/cuda/python_comm.cpp
index 9002556..902d5b9 100644
--- a/torch/csrc/cuda/python_comm.cpp
+++ b/torch/csrc/cuda/python_comm.cpp
@@ -1,10 +1,17 @@
#include "torch/csrc/utils/pybind.h"
#include "torch/csrc/cuda/comm.h"
+#include "torch/csrc/cuda/Stream.h"
+#include "torch/csrc/cuda/THCP.h"
+#include "torch/csrc/utils/auto_gil.h"
-#include <chrono>
+#include <ATen/ATen.h>
+
+#include <THC/THC.h>
+
+#include <cstddef>
+#include <vector>
namespace torch { namespace cuda { namespace python {
-
void initCommMethods(PyObject *module) {
auto m = py::cast<py::module>(module);
m.def("_broadcast_coalesced", [](std::vector<at::Tensor>& tensors, std::vector<int64_t> devices, size_t buffer_size) {
@@ -13,7 +20,36 @@
py::call_guard<py::gil_scoped_release>())
.def("_broadcast", [](at::Tensor& tensor, std::vector<int64_t> devices) {
return broadcast(tensor, devices);
- }, py::call_guard<py::gil_scoped_release>());
+ }, py::call_guard<py::gil_scoped_release>())
+ .def("_scatter", [](
+ at::Tensor& tensor,
+ std::vector<int64_t>& devices,
+ at::optional<std::vector<int64_t>> chunk_sizes,
+ int64_t dim,
+ at::optional<py::object> py_streams) {
+ at::optional<std::vector<THCStream*>> streams;
+ if (py_streams) {
+ py::handle handle = *py_streams;
+ streams = THPUtils_PySequence_to_THCStreamList(handle.ptr());
+ }
+ // Note: We're holding the GIL up to here.
+ AutoNoGIL no_gil;
+ return scatter(tensor, devices, chunk_sizes, dim, streams);
+ },
+ py::arg("tensor"),
+ py::arg("devices"),
+ py::arg("chunk_sizes"),
+ py::arg("dim"),
+ py::arg("streams"))
+ .def("_gather", [](
+ std::vector<at::Tensor>& tensors,
+ int64_t dim,
+ at::optional<int32_t> destination_index) {
+ return gather(tensors, dim, destination_index);
+ },
+ py::arg("tensors"),
+ py::arg("dim"),
+ py::arg("destination_index"),
+ py::call_guard<py::gil_scoped_release>());
}
-
}}}
diff --git a/torch/cuda/comm.py b/torch/cuda/comm.py
index 919fe8f..ab7f1a2 100644
--- a/torch/cuda/comm.py
+++ b/torch/cuda/comm.py
@@ -139,24 +139,7 @@
A tuple containing chunks of the ``tensor``, spread across given
``devices``.
"""
- if chunk_sizes is None:
- chunks = tensor.chunk(len(devices), dim)
- else:
- assert sum(chunk_sizes) == tensor.size(dim), "given chunk sizes " \
- "don't sum up to the tensor's size (sum(chunk_sizes) == {}, but " \
- "expected {})".format(sum(chunk_sizes), tensor.size(dim))
- assert min(chunk_sizes) > 0, "got a negative chunk_size"
- chunks = [tensor.narrow(dim, start - size, size)
- for start, size in zip(_accumulate(chunk_sizes), chunk_sizes)]
- chunks = tuple(chunk.contiguous() for chunk in chunks)
- # TODO: copy to a pinned buffer first (if copying from CPU)
- if streams is None:
- streams = [None] * len(devices)
- outputs = []
- for device, chunk, stream in zip(devices, chunks, streams):
- with torch.cuda.device(device), torch.cuda.stream(stream):
- outputs.append(chunk.cuda(device, non_blocking=True))
- return tuple(outputs)
+ return tuple(torch._C._scatter(tensor, devices, chunk_sizes, dim, streams))
def gather(tensors, dim=0, destination=None):
@@ -174,30 +157,4 @@
A tensor located on ``destination`` device, that is a result of
concatenating ``tensors`` along ``dim``.
"""
- total_size = 0
- expected_size = list(tensors[0].size())
- for tensor in tensors:
- assert tensor.is_cuda, "gather expects all inputs to be on GPUs"
- expected_size[dim] = tensor.size(dim)
- if list(tensor.size()) != expected_size:
- got = 'x'.join(str(x) for x in tensor.size())
- expected = 'x'.join(str(x) for x in expected_size)
- raise ValueError("gather got an input of invalid size: got {}, "
- "but expected {}".format(got, expected))
- total_size += tensor.size(dim)
- expected_size[dim] = total_size
- expected_size = torch.Size(expected_size)
- if destination is None:
- destination = torch.cuda.current_device()
- if destination == -1:
- result = tensors[0].new().cpu().resize_(expected_size)
- else:
- result = tensors[0].new(expected_size, device=destination)
-
- chunk_start = 0
- # TODO: if copying to CPU, allocate a pinned buffer, do async copies to it,
- # and copy it to regular memory
- for tensor in tensors:
- result.narrow(dim, chunk_start, tensor.size(dim)).copy_(tensor, True)
- chunk_start += tensor.size(dim)
- return result
+ return torch._C._gather(tensors, dim, destination)
diff --git a/torch/lib/c10d/private/CUDAUtils.hpp b/torch/lib/c10d/private/CUDAUtils.hpp
index be57725..b2a847b 100644
--- a/torch/lib/c10d/private/CUDAUtils.hpp
+++ b/torch/lib/c10d/private/CUDAUtils.hpp
@@ -7,7 +7,7 @@
#include <cuda_runtime.h>
#include <ATen/ATen.h>
-#include <THCStream.h>
+#include <THC/THCStream.h>
#include "../CUDAUtils.hpp"