| #pragma once |
| |
| #include <condition_variable> |
| #include <memory> |
| #include <mutex> |
| #include <stdexcept> |
| #include <unordered_map> |
| #include <vector> |
| |
| #include <ATen/ATen.h> |
| |
| #include <c10d/Types.hpp> |
| #include <c10d/Utils.hpp> |
| #include <c10d/sequence_num.hpp> |
| |
| // ************************************************************************* |
| // PROCESS GROUP collective communication API IS BEING CHANGED BETWEEN |
| // versions 1.7 and 1.8. |
| // PLEASE DO NOT ADD ANY DEPENDENCIES. |
| // SEE RFC: https://github.com/pytorch/pytorch/issues/39662 |
| // ************************************************************************* |
| |
// Sentinel timeout (0 ms) used as the default for Work::wait(); by convention
// it denotes "no timeout" (wait indefinitely).
constexpr auto kNoTimeout = std::chrono::milliseconds(0);
// Default timeout for ProcessGroup operations: 30 minutes, expressed in ms.
constexpr auto kProcessGroupDefaultTimeout =
    std::chrono::milliseconds(30 * 60 * 1000);
| |
| namespace c10d { |
| |
| constexpr const char * const kSeqNumStoreKey = "SEQ_NUM_STORE_KEY"; |
| |
// Identifies the kind of collective or point-to-point operation a Work object
// represents. Numeric values are explicit and should be treated as stable
// tags; names with a leading underscore are the single-buffer "base" variants
// of a collective (see _allgather_base / _reduce_scatter_base below).
enum class OpType : std::uint8_t {
  BROADCAST = 0,
  ALLREDUCE = 1,
  ALLREDUCE_COALESCED = 2,
  REDUCE = 3,
  ALLGATHER = 4,
  _ALLGATHER_BASE = 5,
  ALLGATHER_COALESCED = 6,
  GATHER = 7,
  SCATTER = 8,
  REDUCE_SCATTER = 9,
  ALLTOALL_BASE = 10,
  ALLTOALL = 11,
  SEND = 12,
  RECV = 13,
  RECVANYSOURCE = 14,
  BARRIER = 15,
  _REDUCE_SCATTER_BASE = 16,
  // Default for Work objects not associated with a specific operation.
  UNKNOWN = 100,
};
| |
// Converts OpType to a human readable string.
std::string opTypeToString(OpType opType);

// Whether or not an OP is a p2p op (SEND, RECV, RECVANYSOURCE).
bool isP2POp(OpType opType);
| |
| // ProcessGroup is a base class that captures collective and point to |
| // point communication in a fixed set of processes. |
| // |
| // The functions specified in the class below describe the API alone; |
| // implementations are provided in subclasses. |
| // |
| // Every function that performs I/O is executed asynchronously by a |
| // thread pool owned by the ProcessGroup (by default). They return an |
| // object that can be used to wait for completion or error. |
| // |
| // The ProcessGroup can instantiate subgroups with fewer or an equal |
| // number of members. Implementations must take care that multiple |
| // process groups can be used in parallel and synchronize accordingly. |
| // |
// The ProcessGroup assumes a fixed set of processes. If the set
// changes, existing instances must be destructed and instantiation
// and initialization must start from scratch. For members of the
// process group to find each other (referred to as rendezvous from
// hereon), an out-of-band mechanism must be used; this class does
// not prescribe how rendezvous is performed.
| // |
class ProcessGroup : public torch::CustomClassHolder {
 public:
  // Please do not use ProcessGroup::Work API, it is going away, to be
  // replaced by ivalue::Future.
  // Python binding for this class might change, please do not assume
  // this will be bound using pybind.
  //
  // Work tracks the state of one (possibly asynchronous) operation:
  // pending -> completed (either successfully or with an exception).
  class Work : public torch::CustomClassHolder {
   public:
    // rank: the calling process' rank (-1 when not applicable);
    // opType: which operation this Work tracks (UNKNOWN by default);
    // profilingTitle / inputTensors: optional profiling metadata, used to
    // populate recordFunctionEndCallback_ (see below).
    Work(
        int rank = -1,
        OpType opType = OpType::UNKNOWN,
        const char* profilingTitle = nullptr,
        const c10::optional<std::vector<at::Tensor>>& inputTensors =
            c10::nullopt);

    virtual ~Work();

    // Checks if request has completed. Non-blocking operation.
    virtual bool isCompleted();

    // Returns if the work completed successfully.
    // If false, the exception function can be called to get details.
    virtual bool isSuccess() const;

    // Returns exception if isSuccess() returned false.
    virtual std::exception_ptr exception() const;

    // Returns source rank if this object represents a recv-from-any.
    virtual int sourceRank() const;

    // Returns result tensors, if applicable.
    // If work is not supposed to have result, we return empty list.
    virtual std::vector<at::Tensor> result();

    // Ensures that operations on the output tensors that are invoked
    // after this function returns are correctly sequenced after the
    // asynchronous completion of this work.
    //
    // For CUDA tensors, it inserts stream synchronization such that
    // the streams of the caller wait for completion of the
    // asynchronous operations on the destination tensors.
    //
    // For CPU tensors, it is currently a nop.
    //
    // This function should only be used if the caller polls for
    // completion through the `isCompleted` function, it has returned
    // true, and the `isSuccess` function also has returned true.
    //
    virtual void synchronize();

    // Waits until request completes. Blocking operation.
    // Throws if the work completed with an exception.
    // Returns false if the work is aborted.
    // Otherwise, it always returns true, indicating the work is completed.
    //
    // Functionally equivalent to:
    //
    //   while (!isCompleted()) { /* nop */ }
    //   auto success = isSuccess();
    //   if (!success) { std::rethrow_exception(exception()); }
    //   return success;
    //
    virtual bool wait(std::chrono::milliseconds timeout = kNoTimeout);

    // Requests that the pending operation be aborted, where the backend
    // supports it (base behavior is defined out-of-line).
    virtual void abort();

    // Returns a Future object that will be associated with the completion of
    // work. Only NCCL backend is currently supported.
    virtual c10::intrusive_ptr<c10::ivalue::Future> getFuture();

    // Returns the OpType this Work object was constructed with.
    OpType retrieveOpType();

   protected:
    // Completes the work object and optionally sets the exception in a
    // thread-safe manner. Notifies all waiting condition variables as well.
    void finish(std::exception_ptr exception = nullptr);

    // Similar to finish, but throws an exception if one is already set or
    // provided by the user.
    void finishAndThrow(std::exception_ptr exception);

    // Guards completed_ and exception_; mutable so that const accessors
    // (isSuccess, exception) can lock it.
    mutable std::mutex mutex_;
    // Paired with mutex_; notified by finish()/finishAndThrow() so that
    // blocked wait() callers wake up.
    std::condition_variable cv_;
    // True once the operation has finished, successfully or not.
    bool completed_ = false;
    // Holds the failure cause when the operation did not succeed.
    std::exception_ptr exception_;

    // Current rank of the node.
    const int rank_;

    // Operation type that this work object refers to.
    OpType opType_;

    // When profiling, the callback to record end of operation event. This
    // callback needs to be called when collective operation is complete.
    std::function<void()> recordFunctionEndCallback_;
  };

  // ProcessGroup Options is a base struct that defines the basic options
  // when constructing a ProcessGroup. Each ProcessGroup subclass should
  // extend this struct and define its options if it wants to provide more
  // config options (beyond basic ones defined here) to end user.
  struct Options : torch::CustomClassHolder {
    explicit Options(
        std::string backend,
        std::chrono::milliseconds timeout = kProcessGroupDefaultTimeout)
        : timeout(timeout), backend(backend) {}
    virtual ~Options() = default;

    // Timeout applied to operations issued by this group.
    std::chrono::milliseconds timeout;

    // backend name
    const std::string backend;
  };

  // rank: this process' rank within the group;
  // size: total number of processes in the group.
  explicit ProcessGroup(int rank, int size);
  virtual ~ProcessGroup();

  // Returns this process' rank within the group.
  int getRank() const {
    return rank_;
  }

  // Returns the number of processes in the group.
  int getSize() const {
    return size_;
  }

  // Subclasses should override this to report their backend's name (used in
  // the error messages below); the base class reports "undefined".
  virtual const std::string getBackendName() const {
    return "undefined";
  }

  // Collective entry points below are pure virtual; semantics are supplied
  // by backend subclasses. Each returns a Work handle tracking completion.

  virtual c10::intrusive_ptr<ProcessGroup::Work> broadcast(
      std::vector<at::Tensor>& data,
      const BroadcastOptions& opts = BroadcastOptions()) = 0;

  virtual c10::intrusive_ptr<ProcessGroup::Work> allreduce(
      std::vector<at::Tensor>& data,
      const AllreduceOptions& opts = AllreduceOptions()) = 0;

  // This will be moved out of ProcessGroup, do not add dependencies on this
  // function.
  virtual c10::intrusive_ptr<ProcessGroup::Work> allreduce_coalesced(
      std::vector<at::Tensor>& tensors,
      const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) = 0;

  virtual c10::intrusive_ptr<ProcessGroup::Work> reduce(
      std::vector<at::Tensor>& tensors,
      const ReduceOptions& opts = ReduceOptions()) = 0;

  virtual c10::intrusive_ptr<ProcessGroup::Work> allgather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) = 0;

  // Gathers a single tensor inputBuffer into a single buffer outputBuffer that
  // is interpreted as a contiguous collection of size inputBuffer * WORLD_SIZE.
  // For implementers of ProcessGroup API and advanced users only.
  // Note: this function will be deprecated in near future.
  virtual c10::intrusive_ptr<ProcessGroup::Work> _allgather_base(
      at::Tensor& outputBuffer,
      at::Tensor& inputBuffer,
      const AllgatherOptions& opts = AllgatherOptions()) = 0;

  // This function is deprecated and will be moved out of ProcessGroup to comms:
  // * do not add dependencies on this function,
  // * do not implement it in your ProcessGroup, implement _allgather_base
  //   instead.
  virtual c10::intrusive_ptr<ProcessGroup::Work> allgather_coalesced(
      std::vector<std::vector<at::Tensor>>& outputTensorLists,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions());

  virtual c10::intrusive_ptr<ProcessGroup::Work> gather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const GatherOptions& opts = GatherOptions()) = 0;

  virtual c10::intrusive_ptr<ProcessGroup::Work> scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ScatterOptions& opts = ScatterOptions()) = 0;

  virtual c10::intrusive_ptr<ProcessGroup::Work> reduce_scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) = 0;

  // Single-buffer variant of reduce_scatter. The base implementation throws;
  // backends that support it must override.
  virtual c10::intrusive_ptr<ProcessGroup::Work> _reduce_scatter_base(
      at::Tensor&,
      at::Tensor&,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) {
    throw std::runtime_error("ProcessGroup does not support reduce_scatter_base");
  }

  // Single-buffer all-to-all with per-rank split sizes. The base
  // implementation throws; backends that support it must override.
  virtual c10::intrusive_ptr<ProcessGroup::Work> alltoall_base(
      at::Tensor& outputTensor,
      at::Tensor& inputTensor,
      std::vector<int64_t>& outputSplitSizes,
      std::vector<int64_t>& inputSplitSizes,
      const AllToAllOptions& opts = AllToAllOptions()) {
    throw std::runtime_error("ProcessGroup does not support alltoall");
  }

  // Tensor-list variant of all-to-all. The base implementation throws;
  // backends that support it must override.
  virtual c10::intrusive_ptr<ProcessGroup::Work> alltoall(
      std::vector<at::Tensor>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllToAllOptions& opts = AllToAllOptions()) {
    throw std::runtime_error("ProcessGroup does not support alltoall");
  }

  // Barrier with failure reporting; per the error message below, only the
  // GLOO backend implements it. The base implementation throws.
  virtual void monitoredBarrier(
      const BarrierOptions& /* unused */, bool /* unused */ = false ) {
    auto backendName = getBackendName();
    throw std::runtime_error(
        c10::str("ProcessGroup ",
                 backendName,
                 " does not support monitoredBarrier, only GLOO supports monitored barrier.")
    );
  }

  // Agrees on an initial sequence number for the whole group by having rank 0
  // create it and broadcast it to other ranks using the store. Only implemented
  // for GLOO and NCCL backends currently.
  virtual void setSequenceNumberForGroup() {
    auto backendName = getBackendName();
    throw std::runtime_error(
        c10::str("ProcessGroup ",
                 backendName,
                 " does not yet support sequence numbers.")
    );
  }

  // Retrieves the current sequence number for the whole group, which should be
  // in sync. If the returned number is not consistent across the group, it
  // may indicate that there is some sort of collective desynchronization.
  virtual uint64_t getSequenceNumberForGroup() {
    auto backendName = getBackendName();
    throw std::runtime_error(
        c10::str("ProcessGroup ",
                 backendName,
                 " does not yet support sequence numbers.")
    );
  }

  // Point-to-point send of tensors to dstRank, matched by tag.
  virtual c10::intrusive_ptr<ProcessGroup::Work> send(
      std::vector<at::Tensor>& tensors,
      int dstRank,
      int tag) = 0;

  // Point-to-point receive of tensors from srcRank, matched by tag.
  virtual c10::intrusive_ptr<ProcessGroup::Work> recv(
      std::vector<at::Tensor>& tensors,
      int srcRank,
      int tag) = 0;

  // Receive from any source; the sender can be queried afterwards via
  // Work::sourceRank().
  virtual c10::intrusive_ptr<ProcessGroup::Work> recvAnysource(
      std::vector<at::Tensor>& tensors,
      int tag) = 0;

  // Blocks until all processes in the group reach this call.
  virtual c10::intrusive_ptr<ProcessGroup::Work> barrier(
      const BarrierOptions& opts = BarrierOptions()) = 0;

 protected:
  // This process' rank within the group (see getRank()).
  const int rank_;
  // Total number of processes in the group (see getSize()).
  const int size_;
  // Optional sequence number structure for matching collectives.
  c10::optional<c10d::SequenceNum> sequenceNum_ = c10::nullopt;
  // Debug level setting. It is parsed once when ProcessGroup is constructed and
  // remains the same across use of this process group.
  DistributedDebugLevel dist_debug_level_;
};
| |
| } // namespace c10d |