blob: 7605553459a36fb4a0f4e26195642a056e426ccf [file] [log] [blame]
#pragma once
#include <condition_variable>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <unordered_map>
#include <vector>
#include <ATen/ATen.h>
#include <c10d/Types.hpp>
namespace c10d {
// ProcessGroup is a base class that captures collective and point-to-point
// communication in a fixed set of processes.
//
// The functions declared in the class below describe the API only;
// implementations are provided in subclasses.
//
// Every function that performs I/O is executed asynchronously by a
// thread pool owned by the ProcessGroup (by default). Each returns a
// `Work` object that can be used to wait for completion or error.
//
// The ProcessGroup can instantiate subgroups with fewer or an equal
// number of members. Implementations must take care that multiple
// process groups can be used in parallel and synchronize accordingly.
//
// The ProcessGroup assumes a fixed set of processes. If the set
// changes, existing instances must be destructed and instantiation
// and initialization must start from scratch. For members of the
// process group to find each other (referred to as rendezvous from
// hereon), some mechanism outside this base class is required: the
// constructor declared here only receives `rank` and `size`, so the
// rendezvous itself is left to subclasses or their callers.
// NOTE(review): this sentence was truncated in the original comment;
// confirm the intended wording.
//
class ProcessGroup {
public:
// Handle for a single asynchronous operation issued on a ProcessGroup.
// Every collective and point-to-point operation declared below returns
// one of these wrapped in a std::shared_ptr.
class Work {
public:
virtual ~Work();
// Checks if request has completed. Non-blocking operation.
virtual bool isCompleted();
// Returns whether the work completed successfully.
// If false, exception() can be called to get details.
virtual bool isSuccess() const;
// Returns the stored exception if isSuccess() returned false.
virtual std::exception_ptr exception() const;
// Returns the source rank if this object represents a recv-from-any
// (see ProcessGroup::recvAnysource).
virtual int sourceRank() const;
// Returns result tensors, if applicable.
virtual std::vector<at::Tensor> result() const;
// Ensures that operations on the output tensors that are invoked
// after this function returns are correctly sequenced after the
// asynchronous completion of this work.
//
// For CUDA tensors, it inserts stream synchronization such that
// the streams of the caller wait for completion of the
// asynchronous operations on the destination tensors.
//
// For CPU tensors, it is currently a nop.
//
// This function should only be used if the caller polls for
// completion through the `isCompleted` function, it has returned
// true, and the `isSuccess` function has also returned true.
//
virtual void synchronize();
// Waits until request completes. Blocking operation.
// Throws if the work completed with an exception; returns nothing.
//
// Functionally equivalent to:
//
//   while (!isCompleted()) { /* nop */ }
//   if (!isSuccess()) { std::rethrow_exception(exception()); }
//
virtual void wait();
protected:
// Presumably called by implementations to mark this work as completed,
// optionally recording `exception` as the failure, and to wake waiters
// blocked on cv_.
// NOTE(review): the definition is not visible here; confirm.
void finish(std::exception_ptr exception = nullptr);
// Guards the state below. Declared `mutable`, presumably so that const
// accessors such as isSuccess()/exception() can lock it.
mutable std::mutex mutex_;
// Presumably signalled on completion so wait() can block on it (confirm).
std::condition_variable cv_;
// Completion flag; starts false.
bool completed_ = false;
// Error recorded for a failed operation; null when none was recorded.
std::exception_ptr exception_;
};
// Presumably `rank` is this process's index in a group of `size`
// processes. Both are stored in the const members rank_/size_.
explicit ProcessGroup(int rank, int size);
virtual ~ProcessGroup();
// Returns rank_, which is fixed for the lifetime of this instance.
int getRank() const {
return rank_;
}
// Returns size_, the number of processes in the group. It is fixed for
// the lifetime of this instance.
int getSize() const {
return size_;
}
// Collective operations. All are pure virtual (implemented by
// subclasses) and return a Work handle for the asynchronous operation.
// Each options argument defaults to a default-constructed options struct
// (presumably declared in c10d/Types.hpp).
// Tensors are passed by non-const reference; presumably outputs are
// written in place. Confirm per implementation.
virtual std::shared_ptr<ProcessGroup::Work> broadcast(
std::vector<at::Tensor>& data,
const BroadcastOptions& opts = BroadcastOptions()) = 0;
virtual std::shared_ptr<ProcessGroup::Work> allreduce(
std::vector<at::Tensor>& data,
const AllreduceOptions& opts = AllreduceOptions()) = 0;
virtual std::shared_ptr<ProcessGroup::Work> reduce(
std::vector<at::Tensor>& tensors,
const ReduceOptions& opts = ReduceOptions()) = 0;
virtual std::shared_ptr<ProcessGroup::Work> allgather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts = AllgatherOptions()) = 0;
virtual std::shared_ptr<ProcessGroup::Work> gather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const GatherOptions& opts = GatherOptions()) = 0;
virtual std::shared_ptr<ProcessGroup::Work> scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ScatterOptions& opts = ScatterOptions()) = 0;
virtual std::shared_ptr<ProcessGroup::Work> reduce_scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ReduceScatterOptions& opts = ReduceScatterOptions()) = 0;
// Point-to-point operations. `dstRank`/`srcRank` name the peer rank.
// `tag` presumably matches a send with its corresponding receive
// (confirm per implementation).
virtual std::shared_ptr<ProcessGroup::Work> send(
std::vector<at::Tensor>& tensors,
int dstRank,
int tag) = 0;
virtual std::shared_ptr<ProcessGroup::Work> recv(
std::vector<at::Tensor>& tensors,
int srcRank,
int tag) = 0;
// Receive from any source rank. The rank that was actually received
// from is reported by the returned Work's sourceRank().
virtual std::shared_ptr<ProcessGroup::Work> recvAnysource(
std::vector<at::Tensor>& tensors,
int tag) = 0;
// Presumably a synchronization point across all ranks of the group,
// observed through the returned Work (confirm per implementation).
virtual std::shared_ptr<ProcessGroup::Work> barrier(
const BarrierOptions& opts = BarrierOptions()) = 0;
protected:
// Fixed at construction: the group assumes a fixed set of processes.
const int rank_;
const int size_;
};
} // namespace c10d