blob: 8255bceebd6cf2c0d2d4c2b98e0396c1020a3d6b [file] [log] [blame]
#pragma once
#include <vector>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
namespace c10d {
// Backend name reported by ProcessGroupRoundRobin::getBackendName().
constexpr const char* ROUND_ROBIN_BACKEND_NAME{"round_robin"};
// ProcessGroupRoundRobin implements simple load balancing.
//
// It is constructed with multiple process groups. Each call is dispatched to
// one of the specified process groups in a round robin fashion. Each process
// group instance must have the same rank and size.
//
// All functions of the class are expected to be called in the same order
// across all processes in the process group. This is the only way that we
// can guarantee to match up the same calls among all processes.
//
class TORCH_API ProcessGroupRoundRobin final : public ProcessGroup {
 public:
  // Takes this process's `rank`, the group `size`, and the list of process
  // groups to use. The constructor is defined out of line.
  explicit ProcessGroupRoundRobin(
      int rank,
      int size,
      std::vector<c10::intrusive_ptr<ProcessGroup>> processGroups);

  ~ProcessGroupRoundRobin() override;

  // Reports the backend name: ROUND_ROBIN_BACKEND_NAME converted to a
  // std::string.
  const std::string getBackendName() const override {
    return ROUND_ROBIN_BACKEND_NAME;
  }

  // ---------------------------------------------------------------------
  // Overrides of ProcessGroup operations. Only declarations live in this
  // header; every function below returns an intrusive_ptr<Work> and is
  // defined out of line.
  // ---------------------------------------------------------------------

  // Operations on a single list of tensors.
  c10::intrusive_ptr<Work> broadcast(
      std::vector<at::Tensor>& tensors,
      const BroadcastOptions& opts = BroadcastOptions()) override;

  c10::intrusive_ptr<Work> allreduce(
      std::vector<at::Tensor>& tensors,
      const AllreduceOptions& opts = AllreduceOptions()) override;

  c10::intrusive_ptr<Work> allreduce_coalesced(
      std::vector<at::Tensor>& tensors,
      const AllreduceCoalescedOptions& opts =
          AllreduceCoalescedOptions()) override;

  c10::intrusive_ptr<Work> reduce(
      std::vector<at::Tensor>& tensors,
      const ReduceOptions& opts = ReduceOptions()) override;

  // Allgather variants: list-of-lists outputs, flat buffers, and coalesced.
  c10::intrusive_ptr<Work> allgather(
      std::vector<std::vector<at::Tensor>>& outputs,
      std::vector<at::Tensor>& inputs,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> _allgather_base(
      at::Tensor& outputBuffer,
      at::Tensor& inputBuffer,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> allgather_coalesced(
      std::vector<std::vector<at::Tensor>>& outputTensorLists,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  // Gather / scatter style operations with separate output and input
  // arguments.
  c10::intrusive_ptr<Work> gather(
      std::vector<std::vector<at::Tensor>>& outputs,
      std::vector<at::Tensor>& inputs,
      const GatherOptions& opts = GatherOptions()) override;

  c10::intrusive_ptr<Work> scatter(
      std::vector<at::Tensor>& outputs,
      std::vector<std::vector<at::Tensor>>& inputs,
      const ScatterOptions& opts = ScatterOptions()) override;

  c10::intrusive_ptr<Work> reduce_scatter(
      std::vector<at::Tensor>& outputs,
      std::vector<std::vector<at::Tensor>>& inputs,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) override;

  c10::intrusive_ptr<Work> alltoall_base(
      at::Tensor& outputTensor,
      at::Tensor& inputTensor,
      std::vector<int64_t>& outputSplitSizes,
      std::vector<int64_t>& inputSplitSizes,
      const AllToAllOptions& opts = AllToAllOptions()) override;

  // Operations addressed by a peer rank and/or a tag.
  c10::intrusive_ptr<Work> send(
      std::vector<at::Tensor>& tensors,
      int dstRank,
      int tag) override;

  c10::intrusive_ptr<Work> recv(
      std::vector<at::Tensor>& tensors,
      int srcRank,
      int tag) override;

  c10::intrusive_ptr<Work> recvAnysource(
      std::vector<at::Tensor>& tensors,
      int tag) override;

  c10::intrusive_ptr<Work> barrier(
      const BarrierOptions& opts = BarrierOptions()) override;

 private:
  // Returns the next ProcessGroup to use.
  const c10::intrusive_ptr<ProcessGroup>& next();

  // Process groups held by this instance.
  std::vector<c10::intrusive_ptr<ProcessGroup>> processGroups_;

  // Const iterator over a vector of process groups.
  // NOTE(review): presumably the cursor into processGroups_ that next()
  // advances -- confirm against the out-of-line implementation.
  std::vector<c10::intrusive_ptr<ProcessGroup>>::const_iterator iterator_;
};
} // namespace c10d