#include <torch/csrc/cuda/nccl.h>
#include <torch/csrc/cuda/device_set.h>
#include <ATen/core/functional.h>
#include <torch/csrc/utils/hash.h>

#include <ATen/ATen.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/Exception.h>

#include <THC/THC.h>

#include <limits>
#include <sstream>
#include <type_traits>
#include <unordered_map>

namespace torch {
namespace cuda {
namespace nccl {

using namespace at;

namespace detail {

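// Raise a std::runtime_error carrying the numeric NCCL status and its
// ncclGetErrorString() description.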
void throw_nccl_error(ncclResult_t status) {
  std::ostringstream err;
  err << "NCCL Error " << status << ": " << ncclGetErrorString(status);
  throw std::runtime_error(err.str());
}

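// Owns a clique of NCCL communicators, one per device, created eagerly with
// ncclCommInitAll.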
struct NcclCommList {
  std::unique_ptr<ncclComm_t[]> comms;
  int ndevices;
  NcclCommList(const std::vector<int>& devices)
      : comms(new ncclComm_t[devices.size()]), ndevices(devices.size()) {
    NCCL_CHECK(ncclCommInitAll(comms.get(), devices.size(), devices.data()));
  }
  NcclCommList(NcclCommList&& foo) = default;
  ~NcclCommList() {
    /*
     * TODO(T30279827) Temporarily disable calling ncclCommDestroy.
     * Calling ncclCommDestroy while the program is exiting is undefined
     * according to NVIDIA, and leads to a segfault in NCCL 2
     * (whether it is called before or after the CUDA runtime destructor).
     * Temporarily disable it in the destructor to avoid the segfault.
     * Following up with NVIDIA for a long-term solution.
     */
    return;

    if (comms) {
      for (int i = 0; i < ndevices; i++) {
        int dummy_var;
        if (cudaGetDevice(&dummy_var) != cudaSuccess) {
          /* There are cases in which this destructor is called after the
             CUDA driver has already been unloaded from the process.
             In these cases, skip ncclCommDestroy. */
          return;
        }
        ncclCommDestroy(comms[i]);
      }
    }
  }
  ArrayRef<ncclComm_t> ref() const {
    return ArrayRef<ncclComm_t>(comms.get(), ndevices);
  }
};

using device_list = std::vector<int>;
// Accesses to this object have to be guarded by THC's CudaFreeMutex.
static std::unordered_map<device_list, NcclCommList, torch::hash<device_list>>
    _communicators;

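// Return the cached communicator clique for the devices the input tensors
// live on, creating it on first use.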
ArrayRef<ncclComm_t> get_communicators(TensorList inputs) {
  static auto get_device = [](const at::Tensor& t) -> int {
    return t.get_device();
  };
  device_list devices = fmap(inputs, get_device);
  auto it = _communicators.find(devices);
  if (it == _communicators.end())
    std::tie(it, std::ignore) = _communicators.emplace(devices, devices);
  return it->second.ref();
}

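// Map a CUDA tensor's scalar type to the corresponding ncclDataType_t;
// throws for non-CUDA tensors and unsupported scalar types.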
ncclDataType_t get_data_type(const Tensor& t) {
  if (t.type().backend() != Backend::CUDA) {
    throw std::runtime_error("Unconvertible NCCL type");
  }
  switch (t.scalar_type()) {
    case at::kFloat:
      return ncclFloat;
    case at::kHalf:
      return ncclHalf;
    case at::kDouble:
      return ncclDouble;
    case at::kLong:
      return ncclInt64;
    case at::kInt:
      return ncclInt;
    case at::kChar:
      return ncclChar;
    case at::kByte:
      return ncclChar;
    default:
      throw std::runtime_error("Unconvertible NCCL type");
  }
}

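// Validate a collective's input/output lists: non-empty and of equal length,
// dense contiguous CUDA tensors of a single type, one input per distinct
// device, each input paired with an output on the same device, equal input
// element counts, and output sizes consistent with the given multipliers.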
void check_inputs(
    TensorList inputs,
    TensorList outputs,
    int input_multiplier,
    int output_multiplier) {
  // len(inputs) == len(outputs)
  size_t len = inputs.size();

  if (len <= 0) {
    throw std::runtime_error("input sequence can't be empty");
  }

  if (len != outputs.size()) {
    std::stringstream err;
    err << "inputs and outputs sequences have to be of the same length, but got input of length "
        << len << " and output of length " << outputs.size();
    throw std::runtime_error(err.str());
  }

  device_set devices;
  int64_t numel = inputs[0].numel();
  auto type = inputs[0].type();

  for (size_t i = 0; i < len; i++) {
    auto input = inputs[i];
    auto output = outputs[i];

    if (!(input.is_cuda() && !input.is_sparse() &&
          output.is_cuda() && !output.is_sparse())) {
      throw std::runtime_error(
          "input and output elements have to be cuda dense Tensors");
    }

    if (!(type == input.type() && type == output.type())) {
      throw std::runtime_error(
          "all inputs and outputs must be of the same Tensor type");
    }

    if (!input.is_contiguous() || !output.is_contiguous()) {
      throw std::runtime_error("all inputs and outputs have to be contiguous");
    }

    auto input_device = input.get_device();
    // inputs must be on unique devices
    if (devices.test(input_device)) {
      throw std::runtime_error("inputs must be on unique devices");
    }
    devices.set(input_device);

    // inputs and outputs must be on the same device, respectively
    if (input_device != output.get_device()) {
      throw std::runtime_error("input and output must be on the same device");
    }

    // all inputs must be the same size
    if (input.numel() != numel) {
      throw std::runtime_error(
          "all inputs must have the same number of elements");
    }

    if (output.numel() * output_multiplier != numel * input_multiplier) {
      throw std::runtime_error(
          "output must be of size input_size * size_multiplier");
    }
  }
}

} // namespace detail

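// True if NCCL was compiled in and every tensor is a dense, contiguous CUDA
// tensor on a distinct device.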
bool is_available(TensorList tensors) {
#ifdef USE_NCCL
  device_set devices;
  for (auto& tensor : tensors) {
    auto type = tensor.type();
    if (!type.is_cuda() || type.is_sparse())
      return false;
    if (!tensor.is_contiguous())
      return false;
    auto device = tensor.get_device();
    if (devices[device])
      return false;
    devices[device] = true;
  }
  return true;
#else
  return false;
#endif
}

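// NCCL version encoded as major * 1000 + minor * 100 + patch; 1000 for NCCL 1
// builds that lack the version macros, 0 when NCCL is unavailable.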
std::uint64_t version() {
#if defined(NCCL_MAJOR)
  return NCCL_MAJOR * 1000 + NCCL_MINOR * 100 + NCCL_PATCH;
#elif defined(USE_NCCL)
  return 1000;
#else
  return 0;
#endif
}

namespace {
// NCCL changed the integral type used for element counts between NCCL 1 and
// NCCL 2. The trait below extracts the type of the second argument of a
// function type; applying it to ncclBcast lets us determine the count type
// statically, whichever NCCL version we are built against.

template <typename T>
struct GetSecondArgType;

template <typename R, typename Arg0, typename Arg1, typename... Args>
struct GetSecondArgType<R(Arg0, Arg1, Args...)> {
  typedef typename std::decay<Arg1>::type type;
};

constexpr auto count_max =
    std::numeric_limits<GetSecondArgType<decltype(ncclBcast)>::type>::max();
} // namespace

size_t get_max_count() {
  return count_max;
}

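// In-place broadcast: copies tensors[0] (rank 0, the root) into the
// corresponding tensor on every other device, one NCCL rank per tensor.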
void broadcast(
    TensorList tensors,
    const stream_list& streams,
    const comm_list& user_comms) {
#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  check_inputs(tensors, tensors, 1, 1);
  ncclDataType_t data_type = get_data_type(tensors[0]);
  int64_t numel = tensors[0].numel();

  AutoNcclGroup nccl_group_guard;
  const auto comms = user_comms.empty() ? get_communicators(tensors)
                                        : ArrayRef<ncclComm_t>(user_comms);

  at::cuda::OptionalCUDAGuard device_guard;
  for (size_t i = 0, num_tensors = tensors.size(); i < num_tensors; i++) {
    int device = tensors[i].get_device();
    device_guard.set_index(device);
    // Default to the current stream
    const auto stream = (streams.empty() || !streams[i])
        ? at::cuda::getCurrentCUDAStream(device).stream()
        : streams[i]->stream();
    TORCH_CHECK(
        static_cast<uint64_t>(numel) <= static_cast<uint64_t>(count_max),
        "Broadcast tensor has ",
        numel,
        " elements, which exceeds the "
        "maximum NCCL supports (",
        count_max,
        ")");
    NCCL_CHECK(ncclBcast(
        tensors[i].data_ptr(), numel, data_type, 0, comms[i], stream));
  }
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

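// Element-wise reduction of the inputs across devices with the given
// ncclRedOp_t; only outputs[root] receives the reduced result.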
void reduce(
    const std::vector<at::Tensor>& inputs,
    std::vector<at::Tensor>& outputs,
    int32_t root,
    int32_t op,
    const stream_list& streams,
    const comm_list& user_comms) {
#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  TORCH_CHECK(
      root >= 0 && static_cast<size_t>(root) < inputs.size(), "invalid root");

  check_inputs(inputs, outputs, 1, 1);
  const auto len = inputs.size();

  ncclDataType_t data_type = get_data_type(inputs[0]);

  const auto count = inputs[0].numel();
  AutoNcclGroup nccl_group_guard;
  auto comms_ref = user_comms.empty() ? get_communicators(inputs)
                                      : ArrayRef<ncclComm_t>(user_comms);

  at::cuda::OptionalCUDAGuard device_guard;
  for (size_t i = 0; i < len; i++) {
    int device = inputs[i].device().index();
    device_guard.set_index(device);
    // Default to the current stream
    const auto stream = (streams.empty() || !streams[i])
        ? at::cuda::getCurrentCUDAStream(device).stream()
        : streams[i]->stream();

    NCCL_CHECK(ncclReduce(
        inputs[i].data_ptr(),
        outputs[i].data_ptr(),
        count,
        data_type,
        (ncclRedOp_t)op,
        root,
        comms_ref[i],
        stream));
  }
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

void reduce(
    std::vector<at::Tensor>& inputs,
    int32_t root,
    int32_t op,
    const stream_list& streams,
    const comm_list& user_comms) {
  reduce(inputs, /*outputs=*/inputs, root, op, streams, user_comms);
}
} // namespace nccl
} // namespace cuda
} // namespace torch