blob: e5dc7992d56c74be4fd1f91815ded223e51f026b [file] [log] [blame]
#include <ATen/core/functional.h>
#include <torch/csrc/cuda/device_set.h>
#include <torch/csrc/cuda/nccl.h>
#include <ATen/ATen.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/Exception.h>
#include <c10/util/hash.h>
#include <c10/util/irange.h>
#include <nccl.h>
#include <limits>
#include <memory>
#include <sstream>
#include <type_traits>
#include <unordered_map>
// Reinterprets a pointer to torch's opaque communicator handle as a pointer to
// NCCL's native ncclComm_t, so NCCL can write a freshly created communicator
// through it (as done for ncclCommInitAll / ncclCommInitRank in this file).
ncclComm_t* to_nccl_comm(torch::cuda::nccl::ncclComm_t* var) {
  auto* const native = reinterpret_cast<ncclComm_t*>(var);
  return native;
}
// Converts torch's opaque communicator handle (by value) into NCCL's native
// ncclComm_t for passing to NCCL API calls.
ncclComm_t to_nccl_comm(torch::cuda::nccl::ncclComm_t var) {
  const auto native = reinterpret_cast<ncclComm_t>(var);
  return native;
}
// Reinterprets a pointer to torch's mirror of the NCCL unique id as a pointer
// to NCCL's native ncclUniqueId.
ncclUniqueId* to_nccl_unique_id(torch::cuda::nccl::ncclUniqueId* var) {
  auto* const native = reinterpret_cast<ncclUniqueId*>(var);
  return native;
}
// Maps torch's mirror of NCCL's status enum back onto the real ncclResult_t.
// Any value without an explicit mapping is rejected with std::runtime_error.
ncclResult_t to_nccl_result(torch::cuda::nccl::ncclResult var) {
  using TorchResult = torch::cuda::nccl::ncclResult;
  switch (var) {
    case TorchResult::Success:
      return ncclResult_t::ncclSuccess;
    case TorchResult::UnhandledCudaError:
      return ncclResult_t::ncclUnhandledCudaError;
    case TorchResult::SystemError:
      return ncclResult_t::ncclSystemError;
    case TorchResult::InternalError:
      return ncclResult_t::ncclInternalError;
    case TorchResult::InvalidArgument:
      return ncclResult_t::ncclInvalidArgument;
    case TorchResult::InvalidUsage:
      return ncclResult_t::ncclInvalidUsage;
    case TorchResult::NumResults:
      return ncclResult_t::ncclNumResults;
    default:
      break;
  }
  throw std::runtime_error("Unconvertible NCCL type");
}
// Maps a raw NCCL return code onto torch's mirror enum. Codes this table does
// not know about are rejected with std::runtime_error.
torch::cuda::nccl::ncclResult from_nccl_result(ncclResult_t var) {
  using TorchResult = torch::cuda::nccl::ncclResult;
  switch (var) {
    case ncclSuccess:
      return TorchResult::Success;
    case ncclUnhandledCudaError:
      return TorchResult::UnhandledCudaError;
    case ncclSystemError:
      return TorchResult::SystemError;
    case ncclInternalError:
      return TorchResult::InternalError;
    case ncclInvalidArgument:
      return TorchResult::InvalidArgument;
    case ncclInvalidUsage:
      return TorchResult::InvalidUsage;
    case ncclNumResults:
      return TorchResult::NumResults;
    default:
      break;
  }
  throw std::runtime_error("Unconvertible NCCL type");
}
// Maps an ATen scalar type onto the NCCL datatype used on the wire.
// kBool is sent as a one-byte unsigned integer, exactly like kByte.
// kBFloat16 is only supported when HAS_NCCL_BF16_DATATYPE is set; every
// unsupported type raises via TORCH_CHECK.
ncclDataType_t to_nccl_data_type(c10::ScalarType type) {
  switch (type) {
    // Floating point.
    case at::kFloat:
      return ncclDataType_t::ncclFloat;
    case at::kDouble:
      return ncclDataType_t::ncclDouble;
    case at::kHalf:
      return ncclDataType_t::ncclHalf;
#if HAS_NCCL_BF16_DATATYPE
    case at::kBFloat16:
      return ncclDataType_t::ncclBfloat16;
#endif
    // Signed integers.
    case at::kChar:
      return ncclDataType_t::ncclChar;
    case at::kInt:
      return ncclDataType_t::ncclInt;
    case at::kLong:
      return ncclDataType_t::ncclInt64;
    // One-byte unsigned payloads.
    case at::kByte:
    case at::kBool:
      return ncclDataType_t::ncclUint8;
    default:
      TORCH_CHECK(false, "Unconvertible NCCL type ", type);
  }
}
// NCCL datatype for a tensor's elements. The tensor must live on a CUDA
// device; otherwise a TORCH_CHECK error names the offending device.
ncclDataType_t to_nccl_data_type(const at::Tensor& t) {
  TORCH_CHECK(
      t.is_cuda(),
      "NCCL only supports CUDA tensors, but got a tensor on ",
      t.device());
  return to_nccl_data_type(t.scalar_type());
}
// Reinterprets an integer as an NCCL reduction op. No range validation is
// done here; presumably callers pass a valid ncclRedOp_t value — confirm.
ncclRedOp_t to_nccl_red_op(int var) {
  return static_cast<ncclRedOp_t>(var);
}
namespace torch {
namespace cuda {
namespace nccl {
using namespace at;
namespace detail {
// Overload for raw NCCL return codes. It translates the code into torch's
// mirror enum and forwards to the NCCL_CHECK(ncclResult) overload, which is
// declared elsewhere and presumably throws on failure.
static inline void NCCL_CHECK(ncclResult_t result) {
  const torch::cuda::nccl::ncclResult translated = from_nccl_result(result);
  NCCL_CHECK(translated);
}
// Scope guard used around a batch of per-device NCCL calls.
// - It holds the CUDA caching allocator's free mutex for its whole lifetime.
// - On NCCL >= 2 it also brackets the scope with ncclGroupStart() /
//   ncclGroupEnd().
//
// Exception safety: the mutex is released even if ncclGroupStart() or
// ncclGroupEnd() fails. The destructor is noexcept(false), so a failing
// ncclGroupEnd() surfaces as an exception instead of std::terminate. It still
// terminates if the failure happens while another exception is unwinding.
struct AutoNcclGroup {
  AutoNcclGroup() {
    (c10::cuda::CUDACachingAllocator::getFreeMutex())->lock();
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
    try {
      NCCL_CHECK(ncclGroupStart());
    } catch (...) {
      // A throwing constructor never runs the destructor, so release the
      // mutex here; otherwise it would stay locked forever.
      (c10::cuda::CUDACachingAllocator::getFreeMutex())->unlock();
      throw;
    }
#endif
  }

  // The guard owns a lock; a copy would unlock it twice.
  AutoNcclGroup(const AutoNcclGroup&) = delete;
  AutoNcclGroup& operator=(const AutoNcclGroup&) = delete;

  ~AutoNcclGroup() noexcept(false) {
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
    // End the group while still holding the mutex (same ordering as before).
    // Unlock before reporting a failure so the mutex is never leaked.
    const auto group_end_result = ncclGroupEnd();
    (c10::cuda::CUDACachingAllocator::getFreeMutex())->unlock();
    NCCL_CHECK(group_end_result);
#else
    (c10::cuda::CUDACachingAllocator::getFreeMutex())->unlock();
#endif
  }
};
// Throws std::runtime_error of the form "NCCL Error <code>: <NCCL text>".
// The NCCL text comes from ncclGetErrorString after mapping the status back
// to ncclResult_t.
void throw_nccl_error(torch::cuda::nccl::ncclResult status) {
  const char* const description = ncclGetErrorString(to_nccl_result(status));
  std::ostringstream message;
  message << "NCCL Error " << static_cast<int>(status) << ": " << description;
  throw std::runtime_error(message.str());
}
struct NcclCommList {
std::unique_ptr<ncclComm_t[]> comms;
int ndevices;
NcclCommList(const std::vector<int>& devices)
: comms(new ncclComm_t[devices.size()]), ndevices(devices.size()) {
NCCL_CHECK(ncclCommInitAll(
to_nccl_comm(comms.get()), devices.size(), devices.data()));
}
NcclCommList(NcclCommList&& foo) = default;
~NcclCommList() {
if (comms) {
for (const auto i : c10::irange(ndevices)) {
int dummy_var;
if (cudaGetDevice(&dummy_var) != cudaSuccess) {
/* there are cases when this destructor is called after the
CUDA driver is already unloaded from the process.
In these cases, skip ncclCommDestroy */
return;
}
comm_destroy(comms[i]);
}
}
}
ArrayRef<ncclComm_t> ref() const {
return ArrayRef<ncclComm_t>(comms.get(), ndevices);
}
};
using device_list = std::vector<int>;
// Process-wide cache of communicator cliques, keyed by the ordered list of
// device indices they span (populated by get_communicators).
// Accesses to this object are meant to be guarded by the CUDA caching
// allocator's free mutex (c10::cuda::CUDACachingAllocator::getFreeMutex(),
// the mutex AutoNcclGroup holds).
// NOTE(review): the collectives in this file call get_communicators() before
// constructing AutoNcclGroup, i.e. without holding that mutex. Confirm whether
// callers serialize these calls externally.
static std::unordered_map<device_list, NcclCommList, c10::hash<device_list>>
    _communicators;
// Returns the cached communicator clique for the devices of `inputs` (one
// device per tensor, in order), creating and caching it on first use.
ArrayRef<ncclComm_t> get_communicators(TensorList inputs) {
  device_list devices;
  devices.reserve(inputs.size());
  for (const at::Tensor& tensor : inputs) {
    devices.push_back(static_cast<int>(tensor.get_device()));
  }
  auto entry = _communicators.find(devices);
  if (entry == _communicators.end()) {
    // Build an NcclCommList (which runs ncclCommInitAll) only on a cache miss.
    entry = _communicators.emplace(devices, devices).first;
  }
  return entry->second.ref();
}
// Validates one input (and optionally its paired output) for a collective.
// Every tensor checked must be:
// - a dense CUDA tensor,
// - of dtype `ref_dtype`,
// - contiguous.
// The input must have exactly `ref_numel` elements. A present output must be
// on the input's device and satisfy
//   output.numel() * output_multiplier == ref_numel * input_multiplier.
// Each failure throws std::runtime_error.
static inline void check_tensor(
    const at::Tensor& input,
    const at::optional<at::Tensor>& output,
    int input_multiplier,
    int output_multiplier,
    int64_t ref_numel,
    ScalarType ref_dtype) {
  const auto validate = [ref_dtype](const at::Tensor& t) {
    const bool dense_cuda = t.is_cuda() && !t.is_sparse();
    if (!dense_cuda) {
      throw std::runtime_error(
          "input and output elements have to be cuda dense Tensors");
    }
    if (t.scalar_type() != ref_dtype) {
      throw std::runtime_error(
          "all inputs and outputs must be of the same Tensor dtype");
    }
    if (!t.is_contiguous()) {
      throw std::runtime_error("all inputs and outputs have to be contiguous");
    }
  };

  validate(input);
  if (input.numel() != ref_numel) {
    throw std::runtime_error(
        "all inputs must have the same number of elements");
  }

  if (!output) {
    return;
  }
  const at::Tensor& out = *output;
  validate(out);
  if (out.get_device() != input.get_device()) {
    throw std::runtime_error("input and output must be on the same device");
  }
  if (out.numel() * output_multiplier != ref_numel * input_multiplier) {
    throw std::runtime_error(
        "output must be of size input_size * size_multiplier");
  }
}
// Validates paired input/output lists for a per-device collective. Throws
// std::runtime_error unless all of the following hold:
// - the lists are non-empty and of equal length;
// - every pair passes check_tensor against inputs[0]'s numel and dtype;
// - every input sits on a distinct device.
void check_inputs(
    TensorList inputs,
    TensorList outputs,
    int input_multiplier,
    int output_multiplier) {
  // len(inputs) == len(outputs)
  const size_t len = inputs.size();
  // size_t is never negative; the only invalid length is zero.
  if (len == 0) {
    throw std::runtime_error("input sequence can't be empty");
  }
  if (len != outputs.size()) {
    std::stringstream err;
    err << "inputs and outputs sequences have to be of the same length, but got input of length "
        << len << " and output of length " << outputs.size();
    throw std::runtime_error(err.str());
  }
  device_set devices;
  const int64_t numel = inputs[0].numel();
  const auto dtype = inputs[0].scalar_type();
  for (const auto i : c10::irange(len)) {
    // Bind by reference: copying an at::Tensor bumps an atomic refcount.
    const auto& input = inputs[i];
    check_tensor(
        input, outputs[i], input_multiplier, output_multiplier, numel, dtype);
    const auto input_device = input.get_device();
    // inputs must be on unique devices
    if (devices.test(input_device)) {
      throw std::runtime_error("inputs must be on unique devices");
    }
    devices.set(input_device);
  }
}
// Validates the inputs of a rooted collective (e.g. reduce). Throws
// std::runtime_error unless all of the following hold:
// - the list is non-empty;
// - every input passes check_tensor against inputs[0]'s numel and dtype;
// - every input sits on a distinct device.
// `output` is checked only alongside inputs[root]; root range checking is left
// to the caller.
void check_inputs(
    TensorList inputs,
    const at::Tensor& output,
    int root,
    int input_multiplier,
    int output_multiplier) {
  const size_t len = inputs.size();
  // size_t is never negative; the only invalid length is zero.
  if (len == 0) {
    throw std::runtime_error("input sequence can't be empty");
  }
  device_set devices;
  const int64_t numel = inputs[0].numel();
  const auto dtype = inputs[0].scalar_type();
  // Explicit form of the implicit int -> size_t conversion the comparison
  // always performed. A negative root becomes huge and matches no index.
  const auto root_index = static_cast<size_t>(root);
  for (const auto i : c10::irange(len)) {
    // Bind by reference: copying an at::Tensor bumps an atomic refcount.
    const auto& input = inputs[i];
    check_tensor(
        input,
        i == root_index ? at::optional<at::Tensor>{output} : at::nullopt,
        input_multiplier,
        output_multiplier,
        numel,
        dtype);
    const auto input_device = input.get_device();
    // inputs must be on unique devices
    if (devices.test(input_device)) {
      throw std::runtime_error("inputs must be on unique devices");
    }
    devices.set(input_device);
  }
}
} // namespace detail
// True when NCCL support is compiled in and every tensor is usable with it:
// - a dense, contiguous CUDA tensor;
// - on a device not used by any other tensor in the list.
bool is_available(TensorList tensors) {
#ifdef USE_NCCL
  device_set seen;
  for (const auto& t : tensors) {
    const bool eligible = t.is_cuda() && !t.is_sparse() && t.is_contiguous();
    if (!eligible) {
      return false;
    }
    const auto dev = t.get_device();
    if (seen[dev]) {
      return false;
    }
    seen[dev] = true;
  }
  return true;
#else
  return false;
#endif
}
// Encodes the NCCL version as (major << 32) | (minor << 16) | patch.
// When built with NCCL but without an NCCL_MAJOR macro, it reports major
// version 1. Without NCCL support it returns 0.
std::uint64_t version() {
#if defined(NCCL_MAJOR)
  constexpr std::uint64_t ver_major = static_cast<std::uint64_t>(NCCL_MAJOR);
  constexpr std::uint64_t ver_minor = static_cast<std::uint64_t>(NCCL_MINOR);
  constexpr std::uint64_t ver_patch = static_cast<std::uint64_t>(NCCL_PATCH);
  return (ver_major << 32) | (ver_minor << 16) | ver_patch;
#elif defined(USE_NCCL)
  return std::uint64_t{1} << 32;
#else
  return 0;
#endif
}
// Fills `id` with a new NCCL unique id, or errors if built without NCCL.
void get_unique_id(ncclUniqueId& id) {
#ifndef USE_NCCL
  AT_ERROR("PyTorch built without NCCL support");
#else
  using namespace torch::cuda::nccl::detail;
  NCCL_CHECK(ncclGetUniqueId(to_nccl_unique_id(&id)));
#endif
}
// Creates the communicator for `rank` in a group of `nranks` that share
// `comm_id`. Returns nullptr when built without NCCL.
ncclComm_t comm_init_rank(int nranks, const ncclUniqueId& comm_id, int rank) {
#ifndef USE_NCCL
  return nullptr;
#else
  using namespace torch::cuda::nccl::detail;
  // Work on a local copy so the caller's const id is never reinterpreted.
  ncclUniqueId id_copy = comm_id;
  ncclComm_t comm = nullptr;
  NCCL_CHECK(ncclCommInitRank(
      to_nccl_comm(&comm), nranks, *to_nccl_unique_id(&id_copy), rank));
  return comm;
#endif
}
// Intended to destroy `comm`, but currently a deliberate no-op.
// TODO(T30279827): calling ncclCommDestroy while the program is exiting is
// undefined according to Nvidia and segfaults in NCCL 2, whether it runs
// before or after the CUDA runtime destructor. The call is kept below, but
// disabled, until a long-term solution is agreed with Nvidia.
void comm_destroy(ncclComm_t comm) {
  constexpr bool kCommDestroyEnabled = false;
  if (!kCommDestroyEnabled) {
    return;
  }
#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  NCCL_CHECK(ncclCommDestroy(to_nccl_comm(comm)));
#endif
}
namespace {
// NCCL changed the numerical type used for count between NCCL1 and NCCL2.
// So we use the following struct, which gets the type of the second argument
// of T, if T is a function type, with ncclBcast, to get that type statically
// and programmatically.
template <typename T>
struct GetSecondArgType;
template <typename R, typename Arg0, typename Arg1, typename... Args>
struct GetSecondArgType<R(Arg0, Arg1, Args...)> {
typedef typename std::decay<Arg1>::type type;
};
constexpr auto count_max =
std::numeric_limits<GetSecondArgType<decltype(ncclBcast)>::type>::max();
} // namespace
// Largest element count NCCL's count parameter can represent.
size_t get_max_count() {
  return static_cast<size_t>(count_max);
}
// Broadcasts tensors[0] (rank 0) into every tensor in `tensors`, one tensor
// per device.
// - Streams default to each device's current stream unless `streams` supplies
//   one for that slot.
// - Communicators default to the cached clique for these devices unless
//   `user_comms` is non-empty.
void broadcast(
    TensorList tensors,
    const stream_list& streams,
    const comm_list& user_comms) {
#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  check_inputs(tensors, tensors, 1, 1);
  auto data_type = to_nccl_data_type(tensors[0]);
  int64_t numel = tensors[0].numel();
  // check_inputs guarantees every tensor has this same numel, so the NCCL
  // count limit is loop-invariant. Checking it once here means an oversized
  // tensor fails before communicators are created, the allocator mutex is
  // taken, or an NCCL group is opened.
  TORCH_CHECK(
      static_cast<uint64_t>(numel) <= static_cast<uint64_t>(count_max),
      "Broadcast tensor has ",
      numel,
      " elements, which exceeds the "
      "maximum NCCL supports (",
      count_max,
      ")");
  const auto comms = user_comms.empty() ? get_communicators(tensors)
                                        : ArrayRef<ncclComm_t>(user_comms);
  AutoNcclGroup nccl_group_guard;
  at::cuda::OptionalCUDAGuard device_guard;
  for (const auto i : c10::irange(tensors.size())) {
    int device = tensors[i].get_device();
    device_guard.set_index(device);
    // Default to the current stream
    const auto stream = (streams.empty() || !streams[i])
        ? at::cuda::getCurrentCUDAStream(device).stream()
        : streams[i]->stream();
    ncclComm_t comm = comms[i];
    NCCL_CHECK(ncclBcast(
        tensors[i].data_ptr(),
        numel,
        data_type,
        0,
        to_nccl_comm(comm),
        stream));
  }
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}
// Reduces `inputs` (one per device) with `op` into `output`, which lives on
// the device of inputs[root]. Only the root rank passes a receive buffer to
// NCCL.
void reduce(
    const std::vector<at::Tensor>& inputs,
    at::Tensor& output,
    int32_t root,
    int32_t op,
    const stream_list& streams,
    const comm_list& user_comms) {
#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  TORCH_CHECK(
      root >= 0 && static_cast<size_t>(root) < inputs.size(), "invalid root");
  check_inputs(inputs, output, root, 1, 1);

  const auto nccl_type = to_nccl_data_type(inputs[0]);
  const auto element_count = inputs[0].numel();
  const auto root_index = static_cast<size_t>(root);
  auto comms_ref = user_comms.empty() ? get_communicators(inputs)
                                      : ArrayRef<ncclComm_t>(user_comms);

  AutoNcclGroup nccl_group_guard;
  at::cuda::OptionalCUDAGuard device_guard;
  size_t rank = 0;
  for (const auto& input : inputs) {
    const int device = input.device().index();
    device_guard.set_index(device);
    // Use the caller's stream for this slot if given, else the current one.
    const auto stream = (!streams.empty() && streams[rank])
        ? streams[rank]->stream()
        : at::cuda::getCurrentCUDAStream(device).stream();
    void* const recvbuff = rank == root_index ? output.data_ptr() : nullptr;
    NCCL_CHECK(ncclReduce(
        input.data_ptr(),
        recvbuff,
        element_count,
        nccl_type,
        to_nccl_red_op(op),
        root,
        to_nccl_comm(comms_ref[rank]),
        stream));
    ++rank;
  }
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}
// In-place reduce: the result is written into inputs[root].
void reduce(
    std::vector<at::Tensor>& inputs,
    int32_t root,
    int32_t op,
    const stream_list& streams,
    const comm_list& user_comms) {
  // Validate root before evaluating inputs[root] for the output reference.
  // The other overload checks too, but only after this index has already been
  // evaluated, which reads out of bounds for an invalid root.
  TORCH_CHECK(
      root >= 0 && static_cast<size_t>(root) < inputs.size(), "invalid root");
  reduce(inputs, /*output=*/inputs[root], root, op, streams, user_comms);
}
// All-reduces inputs[i] into outputs[i] with `op` across all participating
// devices. There is one input/output pair per device.
void all_reduce(
    const std::vector<at::Tensor>& inputs,
    std::vector<at::Tensor>& outputs,
    int32_t op,
    const stream_list& streams,
    const comm_list& user_comms) {
#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  check_inputs(inputs, outputs, 1, 1);
  const auto nccl_type = to_nccl_data_type(inputs[0]);
  const auto element_count = inputs[0].numel();
  auto comms_ref = user_comms.empty() ? get_communicators(inputs)
                                      : ArrayRef<ncclComm_t>(user_comms);

  AutoNcclGroup nccl_group_guard;
  at::cuda::OptionalCUDAGuard device_guard;
  for (size_t idx = 0; idx < inputs.size(); ++idx) {
    const int device = inputs[idx].device().index();
    device_guard.set_index(device);
    // Use the caller-supplied stream if one exists for this slot.
    const bool has_user_stream = !streams.empty() && streams[idx];
    const auto stream = has_user_stream
        ? streams[idx]->stream()
        : at::cuda::getCurrentCUDAStream(device).stream();
    ncclComm_t comm = comms_ref[idx];
    NCCL_CHECK(ncclAllReduce(
        inputs[idx].data_ptr(),
        outputs[idx].data_ptr(),
        element_count,
        nccl_type,
        to_nccl_red_op(op),
        to_nccl_comm(comm),
        stream));
  }
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}
// Reduces the inputs with `op` and scatters equal slices of the result. Each
// outputs[i] receives inputs[0].numel() / inputs.size() elements.
void reduce_scatter(
    const std::vector<at::Tensor>& inputs,
    std::vector<at::Tensor>& outputs,
    int32_t op,
    const stream_list& streams,
    const comm_list& user_comms) {
#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  const auto num_inputs = inputs.size();
  // output_multiplier == num_inputs: each output holds 1/num_inputs of an input.
  check_inputs(inputs, outputs, 1, num_inputs);
  const auto nccl_type = to_nccl_data_type(inputs[0]);
  const auto recv_count = inputs[0].numel() / num_inputs;
  auto comms_ref = user_comms.empty() ? get_communicators(inputs)
                                      : ArrayRef<ncclComm_t>(user_comms);
  // Caller-provided stream for slot `idx` if present, else the current stream.
  const auto stream_for = [&streams](size_t idx, int device) {
    return (streams.empty() || !streams[idx])
        ? at::cuda::getCurrentCUDAStream(device).stream()
        : streams[idx]->stream();
  };

  AutoNcclGroup nccl_group_guard;
  at::cuda::OptionalCUDAGuard device_guard;
  for (const auto i : c10::irange(num_inputs)) {
    const int device = inputs[i].device().index();
    device_guard.set_index(device);
    const auto stream = stream_for(i, device);
    NCCL_CHECK(ncclReduceScatter(
        inputs[i].data_ptr(),
        outputs[i].data_ptr(),
        recv_count,
        nccl_type,
        to_nccl_red_op(op),
        to_nccl_comm(comms_ref[i]),
        stream));
  }
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}
// Gathers every input into every output. Each output's numel is
// inputs.size() times an input's numel, as enforced by check_inputs.
void all_gather(
    const std::vector<at::Tensor>& inputs,
    std::vector<at::Tensor>& outputs,
    const stream_list& streams,
    const comm_list& user_comms) {
#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  const auto world = inputs.size();
  check_inputs(inputs, outputs, world, 1);
  const auto nccl_type = to_nccl_data_type(inputs[0]);
  const auto send_count = inputs[0].numel();
  auto comms_ref = user_comms.empty() ? get_communicators(inputs)
                                      : ArrayRef<ncclComm_t>(user_comms);

  AutoNcclGroup nccl_group_guard;
  at::cuda::OptionalCUDAGuard device_guard;
  size_t slot = 0;
  for (const auto& input : inputs) {
    const int device = input.device().index();
    device_guard.set_index(device);
    const bool use_current_stream = streams.empty() || !streams[slot];
    const auto stream = use_current_stream
        ? at::cuda::getCurrentCUDAStream(device).stream()
        : streams[slot]->stream();
    void* const recvbuff = outputs[slot].data_ptr();
    ncclComm_t comm = comms_ref[slot];
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
    NCCL_CHECK(ncclAllGather(
        input.data_ptr(),
        recvbuff,
        send_count,
        nccl_type,
        to_nccl_comm(comm),
        stream));
#else
    // Pre-NCCL-2 signature: count and datatype come before the recv buffer.
    NCCL_CHECK(ncclAllGather(
        input.data_ptr(),
        send_count,
        nccl_type,
        recvbuff,
        to_nccl_comm(comm),
        stream));
#endif
    ++slot;
  }
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}
// All-to-all over a single tensor split into `size` equal chunks: chunk r of
// `input` goes to rank r, and chunk r of `output` comes from rank r.
// ROCm >= 5.0 uses ncclAllToAll directly; otherwise the exchange is built
// from grouped ncclSend/ncclRecv pairs. Requires NCCL >= 2.7.
void all2all_single_equal_split(
    at::Tensor& input,
    at::Tensor& output,
    int size,
    ncclComm_t _comm,
    at::cuda::CUDAStream& stream) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && \
    (NCCL_MAJOR * 10 + NCCL_MINOR) >= 27
  using namespace torch::cuda::nccl::detail;
  // `size` divides the tensor below; reject 0 instead of dividing by zero.
  TORCH_CHECK(size > 0, "all2all: size must be positive, but got ", size);
  auto type = to_nccl_data_type(input);
  size_t count = input.numel() / size;
  size_t rankdiff = input.nbytes() / size;
  const auto* sendbuff = reinterpret_cast<char*>(input.data_ptr());
  auto* recvbuff = reinterpret_cast<char*>(output.data_ptr());
  auto comm = to_nccl_comm(_comm);
#if defined(USE_ROCM) && ROCM_VERSION >= 50000
  NCCL_CHECK(ncclAllToAll(sendbuff, recvbuff, count, type, comm, stream));
#else
  // numranks is only needed on this path; declaring it here avoids an unused
  // variable in ROCm builds.
  int numranks;
  NCCL_CHECK(ncclCommCount(comm, &numranks));
  NCCL_CHECK(ncclGroupStart());
  // NCCL uses 0 byte message for synchronization.
  // Avoid send/recv when message size is zero; count is loop-invariant.
  if (count != 0) {
    for (const auto r : c10::irange(numranks)) {
      NCCL_CHECK(
          ncclSend(sendbuff + r * rankdiff, count, type, r, comm, stream));
      NCCL_CHECK(
          ncclRecv(recvbuff + r * rankdiff, count, type, r, comm, stream));
    }
  }
  NCCL_CHECK(ncclGroupEnd());
#endif
#else
  AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0");
#endif
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}
// All-to-all with per-rank counts and displacements, built from grouped
// ncclSend/ncclRecv pairs. Requires NCCL >= 2.7.
// - Counts are passed to NCCL as element counts of `_type`.
// - Displacements are scaled by `size` to form byte offsets, so `size` is
//   presumably the element size in bytes; confirm with callers.
// - Zero-count transfers are skipped.
void all2all_single_unequal_split(
    void* sendbuff,
    const size_t* sendcounts,
    const size_t* senddispls,
    void* recvbuff,
    const size_t* recvcounts,
    const size_t* recvdispls,
    size_t size,
    c10::ScalarType _type,
    ncclComm_t _comm,
    at::cuda::CUDAStream& stream) {
#if !defined(USE_NCCL)
  AT_ERROR("PyTorch built without NCCL support");
#elif defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && \
    (NCCL_MAJOR * 10 + NCCL_MINOR) >= 27
  using namespace torch::cuda::nccl::detail;
  const auto type = to_nccl_data_type(_type);
  const auto comm = to_nccl_comm(_comm);
  auto* const send_bytes = static_cast<char*>(sendbuff);
  auto* const recv_bytes = static_cast<char*>(recvbuff);
  int numranks;
  NCCL_CHECK(ncclCommCount(comm, &numranks));
  NCCL_CHECK(ncclGroupStart());
  for (const auto peer : c10::irange(numranks)) {
    // NCCL uses 0 byte messages for synchronization, so never send empty ones.
    if (sendcounts[peer] != 0) {
      NCCL_CHECK(ncclSend(
          send_bytes + senddispls[peer] * size,
          sendcounts[peer],
          type,
          peer,
          comm,
          stream));
    }
    if (recvcounts[peer] != 0) {
      NCCL_CHECK(ncclRecv(
          recv_bytes + recvdispls[peer] * size,
          recvcounts[peer],
          type,
          peer,
          comm,
          stream));
    }
  }
  NCCL_CHECK(ncclGroupEnd());
#else
  AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0");
#endif
}
// All-to-all over per-peer tensor lists: inputTensors[r] is sent to rank r and
// outputTensors[r] is received from rank r. Empty tensors are skipped.
// Requires NCCL >= 2.7.
void all2all(
    std::vector<at::Tensor>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    ncclComm_t _comm,
    at::cuda::CUDAStream& stream) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && \
    (NCCL_MAJOR * 10 + NCCL_MINOR) >= 27
  using namespace torch::cuda::nccl::detail;
  // inputTensors[r] is indexed for every r in outputTensors. Reject mismatched
  // lists before ncclGroupStart() instead of reading past the end.
  TORCH_CHECK(
      inputTensors.size() == outputTensors.size(),
      "all2all: expected the same number of input and output tensors, but got ",
      inputTensors.size(),
      " inputs and ",
      outputTensors.size(),
      " outputs");
  auto comm = to_nccl_comm(_comm);
  NCCL_CHECK(ncclGroupStart());
  for (const auto r : c10::irange(outputTensors.size())) {
    // NCCL identifies peers by int rank.
    const int peer = static_cast<int>(r);
    at::Tensor& input = inputTensors[r];
    at::Tensor& output = outputTensors[r];
    if (input.numel() != 0) {
      NCCL_CHECK(ncclSend(
          input.data_ptr(),
          input.numel(),
          to_nccl_data_type(input),
          peer,
          comm,
          stream.stream()));
    }
    if (output.numel() != 0) {
      NCCL_CHECK(ncclRecv(
          output.data_ptr(),
          output.numel(),
          to_nccl_data_type(output),
          peer,
          comm,
          stream.stream()));
    }
  }
  NCCL_CHECK(ncclGroupEnd());
#else
  AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0");
#endif
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}
// Point-to-point send of the whole `input` tensor to rank `dst` on `stream`.
// Requires NCCL >= 2.7.
void send(
    const at::Tensor& input,
    ncclComm_t comm,
    at::cuda::CUDAStream stream,
    int dst) {
#if !defined(USE_NCCL)
  AT_ERROR("PyTorch built without NCCL support");
#elif defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
    (NCCL_MINOR >= 7)
  using namespace torch::cuda::nccl::detail;
  const auto nccl_type = to_nccl_data_type(input);
  NCCL_CHECK(ncclSend(
      input.data_ptr(),
      input.numel(),
      nccl_type,
      dst,
      to_nccl_comm(comm),
      stream.stream()));
#else
  AT_ERROR("Send is only supported for NCCL lib version >= 2.7.0");
#endif
}
// Point-to-point receive into the whole `output` tensor from rank `src` on
// `stream`. Requires NCCL >= 2.7.
void recv(
    at::Tensor& output,
    ncclComm_t comm,
    at::cuda::CUDAStream stream,
    int src) {
#if !defined(USE_NCCL)
  AT_ERROR("PyTorch built without NCCL support");
#elif defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
    (NCCL_MINOR >= 7)
  using namespace torch::cuda::nccl::detail;
  const auto nccl_type = to_nccl_data_type(output);
  NCCL_CHECK(ncclRecv(
      output.data_ptr(),
      output.numel(),
      nccl_type,
      src,
      to_nccl_comm(comm),
      stream.stream()));
#else
  AT_ERROR("Recv is only supported for NCCL lib version >= 2.7.0");
#endif
}
// Gathers each rank's `inputs` tensor into outputs[rank] on `root`.
// - The root copies its own input locally and receives the rest.
// - Non-root ranks only send.
// Requires NCCL >= 2.7.
void gather(
    const at::Tensor& inputs,
    std::vector<at::Tensor>& outputs,
    ncclComm_t _comm,
    at::cuda::CUDAStream& stream,
    int32_t root) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && \
    (NCCL_MAJOR * 10 + NCCL_MINOR) >= 27
  using namespace torch::cuda::nccl::detail;
  auto comm = to_nccl_comm(_comm);
  int numranks, cur_rank;
  NCCL_CHECK(ncclCommCount(comm, &numranks));
  NCCL_CHECK(ncclCommUserRank(comm, &cur_rank));
  // The root writes outputs[0..numranks). Validate before ncclGroupStart() so
  // a short list neither indexes out of bounds nor throws mid-group.
  TORCH_CHECK(
      cur_rank != root || outputs.size() >= static_cast<size_t>(numranks),
      "gather: root rank expects at least ",
      numranks,
      " output tensors, but got ",
      outputs.size());
  size_t count = inputs.numel();
  auto type = to_nccl_data_type(inputs);
  const auto* sendbuff = reinterpret_cast<char*>(inputs.data_ptr());
  NCCL_CHECK(ncclGroupStart());
  if (cur_rank == root) {
    for (const auto r : c10::irange(numranks)) {
      if (r != root) {
        auto* recvbuff = reinterpret_cast<char*>(outputs[r].data_ptr());
        NCCL_CHECK(ncclRecv(recvbuff, count, type, r, comm, stream));
      } else {
        // on its own rank, simply copy from the input
        outputs[r].copy_(inputs);
      }
    }
  } else {
    NCCL_CHECK(ncclSend(sendbuff, count, type, root, comm, stream));
  }
  NCCL_CHECK(ncclGroupEnd());
#else
  AT_ERROR("gather is only supported for NCCL lib version >= 2.7.0");
#endif
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}
// Scatters inputs[rank] from `root` into each rank's `outputs` tensor.
// - The root copies its own slice locally and sends the rest.
// - Non-root ranks only receive.
// Requires NCCL >= 2.7.
void scatter(
    const std::vector<at::Tensor>& inputs,
    at::Tensor& outputs,
    ncclComm_t _comm,
    at::cuda::CUDAStream& stream,
    int32_t root) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && \
    (NCCL_MAJOR * 10 + NCCL_MINOR) >= 27
  using namespace torch::cuda::nccl::detail;
  auto comm = to_nccl_comm(_comm);
  int numranks, cur_rank;
  NCCL_CHECK(ncclCommCount(comm, &numranks));
  NCCL_CHECK(ncclCommUserRank(comm, &cur_rank));
  // The root reads inputs[0..numranks). Validate before ncclGroupStart() so a
  // short list neither indexes out of bounds nor throws mid-group.
  TORCH_CHECK(
      cur_rank != root || inputs.size() >= static_cast<size_t>(numranks),
      "scatter: root rank expects at least ",
      numranks,
      " input tensors, but got ",
      inputs.size());
  NCCL_CHECK(ncclGroupStart());
  if (cur_rank == root) {
    for (const auto r : c10::irange(numranks)) {
      if (r != root) {
        size_t send_count = inputs[r].numel();
        auto send_type = to_nccl_data_type(inputs[r]);
        const auto* sendbuff = reinterpret_cast<char*>(inputs[r].data_ptr());
        NCCL_CHECK(ncclSend(sendbuff, send_count, send_type, r, comm, stream));
      } else {
        // on its own rank, simply copy it to the output
        outputs.copy_(inputs[r]);
      }
    }
  } else {
    size_t recv_count = outputs.numel();
    auto recv_type = to_nccl_data_type(outputs);
    auto* recvbuff = reinterpret_cast<char*>(outputs.data_ptr());
    NCCL_CHECK(ncclRecv(recvbuff, recv_count, recv_type, root, comm, stream));
  }
  NCCL_CHECK(ncclGroupEnd());
#else
  AT_ERROR("scatter is only supported for NCCL lib version >= 2.7.0");
#endif
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}
} // namespace nccl
} // namespace cuda
} // namespace torch