| #include <c10d/ProcessGroupMPI.hpp> |
| |
| #include <array> |
| #include <cstring> |
| #include <iostream> |
| #include <limits> |
| #include <map> |
| #include <numeric> |
| |
| #include <c10/core/DeviceGuard.h> |
| |
| #if defined(OPEN_MPI) && OPEN_MPI |
| #include <mpi-ext.h> // Needed for CUDA-aware check |
| #endif |
| |
| namespace c10d { |
| |
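| // Wrap an MPI call: throw a std::runtime_error carrying file/line information |
| // if the call does not return MPI_SUCCESS. |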
| #define MPI_CHECK(cmd) \ |
| do { \ |
| int mpiStatus = cmd; \ |
| if (mpiStatus != MPI_SUCCESS) { \ |
| std::string err = "MPI error in: " + std::string(__FILE__) + ":" + \ |
| std::to_string(__LINE__) + \ |
| ", with error code: " + std::to_string(mpiStatus); \ |
| throw std::runtime_error(err); \ |
| } \ |
| } while (0) |
| |
| namespace { |
| |
| // Op mapping |
| std::map<ReduceOp, MPI_Op> mpiOp = { |
| {ReduceOp::MIN, MPI_MIN}, |
| {ReduceOp::MAX, MPI_MAX}, |
| {ReduceOp::SUM, MPI_SUM}, |
| {ReduceOp::PRODUCT, MPI_PROD}, |
| }; |
| // Type mapping |
| std::map<at::ScalarType, MPI_Datatype> mpiDatatype = { |
| {at::kByte, MPI_UNSIGNED_CHAR}, |
| {at::kChar, MPI_CHAR}, |
| {at::kDouble, MPI_DOUBLE}, |
| {at::kFloat, MPI_FLOAT}, |
| {at::kInt, MPI_INT}, |
| {at::kLong, MPI_LONG}, |
| {at::kShort, MPI_SHORT}, |
| }; |
| |
| // Check for CUDA-aware MPI support. Currently we only support CUDA-aware |
| // MPI operations through Open MPI. |
| bool cudaAwareMpiCheck() { |
| // Runtime check |
| #if defined(MPIX_CUDA_AWARE_SUPPORT) |
| if (MPIX_Query_cuda_support() == 1) { |
| return true; |
| } else { |
| return false; |
| } |
| #else // !defined(MPIX_CUDA_AWARE_SUPPORT) |
| return false; |
| #endif // MPIX_CUDA_AWARE_SUPPORT |
| } |
| |
| // Checking the input tensor's validity |
| void checkSingleTensorHelper(const at::Tensor& tensor) { |
| if (!tensor.is_contiguous()) { |
| throw std::runtime_error("input tensor has to be contiguous"); |
| } |
| if (tensor.is_sparse()) { |
| throw std::runtime_error("input tensor has to be dense"); |
| } |
| if (tensor.is_cuda() && !cudaAwareMpiCheck()) { |
| throw std::runtime_error( |
| "CUDA tensor detected and the MPI used doesn't " |
| "have CUDA-aware MPI support"); |
| } |
| } |
| |
| void checkSingleTensor(const std::vector<at::Tensor>& tensors) { |
| if (tensors.size() != 1) { |
| throw std::runtime_error( |
| "MPI process group does not support multi-GPU collectives"); |
| } |
| checkSingleTensorHelper(tensors[0]); |
| } |
| |
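| // Check that every tensor in the vector matches the reference tensor in |
| // number of elements and data type, and passes the single-tensor checks. |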
| void checkSameSizeAndType( |
| const at::Tensor& tensor, |
| const std::vector<at::Tensor>& tensors) { |
| for (size_t i = 0; i < tensors.size(); ++i) { |
| if ((tensors[i].numel() != tensor.numel()) || |
| (tensors[i].type() != tensor.type())) { |
| throw std::runtime_error("Tensors are not equal in size or data type"); |
| } |
| checkSingleTensorHelper(tensors[i]); |
| } |
| } |
| |
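| // Validate split sizes against the tensor's dim 0 and the group size: with |
| // no explicit splits, dim 0 must divide evenly across the group; otherwise |
| // the splits must sum to the dim 0 size. |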
| void checkSplitSizes( |
| const std::vector<int64_t>& split_sizes, |
| const at::Tensor& tensor, |
| int group_size) { |
| if (split_sizes.size() == 0) { |
| TORCH_CHECK( |
| tensor.size(0) % group_size == 0, |
| "Tensor's dim 0 does not divide equally across group size"); |
| } else { |
| TORCH_CHECK( |
| split_sizes.size() == static_cast<size_t>(group_size), |
| "Number of tensor splits not equal to group size"); |
| int sum = std::accumulate(split_sizes.begin(), split_sizes.end(), 0); |
| TORCH_CHECK( |
| sum == tensor.size(0), "Split sizes don't match total dim 0 size"); |
| } |
| } |
| |
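| // Compute per-rank element counts and offsets for MPI_Alltoallv from the |
| // given split sizes (an empty vector means equal splits); returns the total |
| // element count. |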
| int64_t computeLengthsAndOffsets( |
| const std::vector<int64_t>& split_sizes, |
| const at::Tensor& tensor, |
| std::vector<int>* lengths, |
| std::vector<int>* offsets) { |
| int64_t group_size = lengths->size(); |
| bool equal_splits = false; |
| int64_t dim0_size = tensor.size(0); |
| int64_t row_size = (dim0_size ? tensor.numel() / dim0_size : 1); |
| int64_t split_size = 0; |
| int64_t offset = 0; |
| |
| if (split_sizes.size() == 0) { |
| equal_splits = true; |
| split_size = tensor.size(0) / group_size; |
| } |
| for (int i = 0; i < group_size; i++) { |
| int64_t length = row_size * (equal_splits ? split_size : split_sizes[i]); |
| TORCH_INTERNAL_ASSERT( |
| length <= std::numeric_limits<int>::max() && |
| offset <= std::numeric_limits<int>::max(), |
| "Length or offset larger than INT_MAX not supported"); |
| (*lengths)[i] = length; |
| (*offsets)[i] = offset; |
| offset += length; |
| } |
| return offset; |
| } |
| |
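| // Same as above, but derives the per-rank element counts from a list of |
| // tensors (one tensor per rank). |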
| int64_t computeLengthsAndOffsets( |
| const std::vector<at::Tensor>& tensors, |
| std::vector<int>* lengths, |
| std::vector<int>* offsets) { |
| int64_t group_size = lengths->size(); |
| int64_t offset = 0; |
| for (int i = 0; i < group_size; i++) { |
| int64_t length = tensors[i].numel(); |
| TORCH_INTERNAL_ASSERT( |
| length <= std::numeric_limits<int>::max() && |
| offset <= std::numeric_limits<int>::max(), |
| "Length or offset larger than INT_MAX not supported"); |
| (*lengths)[i] = length; |
| (*offsets)[i] = offset; |
| offset += length; |
| } |
| return offset; |
| } |
| |
| } // namespace |
| |
| ProcessGroupMPI::AsyncWork::AsyncWork(at::Tensor tensor, MPI_Request request) |
| : tensor_(std::move(tensor)), request_(request) { |
| memset(&status_, 0, sizeof(status_)); |
| } |
| |
| ProcessGroupMPI::AsyncWork::~AsyncWork() { |
| if (request_ != MPI_REQUEST_NULL) { |
| std::cerr |
| << "Attempted destruction of AsyncWork before work has completed, " |
| << "terminating the program." << std::endl; |
| std::terminate(); |
| } |
| } |
| |
| bool ProcessGroupMPI::AsyncWork::isCompleted() { |
| if (request_ == MPI_REQUEST_NULL) { |
| return true; |
| } |
| |
| std::unique_lock<std::mutex> globalLock(pgGlobalMutex_); |
| int flag = 0; |
| MPI_CHECK(MPI_Test(&request_, &flag, &status_)); |
| if (request_ != MPI_REQUEST_NULL) { |
| return false; |
| } |
| |
| // request_ == MPI_REQUEST_NULL; the work has completed |
| // Populate exception if request was not successful |
| if (status_.MPI_ERROR != MPI_SUCCESS) { |
| populateException(); |
| } |
| |
| return true; |
| } |
| |
| bool ProcessGroupMPI::AsyncWork::isSuccess() const { |
| if (request_ != MPI_REQUEST_NULL) { |
| throw std::runtime_error( |
| "Invalid call to AsyncWork::isSuccess before work has completed"); |
| } |
| |
| return status_.MPI_ERROR == MPI_SUCCESS; |
| } |
| |
| int ProcessGroupMPI::AsyncWork::sourceRank() const { |
| return status_.MPI_SOURCE; |
| } |
| |
| bool ProcessGroupMPI::AsyncWork::wait(std::chrono::milliseconds /* unused */) { |
| if (request_ == MPI_REQUEST_NULL) { |
| return true; |
| } |
| |
| std::unique_lock<std::mutex> globalLock(pgGlobalMutex_); |
| MPI_CHECK(MPI_Wait(&request_, &status_)); |
| auto ok = (status_.MPI_ERROR == MPI_SUCCESS); |
| if (!ok) { |
| populateException(); |
| std::rethrow_exception(exception_); |
| } |
| // Always return true, because the abort API is not implemented. |
| return true; |
| } |
| |
| void ProcessGroupMPI::AsyncWork::abort() { |
| TORCH_CHECK(false, "ProcessGroupMPI::AsyncWork::abort not implemented.") |
| } |
| |
| void ProcessGroupMPI::AsyncWork::populateException() { |
| std::array<char, MPI_MAX_ERROR_STRING> buf; |
| int len = buf.size(); |
| MPI_CHECK(MPI_Error_string(status_.MPI_ERROR, buf.data(), &len)); |
| exception_ = |
| std::make_exception_ptr(std::runtime_error(std::string(buf.data(), len))); |
| } |
| |
| // Static global states |
| int ProcessGroupMPI::mpiThreadSupport_ = 0; |
| std::mutex ProcessGroupMPI::pgGlobalMutex_; |
| // We only want to initialize once |
| std::once_flag ProcessGroupMPI::onceFlagInitMPI; |
| |
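| // Finalize MPI; registered as an atexit handler in initMPIOnce(). |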
| void ProcessGroupMPI::mpiExit() { |
| std::unique_lock<std::mutex> globalLock(pgGlobalMutex_); |
| MPI_CHECK(MPI_Finalize()); |
| } |
| |
| void ProcessGroupMPI::initMPIOnce() { |
| // Initialize MPI environment |
| std::call_once(onceFlagInitMPI, []() { |
| MPI_CHECK(MPI_Init_thread( |
| nullptr, nullptr, MPI_THREAD_SERIALIZED, &mpiThreadSupport_)); |
| if (mpiThreadSupport_ < MPI_THREAD_SERIALIZED) { |
| throw std::runtime_error( |
| "Used MPI implementation doesn't have the " |
| "minimum level of threading support: " |
| "MPI_THREAD_SERIALIZED. This is required by " |
| "c10d package"); |
| } |
| if (std::atexit(ProcessGroupMPI::mpiExit)) { |
| throw std::runtime_error("Fail to register the MPI exit handler"); |
| } |
| }); |
| } |
| |
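| // Create a process group over the given world ranks (all ranks of |
| // MPI_COMM_WORLD when the vector is empty). Ranks that are not members of |
| // the new group receive a null pointer. |
| // |
| // Minimal usage sketch (assuming an mpirun launch and the default options |
| // declared in the headers): |
| //   auto pg = ProcessGroupMPI::createProcessGroupMPI({}); |
| //   std::vector<at::Tensor> tensors = {at::ones({4})}; |
| //   pg->allreduce(tensors)->wait(); |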
| std::shared_ptr<ProcessGroupMPI> ProcessGroupMPI::createProcessGroupMPI( |
| std::vector<int> ranks) { |
| // One-time MPI initialization |
| initMPIOnce(); |
| |
| MPI_Comm groupComm = MPI_COMM_WORLD; |
| int rank = -1; |
| int size = -1; |
| |
| { |
| std::lock_guard<std::mutex> globalLock(pgGlobalMutex_); |
| |
| // If ranks are specified, create a communicator for just those ranks; |
| // otherwise this is the root group and MPI_COMM_WORLD is used as-is. |
| if (!ranks.empty()) { |
| MPI_Group worldGroup; |
| MPI_Group ranksGroup; |
| MPI_CHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup)); |
| MPI_CHECK( |
| MPI_Group_incl(worldGroup, ranks.size(), ranks.data(), &ranksGroup)); |
| MPI_CHECK(MPI_Comm_create(MPI_COMM_WORLD, ranksGroup, &groupComm)); |
| MPI_CHECK(MPI_Group_free(&worldGroup)); |
| MPI_CHECK(MPI_Group_free(&ranksGroup)); |
| } |
| |
| // Fetch rank and world size for this group (MPI_COMM_WORLD or new) |
| if (groupComm != MPI_COMM_NULL) { |
| MPI_CHECK(MPI_Comm_rank(groupComm, &rank)); |
| MPI_CHECK(MPI_Comm_size(groupComm, &size)); |
| |
| if (rank < 0 || size < 0) { |
| throw std::runtime_error("Failed to get the world_size / rank"); |
| } |
| } |
| } |
| |
| // If this process is not part of the group, we don't construct a |
| // process group instance. This is in line with the semantics of the |
| // other process group types. |
| if (groupComm == MPI_COMM_NULL) { |
| return std::shared_ptr<ProcessGroupMPI>(); |
| } |
| |
| return std::make_shared<ProcessGroupMPI>(rank, size, groupComm); |
| } |
| |
| ProcessGroupMPI::ProcessGroupMPI(int rank, int size, MPI_Comm pgComm) |
| : ProcessGroup(rank, size), stop_(false), pgComm_(pgComm) { |
| if (pgComm_ == MPI_COMM_NULL) { |
| throw std::runtime_error("pgComm_ must not be MPI_COMM_NULL"); |
| } |
| |
| // Start the worker thread accepting MPI calls |
| workerThread_ = std::thread(&ProcessGroupMPI::runLoop, this); |
| } |
| |
| ProcessGroupMPI::~ProcessGroupMPI() { |
| destroy(); |
| } |
| |
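| // Drain the work queue, then stop and join the worker thread. |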
| void ProcessGroupMPI::destroy() { |
| std::unique_lock<std::mutex> lock(pgMutex_); |
| queueConsumeCV_.wait(lock, [&] { return queue_.empty(); }); |
| |
| // Queue is empty, signal stop |
| stop_ = true; |
| |
| // Release lock to allow threads to terminate |
| lock.unlock(); |
| queueProduceCV_.notify_all(); |
| |
| // Join the single worker thread |
| workerThread_.join(); |
| } |
| |
| void ProcessGroupMPI::abort() { |
| destroy(); |
| MPI_Abort(pgComm_, EXIT_FAILURE); |
| } |
| |
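| // Worker thread loop: pop work entries off the queue and run them one at a |
| // time, so that all MPI calls are serialized on a single thread. |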
| void ProcessGroupMPI::runLoop() { |
| std::unique_lock<std::mutex> lock(pgMutex_); |
| |
| while (!stop_) { |
| if (queue_.empty()) { |
| queueProduceCV_.wait(lock); |
| continue; |
| } |
| |
| auto workTuple = std::move(queue_.front()); |
| |
| queue_.pop_front(); |
| |
| auto& workEntry = std::get<0>(workTuple); |
| auto& work = std::get<1>(workTuple); |
| |
| lock.unlock(); |
| queueConsumeCV_.notify_one(); |
| |
| try { |
| workEntry->run(workEntry); |
| work->finish(); |
| } catch (...) { |
| work->finish(std::current_exception()); |
| } |
| |
| lock.lock(); |
| } |
| } |
| |
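| // Queue a work entry for the worker thread and return a handle the caller |
| // can wait on. |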
| std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::enqueue( |
| std::unique_ptr<WorkEntry> entry) { |
| auto work = std::make_shared<WorkMPI>(); |
| std::unique_lock<std::mutex> lock(pgMutex_); |
| queue_.push_back(std::make_tuple(std::move(entry), work)); |
| lock.unlock(); |
| queueProduceCV_.notify_one(); |
| return work; |
| } |
| |
| std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::broadcast( |
| std::vector<at::Tensor>& tensors, |
| const BroadcastOptions& opts) { |
| checkSingleTensor(tensors); |
| std::function<void(std::unique_ptr<WorkEntry>&)> runFunc = |
| [opts, this](std::unique_ptr<WorkEntry>& entry) { |
| auto data = (entry->src)[0]; |
| c10::DeviceGuard guard(data.device()); |
| std::unique_lock<std::mutex> globalLock(pgGlobalMutex_); |
| MPI_CHECK(MPI_Bcast( |
| data.data_ptr(), |
| data.numel(), |
| mpiDatatype.at(data.scalar_type()), |
| opts.rootRank, |
| pgComm_)); |
| }; |
| auto entry = std::unique_ptr<WorkEntry>( |
| new WorkEntry(&tensors, nullptr, std::move(runFunc))); |
| return enqueue(std::move(entry)); |
| } |
| |
| std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::allreduce( |
| std::vector<at::Tensor>& tensors, |
| const AllreduceOptions& opts) { |
| checkSingleTensor(tensors); |
| |
| std::function<void(std::unique_ptr<WorkEntry>&)> runFunc = |
| [opts, this](std::unique_ptr<WorkEntry>& entry) { |
| auto data = (entry->src)[0]; |
| c10::DeviceGuard guard(data.device()); |
| std::unique_lock<std::mutex> globalLock(pgGlobalMutex_); |
| MPI_CHECK(MPI_Allreduce( |
| MPI_IN_PLACE, |
| data.data_ptr(), |
| data.numel(), |
| mpiDatatype.at(data.scalar_type()), |
| mpiOp.at(opts.reduceOp), |
| pgComm_)); |
| }; |
| auto entry = std::unique_ptr<WorkEntry>( |
| new WorkEntry(&tensors, nullptr, std::move(runFunc))); |
| return enqueue(std::move(entry)); |
| } |
| |
| std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::allreduce_coalesced( |
| std::vector<at::Tensor>& tensors, |
| const AllreduceCoalescedOptions& opts) { |
| throw std::runtime_error( |
| "allreduce_coalesced is currently not supported with MPI"); |
| } |
| |
| std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::reduce( |
| std::vector<at::Tensor>& tensors, |
| const ReduceOptions& opts) { |
| checkSingleTensor(tensors); |
| |
| std::function<void(std::unique_ptr<WorkEntry>&)> runFunc = |
| [opts, this](std::unique_ptr<WorkEntry>& entry) { |
| auto data = (entry->src)[0]; |
| auto dataPtr = (entry->src)[0].data_ptr(); |
| void* sendbuf = (rank_ == opts.rootRank) ? MPI_IN_PLACE : dataPtr; |
| void* recvbuf = (rank_ == opts.rootRank) ? dataPtr : nullptr; |
| |
| c10::DeviceGuard guard(data.device()); |
| std::unique_lock<std::mutex> globalLock(pgGlobalMutex_); |
| MPI_CHECK(MPI_Reduce( |
| sendbuf, |
| recvbuf, |
| data.numel(), |
| mpiDatatype.at(data.scalar_type()), |
| mpiOp.at(opts.reduceOp), |
| opts.rootRank, |
| pgComm_)); |
| }; |
| auto entry = std::unique_ptr<WorkEntry>( |
| new WorkEntry(&tensors, nullptr, std::move(runFunc))); |
| return enqueue(std::move(entry)); |
| } |
| |
| std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::allgather( |
| std::vector<std::vector<at::Tensor>>& outputTensors, |
| std::vector<at::Tensor>& inputTensors, |
| const AllgatherOptions& opts) { |
| checkSingleTensor(inputTensors); |
| if (outputTensors.size() != 1) { |
| throw std::runtime_error( |
| "MPI process group only supports a single " |
| "tensor op"); |
| } |
| if (static_cast<size_t>(size_) != outputTensors[0].size()) { |
| throw std::runtime_error( |
| "All gather: number of output tensors should equal " |
| "to the world size"); |
| } |
| |
| checkSameSizeAndType(inputTensors[0], outputTensors[0]); |
| |
| std::function<void(std::unique_ptr<WorkEntry>&)> runFunc = |
| [this](std::unique_ptr<WorkEntry>& entry) { |
| auto data = (entry->src)[0]; |
| std::vector<at::Tensor>& outputDataVec = entry->dst; |
| auto flatOutputTensor = newLikeFlat(outputDataVec); |
| |
| c10::DeviceGuard guard(data.device()); |
| std::unique_lock<std::mutex> globalLock(pgGlobalMutex_); |
| MPI_CHECK(MPI_Allgather( |
| data.data_ptr(), |
| data.numel(), |
| mpiDatatype.at(data.scalar_type()), |
| flatOutputTensor.data_ptr(), |
| data.numel(), |
| mpiDatatype.at(data.scalar_type()), |
| pgComm_)); |
| |
| for (size_t i = 0; i < outputDataVec.size(); ++i) { |
| outputDataVec[i].copy_(flatOutputTensor[i]); |
| } |
| }; |
| auto entry = std::unique_ptr<WorkEntry>( |
| new WorkEntry(&inputTensors, &outputTensors[0], std::move(runFunc))); |
| return enqueue(std::move(entry)); |
| } |
| |
| std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::allgather_coalesced( |
| std::vector<std::vector<at::Tensor>>& /* unused */, |
| std::vector<at::Tensor>& /* unused */, |
| const AllgatherOptions& /* unused */) { |
| throw std::runtime_error( |
| "ProcessGroupMPI does not support allgather_coalesced"); |
| } |
| |
| std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::gather( |
| std::vector<std::vector<at::Tensor>>& outputTensors, |
| std::vector<at::Tensor>& inputTensors, |
| const GatherOptions& opts) { |
| checkSingleTensor(inputTensors); |
| |
| if (rank_ != opts.rootRank) { |
| if (outputTensors.size() > 0) { |
| throw std::runtime_error( |
| "Gather: number of output tensors should be 0 " |
| "for non-root"); |
| } |
| } else { |
| if (outputTensors.size() != 1) { |
| throw std::runtime_error("Gather: multi-GPU collective is not supported"); |
| } |
| if (static_cast<size_t>(size_) != outputTensors[0].size()) { |
| throw std::runtime_error( |
| "Gather: number of output tensors should equal " |
| "to the world size"); |
| } |
| checkSameSizeAndType(inputTensors[0], outputTensors[0]); |
| } |
| |
| std::function<void(std::unique_ptr<WorkEntry>&)> runFunc = |
| [opts, this](std::unique_ptr<WorkEntry>& entry) { |
| auto data = (entry->src)[0]; |
| void* recvbuf = nullptr; |
| at::Tensor flatOutputTensor; |
| |
| if (rank_ == opts.rootRank) { |
| flatOutputTensor = newLikeFlat(entry->dst); |
| recvbuf = flatOutputTensor.data_ptr(); |
| } |
| |
| c10::DeviceGuard guard(data.device()); |
| std::unique_lock<std::mutex> globalLock(pgGlobalMutex_); |
| MPI_CHECK(MPI_Gather( |
| data.data_ptr(), |
| data.numel(), |
| mpiDatatype.at(data.scalar_type()), |
| recvbuf, |
| data.numel(), |
| mpiDatatype.at(data.scalar_type()), |
| opts.rootRank, |
| pgComm_)); |
| |
| if (rank_ == opts.rootRank) { |
| std::vector<at::Tensor>& outputDataVec = entry->dst; |
| // Copy the flattened output tensor back into the individual output tensors |
| for (size_t i = 0; i < outputDataVec.size(); ++i) { |
| outputDataVec.at(i).copy_(flatOutputTensor[i]); |
| } |
| } |
| }; |
| |
| if (rank_ == opts.rootRank) { |
| auto entry = std::unique_ptr<WorkEntry>( |
| new WorkEntry(&inputTensors, &outputTensors[0], std::move(runFunc))); |
| return enqueue(std::move(entry)); |
| } else { |
| auto entry = std::unique_ptr<WorkEntry>( |
| new WorkEntry(&inputTensors, nullptr, std::move(runFunc))); |
| return enqueue(std::move(entry)); |
| } |
| } |
| |
| std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::scatter( |
| std::vector<at::Tensor>& outputTensors, |
| std::vector<std::vector<at::Tensor>>& inputTensors, |
| const ScatterOptions& opts) { |
| checkSingleTensor(outputTensors); |
| |
| if (rank_ != opts.rootRank) { |
| if (inputTensors.size() > 0) { |
| throw std::runtime_error( |
| "Scatter: number of input tensors should be 0 " |
| "for non-root"); |
| } |
| } else { |
| if (inputTensors.size() != 1) { |
| throw std::runtime_error( |
| "Scatter: multi-GPU collective is not supported"); |
| } |
| if (static_cast<size_t>(size_) != inputTensors[0].size()) { |
| throw std::runtime_error( |
| "Scatter: number of input tensors should equal " |
| "to the world size"); |
| } |
| checkSameSizeAndType(outputTensors[0], inputTensors[0]); |
| } |
| |
| std::function<void(std::unique_ptr<WorkEntry>&)> runFunc = |
| [opts, this](std::unique_ptr<WorkEntry>& entry) { |
| auto data = (entry->dst)[0]; |
| void* sendbuf = nullptr; |
| at::Tensor flatInputTensor; |
| |
| if (rank_ == opts.rootRank) { |
| std::vector<at::Tensor>& inputDataVec = entry->src; |
| flatInputTensor = newLikeFlat(inputDataVec); |
| sendbuf = flatInputTensor.data_ptr(); |
| |
| // Copy the input tensors into the flattened send buffer |
| for (size_t i = 0; i < inputDataVec.size(); ++i) { |
| flatInputTensor[i].copy_(inputDataVec.at(i)); |
| } |
| } |
| |
| c10::DeviceGuard guard(data.device()); |
| std::unique_lock<std::mutex> globalLock(pgGlobalMutex_); |
| MPI_CHECK(MPI_Scatter( |
| sendbuf, |
| data.numel(), |
| mpiDatatype.at(data.scalar_type()), |
| data.data_ptr(), |
| data.numel(), |
| mpiDatatype.at(data.scalar_type()), |
| opts.rootRank, |
| pgComm_)); |
| }; |
| |
| if (rank_ == opts.rootRank) { |
| auto entry = std::unique_ptr<WorkEntry>( |
| new WorkEntry(&inputTensors[0], &outputTensors, std::move(runFunc))); |
| return enqueue(std::move(entry)); |
| } else { |
| auto entry = std::unique_ptr<WorkEntry>( |
| new WorkEntry(nullptr, &outputTensors, std::move(runFunc))); |
| return enqueue(std::move(entry)); |
| } |
| } |
| |
| std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::reduce_scatter( |
| std::vector<at::Tensor>& outputTensors, |
| std::vector<std::vector<at::Tensor>>& inputTensors, |
| const ReduceScatterOptions& opts) { |
| throw std::runtime_error("ProcessGroupMPI does not support reduce_scatter"); |
| } |
| |
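| // alltoall_base dispatches to MPI_Alltoall when both split-size vectors are |
| // empty (equal splits), and to MPI_Alltoallv otherwise. |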
| std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::alltoall_base( |
| at::Tensor& outputTensor, |
| at::Tensor& inputTensor, |
| std::vector<int64_t>& outputSplitSizes, |
| std::vector<int64_t>& inputSplitSizes, |
| const AllToAllOptions& opts) { |
| checkSingleTensorHelper(inputTensor); |
| checkSingleTensorHelper(outputTensor); |
| |
| if (outputSplitSizes.size() == 0 && inputSplitSizes.size() == 0) { |
| // We can use alltoall |
| TORCH_CHECK( |
| outputTensor.numel() == inputTensor.numel() && |
| outputTensor.type() == inputTensor.type(), |
| "Tensors are not equal in size or data type"); |
| TORCH_CHECK( |
| outputTensor.size(0) % size_ == 0, |
| "Tensor's dim 0 does not divide equally across group size"); |
| |
| std::function<void(std::unique_ptr<WorkEntry>&)> runFunc = |
| [opts, this](std::unique_ptr<WorkEntry>& entry) { |
| auto srcdata = (entry->src)[0]; |
| auto dstdata = (entry->dst)[0]; |
| c10::DeviceGuard guard(srcdata.device()); |
| std::unique_lock<std::mutex> globalLock(pgGlobalMutex_); |
| MPI_CHECK(MPI_Alltoall( |
| srcdata.data_ptr(), |
| srcdata.numel() / size_, |
| mpiDatatype.at(srcdata.scalar_type()), |
| dstdata.data_ptr(), |
| dstdata.numel() / size_, |
| mpiDatatype.at(dstdata.scalar_type()), |
| pgComm_)); |
| }; |
| std::vector<at::Tensor> inputTensors = {inputTensor}; |
| std::vector<at::Tensor> outputTensors = {outputTensor}; |
| auto entry = std::unique_ptr<WorkEntry>( |
| new WorkEntry(&inputTensors, &outputTensors, std::move(runFunc))); |
| return enqueue(std::move(entry)); |
| } else { |
| // Need alltoallv |
| checkSplitSizes(inputSplitSizes, inputTensor, size_); |
| checkSplitSizes(outputSplitSizes, outputTensor, size_); |
| std::function<void(std::unique_ptr<WorkEntry>&)> runFunc = |
| [opts, this, inputSplitSizes, outputSplitSizes]( |
| std::unique_ptr<WorkEntry>& entry) { |
| auto srcdata = (entry->src)[0]; |
| auto dstdata = (entry->dst)[0]; |
| std::vector<int> send_lengths(size_); |
| std::vector<int> recv_lengths(size_); |
| std::vector<int> send_offsets(size_); |
| std::vector<int> recv_offsets(size_); |
| computeLengthsAndOffsets( |
| inputSplitSizes, srcdata, &send_lengths, &send_offsets); |
| computeLengthsAndOffsets( |
| outputSplitSizes, dstdata, &recv_lengths, &recv_offsets); |
| c10::DeviceGuard guard(srcdata.device()); |
| std::unique_lock<std::mutex> globalLock(pgGlobalMutex_); |
| MPI_CHECK(MPI_Alltoallv( |
| srcdata.data_ptr(), |
| send_lengths.data(), |
| send_offsets.data(), |
| mpiDatatype.at(srcdata.scalar_type()), |
| dstdata.data_ptr(), |
| recv_lengths.data(), |
| recv_offsets.data(), |
| mpiDatatype.at(dstdata.scalar_type()), |
| pgComm_)); |
| }; |
| std::vector<at::Tensor> inputTensors = {inputTensor}; |
| std::vector<at::Tensor> outputTensors = {outputTensor}; |
| auto entry = std::unique_ptr<WorkEntry>( |
| new WorkEntry(&inputTensors, &outputTensors, std::move(runFunc))); |
| return enqueue(std::move(entry)); |
| } |
| } |
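| |
| // alltoall with per-rank tensors: flatten the inputs into one contiguous |
| // send buffer, exchange with MPI_Alltoallv, then scatter the received buffer |
| // back into the individual output tensors. |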
| std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::alltoall( |
| std::vector<at::Tensor>& outputTensors, |
| std::vector<at::Tensor>& inputTensors, |
| const AllToAllOptions& opts) { |
| TORCH_CHECK( |
| inputTensors.size() == static_cast<size_t>(size_), |
| "Number of input tensors is not equal to the group size"); |
| TORCH_CHECK( |
| outputTensors.size() == static_cast<size_t>(size_), |
| "Number of output tensors is not equal to the group size"); |
| std::function<void(std::unique_ptr<WorkEntry>&)> runFunc = |
| [opts, this](std::unique_ptr<WorkEntry>& entry) { |
| std::vector<int> send_lengths(size_); |
| std::vector<int> recv_lengths(size_); |
| std::vector<int> send_offsets(size_); |
| std::vector<int> recv_offsets(size_); |
| auto srcdata = entry->src; |
| auto dstdata = entry->dst; |
| int64_t src_len = |
| computeLengthsAndOffsets(srcdata, &send_lengths, &send_offsets); |
| int64_t dst_len = |
| computeLengthsAndOffsets(dstdata, &recv_lengths, &recv_offsets); |
| std::vector<int64_t> send_lengthsL( |
| send_lengths.begin(), send_lengths.end()); |
| std::vector<int64_t> recv_lengthsL( |
| recv_lengths.begin(), recv_lengths.end()); |
| at::Tensor srcFlatData = at::empty({src_len}, srcdata[0].options()); |
| at::Tensor dstFlatData = at::empty({dst_len}, dstdata[0].options()); |
| auto srcFlatDataSplits = |
| srcFlatData.split_with_sizes(c10::IntArrayRef(send_lengthsL), 0); |
| for (int i = 0; i < size_; i++) { |
| srcFlatDataSplits[i].copy_(srcdata[i].view({-1})); |
| } |
| c10::DeviceGuard guard1(srcdata[0].device()); |
| std::unique_lock<std::mutex> globalLock(pgGlobalMutex_); |
| MPI_CHECK(MPI_Alltoallv( |
| srcFlatData.data_ptr(), |
| send_lengths.data(), |
| send_offsets.data(), |
| mpiDatatype.at(srcdata[0].scalar_type()), |
| dstFlatData.data_ptr(), |
| recv_lengths.data(), |
| recv_offsets.data(), |
| mpiDatatype.at(dstdata[0].scalar_type()), |
| pgComm_)); |
| |
| auto dstFlatDataSplits = |
| dstFlatData.split_with_sizes(c10::IntArrayRef(recv_lengthsL), 0); |
| for (int i = 0; i < size_; i++) { |
| dstdata[i].view({-1}).copy_(dstFlatDataSplits[i]); |
| } |
| }; |
| auto entry = std::unique_ptr<WorkEntry>( |
| new WorkEntry(&inputTensors, &outputTensors, std::move(runFunc))); |
| return enqueue(std::move(entry)); |
| } |
| |
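| // Point-to-point ops issue non-blocking MPI calls directly on the caller's |
| // thread (under the global MPI lock) and return an AsyncWork wrapping the |
| // resulting MPI_Request. |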
| std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::send( |
| std::vector<at::Tensor>& tensors, |
| int dstRank, |
| int tag) { |
| checkSingleTensor(tensors); |
| |
| auto& tensor = tensors[0]; |
| MPI_Request request = MPI_REQUEST_NULL; |
| |
| { |
| c10::DeviceGuard guard(tensor.device()); |
| std::unique_lock<std::mutex> globalLock(pgGlobalMutex_); |
| MPI_CHECK(MPI_Isend( |
| tensor.data_ptr(), |
| tensor.numel(), |
| mpiDatatype.at(tensor.scalar_type()), |
| dstRank, |
| tag, |
| pgComm_, |
| &request)); |
| } |
| |
| return std::make_shared<AsyncWork>(tensor, request); |
| } |
| |
| std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::recv( |
| std::vector<at::Tensor>& tensors, |
| int srcRank, |
| int tag) { |
| checkSingleTensor(tensors); |
| |
| auto& tensor = tensors[0]; |
| MPI_Request request = MPI_REQUEST_NULL; |
| |
| { |
| c10::DeviceGuard guard(tensor.device()); |
| std::unique_lock<std::mutex> globalLock(pgGlobalMutex_); |
| MPI_CHECK(MPI_Irecv( |
| tensor.data_ptr(), |
| tensor.numel(), |
| mpiDatatype.at(tensor.scalar_type()), |
| srcRank, |
| tag, |
| pgComm_, |
| &request)); |
| } |
| |
| return std::make_shared<AsyncWork>(tensor, request); |
| } |
| |
| std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::recvAnysource( |
| std::vector<at::Tensor>& tensors, |
| int tag) { |
| checkSingleTensor(tensors); |
| |
| auto& tensor = tensors[0]; |
| MPI_Request request = MPI_REQUEST_NULL; |
| |
| { |
| c10::DeviceGuard guard(tensor.device()); |
| std::unique_lock<std::mutex> globalLock(pgGlobalMutex_); |
| MPI_CHECK(MPI_Irecv( |
| tensor.data_ptr(), |
| tensor.numel(), |
| mpiDatatype.at(tensor.scalar_type()), |
| MPI_ANY_SOURCE, |
| tag, |
| pgComm_, |
| &request)); |
| } |
| |
| return std::make_shared<AsyncWork>(tensor, request); |
| } |
| |
| std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::barrier( |
| const BarrierOptions& opts) { |
| std::function<void(std::unique_ptr<WorkEntry>&)> runFunc = |
| [this](std::unique_ptr<WorkEntry>& entry) { |
| std::unique_lock<std::mutex> globalLock(pgGlobalMutex_); |
| MPI_CHECK(MPI_Barrier(pgComm_)); |
| }; |
| auto entry = std::unique_ptr<WorkEntry>( |
| new WorkEntry(nullptr, nullptr, std::move(runFunc))); |
| return enqueue(std::move(entry)); |
| } |
| |
| std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::allgather_base( |
| at::Tensor& /*unused */, |
| at::Tensor& /*unused */, |
| const AllgatherOptions& /*unused */) { |
| throw std::runtime_error( |
| "no support for allgather_base in MPI process group"); |
| } |
| |
| } // namespace c10d |