| #pragma once |
| |
| #include <vector> |
| |
| #include <torch/csrc/distributed/c10d/ProcessGroup.hpp> |
| |
| namespace c10d { |
| |
| constexpr const char* ROUND_ROBIN_BACKEND_NAME = "round_robin"; |
| |
| // ProcessGroupRoundRobin implements simple load balancing. |
| // |
| // It is constructed with multiple process groups. Each call is dispatched to
| // one of the specified process groups in a round-robin fashion. All process
| // group instances must have the same rank and size.
| // |
| // All functions of the class are expected to be called in the same order
| // across all processes in the process group. This is the only way we can
| // guarantee that corresponding calls are matched up across all processes.
| // |
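| // A minimal usage sketch (illustrative only: `createSubgroup` is a
| // hypothetical helper standing in for however the backend-specific process
| // groups are actually constructed):
| //
| //   std::vector<c10::intrusive_ptr<ProcessGroup>> groups;
| //   groups.push_back(createSubgroup(rank, size)); // hypothetical helper
| //   groups.push_back(createSubgroup(rank, size)); // hypothetical helper
| //   auto pg = c10::make_intrusive<ProcessGroupRoundRobin>(
| //       rank, size, std::move(groups));
| //   pg->allreduce(tensors); // dispatched to groups[0]
| //   pg->allreduce(tensors); // dispatched to groups[1]
| //   pg->allreduce(tensors); // wraps around to groups[0]
| //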
| class TORCH_API ProcessGroupRoundRobin final : public ProcessGroup { |
| public: |
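|   // `processGroups` supplies the groups to cycle through; every entry must
|   // have the same `rank` and `size` as passed here.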
| explicit ProcessGroupRoundRobin( |
| int rank, |
| int size, |
| std::vector<c10::intrusive_ptr<ProcessGroup>> processGroups); |
| |
| ~ProcessGroupRoundRobin() override; |
| |
| const std::string getBackendName() const override { |
| return std::string(ROUND_ROBIN_BACKEND_NAME); |
| } |
| |
| c10::intrusive_ptr<Work> broadcast( |
| std::vector<at::Tensor>& tensors, |
| const BroadcastOptions& opts = BroadcastOptions()) override; |
| |
| c10::intrusive_ptr<Work> allreduce( |
| std::vector<at::Tensor>& tensors, |
| const AllreduceOptions& opts = AllreduceOptions()) override; |
| |
| c10::intrusive_ptr<Work> allreduce_coalesced( |
| std::vector<at::Tensor>& tensors, |
| const AllreduceCoalescedOptions& opts = |
| AllreduceCoalescedOptions()) override; |
| |
| c10::intrusive_ptr<Work> reduce( |
| std::vector<at::Tensor>& tensors, |
| const ReduceOptions& opts = ReduceOptions()) override; |
| |
| c10::intrusive_ptr<Work> allgather( |
| std::vector<std::vector<at::Tensor>>& outputs, |
| std::vector<at::Tensor>& inputs, |
| const AllgatherOptions& opts = AllgatherOptions()) override; |
| |
| c10::intrusive_ptr<Work> _allgather_base( |
| at::Tensor& outputBuffer, |
| at::Tensor& inputBuffer, |
| const AllgatherOptions& opts = AllgatherOptions()) override; |
| |
| c10::intrusive_ptr<Work> allgather_coalesced( |
| std::vector<std::vector<at::Tensor>>& outputTensorLists, |
| std::vector<at::Tensor>& inputTensors, |
| const AllgatherOptions& opts = AllgatherOptions()) override; |
| |
| c10::intrusive_ptr<Work> gather( |
| std::vector<std::vector<at::Tensor>>& outputs, |
| std::vector<at::Tensor>& inputs, |
| const GatherOptions& opts = GatherOptions()) override; |
| |
| c10::intrusive_ptr<Work> scatter( |
| std::vector<at::Tensor>& outputs, |
| std::vector<std::vector<at::Tensor>>& inputs, |
| const ScatterOptions& opts = ScatterOptions()) override; |
| |
| c10::intrusive_ptr<Work> reduce_scatter( |
| std::vector<at::Tensor>& outputs, |
| std::vector<std::vector<at::Tensor>>& inputs, |
| const ReduceScatterOptions& opts = ReduceScatterOptions()) override; |
| |
| c10::intrusive_ptr<Work> alltoall_base( |
| at::Tensor& outputTensor, |
| at::Tensor& inputTensor, |
| std::vector<int64_t>& outputSplitSizes, |
| std::vector<int64_t>& inputSplitSizes, |
| const AllToAllOptions& opts = AllToAllOptions()) override; |
| |
| c10::intrusive_ptr<Work> send( |
| std::vector<at::Tensor>& tensors, |
| int dstRank, |
| int tag) override; |
| |
| c10::intrusive_ptr<Work> recv( |
| std::vector<at::Tensor>& tensors, |
| int srcRank, |
| int tag) override; |
| |
| c10::intrusive_ptr<Work> recvAnysource( |
| std::vector<at::Tensor>& tensors, |
| int tag) override; |
| |
| c10::intrusive_ptr<Work> barrier( |
| const BarrierOptions& opts = BarrierOptions()) override; |
| |
| private: |
|   // The process groups this instance dispatches over.
|   std::vector<c10::intrusive_ptr<ProcessGroup>> processGroups_;
|
|   // Points to the process group that will serve the next call; wraps
|   // around to the beginning of `processGroups_` after the last one.
|   std::vector<c10::intrusive_ptr<ProcessGroup>>::const_iterator iterator_;
| |
| // Returns the next ProcessGroup to use. |
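|   //
|   // A minimal sketch of the advance, assuming `iterator_` cycles through
|   // `processGroups_` (the actual definition belongs in the corresponding
|   // .cpp file):
|   //
|   //   const c10::intrusive_ptr<ProcessGroup>& pg = *iterator_;
|   //   if (++iterator_ == processGroups_.end()) {
|   //     iterator_ = processGroups_.begin();
|   //   }
|   //   return pg;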
| const c10::intrusive_ptr<ProcessGroup>& next(); |
| }; |
| |
| } // namespace c10d |