|  | #pragma once | 
|  |  | 
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <exception>
#include <functional>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
|  |  | 
|  | #include <ATen/ATen.h> | 
|  | #include <c10/macros/Macros.h> | 
|  |  | 
|  | #include <c10d/Types.hpp> | 
|  | #include <c10d/Utils.hpp> | 
|  | #include <c10d/sequence_num.hpp> | 
|  |  | 
|  | // ************************************************************************* | 
|  | // PROCESS GROUP collective communication API IS BEING CHANGED BETWEEN | 
|  | // versions 1.7 and 1.8. | 
|  | // PLEASE DO NOT ADD ANY DEPENDENCIES. | 
|  | // SEE RFC: https://github.com/pytorch/pytorch/issues/39662 | 
|  | // ************************************************************************* | 
|  |  | 
// Zero-length timeout. It is the default `timeout` argument of
// ProcessGroup::Work::wait().
constexpr std::chrono::milliseconds kNoTimeout{0};

// Default timeout used by ProcessGroup::Options: 30 minutes, stored as
// milliseconds. The minutes -> milliseconds conversion is lossless, so the
// value is identical to the hand-computed 30 * 60 * 1000.
constexpr std::chrono::milliseconds kProcessGroupDefaultTimeout =
    std::chrono::minutes(30);
|  |  | 
|  | namespace c10d { | 
|  |  | 
// String key named after the sequence-number store entry.
// NOTE(review): presumably the key under which the group's sequence number
// is exchanged through the store (see setSequenceNumberForGroup) -- confirm
// against the backend implementations.
// `constexpr` already makes the pointer itself const.
constexpr const char* kSeqNumStoreKey = "SEQ_NUM_STORE_KEY";
|  |  | 
// Identifies the kind of operation a ProcessGroup::Work object refers to.
// Stored in a single byte; every enumerator has an explicitly assigned value.
//
// NOTE(review): `_ALLGATHER_BASE` and `_REDUCE_SCATTER_BASE` begin with an
// underscore followed by an uppercase letter, a form the C++ standard
// reserves for the implementation. They are kept as-is because renaming them
// would break existing users of this enum.
enum class OpType : std::uint8_t {
  // Collective operations.
  BROADCAST = 0,
  ALLREDUCE = 1,
  ALLREDUCE_COALESCED = 2,
  REDUCE = 3,
  ALLGATHER = 4,
  _ALLGATHER_BASE = 5,
  ALLGATHER_COALESCED = 6,
  GATHER = 7,
  SCATTER = 8,
  REDUCE_SCATTER = 9,
  ALLTOALL_BASE = 10,
  ALLTOALL = 11,
  // Point-to-point operations.
  SEND = 12,
  RECV = 13,
  RECVANYSOURCE = 14,
  BARRIER = 15,
  _REDUCE_SCATTER_BASE = 16,
  // Default op type of a Work object constructed without one.
  UNKNOWN = 100,
};
|  |  | 
|  | // Converts OpType to human readable string. | 
|  | TORCH_API std::string opTypeToString(OpType opType); | 
|  |  | 
|  | // Whether or not an OP is an p2p op (SEND, RECV, RECVANYSOURCE) | 
|  | TORCH_API bool isP2POp(OpType opType); | 
|  |  | 
|  | // ProcessGroup is a base class that captures collective and point to | 
|  | // point communication in a fixed set of processes. | 
|  | // | 
|  | // The functions specified in the class below describe the API alone; | 
|  | // implementations are provided in subclasses. | 
|  | // | 
|  | // Every function that performs I/O is executed asynchronously by a | 
|  | // thread pool owned by the ProcessGroup (by default). They return an | 
|  | // object that can be used to wait for completion or error. | 
|  | // | 
|  | // The ProcessGroup can instantiate subgroups with fewer or an equal | 
|  | // number of members. Implementations must take care that multiple | 
|  | // process groups can be used in parallel and synchronize accordingly. | 
|  | // | 
// The ProcessGroup assumes a fixed set of processes. If the set
// changes, existing instances must be destructed and instantiation
// and initialization must start from scratch. The process by which
// members of the process group find each other is referred to as
// rendezvous from here on.
//
|  | class TORCH_API ProcessGroup : public torch::CustomClassHolder { | 
|  | public: | 
|  | // Please do not use ProcessGroup::Work API, it is going away, to be | 
|  | // replaced by ivalue::Future. | 
|  | // Python binding for this class might change, please do not assume | 
|  | // this will be bound using pybind. | 
|  | class TORCH_API Work : public torch::CustomClassHolder { | 
|  | public: | 
|  | Work( | 
|  | int rank = -1, | 
|  | OpType opType = OpType::UNKNOWN, | 
|  | const char* profilingTitle = nullptr, | 
|  | const c10::optional<std::vector<at::Tensor>>& inputTensors = | 
|  | c10::nullopt); | 
|  |  | 
|  | virtual ~Work(); | 
|  |  | 
|  | // Checks if request has completed. Non-blocking operation. | 
|  | virtual bool isCompleted(); | 
|  |  | 
|  | // Returns if the work completed successfully. | 
|  | // If false, the exception function can be called to get details. | 
|  | virtual bool isSuccess() const; | 
|  |  | 
|  | // Returns exception if isSuccess() returned false. | 
|  | virtual std::exception_ptr exception() const; | 
|  |  | 
|  | // Returns source rank if this objects represents a recv-from-any. | 
|  | virtual int sourceRank() const; | 
|  |  | 
|  | // Returns result tensors, if applicable. | 
|  | // If work is not supposed to have result, we return empty list. | 
|  | virtual std::vector<at::Tensor> result(); | 
|  |  | 
|  | // Ensures that operations on the output tensors that are invoked | 
|  | // after this function returns are correctly sequenced after the | 
|  | // asynchronous completion of this work. | 
|  | // | 
|  | // For CUDA tensors, it inserts stream synchronization such that | 
|  | // the streams of the caller wait for completion of the | 
|  | // asynchronous operations on the destination tensors. | 
|  | // | 
|  | // For CPU tensors, it is currently a nop. | 
|  | // | 
|  | // This function should only be used if the caller polls for | 
|  | // completion through the `isCompleted` function, it has returned | 
|  | // true, and the `isSuccess` function also has returned true. | 
|  | // | 
|  | virtual void synchronize(); | 
|  |  | 
|  | // Waits until request completes. Blocking operation. | 
|  | // Throws if the work completed with an exception. | 
|  | // Returns false if the work is aborted. | 
|  | // Otherwise, it always returns true, indicating the work is completed. | 
|  | // | 
|  | // Functionally equivalent to: | 
|  | // | 
|  | //   while (!isCompleted()) { /* nop */ } | 
|  | //   auto success = isSuccess(); | 
|  | //   if (!success) { std::rethrow_exception(exception()); } | 
|  | //   return success; | 
|  | // | 
|  | virtual bool wait(std::chrono::milliseconds timeout = kNoTimeout); | 
|  |  | 
|  | virtual void abort(); | 
|  |  | 
|  | // Returns a Future object that will be associated with the completion of | 
|  | // work. Only NCCL backend is currently supported. | 
|  | virtual c10::intrusive_ptr<c10::ivalue::Future> getFuture(); | 
|  |  | 
|  | OpType retrieveOpType(); | 
|  |  | 
|  | protected: | 
|  | // Completes the work object and optionally sets the exception in a | 
|  | // thread-safe manner. Notifies all waiting condition variables as well. | 
|  | void finish(std::exception_ptr exception = nullptr); | 
|  |  | 
|  | // Similar to finish, but throws an exception if one is already set or | 
|  | // provided by the user. | 
|  | void finishAndThrow(std::exception_ptr exception); | 
|  |  | 
|  | mutable std::mutex mutex_; | 
|  | std::condition_variable cv_; | 
|  | bool completed_ = false; | 
|  | std::exception_ptr exception_; | 
|  |  | 
|  | // Current rank of the node. | 
|  | const int rank_; | 
|  |  | 
|  | // Operation type that this work object refers to. | 
|  | OpType opType_; | 
|  |  | 
|  | // When profiling, the callback to record end of operation event. This | 
|  | // callback needs to be called when collective operation is complete. | 
|  | std::function<void()> recordFunctionEndCallback_; | 
|  | }; | 
|  |  | 
|  | // ProcessGroup Options is a base struct that defines the basic options | 
|  | // when constructing a ProcessGroup. Each ProcessGroup subclass should | 
|  | // extend this struct and define its options if it wants to provide more | 
|  | // config options (beyond basic ones defined here) to end user. | 
|  | struct TORCH_API Options : torch::CustomClassHolder { | 
|  | explicit Options( | 
|  | std::string backend, | 
|  | std::chrono::milliseconds timeout = kProcessGroupDefaultTimeout) | 
|  | : timeout(timeout), backend(backend) {} | 
|  | virtual ~Options() = default; | 
|  |  | 
|  | std::chrono::milliseconds timeout; | 
|  |  | 
|  | // backend name | 
|  | const std::string backend; | 
|  | }; | 
|  |  | 
|  | explicit ProcessGroup(int rank, int size); | 
|  | virtual ~ProcessGroup(); | 
|  |  | 
|  | int getRank() const { | 
|  | return rank_; | 
|  | } | 
|  |  | 
|  | int getSize() const { | 
|  | return size_; | 
|  | } | 
|  |  | 
|  | virtual const std::string getBackendName() const { | 
|  | return "undefined"; | 
|  | } | 
|  |  | 
|  | virtual c10::intrusive_ptr<ProcessGroup::Work> broadcast( | 
|  | std::vector<at::Tensor>& data, | 
|  | const BroadcastOptions& opts = BroadcastOptions()) = 0; | 
|  |  | 
|  | virtual c10::intrusive_ptr<ProcessGroup::Work> allreduce( | 
|  | std::vector<at::Tensor>& data, | 
|  | const AllreduceOptions& opts = AllreduceOptions()) = 0; | 
|  |  | 
|  | // This will be moved out of ProcessGroup, do not add dependencies on this | 
|  | // function. | 
|  | virtual c10::intrusive_ptr<ProcessGroup::Work> allreduce_coalesced( | 
|  | std::vector<at::Tensor>& tensors, | 
|  | const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) = 0; | 
|  |  | 
|  | virtual c10::intrusive_ptr<ProcessGroup::Work> reduce( | 
|  | std::vector<at::Tensor>& tensors, | 
|  | const ReduceOptions& opts = ReduceOptions()) = 0; | 
|  |  | 
|  | virtual c10::intrusive_ptr<ProcessGroup::Work> allgather( | 
|  | std::vector<std::vector<at::Tensor>>& outputTensors, | 
|  | std::vector<at::Tensor>& inputTensors, | 
|  | const AllgatherOptions& opts = AllgatherOptions()) = 0; | 
|  |  | 
|  | // Gathers a single tensor inputBuffer into a single buffer outputBuffer that | 
|  | // is interpreted as a contigious collection of size inputBuffer * WORLD_SIZE. | 
|  | // For implementers of ProcessGroup API and advanced users only. | 
|  | // Note: this function will be deprecated in near future. | 
|  | virtual c10::intrusive_ptr<ProcessGroup::Work> _allgather_base( | 
|  | at::Tensor& outputBuffer, | 
|  | at::Tensor& inputBuffer, | 
|  | const AllgatherOptions& opts = AllgatherOptions()) = 0; | 
|  |  | 
|  | // This function is deprecated and will be moved out of ProcessGroup to comms: | 
|  | // * do not add dependencies on this function, | 
|  | // * do not implement it in your ProcessGroup, implement _allgather_base | 
|  | //   instead. | 
|  | virtual c10::intrusive_ptr<ProcessGroup::Work> allgather_coalesced( | 
|  | std::vector<std::vector<at::Tensor>>& outputTensorLists, | 
|  | std::vector<at::Tensor>& inputTensors, | 
|  | const AllgatherOptions& opts = AllgatherOptions()); | 
|  |  | 
|  | virtual c10::intrusive_ptr<ProcessGroup::Work> gather( | 
|  | std::vector<std::vector<at::Tensor>>& outputTensors, | 
|  | std::vector<at::Tensor>& inputTensors, | 
|  | const GatherOptions& opts = GatherOptions()) = 0; | 
|  |  | 
|  | virtual c10::intrusive_ptr<ProcessGroup::Work> scatter( | 
|  | std::vector<at::Tensor>& outputTensors, | 
|  | std::vector<std::vector<at::Tensor>>& inputTensors, | 
|  | const ScatterOptions& opts = ScatterOptions()) = 0; | 
|  |  | 
|  | virtual c10::intrusive_ptr<ProcessGroup::Work> reduce_scatter( | 
|  | std::vector<at::Tensor>& outputTensors, | 
|  | std::vector<std::vector<at::Tensor>>& inputTensors, | 
|  | const ReduceScatterOptions& opts = ReduceScatterOptions()) = 0; | 
|  |  | 
|  | virtual c10::intrusive_ptr<ProcessGroup::Work> _reduce_scatter_base( | 
|  | at::Tensor&, | 
|  | at::Tensor&, | 
|  | const ReduceScatterOptions& opts = ReduceScatterOptions()) { | 
|  | TORCH_CHECK(false, "ProcessGroup does not support reduce_scatter_base"); | 
|  | } | 
|  |  | 
|  |  | 
|  | virtual c10::intrusive_ptr<ProcessGroup::Work> alltoall_base( | 
|  | at::Tensor& outputTensor, | 
|  | at::Tensor& inputTensor, | 
|  | std::vector<int64_t>& outputSplitSizes, | 
|  | std::vector<int64_t>& inputSplitSizes, | 
|  | const AllToAllOptions& opts = AllToAllOptions()) { | 
|  | TORCH_CHECK(false, "ProcessGroup does not support alltoall"); | 
|  | } | 
|  |  | 
|  | virtual c10::intrusive_ptr<ProcessGroup::Work> alltoall( | 
|  | std::vector<at::Tensor>& outputTensors, | 
|  | std::vector<at::Tensor>& inputTensors, | 
|  | const AllToAllOptions& opts = AllToAllOptions()) { | 
|  | TORCH_CHECK(false, "ProcessGroup does not support alltoall"); | 
|  | } | 
|  |  | 
|  | virtual void monitoredBarrier( | 
|  | const BarrierOptions& /* unused */, bool /* unused */ = false ) { | 
|  | auto backendName = getBackendName(); | 
|  | TORCH_CHECK(false, | 
|  | c10::str("ProcessGroup ", | 
|  | backendName, | 
|  | " does not support monitoredBarrier, only GLOO supports monitored barrier.") | 
|  | ); | 
|  | } | 
|  |  | 
|  | // Agrees on an initial sequence number for the whole group by having rank 0 | 
|  | // create it and broadcast it to other ranks using the store. Only implemented | 
|  | // for GLOO and NCCL backends currently. | 
|  | virtual void setSequenceNumberForGroup() { | 
|  | auto backendName = getBackendName(); | 
|  | TORCH_CHECK(false, | 
|  | c10::str("ProcessGroup ", | 
|  | backendName, | 
|  | " does not yet support sequence numbers.") | 
|  | ); | 
|  | } | 
|  |  | 
|  | // Retrieves the current sequence number for the whole group, which should be | 
|  | // in sync. If the returned number is not consistent across the group, it | 
|  | // may indicate that there is some sort of collective desynchronization. | 
|  | virtual uint64_t getSequenceNumberForGroup() { | 
|  | auto backendName = getBackendName(); | 
|  | TORCH_CHECK(false, | 
|  | c10::str("ProcessGroup ", | 
|  | backendName, | 
|  | " does not yet support sequence numbers.") | 
|  | ); | 
|  | } | 
|  |  | 
|  | virtual c10::intrusive_ptr<ProcessGroup::Work> send( | 
|  | std::vector<at::Tensor>& tensors, | 
|  | int dstRank, | 
|  | int tag) = 0; | 
|  |  | 
|  | virtual c10::intrusive_ptr<ProcessGroup::Work> recv( | 
|  | std::vector<at::Tensor>& tensors, | 
|  | int srcRank, | 
|  | int tag) = 0; | 
|  |  | 
|  | virtual c10::intrusive_ptr<ProcessGroup::Work> recvAnysource( | 
|  | std::vector<at::Tensor>& tensors, | 
|  | int tag) = 0; | 
|  |  | 
|  | virtual c10::intrusive_ptr<ProcessGroup::Work> barrier( | 
|  | const BarrierOptions& opts = BarrierOptions()) = 0; | 
|  |  | 
|  | protected: | 
|  | const int rank_; | 
|  | const int size_; | 
|  | // Optional sequence number structure for matching collectives. | 
|  | c10::optional<c10d::SequenceNum> sequenceNum_ = c10::nullopt; | 
|  | // Debug level setting. It is parsed once when ProcessGroup is constructed and | 
|  | // remains the same across use of this process group. | 
|  | DistributedDebugLevel dist_debug_level_; | 
|  | }; | 
|  |  | 
|  | } // namespace c10d |