// blob: 420c78ef028ada77e80df623da2dde7c3e211635 [file] [log] [blame]
#pragma once
#include <condition_variable>
#include <deque>
#include <exception>
#include <functional>
#include <memory>
#include <mutex>
#include <thread>
#include <utility>
#include <vector>
#include <c10d/ProcessGroup.hpp>
#include <c10d/Types.hpp>
#include <c10d/Utils.hpp>
#include <mpi.h>
namespace c10d {
// WorkEntry is the state associated with a single MPI run instance.
// It includes the source Tensor list and destination Tensor list, as well as
// the actual run function that will operate either on src or dst or both.
struct WorkEntry {
  // Builds the state for one operation.
  //
  // srcPtr, dstPtr: optional tensor lists. When a pointer is non-null the
  //   vector it points to is copied into `src` / `dst`; when it is null the
  //   corresponding member stays empty.
  // run: callable taking a reference to a std::unique_ptr<WorkEntry>. It is
  //   received by value and then moved into the member, so the callable (and
  //   whatever it captured) is not copied a second time.
  explicit WorkEntry(
      std::vector<at::Tensor>* srcPtr,
      std::vector<at::Tensor>* dstPtr,
      std::function<void(std::unique_ptr<WorkEntry>&)> run)
      : run(std::move(run)) {
    if (srcPtr) {
      src = *srcPtr;
    }
    if (dstPtr) {
      dst = *dstPtr;
    }
  }

  // Not copyable
  WorkEntry(const WorkEntry&) = delete;
  // Not copy assignable
  WorkEntry& operator=(const WorkEntry&) = delete;

  // For input and output tensors (in-place), we will always use src
  std::vector<at::Tensor> src;
  std::vector<at::Tensor> dst;

  // src rank returned, for recv only
  int* srcRank = nullptr;

  // The function to execute for this entry; it receives the unique_ptr that
  // holds a WorkEntry by reference.
  std::function<void(std::unique_ptr<WorkEntry>&)> run;
};
// ProcessGroupMPI implements MPI bindings for c10d.
//
// All functions on this class are expected to be called in the same
// order across processes in the group. This is the only way that we
// can guarantee to match up the same calls across processes.
//
// All MPI functions provided by this class are asynchronously scheduled on a
// worker thread. Therefore, ProcessGroupMPI requires the MPI implementation
// that is used to have a minimum thread support value of MPI_THREAD_SERIALIZED.
// That is, the process may be multi-threaded, and multiple threads may make
// MPI calls, but only one at a time: MPI calls are not made concurrently from
// two distinct threads (all MPI calls are serialized). However, with
// MPI_THREAD_SERIALIZED, ProcessGroupMPI will only support a single process
// group. In other words, no more than 1 process group can be created globally.
//
// If you would like to use multiple ProcessGroupMPI instances, your MPI
// implementation must have a thread support value of MPI_THREAD_MULTIPLE, that
// is, multiple threads may call MPI, with no restriction.
//
// Also note that ProcessGroupMPI only supports a single Tensor operation. In
// other words, the size of the input Tensor vector should always be 1.
//
// CUDA tensors can be supported if the MPI implementation used is CUDA-aware,
// and ProcessGroupMPI will automatically detect this support.
class ProcessGroupMPI : public ProcessGroup {
public:
class WorkMPI : public ProcessGroup::Work {
protected:
friend class ProcessGroupMPI;
};
class AsyncWork : public ProcessGroup::Work {
public:
AsyncWork(at::Tensor tensor, MPI_Request request);
virtual ~AsyncWork();
bool isCompleted() override;
bool isSuccess() const override;
int sourceRank() const override;
bool wait(std::chrono::milliseconds timeout = kUnsetTimeout) override;
void abort() override;
protected:
void populateException();
at::Tensor tensor_;
MPI_Request request_;
MPI_Status status_;
};
// Constructor will spawn up the worker thread loop
explicit ProcessGroupMPI(int rank, int size, MPI_Comm pgComm);
virtual ~ProcessGroupMPI();
// Abort the MPI program, needs to be called when exception is detected
void abort();
c10::intrusive_ptr<ProcessGroup::Work> broadcast(
std::vector<at::Tensor>& data,
const BroadcastOptions& opts = BroadcastOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> allreduce(
std::vector<at::Tensor>& tensors,
const AllreduceOptions& opts = AllreduceOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> allreduce_coalesced(
std::vector<at::Tensor>& tensors,
const AllreduceCoalescedOptions& opts =
AllreduceCoalescedOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> reduce(
std::vector<at::Tensor>& tensors,
const ReduceOptions& opts = ReduceOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> allgather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts = AllgatherOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> allgather_base(
at::Tensor& outputbuffer,
at::Tensor& inputbuffer,
const AllgatherOptions& opts = AllgatherOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> allgather_coalesced(
std::vector<std::vector<at::Tensor>>& outputTensorLists,
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts = AllgatherOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> gather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const GatherOptions& opts = GatherOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ScatterOptions& opts = ScatterOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> reduce_scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ReduceScatterOptions& opts = ReduceScatterOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> alltoall_base(
at::Tensor& outputTensor,
at::Tensor& inputTensor,
std::vector<int64_t>& outputSplitSizes,
std::vector<int64_t>& inputSplitSizes,
const AllToAllOptions& opts = AllToAllOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> alltoall(
std::vector<at::Tensor>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllToAllOptions& opts = AllToAllOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> send(
std::vector<at::Tensor>& tensors,
int dstRank,
int tag) override;
c10::intrusive_ptr<ProcessGroup::Work> recv(
std::vector<at::Tensor>& tensors,
int srcRank,
int tag) override;
c10::intrusive_ptr<ProcessGroup::Work> recvAnysource(
std::vector<at::Tensor>& tensor,
int tag) override;
c10::intrusive_ptr<ProcessGroup::Work> barrier(
const BarrierOptions& opts = BarrierOptions()) override;
// Creating a new ProcessGroupMPI, will initiialize MPI if not initialized
static c10::intrusive_ptr<ProcessGroupMPI> createProcessGroupMPI(
std::vector<int> ranks = {});
protected:
using WorkType =
std::tuple<std::unique_ptr<WorkEntry>, c10::intrusive_ptr<WorkMPI>>;
// Worker thread loop
void runLoop();
// Helper function that is called by the destructor
void destroy();
c10::intrusive_ptr<ProcessGroup::Work> enqueue(std::unique_ptr<WorkEntry> entry);
bool stop_;
std::mutex pgMutex_;
std::thread workerThread_;
std::deque<WorkType> queue_;
std::condition_variable queueProduceCV_;
std::condition_variable queueConsumeCV_;
// Global states
static void initMPIOnce();
static void mpiExit();
static std::once_flag onceFlagInitMPI;
static std::mutex pgGlobalMutex_;
static int mpiThreadSupport_;
MPI_Comm pgComm_;
};
} // namespace c10d