// blob: a8f9b454b003793d17a68d5456ed0a938fd09e84 [file] [log] [blame]
#pragma once
#ifdef USE_C10D_GLOO
#include <c10d/ProcessGroup.hpp>
#include <c10d/ProcessGroupGloo.hpp>
#include <c10d/Types.hpp>
#include <c10d/Utils.hpp>
namespace c10d {
// A ProcessGroup that is built from two other process groups: a generic
// ProcessGroup and a ProcessGroupGloo. Per the notes on the private members
// below, application collectives are meant to be dispatched to the former,
// while the latter is used for internal coordination.
//
// Every public member function here overrides a virtual of ProcessGroup; the
// definitions live outside this header.
class TORCH_API ProcessGroupWrapper : public ProcessGroup {
 public:
  // pg:     the process group being wrapped.
  // glooPg: the Gloo process group kept alongside it.
  explicit ProcessGroupWrapper(
      c10::intrusive_ptr<ProcessGroup> pg,
      c10::intrusive_ptr<ProcessGroupGloo> glooPg);

  const std::string getBackendName() const override;

  // ---------------------------------------------------------------------
  // Collective operations.
  // ---------------------------------------------------------------------

  c10::intrusive_ptr<ProcessGroup::Work> broadcast(
      std::vector<at::Tensor>& data,
      const BroadcastOptions& opts = BroadcastOptions()) override;

  c10::intrusive_ptr<ProcessGroup::Work> allreduce(
      std::vector<at::Tensor>& data,
      const AllreduceOptions& opts = AllreduceOptions()) override;

  c10::intrusive_ptr<ProcessGroup::Work> allreduce_coalesced(
      std::vector<at::Tensor>& tensors,
      const AllreduceCoalescedOptions& opts =
          AllreduceCoalescedOptions()) override;

  c10::intrusive_ptr<ProcessGroup::Work> reduce(
      std::vector<at::Tensor>& tensors,
      const ReduceOptions& opts = ReduceOptions()) override;

  c10::intrusive_ptr<ProcessGroup::Work> allgather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<ProcessGroup::Work> _allgather_base(
      at::Tensor& outputBuffer,
      at::Tensor& inputBuffer,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  // Deprecated: slated to move out of ProcessGroup into comms.
  //   * Do not take new dependencies on this function.
  //   * Do not implement it in a new ProcessGroup; implement _allgather_base
  //     instead.
  c10::intrusive_ptr<ProcessGroup::Work> allgather_coalesced(
      std::vector<std::vector<at::Tensor>>& outputTensorLists,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<ProcessGroup::Work> gather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const GatherOptions& opts = GatherOptions()) override;

  c10::intrusive_ptr<ProcessGroup::Work> scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ScatterOptions& opts = ScatterOptions()) override;

  c10::intrusive_ptr<ProcessGroup::Work> reduce_scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) override;

  c10::intrusive_ptr<ProcessGroup::Work> alltoall_base(
      at::Tensor& outputTensor,
      at::Tensor& inputTensor,
      std::vector<int64_t>& outputSplitSizes,
      std::vector<int64_t>& inputSplitSizes,
      const AllToAllOptions& opts = AllToAllOptions()) override;

  c10::intrusive_ptr<ProcessGroup::Work> alltoall(
      std::vector<at::Tensor>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllToAllOptions& opts = AllToAllOptions()) override;

  void monitoredBarrier(const BarrierOptions& opts, bool waitAllRanks = false)
      override;

  // ---------------------------------------------------------------------
  // Sequence numbers.
  // ---------------------------------------------------------------------

  // Establishes a starting sequence number shared by the whole group: rank 0
  // creates it and the other ranks receive it through the store. Currently
  // only the GLOO and NCCL backends implement this.
  // NOTE(review): the original author's note here reads "dont implement
  // this"; presumably the wrapper is not expected to provide its own logic
  // for this call. Confirm against the definition.
  void setSequenceNumberForGroup() override;

  // Returns the group's current sequence number, which is expected to be in
  // sync across ranks. A value that differs between ranks may point to some
  // form of collective desynchronization.
  // NOTE(review): the original author's note here reads "just call
  // underlying"; presumably this forwards to a wrapped process group. Confirm
  // against the definition.
  uint64_t getSequenceNumberForGroup() override;

  // ---------------------------------------------------------------------
  // Point-to-point operations and barrier.
  // ---------------------------------------------------------------------

  c10::intrusive_ptr<ProcessGroup::Work> send(
      std::vector<at::Tensor>& tensors,
      int dstRank,
      int tag) override;

  c10::intrusive_ptr<ProcessGroup::Work> recv(
      std::vector<at::Tensor>& tensors,
      int srcRank,
      int tag) override;

  c10::intrusive_ptr<ProcessGroup::Work> recvAnysource(
      std::vector<at::Tensor>& tensors,
      int tag) override;

  c10::intrusive_ptr<ProcessGroup::Work> barrier(
      const BarrierOptions& opts = BarrierOptions()) override;

 private:
  // Runs a set of checks intended to verify that the collective about to be
  // issued is well formed, so that incorrect collective usage in the
  // application can be reported to the user.
  void runCollectiveChecks(
      OpType op_type,
      const std::vector<at::Tensor>& tensors) const;

  // The wrapped process group; the application's actual collectives are
  // dispatched to it.
  c10::intrusive_ptr<ProcessGroup> pg_;

  // Gloo process group used for internal coordination, e.g. monitored
  // barrier, sequence number checking, and collecting collective
  // fingerprints.
  c10::intrusive_ptr<ProcessGroupGloo> glooPg_;
};
} // namespace c10d
#endif // USE_C10D_GLOO