|  | #include <c10/util/Exception.h> | 
|  | #include <torch/csrc/distributed/c10d/ProcessGroupGloo.hpp> | 
|  |  | 
|  | #ifdef USE_C10D_GLOO | 
|  |  | 
|  | #include <torch/csrc/distributed/c10d/GlooDeviceFactory.hpp> | 
|  | #include <torch/csrc/distributed/c10d/PrefixStore.hpp> | 
|  | #include <chrono> | 
|  | #include <exception> | 
|  | #include <ratio> | 
|  | #include <tuple> | 
|  |  | 
|  | #ifdef _WIN32 | 
|  | #include <gloo/common/win.h> | 
|  | #include <winsock2.h> | 
|  | #include <ws2tcpip.h> | 
|  | #else | 
|  | #include <netdb.h> | 
|  | #include <sys/socket.h> | 
|  | #include <unistd.h> | 
|  | #endif | 
|  | #include <sys/types.h> | 
|  |  | 
|  | #include <type_traits> | 
|  |  | 
|  | #include <gloo/allgather.h> | 
|  | #include <gloo/allgatherv.h> | 
|  | #include <gloo/allreduce.h> | 
|  | #include <gloo/alltoall.h> | 
|  | #include <gloo/alltoallv.h> | 
|  | #include <gloo/barrier.h> | 
|  | #include <gloo/broadcast.h> | 
|  | #include <gloo/gather.h> | 
|  | #include <gloo/reduce.h> | 
|  | #include <gloo/scatter.h> | 
|  |  | 
|  | #include <ATen/SparseTensorUtils.h> | 
|  | #include <ATen/ThreadLocalState.h> | 
|  |  | 
|  | #include <c10/util/StringUtil.h> | 
|  | #include <c10/util/intrusive_ptr.h> | 
|  | #include <c10/util/irange.h> | 
|  | #include <gloo/config.h> | 
|  | #include <gloo/rendezvous/context.h> | 
|  | #include <gloo/rendezvous/prefix_store.h> | 
|  |  | 
|  | #ifdef _WIN32 | 
|  | #define GENERATE_ALL_TYPES(type, func, ...)      \ | 
|  | switch (type) {                                \ | 
|  | case ::at::ScalarType::Float:                \ | 
|  | func<float>(__VA_ARGS__);                  \ | 
|  | break;                                     \ | 
|  | case ::at::ScalarType::Double:               \ | 
|  | func<double>(__VA_ARGS__);                 \ | 
|  | break;                                     \ | 
|  | case ::at::ScalarType::Half:                 \ | 
|  | func<gloo::float16>(__VA_ARGS__);          \ | 
|  | break;                                     \ | 
|  | case ::at::ScalarType::Char:                 \ | 
|  | func<int8_t>(__VA_ARGS__);                 \ | 
|  | break;                                     \ | 
|  | case ::at::ScalarType::Byte:                 \ | 
|  | func<uint8_t>(__VA_ARGS__);                \ | 
|  | break;                                     \ | 
|  | case ::at::ScalarType::Int:                  \ | 
|  | func<int32_t>(__VA_ARGS__);                \ | 
|  | break;                                     \ | 
|  | case ::at::ScalarType::Long:                 \ | 
|  | func<int64_t>(__VA_ARGS__);                \ | 
|  | break;                                     \ | 
|  | default:                                     \ | 
|  | TORCH_CHECK(false, "Invalid scalar type"); \ | 
|  | } | 
|  |  | 
|  | #define HOST_NAME_MAX 256 | 
|  | #else | 
|  | #define GENERATE_ALL_TYPES(type, func, args...)  \ | 
|  | switch (type) {                                \ | 
|  | case ::at::ScalarType::Float:                \ | 
|  | func<float>(args);                         \ | 
|  | break;                                     \ | 
|  | case ::at::ScalarType::Double:               \ | 
|  | func<double>(args);                        \ | 
|  | break;                                     \ | 
|  | case ::at::ScalarType::Half:                 \ | 
|  | func<gloo::float16>(args);                 \ | 
|  | break;                                     \ | 
|  | case ::at::ScalarType::Char:                 \ | 
|  | func<int8_t>(args);                        \ | 
|  | break;                                     \ | 
|  | case ::at::ScalarType::Byte:                 \ | 
|  | func<uint8_t>(args);                       \ | 
|  | break;                                     \ | 
|  | case ::at::ScalarType::Int:                  \ | 
|  | func<int32_t>(args);                       \ | 
|  | break;                                     \ | 
|  | case ::at::ScalarType::Long:                 \ | 
|  | func<int64_t>(args);                       \ | 
|  | break;                                     \ | 
|  | default:                                     \ | 
|  | TORCH_CHECK(false, "Invalid scalar type"); \ | 
|  | } | 
|  | #endif | 
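|  |  | 
|  | // Illustrative usage sketch: GENERATE_ALL_TYPES switches on the runtime | 
|  | // scalar type and instantiates `func` for the matching C++ type. For | 
|  | // example, | 
|  | // | 
|  | //   gloo::BroadcastOptions opts(context); | 
|  | //   GENERATE_ALL_TYPES(tensor.scalar_type(), setOutput, opts, tensor); | 
|  | // | 
|  | // expands, for a Float tensor, to setOutput<float>(opts, tensor). This is | 
|  | // how the collectives below bind typed data pointers without templating | 
|  | // their call sites on the scalar type. | 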
|  |  | 
|  | namespace c10d { | 
|  |  | 
|  | namespace { | 
|  |  | 
|  | constexpr int kBytes = 8; | 
|  |  | 
|  | using steady_clock_time_point = | 
|  | std::chrono::time_point<std::chrono::steady_clock>; | 
|  |  | 
|  | std::chrono::milliseconds getRemainingTime( | 
|  | steady_clock_time_point startTime, | 
|  | const std::chrono::milliseconds& timeout, | 
|  | bool waitAllRanks) { | 
|  | if (waitAllRanks) { | 
|  | // See Note in monitoredBarrier | 
|  | return timeout; | 
|  | } | 
|  | auto elapsedTime = std::chrono::steady_clock::now() - startTime; | 
|  | auto remainingMillis = timeout - | 
|  | std::chrono::duration_cast<std::chrono::milliseconds>(elapsedTime); | 
|  |  | 
|  | // If there is no remaining time, return -1 to signal a timeout to the caller. | 
|  | if (remainingMillis.count() <= 0) { | 
|  | return std::chrono::milliseconds(-1); | 
|  | } | 
|  |  | 
|  | return remainingMillis; | 
|  | } | 
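|  |  | 
|  | // Worked example (illustrative): with a 5000 ms timeout and 6200 ms already | 
|  | // elapsed, remainingMillis is negative, so -1 ms is returned and | 
|  | // checkRemainingTime below treats it as a timeout. With waitAllRanks == true | 
|  | // the full timeout is returned unchanged (see Note in monitoredBarrier). | 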
|  |  | 
|  | // Emits a LOG(ERROR) with logMessage and throws via TORCH_CHECK with errorMessage. | 
|  | void logAndThrow( | 
|  | const std::string& logMessage, | 
|  | const std::string& errorMessage) { | 
|  | LOG(ERROR) << logMessage; | 
|  | TORCH_CHECK(false, errorMessage); | 
|  | } | 
|  |  | 
|  | // For monitoredBarrier: checks the remaining time left to finish processing | 
|  | // ranks and throws an error if the timeout has elapsed. | 
|  | void checkRemainingTime( | 
|  | const std::chrono::milliseconds& monitoredBarrierTimeout, | 
|  | const std::chrono::milliseconds& remainingTime, | 
|  | const std::vector<int>& processedRanks, | 
|  | int currentRank) { | 
|  | const std::string kNoRemainingTimeError = c10::str( | 
|  | "Rank ", | 
|  | currentRank, | 
|  | " timed out in monitoredBarrier after ", | 
|  | monitoredBarrierTimeout.count(), | 
|  | " ms."); | 
|  | if (remainingTime.count() < 0) { | 
|  | std::string rankInfo; | 
|  | if (processedRanks.size() > 0) { | 
|  | rankInfo = c10::str( | 
|  | "Successfully processed ranks: ", c10::Join(", ", processedRanks)); | 
|  | } else { | 
|  | rankInfo = "No ranks successfully processed in monitoredBarrier."; | 
|  | } | 
|  | auto error = c10::str(kNoRemainingTimeError, "\n", rankInfo); | 
|  | logAndThrow(error, error); | 
|  | } | 
|  | } | 
|  |  | 
|  | typedef void (*ReduceFunc)(void*, const void*, const void*, size_t); | 
|  |  | 
|  | template < | 
|  | typename T, | 
|  | typename std::enable_if<!std::is_integral<T>::value, int>::type = 0> | 
|  | ReduceFunc toFunction(const ReduceOp& r) { | 
|  | switch (r) { | 
|  | case ReduceOp::SUM: | 
|  | return ReduceFunc(&::gloo::sum<T>); | 
|  | case ReduceOp::PRODUCT: | 
|  | return ReduceFunc(&::gloo::product<T>); | 
|  | case ReduceOp::MIN: | 
|  | return ReduceFunc(&::gloo::min<T>); | 
|  | case ReduceOp::MAX: | 
|  | return ReduceFunc(&::gloo::max<T>); | 
|  | case ReduceOp::BAND: | 
|  | TORCH_CHECK(false, "Cannot use ReduceOp.BAND with non-integral dtype"); | 
|  | break; | 
|  | case ReduceOp::BOR: | 
|  | TORCH_CHECK(false, "Cannot use ReduceOp.BOR with non-integral dtype"); | 
|  | break; | 
|  | case ReduceOp::BXOR: | 
|  | TORCH_CHECK(false, "Cannot use ReduceOp.BXOR with non-integral dtype"); | 
|  | break; | 
|  | case ReduceOp::AVG: | 
|  | TORCH_CHECK(false, "Cannot use ReduceOp.AVG with Gloo"); | 
|  | break; | 
|  | case ReduceOp::PREMUL_SUM: | 
|  | TORCH_CHECK(false, "Cannot use ReduceOp.PREMUL_SUM with Gloo"); | 
|  | break; | 
|  | case ReduceOp::UNUSED: | 
|  | break; | 
|  | } | 
|  |  | 
|  | TORCH_CHECK(false, "Unhandled ReduceOp"); | 
|  | } | 
|  |  | 
|  | // Bitwise AND with SFINAE guard for integral types. | 
|  | template < | 
|  | typename T, | 
|  | typename std::enable_if<std::is_integral<T>::value, int>::type = 0> | 
|  | void band(void* c, const void* a, const void* b, size_t n) { | 
|  | auto tc = static_cast<T*>(c); | 
|  | auto ta = static_cast<const T*>(a); | 
|  | auto tb = static_cast<const T*>(b); | 
|  | for (const auto i : c10::irange(n)) { | 
|  | tc[i] = ta[i] & tb[i]; | 
|  | } | 
|  | } | 
|  |  | 
|  | // Bitwise OR with SFINAE guard for integral types. | 
|  | template < | 
|  | typename T, | 
|  | typename std::enable_if<std::is_integral<T>::value, int>::type = 0> | 
|  | void bor(void* c, const void* a, const void* b, size_t n) { | 
|  | auto tc = static_cast<T*>(c); | 
|  | auto ta = static_cast<const T*>(a); | 
|  | auto tb = static_cast<const T*>(b); | 
|  | for (const auto i : c10::irange(n)) { | 
|  | tc[i] = ta[i] | tb[i]; | 
|  | } | 
|  | } | 
|  |  | 
|  | // Bitwise XOR with SFINAE guard for integral types. | 
|  | template < | 
|  | typename T, | 
|  | typename std::enable_if<std::is_integral<T>::value, int>::type = 0> | 
|  | void bxor(void* c, const void* a, const void* b, size_t n) { | 
|  | auto tc = static_cast<T*>(c); | 
|  | auto ta = static_cast<const T*>(a); | 
|  | auto tb = static_cast<const T*>(b); | 
|  | for (const auto i : c10::irange(n)) { | 
|  | tc[i] = ta[i] ^ tb[i]; | 
|  | } | 
|  | } | 
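|  |  | 
|  | // Example (illustrative): for single-element inputs a = {0b1100} and | 
|  | // b = {0b1010}, band writes {0b1000}, bor writes {0b1110}, and bxor writes | 
|  | // {0b0110} into the output buffer. | 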
|  |  | 
|  | template < | 
|  | typename T, | 
|  | typename std::enable_if<std::is_integral<T>::value, int>::type = 0> | 
|  | ReduceFunc toFunction(const ReduceOp& r) { | 
|  | switch (r) { | 
|  | case ReduceOp::SUM: | 
|  | return ReduceFunc(&::gloo::sum<T>); | 
|  | case ReduceOp::PRODUCT: | 
|  | return ReduceFunc(&::gloo::product<T>); | 
|  | case ReduceOp::MIN: | 
|  | return ReduceFunc(&::gloo::min<T>); | 
|  | case ReduceOp::MAX: | 
|  | return ReduceFunc(&::gloo::max<T>); | 
|  | case ReduceOp::BAND: | 
|  | return ReduceFunc(&band<T>); | 
|  | case ReduceOp::BOR: | 
|  | return ReduceFunc(&bor<T>); | 
|  | case ReduceOp::BXOR: | 
|  | return ReduceFunc(&bxor<T>); | 
|  | case ReduceOp::AVG: | 
|  | TORCH_CHECK(false, "Cannot use ReduceOp.AVG with Gloo"); | 
|  | break; | 
|  | case ReduceOp::PREMUL_SUM: | 
|  | TORCH_CHECK(false, "Cannot use ReduceOp.PREMUL_SUM with Gloo"); | 
|  | break; | 
|  | case ReduceOp::UNUSED: | 
|  | break; | 
|  | } | 
|  |  | 
|  | TORCH_CHECK(false, "Unhandled ReduceOp"); | 
|  | } | 
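|  |  | 
|  | // For example (illustrative), toFunction<int32_t>(ReduceOp::BAND) yields | 
|  | // ReduceFunc(&band<int32_t>), while the non-integral overload above rejects | 
|  | // the bitwise ops with a TORCH_CHECK failure. Both overloads reject AVG and | 
|  | // PREMUL_SUM, which this Gloo backend does not implement. | 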
|  |  | 
|  | template <typename T, typename O> | 
|  | void setInputs(O& opts, std::vector<at::Tensor>& tensors) { | 
|  | opts.setInputs(getDataPointers<T>(tensors), tensors[0].numel()); | 
|  | } | 
|  |  | 
|  | template <typename T, typename O> | 
|  | void setInput(O& opts, at::Tensor& tensor) { | 
|  | opts.setInput(getDataPointer<T>(tensor), tensor.numel()); | 
|  | } | 
|  |  | 
|  | template <typename T, typename O> | 
|  | void setInput(O& opts, at::Tensor& tensor, std::vector<size_t>& counts) { | 
|  | opts.setInput(getDataPointer<T>(tensor), counts); | 
|  | } | 
|  |  | 
|  | template <typename T, typename O> | 
|  | void setInput(O& opts, at::Tensor& tensor, std::vector<int64_t>& counts) { | 
|  | opts.setInput(getDataPointer<T>(tensor), counts); | 
|  | } | 
|  |  | 
|  | template <typename T, typename O> | 
|  | void setOutputs(O& opts, std::vector<at::Tensor>& tensors) { | 
|  | opts.setOutputs(getDataPointers<T>(tensors), tensors[0].numel()); | 
|  | } | 
|  |  | 
|  | template <typename T, typename O> | 
|  | void setOutput(O& opts, at::Tensor& tensor) { | 
|  | opts.setOutput(getDataPointer<T>(tensor), tensor.numel()); | 
|  | } | 
|  |  | 
|  | template <typename T, typename O> | 
|  | void setOutput(O& opts, at::Tensor& tensor, std::vector<size_t>& counts) { | 
|  | opts.setOutput(getDataPointer<T>(tensor), counts); | 
|  | } | 
|  |  | 
|  | template <typename T, typename O> | 
|  | void setOutput(O& opts, at::Tensor& tensor, std::vector<int64_t>& counts) { | 
|  | opts.setOutput(getDataPointer<T>(tensor), counts); | 
|  | } | 
|  |  | 
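|  | // Returns a CPU tensor with the same sizes and strides as `tensor`, backed | 
|  | // by pinned (page-locked) memory, so that copies to and from CUDA tensors | 
|  | // can be issued asynchronously on the side streams set up below. | 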
|  | at::Tensor pinnedLike(at::Tensor& tensor) { | 
|  | auto* allocator = at::detail::getCUDAHooks().getPinnedMemoryAllocator(); | 
|  | auto storage = c10::Storage( | 
|  | c10::Storage::use_byte_size_t(), | 
|  | at::detail::computeStorageNbytes( | 
|  | tensor.sizes(), tensor.strides(), tensor.dtype().itemsize()), | 
|  | allocator, | 
|  | /*resizable=*/false); | 
|  | return at::empty({0}, tensor.options().device(at::kCPU)) | 
|  | .set_(storage, 0, tensor.sizes(), tensor.strides()); | 
|  | } | 
|  |  | 
|  | // This function initializes a vector of CUDA streams, one for every | 
|  | // tensor in the input tensor vector, and ensures that these streams are | 
|  | // synchronized with the current default streams. This is needed so | 
|  | // that new work on the new streams is serialized w.r.t. all operations | 
|  | // on the tensors. | 
|  | void initializeStreamsEvents( | 
|  | const std::vector<at::Tensor>& tensors, | 
|  | std::vector<c10::Stream>& streams, | 
|  | std::vector<c10::Event>& events) { | 
|  | streams.reserve(tensors.size()); | 
|  | events.reserve(tensors.size()); | 
|  | for (const auto i : c10::irange(tensors.size())) { | 
|  | c10::Device device = tensors[i].device(); | 
|  | c10::impl::VirtualGuardImpl impl(device.type()); | 
|  | // Record event on current stream | 
|  | events.emplace_back(device.type()); | 
|  | events[i].record(impl.getStream(device)); | 
|  | // Get a non-default stream to execute asynchronous CUDA operations | 
|  | // on for this device. This ensures that the default stream used | 
|  | // by the caller is not occupied by c10d related operations. | 
|  | streams.push_back( | 
|  | impl.getStreamFromGlobalPool(device, /*isHighPriority=*/true)); | 
|  | // Ensure the new stream is synchronized with the current stream. | 
|  | events[i].block(streams[i]); | 
|  |  | 
|  | // `tensors` are created on a different stream. Hence, they must record | 
|  | // new streams in this Work to prevent being freed before the Work finishes. | 
|  | if (tensors[i].is_sparse()) { | 
|  | if (tensors[i].is_coalesced()) { | 
|  | impl.recordDataPtrOnStream( | 
|  | tensors[i].indices().storage().data_ptr(), streams[i]); | 
|  | impl.recordDataPtrOnStream( | 
|  | tensors[i].values().storage().data_ptr(), streams[i]); | 
|  | } else { | 
|  | // We will need to coalesce first, which means new tensors will | 
|  | // be allocated on the streams we just allocated, and there | 
|  | // is no need to record them separately. | 
|  | } | 
|  | } else { | 
|  | impl.recordDataPtrOnStream(tensors[i].storage().data_ptr(), streams[i]); | 
|  | } | 
|  | } | 
|  | } | 
|  |  | 
|  | // This function initializes a vector of CUDA streams, one per device, | 
|  | // and ensures that these streams are synchronized with the current default | 
|  | // streams. It is assumed that the tensors in the nested tensor vectors are | 
|  | // on the same device. | 
|  | void initializeStreamsEvents( | 
|  | std::vector<std::vector<at::Tensor>>& tensors, | 
|  | std::vector<c10::Stream>& streams, | 
|  | std::vector<c10::Event>& events) { | 
|  | // Ensure that the tensors in the nested tensor vectors are on the same | 
|  | // device. | 
|  | for (const auto& tensorgroup : tensors) { | 
|  | const auto device_id = tensorgroup[0].device().index(); | 
|  | for (const auto& tensor : tensorgroup) { | 
|  | if (tensor.device().index() != device_id) { | 
|  | TORCH_CHECK( | 
|  | false, | 
|  | "tensors in the nested tensor vectors need to " | 
|  | "be on the same device"); | 
|  | } | 
|  | } | 
|  | } | 
|  |  | 
|  | streams.reserve(tensors.size()); | 
|  | events.reserve(tensors.size()); | 
|  | for (const auto i : c10::irange(tensors.size())) { | 
|  | c10::Device device = tensors[i][0].device(); | 
|  | c10::impl::VirtualGuardImpl impl(device.type()); | 
|  | // Record event on current stream | 
|  | events.emplace_back(device.type()); | 
|  | events[i].record(impl.getStream(device)); | 
|  | // Get a non-default stream to execute asynchronous CUDA operations | 
|  | // on for this output. This ensures that the default stream used | 
|  | // by the caller is not occupied by c10d related operations. | 
|  | streams.push_back( | 
|  | impl.getStreamFromGlobalPool(device, /*isHighPriority=*/true)); | 
|  | // Ensure the new stream is synchronized with the current stream. | 
|  | events[i].block(streams[i]); | 
|  |  | 
|  | for (at::Tensor& tensor : tensors[i]) { | 
|  | // `tensors` are created on a different stream. Hence, they must record | 
|  | // new streams in this Work to prevent being freed before the Work | 
|  | // finishes. | 
|  | impl.recordDataPtrOnStream(tensor.storage().data_ptr(), streams[i]); | 
|  | } | 
|  | } | 
|  | } | 
|  |  | 
|  | const auto kLoopbackAddress = "127.0.0.1"; | 
|  |  | 
|  | } // namespace | 
|  |  | 
|  | // static | 
|  | void ProcessGroupGloo::AsyncWork::execute(c10::intrusive_ptr<AsyncWork> work) { | 
|  | if (work->recordFunctionBeforeCallback_) { | 
|  | work->recordFunctionBeforeCallback_(); | 
|  | } | 
|  | try { | 
|  | work->run(); | 
|  | } catch (...) { | 
|  | work->finishWorkGlooError(std::current_exception()); | 
|  | return; | 
|  | } | 
|  |  | 
|  | // FIXME: We need to call it here since Future completion requires all | 
|  | // the work to be synchronized to CUDA. | 
|  | work->synchronize(); | 
|  | work->finishWorkGloo(); | 
|  | } | 
|  |  | 
|  | std::vector<at::Tensor> ProcessGroupGloo::AsyncWork::result() { | 
|  | TORCH_CHECK( | 
|  | isCompleted(), | 
|  | "Work needs to be completed before calling result(). " | 
|  | "Should call wait() before result()."); | 
|  | TORCH_CHECK( | 
|  | outputTensors_.size() <= 1, | 
|  | "work result does not support list of lists, use .getFuture() and value()"); | 
|  | return outputTensors_.size() == 0 ? std::vector<at::Tensor>() | 
|  | : outputTensors_.at(0); | 
|  | } | 
|  |  | 
|  | c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupGloo::AsyncWork:: | 
|  | getFuture() { | 
|  | return future_; | 
|  | } | 
|  |  | 
|  | namespace { | 
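|  | // The future's element type mirrors the shape of the outputs: a flat list | 
|  | // of tensors for the common single-output-list case, and a list of lists | 
|  | // when a work produces multiple output lists (in which case result() above | 
|  | // refuses to flatten and callers should use getFuture() and value()). | 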
|  | c10::intrusive_ptr<c10::ivalue::Future> createFutureAsOutput( | 
|  | const std::vector<std::vector<at::Tensor>>& outputTensors) { | 
|  | if (outputTensors.size() > 1) { | 
|  | return c10::make_intrusive<c10::ivalue::Future>( | 
|  | c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); | 
|  | } | 
|  | return c10::make_intrusive<c10::ivalue::Future>( | 
|  | c10::ListType::create(c10::TensorType::get())); | 
|  | } | 
|  |  | 
|  | void returnFutureWithOutput( | 
|  | c10::intrusive_ptr<c10::ivalue::Future>& future, | 
|  | const std::vector<std::vector<at::Tensor>>& outputTensors) { | 
|  | if (outputTensors.size() == 0) { | 
|  | future->markCompleted(c10::IValue(std::vector<at::Tensor>())); | 
|  | return; | 
|  | } | 
|  | if (outputTensors.size() > 1) { | 
|  | future->markCompleted(c10::IValue(outputTensors)); | 
|  | return; | 
|  | } | 
|  | future->markCompleted(c10::IValue(outputTensors[0])); | 
|  | } | 
|  | } // namespace | 
|  |  | 
|  | inline void ProcessGroupGloo::AsyncWork::recordAsyncWorkProfilingInfo( | 
|  | const char* profilingTitle, | 
|  | const c10::optional<std::vector<at::Tensor>>& inputTensors) { | 
|  | auto recordingFunction = | 
|  | std::make_shared<at::RecordFunction>(at::RecordScope::USER_SCOPE); | 
|  | if (recordingFunction->isActive()) { | 
|  | std::function<void()> before_handler = | 
|  | [inputTensors, profilingTitle, recordingFunction]() { | 
|  | // The work will be started and completed by different threads. | 
|  | recordingFunction->_setAsync(); | 
|  | std::vector<c10::IValue> inputs; | 
|  | if (inputTensors) { | 
|  | inputs.reserve(inputTensors->size()); | 
|  | for (const auto& tensor : *inputTensors) { | 
|  | inputs.emplace_back(tensor); | 
|  | } | 
|  | } | 
|  | recordingFunction->before( | 
|  | profilingTitle, | 
|  | c10::ArrayRef<const c10::IValue>(inputs.data(), inputs.size())); | 
|  | }; | 
|  | recordFunctionBeforeCallback_ = at::wrapPropagateTLSState(before_handler); | 
|  | std::function<void()> end_handler = [recordingFunction]() { | 
|  | recordingFunction->end(); | 
|  | }; | 
|  | recordFunctionEndCallback_ = at::wrapPropagateTLSState(end_handler); | 
|  | } | 
|  | } | 
|  |  | 
|  | ProcessGroupGloo::AsyncWork::AsyncWork( | 
|  | std::vector<std::vector<at::Tensor>> outputTensors, | 
|  | const char* profilingTitle, | 
|  | const c10::optional<std::vector<at::Tensor>>& inputTensors) | 
|  | // Profiler: Pass nullptr as profilingTitle to parent constructor to | 
|  | // replace default profiler implementation with async version that reports | 
|  | // correct timestamps for work that is asynchronously executed. | 
|  | : Work(-1, OpType::UNKNOWN, nullptr, inputTensors), | 
|  | outputTensors_(std::move(outputTensors)), | 
|  | // Note: the member outputTensors_ is used here because the outputTensors | 
|  | // parameter has just been moved from in the initializer above. | 
|  | future_(createFutureAsOutput(outputTensors_)) { | 
|  | if (profilingTitle != nullptr) { | 
|  | recordAsyncWorkProfilingInfo(profilingTitle, inputTensors); | 
|  | } | 
|  | } | 
|  |  | 
|  | void ProcessGroupGloo::AsyncWork::finishWorkGlooError(std::exception_ptr eptr) { | 
|  | future_->setError(eptr); | 
|  | finish(eptr); | 
|  | } | 
|  |  | 
|  | void ProcessGroupGloo::AsyncWork::finishWorkGloo() { | 
|  | returnFutureWithOutput(future_, outputTensors_); | 
|  | finish(); | 
|  | } | 
|  |  | 
|  | ProcessGroupGloo::SendWork::SendWork( | 
|  | at::Tensor& tensor, | 
|  | std::unique_ptr<::gloo::transport::UnboundBuffer> buffer) | 
|  | : Work( | 
|  | -1, | 
|  | OpType::SEND, | 
|  | "gloo:send", | 
|  | c10::optional<std::vector<at::Tensor>>({tensor})), | 
|  | tensor_(tensor), | 
|  | buffer_(std::move(buffer)) {} | 
|  |  | 
|  | bool ProcessGroupGloo::SendWork::wait(std::chrono::milliseconds timeout) { | 
|  | bool sendCompleted = false; | 
|  | std::exception_ptr exception{nullptr}; | 
|  | try { | 
|  | if (timeout == kNoTimeout) { | 
|  | sendCompleted = buffer_->waitSend(); | 
|  | } else { | 
|  | sendCompleted = buffer_->waitSend(timeout); | 
|  | } | 
|  | } catch (...) { | 
|  | exception = std::current_exception(); | 
|  | } | 
|  |  | 
|  | // Completes the Work object and rethrows the exception, if one was caught. | 
|  | finishAndThrow(exception); | 
|  | return sendCompleted; | 
|  | } | 
|  |  | 
|  | void ProcessGroupGloo::SendWork::abort() { | 
|  | buffer_->abortWaitSend(); | 
|  | } | 
|  |  | 
|  | ProcessGroupGloo::RecvWork::RecvWork( | 
|  | at::Tensor& tensor, | 
|  | std::unique_ptr<::gloo::transport::UnboundBuffer> buffer, | 
|  | const char* profilingTitle) | 
|  | : Work( | 
|  | -1, | 
|  | OpType::UNKNOWN, | 
|  | profilingTitle, | 
|  | c10::optional<std::vector<at::Tensor>>({tensor})), | 
|  | tensor_(tensor), | 
|  | buffer_(std::move(buffer)), | 
|  | srcRank_(-1) {} | 
|  |  | 
|  | int ProcessGroupGloo::RecvWork::sourceRank() const { | 
|  | std::lock_guard<std::mutex> lock(mutex_); | 
|  | return srcRank_; | 
|  | } | 
|  |  | 
|  | bool ProcessGroupGloo::RecvWork::wait(std::chrono::milliseconds timeout) { | 
|  | bool recvCompleted = false; | 
|  | std::exception_ptr exception{nullptr}; | 
|  | try { | 
|  | if (timeout == kNoTimeout) { | 
|  | recvCompleted = buffer_->waitRecv(&srcRank_); | 
|  | } else { | 
|  | recvCompleted = buffer_->waitRecv(&srcRank_, timeout); | 
|  | } | 
|  | } catch (...) { | 
|  | exception = std::current_exception(); | 
|  | } | 
|  |  | 
|  | // Completes the Work object and rethrows the exception, if one was caught. | 
|  | finishAndThrow(exception); | 
|  | return recvCompleted; | 
|  | } | 
|  |  | 
|  | void ProcessGroupGloo::RecvWork::abort() { | 
|  | buffer_->abortWaitRecv(); | 
|  | } | 
|  |  | 
|  | ProcessGroupGloo::Options::Options(std::chrono::milliseconds timeout) | 
|  | : Backend::Options(GLOO_BACKEND_NAME, timeout), threads(2) {} | 
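|  |  | 
|  | // Construction sketch (illustrative only; `store`, `rank` and `size` are | 
|  | // assumed to be provided by the caller): | 
|  | // | 
|  | //   auto opts = c10::make_intrusive<ProcessGroupGloo::Options>( | 
|  | //       std::chrono::milliseconds(30 * 60 * 1000)); | 
|  | //   opts->devices.push_back(ProcessGroupGloo::createDefaultDevice()); | 
|  | //   opts->threads = 2; | 
|  | //   auto pg = c10::make_intrusive<ProcessGroupGloo>(store, rank, size, opts); | 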
|  |  | 
|  | namespace { | 
|  |  | 
|  | void socketInitialize() { | 
|  | #ifdef _WIN32 | 
|  | ::gloo::init_winsock(); | 
|  | #endif | 
|  | } | 
|  |  | 
|  | // Gloo assumes that this machine's hostname can always be resolved | 
|  | // to an address. If it cannot, Gloo throws a runtime error saying | 
|  | // that the hostname cannot be resolved. Instead of catching that error, | 
|  | // we proactively check whether the hostname resolves to a usable address, | 
|  | // so we can gracefully fall back to an alternative if it does not. | 
|  | bool doesHostnameResolveToUsableAddress(const std::string& hostname) { | 
|  | socketInitialize(); | 
|  | struct addrinfo hints {}; | 
|  | memset(&hints, 0, sizeof(hints)); | 
|  | hints.ai_family = AF_UNSPEC; | 
|  | hints.ai_socktype = SOCK_STREAM; | 
|  | struct addrinfo* result = nullptr; | 
|  | auto rv = getaddrinfo(hostname.c_str(), nullptr, &hints, &result); | 
|  | if (rv < 0) { | 
|  | return false; | 
|  | } | 
|  | struct addrinfo* rp = nullptr; | 
|  | for (rp = result; rp != nullptr; rp = rp->ai_next) { | 
|  | auto fd = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); | 
|  | if (fd == -1) { | 
|  | continue; | 
|  | } | 
|  | rv = bind(fd, rp->ai_addr, rp->ai_addrlen); | 
|  | #ifdef _WIN32 | 
|  | closesocket(fd); | 
|  | #else | 
|  | close(fd); | 
|  | #endif | 
|  | if (rv == -1) { | 
|  | continue; | 
|  | } | 
|  | break; | 
|  | } | 
|  | freeaddrinfo(result); | 
|  | return rp != nullptr; | 
|  | } | 
|  |  | 
|  | } // namespace | 
|  |  | 
|  | std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo:: | 
|  | createDeviceForInterface(const std::string& interface_name) { | 
|  | return ::c10d::GlooDeviceFactory::makeDeviceForInterface(interface_name); | 
|  | } | 
|  |  | 
|  | std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo:: | 
|  | createDeviceForHostname(const std::string& hostname) { | 
|  | TORCH_CHECK( | 
|  | doesHostnameResolveToUsableAddress(hostname), | 
|  | "Cannot resolve ", | 
|  | hostname, | 
|  | " to a (local) address"); | 
|  | return ::c10d::GlooDeviceFactory::makeDeviceForHostname(hostname); | 
|  | } | 
|  |  | 
|  | #if defined(__linux__) || defined(_WIN32) | 
|  | std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo:: | 
|  | createDefaultDevice() { | 
|  | // Use the hostname to resolve the network address to | 
|  | // use. Note: if the hostname does not resolve to an address (e.g. | 
|  | // because of misconfigured /etc/hosts file), this will not work. | 
|  | socketInitialize(); | 
|  | std::array<char, HOST_NAME_MAX> hostname{}; | 
|  | auto rv = gethostname(hostname.data(), HOST_NAME_MAX); | 
|  | if (rv != 0) { | 
|  | throw std::system_error(errno, std::system_category()); | 
|  | } | 
|  |  | 
|  | // Use this machine's hostname if it resolves to an address. | 
|  | if (doesHostnameResolveToUsableAddress(hostname.data())) { | 
|  | return ::c10d::GlooDeviceFactory::makeDeviceForHostname(hostname.data()); | 
|  | } | 
|  |  | 
|  | // Otherwise, use the loopback address. | 
|  | TORCH_WARN_ONCE( | 
|  | "Unable to resolve hostname to a (local) address. ", | 
|  | "Using the loopback address as fallback. ", | 
|  | "Manually set the network interface to bind to with GLOO_SOCKET_IFNAME."); | 
|  | return createDeviceForHostname(kLoopbackAddress); | 
|  | } | 
|  | #endif | 
|  |  | 
|  | #ifdef __APPLE__ | 
|  | std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo:: | 
|  | createDefaultDevice() { | 
|  | // Use the hostname to resolve the network address to | 
|  | // use. Note: if the hostname does not resolve to an address (e.g. | 
|  | // because of misconfigured /etc/hosts file), this will not work. | 
|  | const auto hostNameMax = sysconf(_SC_HOST_NAME_MAX); | 
|  | auto hostname = std::unique_ptr<char[]>(new char[hostNameMax]); | 
|  | auto rv = gethostname(hostname.get(), hostNameMax); | 
|  | if (rv != 0) { | 
|  | throw std::system_error(errno, std::system_category()); | 
|  | } | 
|  |  | 
|  | // Use this machine's hostname if it resolves to an address. | 
|  | if (doesHostnameResolveToUsableAddress(hostname.get())) { | 
|  | return ::c10d::GlooDeviceFactory::makeDeviceForHostname(hostname.get()); | 
|  | } | 
|  |  | 
|  | // Otherwise, use the loopback address. | 
|  | TORCH_WARN_ONCE( | 
|  | "Unable to resolve hostname to a (local) address. ", | 
|  | "Using the loopback address as fallback. ", | 
|  | "Manually set the network interface to bind to with GLOO_SOCKET_IFNAME."); | 
|  | return createDeviceForHostname(kLoopbackAddress); | 
|  | } | 
|  | #endif | 
|  |  | 
|  | ProcessGroupGloo::ProcessGroupGloo( | 
|  | const c10::intrusive_ptr<Store>& store, | 
|  | int rank, | 
|  | int size, | 
|  | c10::intrusive_ptr<Options> options) | 
|  | : Backend(rank, size), | 
|  | store_(new GlooStore(store)), | 
|  | options_(options), | 
|  | stop_(false), | 
|  | collectiveCounter_(0) { | 
|  | auto& devices = options->devices; | 
|  | if (devices.empty()) { | 
|  | TORCH_CHECK(false, "No device(s) specified"); | 
|  | } | 
|  |  | 
|  | // Create and connect a context for every device. | 
|  | // | 
|  | // Note that the same device can be specified multiple times, either | 
|  | // the same object, or the same logical device as different objects. | 
|  | // Either mode is fine and only has performance implications. | 
|  | // | 
|  | // Using the same object multiple times means all contexts share a | 
|  | // single I/O thread. If you use different objects for the same | 
|  | // logical device they will have independent I/O threads. The latter | 
|  | // option is needed if you have a fast NIC that cannot be saturated | 
|  | // by a single I/O thread. | 
|  | // | 
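|  | // For example (illustrative), a caller that wants two I/O threads for one | 
|  | // fast NIC could pass two distinct device objects for the same interface | 
|  | // ("eth0" is a placeholder name): | 
|  | // | 
|  | //   opts->devices.push_back( | 
|  | //       ProcessGroupGloo::createDeviceForInterface("eth0")); | 
|  | //   opts->devices.push_back( | 
|  | //       ProcessGroupGloo::createDeviceForInterface("eth0")); | 
|  | // | 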
|  | contexts_.reserve(options->devices.size()); | 
|  | for (const auto i : c10::irange(options->devices.size())) { | 
|  | auto context = std::make_shared<::gloo::rendezvous::Context>(rank_, size_); | 
|  | auto store = ::gloo::rendezvous::PrefixStore(std::to_string(i), *store_); | 
|  | context->setTimeout(options->timeout); | 
|  | context->connectFullMesh(store, options->devices[i]); | 
|  | contexts_.push_back(std::move(context)); | 
|  | } | 
|  |  | 
|  | // Every worker thread stores the AsyncWork object it's currently | 
|  | // working on in the workInProgress_ vector. It must have size equal | 
|  | // to the number of workers such that they can simply index into it | 
|  | // using the worker index they are started with. | 
|  | workInProgress_.resize(options->threads); | 
|  |  | 
|  | threads_.resize(options->threads); | 
|  | for (const auto i : c10::irange(threads_.size())) { | 
|  | threads_[i] = std::thread(&ProcessGroupGloo::runLoop, this, i); | 
|  | } | 
|  |  | 
|  | init(); | 
|  | } | 
|  |  | 
|  | ProcessGroupGloo::~ProcessGroupGloo() { | 
|  | std::unique_lock<std::mutex> lock(workMutex_); | 
|  | workConsumeCV_.wait(lock, [&] { return workQueue_.empty(); }); | 
|  |  | 
|  | // Queue is empty, signal stop | 
|  | stop_ = true; | 
|  |  | 
|  | // Release lock to allow threads to terminate | 
|  | lock.unlock(); | 
|  |  | 
|  | workProduceCV_.notify_all(); | 
|  |  | 
|  | // Wait for worker threads to terminate | 
|  | for (auto& thread : threads_) { | 
|  | thread.join(); | 
|  | } | 
|  | } | 
|  |  | 
|  | uint32_t ProcessGroupGloo::nextTag() { | 
|  | return collectiveCounter_++; | 
|  | } | 
|  |  | 
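|  | // Each collective draws a tag from nextTag(); the tag also selects which | 
|  | // connected context (and hence which device / I/O thread) runs it, by | 
|  | // round-robin indexing into contexts_. | 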
|  | std::shared_ptr<::gloo::Context> ProcessGroupGloo::getContext(uint32_t tag) { | 
|  | return contexts_[tag % contexts_.size()]; | 
|  | } | 
|  |  | 
|  | void ProcessGroupGloo::runLoop(int workerIndex) { | 
|  | std::unique_lock<std::mutex> lock(workMutex_); | 
|  |  | 
|  | while (!stop_) { | 
|  | if (workQueue_.empty()) { | 
|  | workProduceCV_.wait(lock); | 
|  | continue; | 
|  | } | 
|  |  | 
|  | auto work = std::move(workQueue_.front()); | 
|  | workQueue_.pop_front(); | 
|  | workInProgress_[workerIndex] = work; | 
|  | lock.unlock(); | 
|  |  | 
|  | // Notify after releasing the lock so that the waiter | 
|  | // does not immediately block. | 
|  | workConsumeCV_.notify_one(); | 
|  |  | 
|  | AsyncWork::execute(std::move(work)); | 
|  | lock.lock(); | 
|  | workInProgress_[workerIndex].reset(); | 
|  | } | 
|  | } | 
|  |  | 
|  | void ProcessGroupGloo::enqueue(c10::intrusive_ptr<AsyncWork> work) { | 
|  | std::unique_lock<std::mutex> lock(workMutex_); | 
|  | // Bump collective counter | 
|  | if (sequenceNum_) { | 
|  | sequenceNum_->increment(); | 
|  | } | 
|  | workQueue_.push_back(std::move(work)); | 
|  | lock.unlock(); | 
|  |  | 
|  | // Notify after releasing the lock so that the waiter | 
|  | // does not immediately block. | 
|  | workProduceCV_.notify_one(); | 
|  | } | 
|  |  | 
|  | namespace { | 
|  |  | 
|  | class AsyncBroadcastWork : public ProcessGroupGloo::AsyncWork { | 
|  | public: | 
|  | AsyncBroadcastWork( | 
|  | const std::shared_ptr<gloo::Context>& context, | 
|  | std::vector<at::Tensor>& inputs, | 
|  | int rootRank, | 
|  | int rootTensor, | 
|  | uint32_t tag) | 
|  | : ProcessGroupGloo::AsyncWork({inputs}, "gloo:broadcast", inputs), | 
|  | context(context), | 
|  | inputs(inputs), | 
|  | rootRank(rootRank), | 
|  | rootTensor(rootTensor), | 
|  | tag(tag) {} | 
|  |  | 
|  | std::shared_ptr<gloo::Context> context; | 
|  | std::vector<at::Tensor> inputs; | 
|  | const int rootRank; | 
|  | const int rootTensor; | 
|  | const uint32_t tag; | 
|  |  | 
|  | void broadcast(at::Tensor& tensor) { | 
|  | const auto& scalarType = tensor.scalar_type(); | 
|  | gloo::BroadcastOptions opts(context); | 
|  | opts.setRoot(rootRank); | 
|  | opts.setTag(tag); | 
|  | GENERATE_ALL_TYPES(scalarType, setOutput, opts, tensor); | 
|  | gloo::broadcast(opts); | 
|  | } | 
|  |  | 
|  | void run() override { | 
|  | broadcast(inputs[rootTensor]); | 
|  |  | 
|  | // Copy to non-root tensors | 
|  | for (const auto i : c10::irange(inputs.size())) { | 
|  | if (i == static_cast<size_t>(rootTensor)) { | 
|  | continue; | 
|  | } | 
|  | inputs[i].copy_(inputs[rootTensor]); | 
|  | } | 
|  | } | 
|  | }; | 
|  |  | 
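|  | // CUDA variant: the collective itself runs on host memory, so the root's | 
|  | // CUDA tensor is staged through a pinned host buffer (`tmp`) on a side | 
|  | // stream, the broadcast runs on that buffer, and the result is copied back | 
|  | // to every CUDA input asynchronously; synchronize() then makes the caller's | 
|  | // current streams wait on those copies via the recorded events. | 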
|  | class AsyncBroadcastCUDAWork : public AsyncBroadcastWork { | 
|  | public: | 
|  | AsyncBroadcastCUDAWork( | 
|  | const std::shared_ptr<gloo::Context>& context, | 
|  | std::vector<at::Tensor>& inputs, | 
|  | int rootRank, | 
|  | int rootTensor, | 
|  | uint32_t tag) | 
|  | : AsyncBroadcastWork(context, inputs, rootRank, rootTensor, tag) { | 
|  | initializeStreamsEvents(inputs, streams, events); | 
|  |  | 
|  | // Create pinned host side tensors. | 
|  | tmp = pinnedLike(inputs[rootTensor]); | 
|  | c10::OptionalStreamGuard guard; | 
|  | if (context->rank == rootRank) { | 
|  | guard.reset_stream(streams[rootTensor]); | 
|  | tmp.copy_(inputs[rootTensor], /* non_blocking */ true); | 
|  | } | 
|  | } | 
|  |  | 
|  | void run() override { | 
|  | // Synchronize with copy operation if applicable. | 
|  | if (context->rank == rootRank) { | 
|  | streams[rootTensor].synchronize(); | 
|  | } | 
|  |  | 
|  | // Run broadcast on host side tensors. | 
|  | broadcast(tmp); | 
|  |  | 
|  | // Kick off copy back to the CUDA tensors. | 
|  | c10::OptionalStreamGuard guard; | 
|  | for (const auto i : c10::irange(inputs.size())) { | 
|  | guard.reset_stream(streams[i]); | 
|  | inputs[i].copy_(tmp, /* non_blocking */ true); | 
|  | events[i].record(streams[i]); | 
|  | } | 
|  | } | 
|  |  | 
|  | void synchronize() override { | 
|  | // Synchronize with the copy back to CUDA tensors. | 
|  | for (const auto i : c10::irange(inputs.size())) { | 
|  | c10::Device device = inputs[i].device(); | 
|  | events[i].block( | 
|  | c10::impl::VirtualGuardImpl(device.type()).getStream(device)); | 
|  | } | 
|  | } | 
|  |  | 
|  | at::Tensor tmp; | 
|  | std::vector<c10::Stream> streams; | 
|  | std::vector<c10::Event> events; | 
|  | }; | 
|  |  | 
|  | } // namespace | 
|  |  | 
|  | c10::intrusive_ptr<Work> ProcessGroupGloo::broadcast( | 
|  | std::vector<at::Tensor>& inputs, | 
|  | const BroadcastOptions& opts) { | 
|  | static auto invalidArgument = [](const std::string& msg) { | 
|  | TORCH_CHECK(false, "ProcessGroupGloo::broadcast: " + msg); | 
|  | }; | 
|  |  | 
|  | assertRootRank(invalidArgument, opts.rootRank, size_); | 
|  | assertRootTensor(invalidArgument, opts.rootTensor, inputs.size()); | 
|  | assertDense(invalidArgument, inputs); | 
|  | assertTypeAndSizesMatch(invalidArgument, inputs); | 
|  |  | 
|  | const auto& device = inputs[0].device(); | 
|  | switch (device.type()) { | 
|  | case at::kCPU: | 
|  | break; | 
|  | case at::kCUDA: | 
|  | // If the user gave us a CUDA tensor then CUDA must be loaded. | 
|  | TORCH_INTERNAL_ASSERT(at::hasCUDA()); | 
|  | break; | 
|  | default: | 
|  | invalidArgument(c10::str("unsupported device type ", device.type())); | 
|  | } | 
|  |  | 
|  | c10::intrusive_ptr<AsyncBroadcastWork> work; | 
|  | auto tag = nextTag(); | 
|  | auto context = getContext(tag); | 
|  | if (device.type() == at::kCPU) { | 
|  | work = c10::make_intrusive<AsyncBroadcastWork>( | 
|  | std::move(context), inputs, opts.rootRank, opts.rootTensor, tag); | 
|  | } else if (device.type() == at::kCUDA) { | 
|  | work = c10::make_intrusive<AsyncBroadcastCUDAWork>( | 
|  | std::move(context), inputs, opts.rootRank, opts.rootTensor, tag); | 
|  | } else { | 
|  | TORCH_CHECK(false, "Invalid backend"); | 
|  | } | 
|  |  | 
|  | enqueue(work); | 
|  | return work; | 
|  | } | 
|  |  | 
|  | namespace { | 
|  |  | 
|  | class AsyncAllreduceWork : public ProcessGroupGloo::AsyncWork { | 
|  | public: | 
|  | AsyncAllreduceWork( | 
|  | const std::shared_ptr<gloo::Context>& context, | 
|  | std::vector<at::Tensor>& inputs, | 
|  | ReduceOp reduceOp, | 
|  | uint32_t tag) | 
|  | : ProcessGroupGloo::AsyncWork({inputs}, "gloo:all_reduce", inputs), | 
|  | context(context), | 
|  | inputs(inputs), | 
|  | reduceOp(reduceOp), | 
|  | tag(tag) {} | 
|  |  | 
|  | std::shared_ptr<gloo::Context> context; | 
|  | std::vector<at::Tensor> inputs; | 
|  | const ReduceOp reduceOp; | 
|  | const uint32_t tag; | 
|  |  | 
|  | void allreduce(std::vector<at::Tensor>& tensors) { | 
|  | const auto& scalarType = tensors[0].scalar_type(); | 
|  | gloo::AllreduceOptions opts(context); | 
|  | opts.setReduceFunction(getFunction(scalarType, reduceOp)); | 
|  | opts.setTag(tag); | 
|  | GENERATE_ALL_TYPES(scalarType, setOutputs, opts, tensors); | 
|  | gloo::allreduce(opts); | 
|  | } | 
|  |  | 
|  | void run() override { | 
|  | allreduce(inputs); | 
|  | } | 
|  |  | 
|  | template <typename T> | 
|  | void getFunction(gloo::AllreduceOptions::Func& fn, const ReduceOp op) { | 
|  | fn = toFunction<T>(op); | 
|  | } | 
|  |  | 
|  | gloo::AllreduceOptions::Func getFunction( | 
|  | const at::ScalarType& dtype, | 
|  | const ReduceOp op) { | 
|  | gloo::AllreduceOptions::Func fn; | 
|  | GENERATE_ALL_TYPES(dtype, getFunction, fn, op); | 
|  | return fn; | 
|  | } | 
|  | }; | 
|  |  | 
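|  | // Coalesced variant: the input tensors are flattened into one contiguous | 
|  | // buffer, a single allreduce runs on that buffer, and the reduced slices | 
|  | // are copied back into the original tensors. Worked example (illustrative): | 
|  | // for inputs of shapes {2, 2} and {3}, the flat buffer has 7 elements; | 
|  | // after the allreduce, elements [0, 4) are reshaped back to {2, 2} and | 
|  | // elements [4, 7) back to {3}. | 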
|  | class AsyncAllreduceCoalescedWork : public AsyncAllreduceWork { | 
|  | public: | 
|  | AsyncAllreduceCoalescedWork( | 
|  | const std::shared_ptr<gloo::Context>& context, | 
|  | std::vector<at::Tensor>& inputs, | 
|  | ReduceOp reduceOp, | 
|  | uint32_t tag) | 
|  | : AsyncAllreduceWork(context, inputs, reduceOp, tag) {} | 
|  |  | 
|  | void run() override { | 
|  | allreduceCoalesced(inputs); | 
|  | } | 
|  |  | 
|  | private: | 
|  | void allreduceCoalesced(std::vector<at::Tensor>& tensors) { | 
|  | // reduce coalesced, flattened tensors. | 
|  | at::Tensor coalescedTensor = flattenDenseTensors(tensors); | 
|  | std::vector<at::Tensor> allreduceInput = {coalescedTensor}; | 
|  | allreduce(allreduceInput); | 
|  |  | 
|  | // separate and reshape tensors. | 
|  | size_t offset = 0; | 
|  | for (at::Tensor& tensor : tensors) { | 
|  | const int64_t tensorNumel = tensor.numel(); | 
|  | const c10::IntArrayRef tensorShape = tensor.sizes(); | 
|  | tensor.copy_(coalescedTensor.slice(0, offset, offset + tensorNumel) | 
|  | .view(tensorShape)); | 
|  | offset += tensorNumel; | 
|  | } | 
|  | } | 
|  | }; | 
|  |  | 
|  | class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork { | 
|  | public: | 
|  | AsyncSparseAllreduceWork( | 
|  | const std::shared_ptr<gloo::Context>& context, | 
|  | std::vector<at::Tensor>& inputs, | 
|  | uint32_t tag) | 
|  | : ProcessGroupGloo::AsyncWork({inputs}, "gloo:sparse_all_reduce", inputs), | 
|  | context(context), | 
|  | inputs(inputs), | 
|  | tag(tag) {} | 
|  |  | 
|  | std::shared_ptr<gloo::Context> context; | 
|  | std::vector<at::Tensor> inputs; | 
|  | const uint32_t tag; | 
|  |  | 
|  | // We share the dimensionality of the sparse tensors before collecting | 
|  | // their contents. We assume here that the maximum number of sparse | 
|  | // and dense dimensions is 4. This is stored in a contiguous piece of | 
|  | // memory so that we can easily run allgather on it. | 
|  | // | 
|  | // The layout of this memory is as follows: | 
|  | // | 
|  | //   - [0:4]: sparse dims | 
|  | //   - [4:8]: dense dims | 
|  | //   -   [8]: nnz | 
|  | // | 
|  | class SparseTensorMetadata { | 
|  | public: | 
|  | static constexpr auto dim = 9; | 
|  |  | 
|  | // Construct from an existing metadata tensor to facilitate structured | 
|  | // access to metadata from peers, after gathering it. | 
|  | explicit SparseTensorMetadata(at::Tensor metadata) | 
|  | : metadata_(metadata), data_(metadata_.data_ptr<int64_t>()) { | 
|  | AT_ASSERT(metadata.scalar_type() == at::kLong); | 
|  | AT_ASSERT(metadata.dim() == 1); | 
|  | AT_ASSERT(metadata.size(0) == dim); | 
|  | } | 
|  |  | 
|  | // Populate the metadata. | 
|  | void populate_from_sparse_tensor(const at::Tensor& tensor) { | 
|  | const auto sparse_dim = tensor.sparse_dim(); | 
|  | AT_ASSERT(sparse_dim <= 4); | 
|  | for (const auto i : c10::irange(4)) { | 
|  | if (i < sparse_dim) { | 
|  | data_[i] = tensor.size(i); | 
|  | } | 
|  | } | 
|  | const auto dense_dim = tensor.dense_dim(); | 
|  | AT_ASSERT(dense_dim <= 4); | 
|  | for (const auto i : c10::irange(4)) { | 
|  | if (i < dense_dim) { | 
|  | data_[i + 4] = tensor.size(sparse_dim + i); | 
|  | } | 
|  | } | 
|  | data_[8] = tensor._nnz(); | 
|  | } | 
|  |  | 
|  | std::vector<int64_t> sizes() const { | 
|  | std::vector<int64_t> sizes; | 
|  | // Sparse sizes | 
|  | for (const auto i : c10::irange(4)) { | 
|  | if (data_[i] <= 0) { | 
|  | break; | 
|  | } | 
|  | sizes.push_back(data_[i]); | 
|  | } | 
|  | // Dense sizes | 
|  | for (const auto i : c10::irange(4, 8)) { | 
|  | if (data_[i] <= 0) { | 
|  | break; | 
|  | } | 
|  | sizes.push_back(data_[i]); | 
|  | } | 
|  | return sizes; | 
|  | } | 
|  |  | 
|  | int64_t nnz() const { | 
|  | return data_[8]; | 
|  | } | 
|  |  | 
|  | protected: | 
|  | at::Tensor metadata_; | 
|  | int64_t* data_; | 
|  | }; | 
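|  |  | 
|  | // Worked example (illustrative): a coalesced sparse tensor of overall size | 
|  | // {10, 20, 3} with sparse_dim() == 2, dense_dim() == 1 and 5 non-zero | 
|  | // entries is encoded as {10, 20, 0, 0, 3, 0, 0, 0, 5}; unused slots remain | 
|  | // zero because the metadata buffer is allocated with at::zeros below. | 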
|  |  | 
|  | // Sparse allreduce is implemented with allgather on indices and values. | 
|  | // Every process then sums the resulting sparse tensors locally. | 
|  | // The nnz for sparse tensors may be different across processes, so we | 
|  | // first allgather the metadata (which includes nnz) and then use | 
|  | // allgatherv with per-rank counts to collect the indices and values. | 
|  | at::Tensor allreduce(std::vector<at::Tensor>& tensors) { | 
|  | // TODO: This is a massive hack!  There is some confusion about | 
|  | // Variable/Tensor inside the body of this function.  Turning off | 
|  | // grad smooths over the confusion for now.  This fixes | 
|  | // test/test_c10d_gloo.py ProcessGroupGlooTest.test_sparse_allreduce_basics | 
|  | // | 
|  | // The correct fix is to stop allocating tensors that are not variables, | 
|  | // but to conveniently do this c10d must depend on torch not ATen | 
|  | at::AutoDispatchBelowAutograd guard; | 
|  | auto input = tensors[0]; | 
|  |  | 
|  | // Perform local reduction if we have multiple inputs. | 
|  | for (const auto i : c10::irange(1, tensors.size())) { | 
|  | input += tensors[i]; | 
|  | } | 
|  |  | 
|  | // Need to coalesce before we can access indices and values. | 
|  | input = input.coalesce(); | 
|  |  | 
|  | // Gather metadata information from all ranks. | 
|  | auto metadata = allgather_metadata(input); | 
|  |  | 
|  | // Sanity check dimensionality across ranks. | 
|  | { | 
|  | const auto expected = metadata[context->rank].sizes(); | 
|  | for (const auto i : c10::irange(context->size)) { | 
|  | if (i == context->rank) { | 
|  | continue; | 
|  | } | 
|  | const auto actual = metadata[i].sizes(); | 
|  | TORCH_CHECK(actual == expected, "Sparse dimensions do not match"); | 
|  | } | 
|  | } | 
|  |  | 
|  | // Gather all indices and all values. | 
|  | auto indices = allgather_indices(input, metadata); | 
|  | auto values = allgather_values(input, metadata); | 
|  |  | 
|  | // Perform global reduction. | 
|  | AT_ASSERT(static_cast<int>(indices.size()) == context->size); | 
|  | AT_ASSERT(static_cast<int>(values.size()) == context->size); | 
|  | auto output = at::sparse_coo_tensor( | 
|  | indices[0], values[0], input.sizes(), input.options()); | 
|  | for (const auto i : c10::irange(1, context->size)) { | 
|  | output += at::sparse_coo_tensor( | 
|  | indices[i], values[i], input.sizes(), input.options()); | 
|  | } | 
|  |  | 
|  | // Coalesce for good measure. | 
|  | return output.coalesce(); | 
|  | } | 
|  |  | 
|  | void run() override { | 
|  | auto output = allreduce(inputs); | 
|  |  | 
|  | // This copy is needed when we run a multi-gpu version of allreduce | 
|  | // (multiple inputs per rank). | 
|  | for (const auto i : c10::irange(inputs.size())) { | 
|  | inputs[i].copy_(output); | 
|  | } | 
|  | } | 
|  |  | 
|  | private: | 
|  | std::vector<SparseTensorMetadata> allgather_metadata( | 
|  | const at::Tensor& tensor) { | 
|  | auto buffer = | 
|  | at::zeros({context->size, SparseTensorMetadata::dim}, at::kLong); | 
|  |  | 
|  | // Prepare metadata vector (1 entry per rank) | 
|  | std::vector<SparseTensorMetadata> metadata; | 
|  | metadata.reserve(context->size); | 
|  | for (const auto i : c10::irange(context->size)) { | 
|  | metadata.emplace_back(buffer.select(0, i)); | 
|  | } | 
|  |  | 
|  | // Populate data for this rank | 
|  | metadata[context->rank].populate_from_sparse_tensor(tensor); | 
|  |  | 
|  | // Allgather metadata | 
|  | gloo::AllgatherOptions opts(context); | 
|  | opts.setOutput(buffer.data_ptr<int64_t>(), buffer.numel()); | 
|  | opts.setTag(tag); | 
|  | gloo::allgather(opts); | 
|  |  | 
|  | return metadata; | 
|  | } | 
|  |  | 
|  | std::vector<at::Tensor> allgather_indices( | 
|  | const at::Tensor& tensor, | 
|  | const std::vector<SparseTensorMetadata>& metadata) { | 
|  | const auto sparseDim = tensor.sparse_dim(); | 
|  |  | 
|  | std::vector<size_t> counts(context->size); | 
|  | int64_t totalSize = 0; | 
|  | for (const auto i : c10::irange(metadata.size())) { | 
|  | counts[i] = metadata[i].nnz() * sparseDim; | 
|  | totalSize += counts[i]; | 
|  | } | 
|  |  | 
|  | auto output = at::empty({totalSize}, at::kLong); | 
|  |  | 
|  | // Tensors copied from CUDA may not be contiguous; get a contiguous | 
|  | // tensor before using its data_ptr. | 
|  | auto input = tensor.indices().contiguous(); | 
|  |  | 
|  | // Allgatherv indices. | 
|  | gloo::AllgathervOptions opts(context); | 
|  | opts.setInput(input.data_ptr<int64_t>(), input.numel()); | 
|  | opts.setOutput(output.data_ptr<int64_t>(), counts); | 
|  | opts.setTag(tag); | 
|  | gloo::allgatherv(opts); | 
|  |  | 
|  | // Compile indices tensor per rank. | 
|  | std::vector<at::Tensor> indices; | 
|  | indices.reserve(metadata.size()); | 
|  | size_t offset = 0; | 
|  | for (const auto& i : metadata) { | 
|  | const auto nnz = i.nnz(); | 
|  | const auto numel = sparseDim * nnz; | 
|  | indices.push_back( | 
|  | output.narrow(0, offset, numel).reshape({sparseDim, nnz})); | 
|  | offset += numel; | 
|  | } | 
|  |  | 
|  | return indices; | 
|  | } | 
|  |  | 
|  | std::vector<at::Tensor> allgather_values( | 
|  | const at::Tensor& tensor, | 
|  | const std::vector<SparseTensorMetadata>& metadata) { | 
|  | // There are nnz #dense_dim()-dimensional tensors per rank. | 
|  | const auto valueShape = tensor.sizes().slice(tensor.sparse_dim()); | 
|  | size_t denseNumel = 1; | 
|  | for (auto dim : valueShape) { | 
|  | denseNumel *= dim; | 
|  | } | 
|  |  | 
|  | std::vector<size_t> counts(context->size); | 
|  | int64_t totalSize = 0; | 
|  | for (const auto i : c10::irange(metadata.size())) { | 
|  | counts[i] = metadata[i].nnz() * denseNumel; | 
|  | totalSize += counts[i]; | 
|  | } | 
|  |  | 
|  | auto output = at::empty({totalSize}, tensor.scalar_type()); | 
|  |  | 
|  | // Allgatherv values. | 
|  | gloo::AllgathervOptions opts(context); | 
|  | // Tensors copied from CUDA may not be contiguous; get a contiguous | 
|  | // tensor before using its data_ptr. | 
|  | at::Tensor valueTensor = tensor.values().contiguous(); | 
|  | GENERATE_ALL_TYPES(valueTensor.scalar_type(), setInput, opts, valueTensor); | 
|  | GENERATE_ALL_TYPES( | 
|  | valueTensor.scalar_type(), setOutput, opts, output, counts); | 
|  | opts.setTag(tag); | 
|  | gloo::allgatherv(opts); | 
|  |  | 
|  | // Compile values tensor per rank. | 
|  | std::vector<at::Tensor> values; | 
|  | values.reserve(metadata.size()); | 
|  | size_t offset = 0; | 
|  | for (const auto& i : metadata) { | 
|  | const auto nnz = i.nnz(); | 
|  | const auto numel = denseNumel * nnz; | 
|  | auto tensorShape = std::vector<int64_t>({(int64_t)nnz}); | 
|  | std::copy( | 
|  | valueShape.begin(), | 
|  | valueShape.end(), | 
|  | std::back_inserter(tensorShape)); | 
|  | values.push_back(output.narrow(0, offset, numel).reshape(tensorShape)); | 
|  | offset += numel; | 
|  | } | 
|  |  | 
|  | return values; | 
|  | } | 
|  | }; | 
|  |  | 
|  | class AsyncAllreduceCUDAWork : public AsyncAllreduceWork { | 
|  | public: | 
|  | AsyncAllreduceCUDAWork( | 
|  | const std::shared_ptr<gloo::Context>& context, | 
|  | std::vector<at::Tensor>& inputs, | 
|  | ReduceOp reduceOp, | 
|  | uint32_t tag) | 
|  | : AsyncAllreduceWork(context, inputs, reduceOp, tag) { | 
|  | initializeStreamsEvents(inputs, streams, events); | 
|  |  | 
|  | // Kick off copy from CUDA tensors to pinned CPU tensors. | 
|  | tmp.reserve(inputs.size()); | 
|  | c10::OptionalStreamGuard guard; | 
|  | for (const auto i : c10::irange(inputs.size())) { | 
|  | guard.reset_stream(streams[i]); | 
|  | tmp.push_back(pinnedLike(inputs[i]).copy_(inputs[i], true)); | 
|  | } | 
|  | } | 
|  |  | 
|  | void run() override { | 
|  | // Synchronize with copy operations. | 
|  | for (const auto i : c10::irange(inputs.size())) { | 
|  | streams[i].synchronize(); | 
|  | } | 
|  |  | 
|  | // Run allreduce on host side tensors. | 
|  | allreduce(tmp); | 
|  |  | 
|  | c10::OptionalStreamGuard guard; | 
|  | for (const auto i : c10::irange(inputs.size())) { | 
|  | guard.reset_stream(streams[i]); | 
|  | inputs[i].copy_(tmp[i], /* non_blocking */ true); | 
|  | events[i].record(streams[i]); | 
|  | } | 
|  | } | 
|  |  | 
|  | void synchronize() override { | 
|  | // Synchronize with the copy back to CUDA tensors. | 
|  | for (const auto i : c10::irange(inputs.size())) { | 
|  | c10::Device device = inputs[i].device(); | 
|  | events[i].block( | 
|  | c10::impl::VirtualGuardImpl(device.type()).getStream(device)); | 
|  | } | 
|  | } | 
|  |  | 
|  | std::vector<at::Tensor> tmp; | 
|  | std::vector<c10::Stream> streams; | 
|  | std::vector<c10::Event> events; | 
|  | }; | 
|  |  | 
|  | class AsyncSparseAllreduceCUDAWork : public AsyncSparseAllreduceWork { | 
|  | public: | 
|  | AsyncSparseAllreduceCUDAWork( | 
|  | const std::shared_ptr<gloo::Context>& context, | 
|  | std::vector<at::Tensor>& inputs, | 
|  | uint32_t tag) | 
|  | : AsyncSparseAllreduceWork(context, inputs, tag) { | 
|  | initializeStreamsEvents(inputs, streams, events); | 
|  |  | 
|  | // Kick off copy from CUDA tensors to CPU tensors. | 
|  | // Note that both coalescing the sparse tensor and copying it to CPU | 
|  | // memory must be performed asynchronously, or we block the caller. | 
|  | tmp.reserve(inputs.size()); | 
|  | c10::OptionalStreamGuard guard; | 
|  | for (const auto i : c10::irange(inputs.size())) { | 
|  | guard.reset_stream(streams[i]); | 
|  | tmp.push_back( | 
|  | inputs[i].coalesce().to(at::DeviceType::CPU, /*non_blocking=*/true)); | 
|  | } | 
|  | } | 
|  |  | 
|  | void run() override { | 
|  | // Synchronize with copy operations. | 
|  | for (const auto i : c10::irange(inputs.size())) { | 
|  | streams[i].synchronize(); | 
|  | } | 
|  |  | 
|  | // Run allreduce on host side tensors. | 
|  | auto output = allreduce(tmp); | 
|  |  | 
|  | // Kick off copy back to the CUDA tensors. | 
|  | c10::OptionalStreamGuard guard; | 
|  | for (const auto i : c10::irange(inputs.size())) { | 
|  | guard.reset_stream(streams[i]); | 
|  | inputs[i].copy_(output, /*non_blocking=*/true); | 
|  | events[i].record(streams[i]); | 
|  | } | 
|  | } | 
|  |  | 
|  | void synchronize() override { | 
|  | // Synchronize with the copy back to CUDA tensors. | 
|  | for (const auto i : c10::irange(inputs.size())) { | 
|  | c10::Device device = inputs[i].device(); | 
|  | events[i].block( | 
|  | c10::impl::VirtualGuardImpl(device.type()).getStream(device)); | 
|  | } | 
|  | } | 
|  |  | 
|  | std::vector<at::Tensor> tmp; | 
|  | std::vector<c10::Stream> streams; | 
|  | std::vector<c10::Event> events; | 
|  | }; | 
|  |  | 
|  | } // namespace | 
|  |  | 
|  | c10::intrusive_ptr<Work> ProcessGroupGloo::allreduce( | 
|  | std::vector<at::Tensor>& inputs, | 
|  | const AllreduceOptions& opts) { | 
|  | static auto invalidArgument = [](const std::string& msg) { | 
|  | TORCH_CHECK(false, "ProcessGroupGloo::allreduce: " + msg); | 
|  | }; | 
|  |  | 
|  | assertNonEmpty(invalidArgument, inputs); | 
|  | assertLayoutMatch(invalidArgument, inputs); | 
|  | assertTypeAndSizesMatch(invalidArgument, inputs); | 
|  |  | 
|  | const auto& device = inputs[0].device(); | 
|  | switch (device.type()) { | 
|  | case at::kCPU: | 
|  | break; | 
|  | case at::kCUDA: | 
|  | // If the user gave us a CUDA tensor then CUDA must be loaded. | 
|  | TORCH_INTERNAL_ASSERT(at::hasCUDA()); | 
|  | break; | 
|  | default: | 
|  | invalidArgument(c10::str("unsupported device type ", device.type())); | 
|  | } | 
|  |  | 
|  | const auto& layout = inputs[0].layout(); | 
|  | if (layout == c10::kSparse && opts.reduceOp != ReduceOp::SUM) { | 
|  | invalidArgument( | 
|  | "unsupported reduction operation " | 
|  | "(allreduce of sparse tensors only works with ReduceOp.SUM)"); | 
|  | } | 
|  |  | 
|  | c10::intrusive_ptr<AsyncWork> work; | 
|  | auto tag = nextTag(); | 
|  | auto context = getContext(tag); | 
|  | if (device.type() == at::kCPU) { | 
|  | if (layout == c10::kStrided) { | 
|  | work = c10::make_intrusive<AsyncAllreduceWork>( | 
|  | std::move(context), inputs, opts.reduceOp, tag); | 
|  | } else if (layout == c10::kSparse) { | 
|  | work = c10::make_intrusive<AsyncSparseAllreduceWork>( | 
|  | std::move(context), inputs, tag); | 
|  | } else { | 
|  | invalidArgument("unsupported layout"); | 
|  | } | 
|  | } else if (device.type() == at::kCUDA) { | 
|  | if (layout == c10::kStrided) { | 
|  | work = c10::make_intrusive<AsyncAllreduceCUDAWork>( | 
|  | std::move(context), inputs, opts.reduceOp, tag); | 
|  | } else if (layout == c10::kSparse) { | 
|  | work = c10::make_intrusive<AsyncSparseAllreduceCUDAWork>( | 
|  | std::move(context), inputs, tag); | 
|  | } else { | 
|  | invalidArgument("unsupported layout"); | 
|  | } | 
|  | } else { | 
|  | TORCH_CHECK(false, "Invalid backend"); | 
|  | } | 
|  |  | 
|  | enqueue(work); | 
|  | return work; | 
|  | } | 
|  |  | 
|  | c10::intrusive_ptr<Work> ProcessGroupGloo::allreduce_coalesced( | 
|  | std::vector<at::Tensor>& tensors, | 
|  | const AllreduceCoalescedOptions& opts) { | 
|  | static auto invalidArgument = [](const std::string& msg) { | 
|  | TORCH_CHECK(false, "ProcessGroupGloo::allreduce_coalesced: " + msg); | 
|  | }; | 
|  | assertNonEmpty(invalidArgument, tensors); | 
|  |  | 
|  | // Tensors will be flattened and concatenated (coalesced). This means that | 
|  | // input tensors must have the same device, layout and type. | 
|  | assertLayoutMatch(invalidArgument, tensors); | 
|  | if (!std::all_of(tensors.begin(), tensors.end(), [&](at::Tensor& t) { | 
|  | return t.options().type_equal(tensors[0].options()); | 
|  | })) { | 
|  | invalidArgument("tensors must all have the same type"); | 
|  | } | 
|  | if (!std::all_of(tensors.begin(), tensors.end(), [&](at::Tensor& t) { | 
|  | return t.device() == tensors[0].device(); | 
|  | })) { | 
|  | invalidArgument("tensors must all be on the same device"); | 
|  | } | 
|  |  | 
|  | const c10::Device& device = tensors[0].device(); | 
|  | const c10::Layout& layout = tensors[0].layout(); | 
|  |  | 
|  | // invalid arguments are detected early here before any calls to nextTag() | 
|  | // which result in the collectiveCounter_ being incremented. | 
|  | switch (device.type()) { | 
|  | case c10::kCPU: | 
|  | break; | 
|  | default: | 
|  | invalidArgument(c10::str("unsupported device type ", device.type())); | 
|  | } | 
|  |  | 
|  | switch (layout) { | 
|  | case c10::kStrided: | 
|  | break; | 
|  | default: | 
|  | invalidArgument("unsupported layout"); | 
|  | } | 
|  |  | 
|  | c10::intrusive_ptr<AsyncWork> work; | 
|  | const uint32_t tag = nextTag(); | 
|  | std::shared_ptr<gloo::Context> context = getContext(tag); | 
|  | if (device.type() == c10::kCPU) { | 
|  | if (layout == c10::kStrided) { | 
|  | work = c10::make_intrusive<AsyncAllreduceCoalescedWork>( | 
|  | std::move(context), tensors, opts.reduceOp, tag); | 
|  | } else { | 
|  | invalidArgument("unsupported layout"); | 
|  | } | 
|  | } else { | 
|  | TORCH_CHECK(false, "Invalid backend"); | 
|  | } | 
|  | enqueue(work); | 
|  | return work; | 
|  | } | 
|  |  | 
|  | namespace { | 
|  |  | 
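// AsyncReduceWork runs gloo::reduce over CPU tensors on the worker thread.
// The reduction function is selected per scalar type through
// GENERATE_ALL_TYPES, and the reduced result is delivered to the root rank.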
|  | class AsyncReduceWork : public ProcessGroupGloo::AsyncWork { | 
|  | public: | 
|  | AsyncReduceWork( | 
|  | const std::shared_ptr<gloo::Context>& context, | 
|  | std::vector<at::Tensor>& inputs, | 
|  | int rootRank, | 
|  | int rootTensor, | 
|  | ReduceOp reduceOp, | 
|  | uint32_t tag) | 
|  | : ProcessGroupGloo::AsyncWork({inputs}, "gloo:reduce", inputs), | 
|  | context(context), | 
|  | inputs(inputs), | 
|  | rootRank(rootRank), | 
|  | rootTensor(rootTensor), | 
|  | reduceOp(reduceOp), | 
|  | tag(tag) {} | 
|  |  | 
|  | std::shared_ptr<gloo::Context> context; | 
|  | std::vector<at::Tensor> inputs; | 
|  | const int rootRank; | 
|  | const int rootTensor; | 
|  | const ReduceOp reduceOp; | 
|  | const uint32_t tag; | 
|  |  | 
|  | void reduce(std::vector<at::Tensor>& tensors) { | 
|  | const auto& scalarType = tensors[0].scalar_type(); | 
|  | gloo::ReduceOptions opts(context); | 
|  | opts.setRoot(rootRank); | 
|  | opts.setTag(tag); | 
|  | opts.setReduceFunction(getFunction(scalarType, reduceOp)); | 
|  | GENERATE_ALL_TYPES(scalarType, setOutput, opts, tensors[0]); | 
|  | gloo::reduce(opts); | 
|  | } | 
|  |  | 
|  | void run() override { | 
|  | reduce(inputs); | 
|  | } | 
|  |  | 
|  | protected: | 
|  | template <typename T> | 
|  | void getFunction(gloo::ReduceOptions::Func& fn, const ReduceOp op) { | 
|  | fn = toFunction<T>(op); | 
|  | } | 
|  |  | 
|  | gloo::ReduceOptions::Func getFunction( | 
|  | const at::ScalarType& dtype, | 
|  | const ReduceOp op) { | 
|  | gloo::ReduceOptions::Func fn; | 
|  | GENERATE_ALL_TYPES(dtype, getFunction, fn, op); | 
|  | return fn; | 
|  | } | 
|  | }; | 
|  |  | 
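// CUDA variant, following the same staging pattern as the other *CUDAWork
// classes in this file: the constructor kicks off device-to-host copies into
// pinned tensors on per-tensor streams, run() waits for those copies, runs
// the collective on the host copies and copies the result back
// asynchronously; synchronize() blocks the current streams on the recorded
// copy-back events.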
|  | class AsyncReduceCUDAWork : public AsyncReduceWork { | 
|  | public: | 
|  | AsyncReduceCUDAWork( | 
|  | const std::shared_ptr<gloo::Context>& context, | 
|  | std::vector<at::Tensor>& inputs, | 
|  | int rootRank, | 
|  | int rootTensor, | 
|  | ReduceOp reduceOp, | 
|  | uint32_t tag) | 
|  | : AsyncReduceWork(context, inputs, rootRank, rootTensor, reduceOp, tag) { | 
|  | initializeStreamsEvents(inputs, streams, events); | 
|  |  | 
|  | // Kick off copy from CUDA tensors to pinned CPU tensors. | 
|  | tmp.reserve(inputs.size()); | 
|  | c10::OptionalStreamGuard guard; | 
|  | for (const auto i : c10::irange(inputs.size())) { | 
|  | guard.reset_stream(streams[i]); | 
|  | tmp.push_back(pinnedLike(inputs[i]).copy_(inputs[i], true)); | 
|  | } | 
|  | } | 
|  |  | 
|  | void run() override { | 
|  | // Synchronize with copy operations. | 
|  | for (const auto i : c10::irange(inputs.size())) { | 
|  | streams[i].synchronize(); | 
|  | } | 
|  |  | 
|  | // Run reduce on host side tensors. | 
|  | reduce(tmp); | 
|  |  | 
|  | // Kick off copy back to the CUDA tensors. | 
|  | c10::OptionalStreamGuard guard; | 
|  | for (const auto i : c10::irange(inputs.size())) { | 
|  | guard.reset_stream(streams[i]); | 
|  | inputs[i].copy_(tmp[i], /* non_blocking */ true); | 
|  | events[i].record(streams[i]); | 
|  | } | 
|  | } | 
|  |  | 
|  | void synchronize() override { | 
|  | // Synchronize with the copy back to CUDA tensors. | 
|  | for (const auto i : c10::irange(inputs.size())) { | 
|  | c10::Device device = inputs[i].device(); | 
|  | events[i].block( | 
|  | c10::impl::VirtualGuardImpl(device.type()).getStream(device)); | 
|  | } | 
|  | } | 
|  |  | 
|  | std::vector<at::Tensor> tmp; | 
|  | std::vector<c10::Stream> streams; | 
|  | std::vector<c10::Event> events; | 
|  | }; | 
|  |  | 
|  | } // namespace | 
|  |  | 
|  | c10::intrusive_ptr<Work> ProcessGroupGloo::reduce( | 
|  | std::vector<at::Tensor>& inputs, | 
|  | const ReduceOptions& opts) { | 
|  | static auto invalidArgument = [](const std::string& msg) { | 
|  | TORCH_CHECK(false, "ProcessGroupGloo::reduce: " + msg); | 
|  | }; | 
|  |  | 
|  | assertRootRank(invalidArgument, opts.rootRank, size_); | 
|  | assertRootTensor(invalidArgument, opts.rootTensor, inputs.size()); | 
|  | assertSingleElement(invalidArgument, inputs); | 
|  | assertDense(invalidArgument, inputs); | 
|  |  | 
|  | const auto& device = inputs[0].device(); | 
|  | switch (device.type()) { | 
|  | case at::kCPU: | 
|  | break; | 
|  | case at::kCUDA: | 
|  | // If the user gave us a CUDA tensor then CUDA must be loaded. | 
|  | TORCH_INTERNAL_ASSERT(at::hasCUDA()); | 
|  | break; | 
|  | default: | 
|  | invalidArgument(c10::str("unsupported device type ", device.type())); | 
|  | } | 
|  |  | 
|  | c10::intrusive_ptr<AsyncReduceWork> work; | 
|  | auto tag = nextTag(); | 
|  | auto context = getContext(tag); | 
|  | if (device.type() == at::kCPU) { | 
|  | work = c10::make_intrusive<AsyncReduceWork>( | 
|  | std::move(context), | 
|  | inputs, | 
|  | opts.rootRank, | 
|  | opts.rootTensor, | 
|  | opts.reduceOp, | 
|  | tag); | 
|  | } else if (device.type() == at::kCUDA) { | 
|  | work = c10::make_intrusive<AsyncReduceCUDAWork>( | 
|  | std::move(context), | 
|  | inputs, | 
|  | opts.rootRank, | 
|  | opts.rootTensor, | 
|  | opts.reduceOp, | 
|  | tag); | 
|  | } else { | 
|  | TORCH_CHECK(false, "Invalid backend"); | 
|  | } | 
|  | enqueue(work); | 
|  | return work; | 
|  | } | 
|  |  | 
|  | namespace { | 
|  |  | 
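// AsyncAllgatherWork flattens the local inputs into a single contiguous
// tensor, runs gloo::allgather into one flat output tensor, and unflattens
// the result into the nested output tensors afterwards.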
|  | class AsyncAllgatherWork : public ProcessGroupGloo::AsyncWork { | 
|  | public: | 
|  | AsyncAllgatherWork( | 
|  | const std::shared_ptr<gloo::Context>& context, | 
|  | std::vector<std::vector<at::Tensor>>& outputs, | 
|  | std::vector<at::Tensor>& inputs, | 
|  | uint32_t tag) | 
|  | : ProcessGroupGloo::AsyncWork(outputs, "gloo:all_gather", inputs), | 
|  | context(context), | 
|  | outputs(outputs), | 
|  | inputs(inputs), | 
|  | tag(tag) {} | 
|  |  | 
|  | std::shared_ptr<gloo::Context> context; | 
|  | std::vector<std::vector<at::Tensor>> outputs; | 
|  | std::vector<at::Tensor> inputs; | 
|  | const uint32_t tag; | 
|  |  | 
|  | void allgather( | 
|  | std::vector<std::vector<at::Tensor>>& outputs, | 
|  | std::vector<at::Tensor>& inputs) { | 
|  | const auto& scalarType = inputs[0].scalar_type(); | 
|  | gloo::AllgatherOptions opts(context); | 
|  | opts.setTag(tag); | 
|  |  | 
|  | // Use single flattened input tensor. | 
|  | at::Tensor flatInputTensor = flattenDenseTensors(inputs); | 
|  | GENERATE_ALL_TYPES(scalarType, setInput, opts, flatInputTensor); | 
|  |  | 
|  | // Use single flat output tensor. | 
|  | // The first dimension corresponds to the index into outputs[N], | 
|  | // so copying into the actual output later is easy. | 
|  | at::Tensor flatOutputTensor = newLikeFlat(outputs[0]); | 
|  | GENERATE_ALL_TYPES(scalarType, setOutput, opts, flatOutputTensor); | 
|  | gloo::allgather(opts); | 
|  |  | 
|  | // Unflatten into output tensors. | 
|  | for (auto& outputgroup : outputs) { | 
|  | for (const auto j : c10::irange(outputgroup.size())) { | 
|  | outputgroup[j].copy_(flatOutputTensor[j]); | 
|  | } | 
|  | } | 
|  | } | 
|  |  | 
|  | void run() override { | 
|  | allgather(outputs, inputs); | 
|  | } | 
|  | }; | 
|  |  | 
|  | // Note: current CUDA implementation holds the assumption that the | 
|  | // tensors in the nested output tensor vectors are on the same device. | 
|  | class AsyncAllgatherCUDAWork : public AsyncAllgatherWork { | 
|  | public: | 
|  | AsyncAllgatherCUDAWork( | 
|  | const std::shared_ptr<gloo::Context>& context, | 
|  | std::vector<std::vector<at::Tensor>>& outputs, | 
|  | std::vector<at::Tensor>& inputs, | 
|  | uint32_t tag) | 
|  | : AsyncAllgatherWork(context, outputs, inputs, tag) { | 
|  | initializeStreamsEvents(inputs, inputStreams, inputEvents); | 
|  | initializeStreamsEvents(outputs, outputStreams, outputEvents); | 
|  |  | 
|  | // Kick off copy from CUDA tensors to pinned CPU tensors. | 
|  | tmpInputs.reserve(inputs.size()); | 
|  | c10::OptionalStreamGuard guard; | 
|  | for (const auto i : c10::irange(inputs.size())) { | 
|  | guard.reset_stream(inputStreams[i]); | 
|  | tmpInputs.push_back(pinnedLike(inputs[i]).copy_(inputs[i], true)); | 
|  | } | 
|  |  | 
|  | tmpOutputs.resize(outputs.size()); | 
|  | for (const auto i : c10::irange(outputs.size())) { | 
|  | tmpOutputs[i].reserve(outputs[i].size()); | 
|  | for (const auto j : c10::irange(outputs[i].size())) { | 
|  | tmpOutputs[i].push_back(pinnedLike(outputs[i][j])); | 
|  | } | 
|  | } | 
|  | } | 
|  |  | 
|  | void run() override { | 
|  | // Synchronize with copy operations. | 
|  | for (const auto i : c10::irange(inputs.size())) { | 
|  | inputStreams[i].synchronize(); | 
|  | } | 
|  |  | 
|  | for (const auto i : c10::irange(outputs.size())) { | 
|  | outputStreams[i].synchronize(); | 
|  | } | 
|  |  | 
|  | // Run allgather on host side tensors. | 
|  | allgather(tmpOutputs, tmpInputs); | 
|  |  | 
|  | // Kick off copy back to the CUDA tensors. | 
|  | c10::OptionalStreamGuard guard; | 
|  | for (const auto i : c10::irange(outputs.size())) { | 
|  | guard.reset_stream(outputStreams[i]); | 
|  | for (const auto j : c10::irange(outputs[i].size())) { | 
|  | outputs[i][j].copy_(tmpOutputs[i][j], /* non_blocking */ true); | 
|  | } | 
|  | outputEvents[i].record(outputStreams[i]); | 
|  | } | 
|  | } | 
|  |  | 
|  | void synchronize() override { | 
|  | // Synchronize with the copy back to CUDA tensors. | 
|  | for (const auto i : c10::irange(outputs.size())) { | 
|  | c10::Device device = outputs[i][0].device(); | 
|  | outputEvents[i].block( | 
|  | c10::impl::VirtualGuardImpl(device.type()).getStream(device)); | 
|  | } | 
|  | } | 
|  |  | 
|  | std::vector<at::Tensor> tmpInputs; | 
|  | std::vector<c10::Stream> inputStreams; | 
|  | std::vector<c10::Event> inputEvents; | 
|  |  | 
|  | std::vector<std::vector<at::Tensor>> tmpOutputs; | 
|  | std::vector<c10::Stream> outputStreams; | 
|  | std::vector<c10::Event> outputEvents; | 
|  | }; | 
|  |  | 
|  | } // namespace | 
|  |  | 
|  | // Note: current CUDA implementation holds the assumption that the | 
|  | // tensors in the nested output tensor vectors are on the same device. | 
|  | c10::intrusive_ptr<Work> ProcessGroupGloo::allgather( | 
|  | std::vector<std::vector<at::Tensor>>& outputs, | 
|  | std::vector<at::Tensor>& inputs, | 
|  | const AllgatherOptions& opts) { | 
|  | static auto invalidArgument = [](const std::string& msg) { | 
|  | TORCH_CHECK(false, "ProcessGroupGloo::allgather: " + msg); | 
|  | }; | 
|  |  | 
|  | if (inputs.size() == 0) { | 
|  | invalidArgument("requires non-empty input tensor list"); | 
|  | } | 
|  |  | 
|  | if (inputs.size() != outputs.size()) { | 
|  | invalidArgument( | 
|  | "requires input/output tensor lists to have the same length"); | 
|  | } | 
|  |  | 
|  | for (const auto i : c10::irange(outputs.size())) { | 
|  | const auto expected = inputs.size() * getSize(); | 
|  | const auto actual = outputs[i].size(); | 
|  | if (actual != expected) { | 
|  | invalidArgument( | 
|  | "invalid output tensor list at index " + std::to_string(i) + | 
|  | " (expected length " + std::to_string(expected) + ", got " + | 
|  | std::to_string(actual) + ")"); | 
|  | } | 
|  | } | 
|  |  | 
|  | assertDense(invalidArgument, inputs); | 
|  |  | 
|  | // Expect all input/output tensors to have the same type and sizes | 
|  | const auto& options = inputs[0].options(); | 
|  | const auto& sizes = inputs[0].sizes(); | 
|  | assertTypeAndSizesMatch(invalidArgument, inputs, options, sizes); | 
|  | for (const auto& output : outputs) { | 
|  | assertTypeAndSizesMatch(invalidArgument, output, options, sizes); | 
|  | } | 
|  |  | 
|  | const auto& device = inputs[0].device(); | 
|  | switch (device.type()) { | 
|  | case at::kCPU: | 
|  | break; | 
|  | case at::kCUDA: | 
|  | // If the user gave us a CUDA tensor then CUDA must be loaded. | 
|  | TORCH_INTERNAL_ASSERT(at::hasCUDA()); | 
|  | break; | 
|  | default: | 
|  | invalidArgument(c10::str("unsupported device type ", device.type())); | 
|  | } | 
|  |  | 
|  | c10::intrusive_ptr<AsyncAllgatherWork> work; | 
|  | auto tag = nextTag(); | 
|  | auto context = getContext(tag); | 
|  | if (device.type() == at::kCPU) { | 
|  | work = c10::make_intrusive<AsyncAllgatherWork>( | 
|  | std::move(context), outputs, inputs, tag); | 
|  | } else if (device.type() == at::kCUDA) { | 
|  | work = c10::make_intrusive<AsyncAllgatherCUDAWork>( | 
|  | std::move(context), outputs, inputs, tag); | 
|  | } else { | 
|  | TORCH_CHECK(false, "Invalid backend"); | 
|  | } | 
|  | enqueue(work); | 
|  | return work; | 
|  | } | 
|  |  | 
|  | namespace { | 
|  |  | 
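// AsyncAllgatherCoalescedWork allgathers a list of tensors in one shot: the
// inputs are flattened into a single buffer, the flat result is gathered
// from all ranks, and every output tensor is filled by narrowing into the
// flat output at its running element offset.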
|  | class AsyncAllgatherCoalescedWork : public ProcessGroupGloo::AsyncWork { | 
|  | public: | 
|  | AsyncAllgatherCoalescedWork( | 
|  | const std::shared_ptr<gloo::Context>& context, | 
|  | std::vector<std::vector<at::Tensor>>& output_lists, | 
|  | std::vector<at::Tensor>& input_list, | 
|  | uint32_t tag) | 
|  | : ProcessGroupGloo::AsyncWork( | 
|  | output_lists, | 
|  | "gloo:all_gather", | 
|  | input_list), | 
|  | context(context), | 
|  | output_lists(output_lists), | 
|  | input_list(input_list), | 
|  | tag(tag) {} | 
|  |  | 
|  | std::shared_ptr<gloo::Context> context; | 
|  | std::vector<std::vector<at::Tensor>> output_lists; | 
|  | std::vector<at::Tensor> input_list; | 
|  | const uint32_t tag; | 
|  |  | 
|  | void allgather_coalesced() { | 
|  | assert(!output_lists.empty()); | 
|  | assert(!output_lists[0].empty()); | 
|  | assert(!input_list.empty()); | 
|  |  | 
|  | const auto& scalarType = input_list[0].scalar_type(); | 
|  | gloo::AllgatherOptions opts(context); | 
|  | opts.setTag(tag); | 
|  |  | 
|  | // Use single flattened input tensor. | 
|  | at::Tensor flatInputTensor = flattenDenseTensors(input_list); | 
|  | GENERATE_ALL_TYPES(scalarType, setInput, opts, flatInputTensor); | 
|  |  | 
|  | // Compute total number of elements we need to allocate for all tensors | 
|  | // requested. | 
|  | int64_t output_numel = 0; | 
|  | for (const auto& t : output_lists[0]) { | 
|  | output_numel += t.numel(); | 
|  | } | 
|  | output_numel *= output_lists.size(); | 
|  | // Use single flat output tensor. | 
|  | at::Tensor flatOutputTensor = | 
|  | at::empty({output_numel}, output_lists[0][0].options()); | 
|  | GENERATE_ALL_TYPES(scalarType, setOutput, opts, flatOutputTensor); | 
|  | gloo::allgather(opts); | 
|  |  | 
|  | int64_t current_element = 0; | 
|  | for (auto& output_list : output_lists) { | 
|  | for (auto& output_tensor : output_list) { | 
|  | output_tensor.copy_( | 
|  | flatOutputTensor.narrow(0, current_element, output_tensor.numel()) | 
|  | .reshape(output_tensor.sizes()), | 
|  | true); | 
|  | current_element += output_tensor.numel(); | 
|  | } | 
|  | } | 
|  | } | 
|  |  | 
|  | void run() override { | 
|  | allgather_coalesced(); | 
|  | } | 
|  | }; | 
|  |  | 
|  | } // namespace | 
|  |  | 
|  | c10::intrusive_ptr<Work> ProcessGroupGloo::allgather_coalesced( | 
|  | std::vector<std::vector<at::Tensor>>& output_lists, | 
|  | std::vector<at::Tensor>& input_list, | 
|  | const AllgatherOptions& /* unused */) { | 
|  | static auto invalidArgument = [](const std::string& msg) { | 
|  | TORCH_CHECK(false, "ProcessGroupGloo::allgather_coalesced: " + msg); | 
|  | }; | 
|  |  | 
|  | if (input_list.empty()) { | 
|  | invalidArgument("requires non-empty input tensor list"); | 
|  | } | 
|  |  | 
if (output_lists.size() != static_cast<size_t>(getSize())) {
invalidArgument("requires the number of output lists to equal the world size");
}
|  |  | 
|  | assertSameDevice(invalidArgument, input_list); | 
|  |  | 
// Expect the i'th tensor of each list from 'output_lists' to match the i'th
// tensor from 'input_list' in type and size.
|  | for (const auto& output_list : output_lists) { | 
|  | if (output_list.size() != input_list.size()) { | 
|  | invalidArgument( | 
|  | "invalid output size: (expected length " + | 
|  | std::to_string(input_list.size()) + ", got " + | 
|  | std::to_string(output_list.size()) + ")"); | 
|  | } | 
|  | for (const auto i : c10::irange(output_list.size())) { | 
|  | const auto expected = input_list[i].sizes(); | 
|  | const auto actual = output_list[i].sizes(); | 
|  | if (actual != expected) { | 
|  | invalidArgument( | 
|  | "invalid size of output tensor at index " + std::to_string(i) + | 
|  | " (expected length " + toString(expected) + ", got " + | 
|  | toString(actual) + ")"); | 
|  | } | 
|  | if (!input_list[i].options().type_equal(output_list[i].options())) { | 
|  | invalidArgument( | 
|  | "invalid tensor type at index " + std::to_string(i) + | 
|  | " (expected " + input_list[i].toString() + ", got " + | 
|  | output_list[i].toString() + ")"); | 
|  | } | 
|  | } | 
|  | } | 
|  |  | 
|  | assertDense(invalidArgument, input_list); | 
|  |  | 
|  | auto tag = nextTag(); | 
|  | auto context = getContext(tag); | 
|  | auto work = c10::make_intrusive<AsyncAllgatherCoalescedWork>( | 
|  | std::move(context), output_lists, input_list, tag); | 
|  | enqueue(work); | 
|  | return work; | 
|  | } | 
|  |  | 
|  | c10::intrusive_ptr<Work> ProcessGroupGloo::_allgather_base( | 
at::Tensor& /* unused */,
at::Tensor& /* unused */,
const AllgatherOptions& /* unused */) {
|  | TORCH_CHECK(false, "no support for _allgather_base in Gloo process group"); | 
|  | } | 
|  |  | 
|  | namespace { | 
|  |  | 
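// AsyncGatherWork gathers a single input tensor from every rank onto the
// root: the root allocates one flat output tensor, gloo::gather fills it,
// and the root then unflattens it into the separate output tensors.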
|  | class AsyncGatherWork : public ProcessGroupGloo::AsyncWork { | 
|  | public: | 
|  | AsyncGatherWork( | 
|  | const std::shared_ptr<gloo::Context>& context, | 
|  | std::vector<std::vector<at::Tensor>>& outputs, | 
|  | std::vector<at::Tensor>& inputs, | 
|  | int root, | 
|  | uint32_t tag) | 
|  | : ProcessGroupGloo::AsyncWork(outputs, "gloo:gather", inputs), | 
|  | context(context), | 
|  | outputs(outputs), | 
|  | inputs(inputs), | 
|  | root(root), | 
|  | tag(tag) {} | 
|  |  | 
|  | std::shared_ptr<gloo::Context> context; | 
|  | std::vector<std::vector<at::Tensor>> outputs; | 
|  | std::vector<at::Tensor> inputs; | 
|  | const int root; | 
|  | const uint32_t tag; | 
|  |  | 
|  | void gather( | 
|  | std::vector<std::vector<at::Tensor>>& outputs, | 
|  | std::vector<at::Tensor>& inputs) { | 
|  | const auto scalarType = inputs[0].scalar_type(); | 
|  | gloo::GatherOptions opts(context); | 
|  | opts.setRoot(root); | 
|  | opts.setTag(tag); | 
|  |  | 
|  | // Set single temporary tensor on root process. | 
|  | // This is later scattered to the separate output tensors. | 
|  | at::Tensor flatOutputTensor; | 
|  | if (context->rank == root) { | 
|  | flatOutputTensor = newLikeFlat(outputs[0]); | 
|  | GENERATE_ALL_TYPES(scalarType, setOutput, opts, flatOutputTensor); | 
|  | } | 
|  |  | 
|  | // Set single input tensor on all processes. | 
|  | GENERATE_ALL_TYPES(scalarType, setInput, opts, inputs[0]); | 
|  | gloo::gather(opts); | 
|  |  | 
|  | // Unflatten into output tensors on root process. | 
|  | if (context->rank == root) { | 
|  | for (const auto i : c10::irange(outputs[0].size())) { | 
|  | outputs[0][i].copy_(flatOutputTensor[i]); | 
|  | } | 
|  | } | 
|  | } | 
|  |  | 
|  | void run() override { | 
|  | gather(outputs, inputs); | 
|  | } | 
|  | }; | 
|  |  | 
// Note: current CUDA implementation holds the assumptions:
//     - inputs.size() is 1
//     - outputs.size() is 1
//     - the size of the nested output tensor list, i.e. outputs[0].size(),
//       is the world size
|  | class AsyncGatherCUDAWork : public AsyncGatherWork { | 
|  | public: | 
|  | AsyncGatherCUDAWork( | 
|  | const std::shared_ptr<gloo::Context>& context, | 
|  | std::vector<std::vector<at::Tensor>>& outputs, | 
|  | std::vector<at::Tensor>& inputs, | 
|  | int root, | 
|  | uint32_t tag) | 
|  | : AsyncGatherWork(context, outputs, inputs, root, tag) { | 
|  | initializeStreamsEvents(inputs, inputStreams, inputEvents); | 
|  | initializeStreamsEvents(outputs, outputStreams, outputEvents); | 
|  |  | 
|  | // Kick off copy from CUDA tensors to pinned CPU tensors. | 
|  | tmpInputs.reserve(inputs.size()); | 
|  | c10::OptionalStreamGuard guard; | 
|  | for (const auto i : c10::irange(inputs.size())) { | 
|  | guard.reset_stream(inputStreams[i]); | 
|  | tmpInputs.push_back(pinnedLike(inputs[i]).copy_(inputs[i], true)); | 
|  | } | 
|  |  | 
|  | tmpOutputs.resize(outputs.size()); | 
|  | for (const auto i : c10::irange(outputs.size())) { | 
|  | tmpOutputs[i].reserve(outputs[i].size()); | 
|  | for (const auto j : c10::irange(outputs[i].size())) { | 
|  | tmpOutputs[i].push_back(pinnedLike(outputs[i][j])); | 
|  | } | 
|  | } | 
|  | } | 
|  |  | 
|  | void run() override { | 
|  | // Synchronize with copy operations. | 
|  | for (const auto i : c10::irange(inputs.size())) { | 
|  | inputStreams[i].synchronize(); | 
|  | } | 
|  |  | 
|  | for (const auto i : c10::irange(outputs.size())) { | 
|  | outputStreams[i].synchronize(); | 
|  | } | 
|  |  | 
|  | // Run gather on host side tensors. | 
|  | gather(tmpOutputs, tmpInputs); | 
|  |  | 
|  | // Kick off copy back to the CUDA tensors. | 
|  | c10::OptionalStreamGuard guard; | 
|  | for (const auto i : c10::irange(outputs.size())) { | 
|  | guard.reset_stream(outputStreams[i]); | 
|  | for (const auto j : c10::irange(outputs[i].size())) { | 
|  | outputs[i][j].copy_(tmpOutputs[i][j], /* non_blocking */ true); | 
|  | } | 
|  | outputEvents[i].record(outputStreams[i]); | 
|  | } | 
|  | } | 
|  |  | 
|  | void synchronize() override { | 
|  | // Synchronize with the copy back to CUDA tensors. | 
|  | for (const auto i : c10::irange(outputs.size())) { | 
|  | c10::Device device = outputs[i][0].device(); | 
|  | outputEvents[i].block( | 
|  | c10::impl::VirtualGuardImpl(device.type()).getStream(device)); | 
|  | } | 
|  | } | 
|  |  | 
|  | std::vector<at::Tensor> tmpInputs; | 
|  | std::vector<c10::Stream> inputStreams; | 
|  | std::vector<c10::Event> inputEvents; | 
|  |  | 
|  | std::vector<std::vector<at::Tensor>> tmpOutputs; | 
|  | std::vector<c10::Stream> outputStreams; | 
|  | std::vector<c10::Event> outputEvents; | 
|  | }; | 
|  |  | 
|  | } // namespace | 
|  |  | 
|  | c10::intrusive_ptr<Work> ProcessGroupGloo::gather( | 
|  | std::vector<std::vector<at::Tensor>>& outputs, | 
|  | std::vector<at::Tensor>& inputs, | 
|  | const GatherOptions& opts) { | 
|  | static auto invalidArgument = [](const std::string& msg) { | 
|  | TORCH_CHECK(false, "ProcessGroupGloo::gather: " + msg); | 
|  | }; | 
|  |  | 
|  | assertRootRank(invalidArgument, opts.rootRank, size_); | 
|  | assertSingleElementInput(invalidArgument, inputs); | 
|  | assertDense(invalidArgument, inputs); | 
|  |  | 
|  | if (getRank() == opts.rootRank) { | 
|  | if (outputs.size() != 1) { | 
|  | std::stringstream ss; | 
|  | ss << "requires a single-element output list containing a list with " | 
|  | << getSize() << " tensors."; | 
|  | invalidArgument(ss.str()); | 
|  | } else if (outputs[0].size() != static_cast<size_t>(getSize())) { | 
|  | std::stringstream ss; | 
|  | ss << "Incorrect output list size " << outputs[0].size() | 
|  | << ". Output list size should be " << getSize() | 
|  | << ", same as size of the process group."; | 
|  | invalidArgument(ss.str()); | 
|  | } | 
|  |  | 
|  | const auto& options = inputs[0].options(); | 
|  | const auto& sizes = inputs[0].sizes(); | 
|  | assertTypeAndSizesMatch(invalidArgument, outputs[0], options, sizes); | 
|  | } else { | 
|  | if (outputs.size() != 0) { | 
|  | invalidArgument("requires empty output on non-root"); | 
|  | } | 
|  | } | 
|  |  | 
|  | const auto& device = inputs[0].device(); | 
|  | switch (device.type()) { | 
|  | case at::kCPU: | 
|  | break; | 
|  | case at::kCUDA: | 
|  | // If the user gave us a CUDA tensor then CUDA must be loaded. | 
|  | TORCH_INTERNAL_ASSERT(at::hasCUDA()); | 
|  | break; | 
|  | default: | 
|  | invalidArgument(c10::str("unsupported device type ", device.type())); | 
|  | } | 
|  |  | 
|  | c10::intrusive_ptr<AsyncGatherWork> work; | 
|  | auto tag = nextTag(); | 
|  | auto context = getContext(tag); | 
|  | if (device.type() == at::kCPU) { | 
|  | work = c10::make_intrusive<AsyncGatherWork>( | 
|  | std::move(context), outputs, inputs, opts.rootRank, tag); | 
|  | } else if (device.type() == at::kCUDA) { | 
|  | work = c10::make_intrusive<AsyncGatherCUDAWork>( | 
|  | std::move(context), outputs, inputs, opts.rootRank, tag); | 
|  | } else { | 
|  | TORCH_CHECK(false, "Invalid backend"); | 
|  | } | 
|  | enqueue(work); | 
|  | return work; | 
|  | } | 
|  |  | 
|  | namespace { | 
|  |  | 
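// AsyncScatterWork scatters a list of tensors from the root: the root sets
// the full input list, every rank sets a single output tensor, and
// gloo::scatter delivers to each rank the tensor destined for it.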
|  | class AsyncScatterWork : public ProcessGroupGloo::AsyncWork { | 
|  | public: | 
|  | AsyncScatterWork( | 
|  | const std::shared_ptr<gloo::Context>& context, | 
|  | std::vector<at::Tensor>& outputs, | 
|  | std::vector<std::vector<at::Tensor>>& inputs, | 
|  | int root, | 
|  | uint32_t tag) | 
|  | : ProcessGroupGloo::AsyncWork( | 
|  | {outputs}, | 
|  | "gloo:scatter", | 
|  | inputs.size() > 0 | 
|  | ? c10::optional<std::vector<at::Tensor>>(inputs[0]) | 
|  | : c10::nullopt), | 
|  | context(context), | 
|  | outputs(outputs), | 
|  | inputs(inputs), | 
|  | root(root), | 
|  | tag(tag) {} | 
|  |  | 
|  | std::shared_ptr<gloo::Context> context; | 
|  | std::vector<at::Tensor> outputs; | 
|  | std::vector<std::vector<at::Tensor>> inputs; | 
|  | const int root; | 
|  | const uint32_t tag; | 
|  |  | 
|  | void scatter( | 
|  | std::vector<at::Tensor>& outputs, | 
|  | std::vector<std::vector<at::Tensor>>& inputs) { | 
|  | const auto scalarType = outputs[0].scalar_type(); | 
|  | gloo::ScatterOptions opts(context); | 
|  | opts.setRoot(root); | 
|  | opts.setTag(tag); | 
|  |  | 
|  | // Set list of input tensors on root process | 
|  | if (context->rank == root) { | 
|  | GENERATE_ALL_TYPES(scalarType, setInputs, opts, inputs[0]); | 
|  | } | 
|  |  | 
|  | // Set single output tensor on all processes | 
|  | GENERATE_ALL_TYPES(scalarType, setOutput, opts, outputs[0]); | 
|  | gloo::scatter(opts); | 
|  | } | 
|  |  | 
|  | void run() override { | 
|  | scatter(outputs, inputs); | 
|  | } | 
|  | }; | 
|  |  | 
|  | class AsyncScatterCUDAWork : public AsyncScatterWork { | 
|  | public: | 
|  | AsyncScatterCUDAWork( | 
|  | const std::shared_ptr<gloo::Context>& context, | 
|  | std::vector<at::Tensor>& outputs, | 
|  | std::vector<std::vector<at::Tensor>>& inputs, | 
|  | int root, | 
|  | uint32_t tag) | 
|  | : AsyncScatterWork(context, outputs, inputs, root, tag) { | 
|  | initializeStreamsEvents(inputs, inputStreams, inputEvents); | 
|  | initializeStreamsEvents(outputs, outputStreams, outputEvents); | 
|  |  | 
|  | // Kick off copy from CUDA tensors to pinned CPU tensors. | 
|  | tmpInputs.resize(inputs.size()); | 
|  | c10::OptionalStreamGuard guard; | 
|  | for (const auto i : c10::irange(inputs.size())) { | 
|  | guard.reset_stream(inputStreams[i]); | 
|  | tmpInputs[i].reserve(inputs[i].size()); | 
|  | for (const auto j : c10::irange(inputs[i].size())) { | 
|  | tmpInputs[i].push_back( | 
|  | pinnedLike(inputs[i][j]).copy_(inputs[i][j], true)); | 
|  | } | 
|  | } | 
|  |  | 
|  | tmpOutputs.reserve(outputs.size()); | 
|  | for (auto& output : outputs) { | 
|  | tmpOutputs.push_back(pinnedLike(output)); | 
|  | } | 
|  | } | 
|  |  | 
|  | void run() override { | 
|  | // Synchronize with copy operations. | 
|  | for (const auto i : c10::irange(inputs.size())) { | 
|  | inputStreams[i].synchronize(); | 
|  | } | 
|  | for (const auto i : c10::irange(outputs.size())) { | 
|  | outputStreams[i].synchronize(); | 
|  | } | 
|  |  | 
|  | // Run scatter on host side tensors. | 
|  | scatter(tmpOutputs, tmpInputs); | 
|  |  | 
|  | // Kick off copy back to the CUDA tensors. | 
|  | c10::OptionalStreamGuard guard; | 
|  | for (const auto i : c10::irange(outputs.size())) { | 
|  | guard.reset_stream(outputStreams[i]); | 
|  | outputs[i].copy_(tmpOutputs[i], /* non_blocking */ true); | 
|  | outputEvents[i].record(outputStreams[i]); | 
|  | } | 
|  | } | 
|  |  | 
|  | void synchronize() override { | 
|  | // Synchronize with the copy back to CUDA tensors. | 
|  | for (const auto i : c10::irange(outputs.size())) { | 
|  | c10::Device device = outputs[i].device(); | 
|  | outputEvents[i].block( | 
|  | c10::impl::VirtualGuardImpl(device.type()).getStream(device)); | 
|  | } | 
|  | } | 
|  |  | 
|  | std::vector<at::Tensor> tmpOutputs; | 
|  | std::vector<c10::Stream> outputStreams; | 
|  | std::vector<c10::Event> outputEvents; | 
|  |  | 
|  | std::vector<std::vector<at::Tensor>> tmpInputs; | 
|  | std::vector<c10::Stream> inputStreams; | 
|  | std::vector<c10::Event> inputEvents; | 
|  | }; | 
|  |  | 
|  | } // namespace | 
|  |  | 
|  | c10::intrusive_ptr<Work> ProcessGroupGloo::scatter( | 
|  | std::vector<at::Tensor>& outputs, | 
|  | std::vector<std::vector<at::Tensor>>& inputs, | 
|  | const ScatterOptions& opts) { | 
|  | static auto invalidArgument = [](const std::string& msg) { | 
|  | TORCH_CHECK(false, "ProcessGroupGloo::scatter: " + msg); | 
|  | }; | 
|  |  | 
|  | assertRootRank(invalidArgument, opts.rootRank, size_); | 
|  | assertSingleElementOutput(invalidArgument, outputs); | 
|  | assertDense(invalidArgument, outputs); | 
|  |  | 
|  | if (getRank() == opts.rootRank) { | 
|  | if (inputs.size() != 1) { | 
|  | std::stringstream ss; | 
|  | ss << "requires a single-element input list containing a list with " | 
|  | << getSize() << " tensors"; | 
|  | invalidArgument(ss.str()); | 
|  | } else if (inputs[0].size() != static_cast<size_t>(getSize())) { | 
|  | std::stringstream ss; | 
|  | ss << "Incorrect input list size " << inputs[0].size() | 
|  | << ". Input list size should be " << getSize() | 
|  | << ", same as size of the process group."; | 
|  | invalidArgument(ss.str()); | 
|  | } | 
|  | const auto& options = outputs[0].options(); | 
|  | const auto& sizes = outputs[0].sizes(); | 
|  | assertTypeAndSizesMatch(invalidArgument, inputs[0], options, sizes); | 
|  | } else { | 
|  | if (inputs.size() != 0) { | 
|  | invalidArgument("requires empty input on non-root"); | 
|  | } | 
|  | } | 
|  |  | 
|  | const auto& device = outputs[0].device(); | 
|  | switch (device.type()) { | 
|  | case at::kCPU: | 
|  | break; | 
|  | case at::kCUDA: | 
|  | // If the user gave us a CUDA tensor then CUDA must be loaded. | 
|  | TORCH_INTERNAL_ASSERT(at::hasCUDA()); | 
|  | break; | 
|  | default: | 
|  | invalidArgument(c10::str("unsupported device type ", device.type())); | 
|  | } | 
|  |  | 
|  | c10::intrusive_ptr<AsyncScatterWork> work; | 
|  | auto tag = nextTag(); | 
|  | auto context = getContext(tag); | 
|  | if (device.type() == at::kCPU) { | 
|  | work = c10::make_intrusive<AsyncScatterWork>( | 
|  | std::move(context), outputs, inputs, opts.rootRank, tag); | 
|  | } else if (device.type() == at::kCUDA) { | 
|  | work = c10::make_intrusive<AsyncScatterCUDAWork>( | 
|  | std::move(context), outputs, inputs, opts.rootRank, tag); | 
|  | } else { | 
|  | TORCH_CHECK(false, "Invalid backend"); | 
|  | } | 
|  | enqueue(work); | 
|  | return work; | 
|  | } | 
|  |  | 
|  | c10::intrusive_ptr<Work> ProcessGroupGloo::reduce_scatter( | 
|  | std::vector<at::Tensor>& outputs, | 
|  | std::vector<std::vector<at::Tensor>>& inputs, | 
|  | const ReduceScatterOptions& opts) { | 
|  | TORCH_CHECK(false, "ProcessGroupGloo does not support reduce_scatter"); | 
|  | } | 
|  |  | 
|  | namespace { | 
|  |  | 
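// AsyncAlltoallWork dispatches to gloo::alltoall when no split sizes are
// given (equal splits across ranks) and to gloo::alltoallv otherwise, after
// converting the user-provided split sizes into per-rank element counts and
// offsets.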
|  | class AsyncAlltoallWork : public ProcessGroupGloo::AsyncWork { | 
|  | public: | 
|  | AsyncAlltoallWork( | 
|  | const std::shared_ptr<gloo::Context>& context, | 
|  | at::Tensor& outputTensor, | 
|  | at::Tensor& inputTensor, | 
|  | std::vector<int64_t>& outputCounts, | 
|  | std::vector<int64_t>& inputCounts, | 
|  | uint32_t tag) | 
|  | : ProcessGroupGloo::AsyncWork( | 
|  | {{outputTensor}}, | 
|  | "gloo:all_to_all", | 
|  | c10::optional<std::vector<at::Tensor>>({inputTensor})), | 
|  | context(context), | 
|  | outputTensor(outputTensor), | 
|  | inputTensor(inputTensor), | 
|  | outputCounts(std::move(outputCounts)), | 
|  | inputCounts(std::move(inputCounts)), | 
|  | tag(tag) {} | 
|  |  | 
|  | std::shared_ptr<gloo::Context> context; | 
|  | at::Tensor outputTensor; | 
|  | at::Tensor inputTensor; | 
|  | std::vector<int64_t> outputCounts; | 
|  | std::vector<int64_t> inputCounts; | 
|  | const uint32_t tag; | 
|  |  | 
|  | void alltoall(at::Tensor& outputTensor, at::Tensor& inputTensor) { | 
|  | const auto scalarType = outputTensor.scalar_type(); | 
if (outputCounts.empty() && inputCounts.empty()) {
|  | // Gloo alltoall | 
|  | gloo::AlltoallOptions opts(context); | 
|  | opts.setTag(tag); | 
|  | GENERATE_ALL_TYPES(scalarType, setInput, opts, inputTensor); | 
|  | GENERATE_ALL_TYPES(scalarType, setOutput, opts, outputTensor); | 
|  | gloo::alltoall(opts); | 
|  | } else { | 
|  | // Gloo alltoallv | 
|  | c10d::checkSplitSizes(inputCounts, inputTensor, context->size); | 
|  | c10d::checkSplitSizes(outputCounts, outputTensor, context->size); | 
|  | std::vector<int64_t> sendCounts(context->size); | 
|  | std::vector<int64_t> recvCounts(context->size); | 
|  | std::vector<int64_t> sendOffsets(context->size); | 
|  | std::vector<int64_t> recvOffsets(context->size); | 
|  | c10d::computeLengthsAndOffsets( | 
|  | inputCounts, inputTensor, &sendCounts, &sendOffsets); | 
|  | c10d::computeLengthsAndOffsets( | 
|  | outputCounts, outputTensor, &recvCounts, &recvOffsets); | 
|  | gloo::AlltoallvOptions opts(context); | 
|  | opts.setTag(tag); | 
|  | GENERATE_ALL_TYPES(scalarType, setInput, opts, inputTensor, sendCounts); | 
|  | GENERATE_ALL_TYPES(scalarType, setOutput, opts, outputTensor, recvCounts); | 
|  | gloo::alltoallv(opts); | 
|  | } | 
|  | } | 
|  |  | 
|  | void run() override { | 
|  | alltoall(outputTensor, inputTensor); | 
|  | } | 
|  | }; | 
|  |  | 
|  | class AsyncAlltoallCUDAWork : public AsyncAlltoallWork { | 
|  | public: | 
|  | AsyncAlltoallCUDAWork( | 
|  | const std::shared_ptr<gloo::Context>& context, | 
|  | at::Tensor& outputTensor, | 
|  | at::Tensor& inputTensor, | 
|  | std::vector<int64_t>& outputCounts, | 
|  | std::vector<int64_t>& inputCounts, | 
|  | uint32_t tag) | 
|  | : AsyncAlltoallWork( | 
|  | context, | 
|  | outputTensor, | 
|  | inputTensor, | 
|  | outputCounts, | 
|  | inputCounts, | 
|  | tag) { | 
|  | initializeStreamsEvents({inputTensor}, inputStreams, inputEvents); | 
|  | initializeStreamsEvents({outputTensor}, outputStreams, outputEvents); | 
|  |  | 
|  | // Kick off copy from CUDA tensors to pinned CPU tensors. | 
|  | c10::OptionalStreamGuard guard; | 
|  | guard.reset_stream(inputStreams.front()); | 
|  | cpuInput = pinnedLike(inputTensor).copy_(inputTensor, true); | 
|  |  | 
|  | guard.reset_stream(outputStreams.front()); | 
|  | cpuOutput = pinnedLike(outputTensor); | 
|  | } | 
|  |  | 
|  | void run() override { | 
|  | // Synchronize with copy operations. | 
|  | inputStreams.front().synchronize(); | 
|  | outputStreams.front().synchronize(); | 
|  |  | 
|  | // Run alltoall on host side tensors. | 
|  | alltoall(cpuOutput, cpuInput); | 
|  |  | 
|  | // Kick off copy back to the CUDA tensors. | 
|  | c10::OptionalStreamGuard guard; | 
|  | guard.reset_stream(outputStreams.front()); | 
|  | outputTensor.copy_(cpuOutput, /* non_blocking */ true); | 
|  | outputEvents.front().record(outputStreams.front()); | 
|  | } | 
|  |  | 
|  | void synchronize() override { | 
|  | // Synchronize with the copy back to CUDA tensors. | 
|  | c10::Device device = outputTensor.device(); | 
|  | outputEvents.front().block( | 
|  | c10::impl::VirtualGuardImpl(device.type()).getStream(device)); | 
|  | } | 
|  |  | 
|  | at::Tensor cpuOutput; | 
|  | std::vector<c10::Stream> outputStreams; | 
|  | std::vector<c10::Event> outputEvents; | 
|  |  | 
|  | at::Tensor cpuInput; | 
|  | std::vector<c10::Stream> inputStreams; | 
|  | std::vector<c10::Event> inputEvents; | 
|  | }; | 
|  |  | 
|  | } // namespace | 
|  |  | 
|  | c10::intrusive_ptr<Work> ProcessGroupGloo::alltoall_base( | 
|  | at::Tensor& outputTensor, | 
|  | at::Tensor& inputTensor, | 
|  | std::vector<int64_t>& outputCounts, | 
|  | std::vector<int64_t>& inputCounts, | 
|  | const AllToAllOptions& /* unused */) { | 
|  | static auto invalidArgument = [](const std::string& msg) { | 
|  | TORCH_CHECK(false, "ProcessGroupGloo::alltoall_base: " + msg); | 
|  | }; | 
|  |  | 
|  | TORCH_CHECK( | 
|  | outputTensor.device() == inputTensor.device(), | 
|  | "output tensor and input tensor must be on the same type of device"); | 
|  | assertDense(invalidArgument, {outputTensor}); | 
|  | assertDense(invalidArgument, {inputTensor}); | 
|  |  | 
|  | const auto& device = outputTensor.device(); | 
|  | c10::intrusive_ptr<AsyncAlltoallWork> work; | 
|  | auto tag = nextTag(); | 
|  | auto context = getContext(tag); | 
|  |  | 
|  | if (device.type() == at::kCPU) { | 
|  | work = c10::make_intrusive<AsyncAlltoallWork>( | 
|  | std::move(context), | 
|  | outputTensor, | 
|  | inputTensor, | 
|  | outputCounts, | 
|  | inputCounts, | 
|  | tag); | 
|  | } else if (device.type() == at::kCUDA) { | 
|  | work = c10::make_intrusive<AsyncAlltoallCUDAWork>( | 
|  | std::move(context), | 
|  | outputTensor, | 
|  | inputTensor, | 
|  | outputCounts, | 
|  | inputCounts, | 
|  | tag); | 
|  | } else { | 
|  | invalidArgument(c10::str("unsupported device type ", device.type())); | 
|  | } | 
|  | enqueue(work); | 
|  | return work; | 
|  | } | 
|  |  | 
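// Point-to-point helpers: the send/recv entry points below take a vector for
// interface uniformity but require exactly one dense, contiguous tensor.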
|  | at::Tensor& checkSingleTensor(std::vector<at::Tensor>& tensors) { | 
|  | if (tensors.size() != 1) { | 
|  | TORCH_CHECK(false, "ProcessGroupGloo::send takes a single tensor"); | 
|  | } | 
|  | auto& tensor = tensors[0]; | 
|  | if (!tensor.is_contiguous()) { | 
|  | TORCH_CHECK(false, "input tensor has to be contiguous"); | 
|  | } | 
|  | if (tensor.is_sparse()) { | 
|  | TORCH_CHECK(false, "input tensor has to be dense"); | 
|  | } | 
|  | return tensor; | 
|  | } | 
|  |  | 
uint32_t checkTag(int32_t tag) {
TORCH_CHECK(tag >= 0, "Tag must be nonnegative");
return static_cast<uint32_t>(tag);
}
|  |  | 
|  | c10::intrusive_ptr<Work> ProcessGroupGloo::send( | 
|  | std::vector<at::Tensor>& tensors, | 
|  | int dstRank, | 
|  | int tag) { | 
|  | auto& tensor = checkSingleTensor(tensors); | 
|  | auto utag = checkTag(tag); | 
|  | auto ptr = tensor.data_ptr(); | 
|  | auto size = tensor.numel() * tensor.element_size(); | 
|  |  | 
|  | // Construct unbound buffer. | 
|  | auto context = getContext(tag); | 
|  | auto buf = context->createUnboundBuffer(ptr, size); | 
|  | buf->send(dstRank, utag); | 
|  |  | 
|  | // The work captures the tensor to prevent it being deallocated and | 
|  | // the unbound buffer to synchronize on completion of the send. | 
|  | return c10::make_intrusive<SendWork>(tensor, std::move(buf)); | 
|  | } | 
|  |  | 
|  | c10::intrusive_ptr<Work> ProcessGroupGloo::recv( | 
|  | std::vector<at::Tensor>& tensors, | 
|  | int srcRank, | 
|  | int tag) { | 
|  | auto& tensor = checkSingleTensor(tensors); | 
|  | auto utag = checkTag(tag); | 
|  | auto ptr = tensor.data_ptr(); | 
|  | auto size = tensor.numel() * tensor.element_size(); | 
|  |  | 
|  | // Construct unbound buffer. | 
|  | auto context = getContext(tag); | 
|  | auto buf = context->createUnboundBuffer(ptr, size); | 
|  | buf->recv(srcRank, utag); | 
|  |  | 
|  | // The work captures the tensor to prevent it being deallocated and | 
|  | // the unbound buffer to synchronize on completion of the recv. | 
|  | return c10::make_intrusive<RecvWork>(tensor, std::move(buf), "gloo:recv"); | 
|  | } | 
|  |  | 
|  | c10::intrusive_ptr<Work> ProcessGroupGloo::recvAnysource( | 
|  | std::vector<at::Tensor>& tensors, | 
|  | int tag) { | 
|  | auto& tensor = checkSingleTensor(tensors); | 
|  | auto utag = checkTag(tag); | 
|  | auto ptr = tensor.data_ptr(); | 
|  | auto size = tensor.numel() * tensor.element_size(); | 
|  |  | 
|  | // Construct unbound buffer. | 
|  | auto context = getContext(tag); | 
|  | auto buf = context->createUnboundBuffer(ptr, size); | 
|  |  | 
|  | // Build list of ranks that this operation can recv from. In these | 
|  | // bindings we don't differentiate between ranks and can receive | 
|  | // from any other process in the group. | 
std::vector<int> srcRanks;
srcRanks.reserve(size_);
for (const auto i : c10::irange(size_)) {
srcRanks.push_back(i);
}
|  |  | 
|  | buf->recv(srcRanks, utag); | 
|  |  | 
|  | // The work captures the tensor to prevent it being deallocated and | 
|  | // the unbound buffer to synchronize on completion of the recv. | 
|  | return c10::make_intrusive<RecvWork>( | 
|  | tensor, std::move(buf), "gloo:recvAnySource"); | 
|  | } | 
|  |  | 
|  | namespace { | 
|  |  | 
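// AsyncBarrierWork first waits on any prior work that was still pending or
// in progress when the barrier was enqueued, then runs gloo::barrier, so the
// barrier only completes once all preceding work on this process group has
// finished.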
|  | class AsyncBarrierWork : public ProcessGroupGloo::AsyncWork { | 
|  | public: | 
|  | AsyncBarrierWork( | 
|  | const std::shared_ptr<gloo::Context>& context, | 
|  | std::vector<c10::weak_intrusive_ptr<AsyncWork>> priorWork, | 
|  | uint32_t tag) | 
|  | : ProcessGroupGloo::AsyncWork({}, "gloo:barrier", c10::nullopt), | 
|  | context(context), | 
|  | priorWork(std::move(priorWork)), | 
|  | tag(tag) {} | 
|  |  | 
|  | std::shared_ptr<gloo::Context> context; | 
|  | std::vector<c10::weak_intrusive_ptr<AsyncWork>> priorWork; | 
|  | const uint32_t tag; | 
|  |  | 
|  | void run() override { | 
|  | // Wait on prior work to complete | 
|  | for (auto& weakWork : priorWork) { | 
|  | auto work = weakWork.lock(); | 
|  | if (work) { | 
|  | work->wait(); | 
|  | } | 
|  | } | 
|  |  | 
|  | gloo::BarrierOptions opts(context); | 
|  | opts.setTag(tag); | 
|  | gloo::barrier(opts); | 
|  | } | 
|  | }; | 
|  |  | 
|  | } // namespace | 
|  |  | 
|  | c10::intrusive_ptr<Work> ProcessGroupGloo::barrier(const BarrierOptions& opts) { | 
|  | std::vector<c10::weak_intrusive_ptr<AsyncWork>> priorWork; | 
|  |  | 
|  | // Snapshot all in progress and pending work as weak_ptr. | 
|  | // When executing a barrier, we need to ensure that all prior work | 
|  | // has completed before completing itself. | 
|  | { | 
|  | std::unique_lock<std::mutex> lock(workMutex_); | 
|  | priorWork.insert( | 
|  | priorWork.end(), workInProgress_.begin(), workInProgress_.end()); | 
|  | priorWork.insert(priorWork.end(), workQueue_.begin(), workQueue_.end()); | 
|  | } | 
|  |  | 
|  | auto tag = nextTag(); | 
|  | auto context = getContext(tag); | 
|  | auto work = c10::make_intrusive<AsyncBarrierWork>( | 
|  | std::move(context), std::move(priorWork), tag); | 
|  | enqueue(work); | 
|  | return work; | 
|  | } | 
|  |  | 
|  | void ProcessGroupGloo::monitoredBarrier( | 
|  | const BarrierOptions& opts, | 
|  | bool waitAllRanks) { | 
|  | C10_LOG_API_USAGE_ONCE("torch.distributed.monitored_barrier"); | 
|  | // Use default timeout if no timeout was specified. | 
|  | auto monitoredBarrierTimeout = | 
|  | (opts.timeout == kUnsetTimeout) ? this->options_->timeout : opts.timeout; | 
|  | auto rank = this->getRank(); | 
|  | auto t1 = nextTag(); | 
|  | auto t2 = nextTag(); | 
|  | std::vector<at::Tensor> commTensor = {at::tensor({rank})}; | 
// Only enforce the timeout on rank 0. This is so that other ranks aren't timed
// out first, bringing down the job without reporting which rank timed out.
|  | if (rank != 0) { | 
|  | auto sendWork = send(commTensor, 0, t1); | 
|  | auto recvWork = recv(commTensor, 0, t2); | 
|  | try { | 
|  | sendWork->wait(); | 
|  | recvWork->wait(); | 
|  | } catch (const std::exception& e) { | 
|  | const std::string error = c10::str( | 
|  | "Rank ", | 
|  | rank, | 
|  | " successfully reached monitoredBarrier, but received errors while waiting", | 
|  | " for send/recv from rank 0. Please check rank 0 logs for faulty rank."); | 
|  | logAndThrow( | 
|  | error, c10::str(error, "\n Original exception: \n", e.what())); | 
|  | } | 
|  | return; | 
|  | } | 
|  | auto startTime = std::chrono::steady_clock::now(); | 
|  | auto worldSize = this->getSize(); | 
|  | // Mappings of rank to recvWork/sendWork respectively. | 
|  | std::map<int, c10::intrusive_ptr<Work>> recvWorkMap; | 
|  | std::map<int, c10::intrusive_ptr<Work>> sendWorkMap; | 
|  | // Kick off recvWork and wait to unblock sendWork->wait() from non-zero ranks. | 
|  | // Failed/hanging ranks will not ack this call, letting rank 0 know about the | 
|  | // failure. | 
|  | for (const auto dstRank : c10::irange(1, worldSize)) { | 
|  | recvWorkMap.insert({dstRank, recv(commTensor, dstRank, t1)}); | 
|  | } | 
|  |  | 
|  | auto waitLoop = [&](const std::map<int, c10::intrusive_ptr<Work>>& works) { | 
|  | std::vector<int> processedRanks; | 
|  | for (auto& work : works) { | 
|  | bool rankResponded = false; | 
|  | try { | 
|  | // Note: if waitAllRanks=false, we recompute the time remaining in | 
|  | // barrier and use this recomputed time in wait(). However, if | 
|  | // waitAllRanks=true, we use the original timeout, since if we use | 
|  | // up the entire timeout waiting for response from rank n, then we | 
|  | // won't have any timeout left to query ranks beginning with n + 1. | 
|  | auto remainingTime = | 
|  | getRemainingTime(startTime, monitoredBarrierTimeout, waitAllRanks); | 
|  | if (!waitAllRanks) { | 
|  | checkRemainingTime( | 
|  | monitoredBarrierTimeout, remainingTime, processedRanks, rank); | 
|  | } | 
|  | work.second->wait(remainingTime); | 
|  | rankResponded = true; | 
|  | } catch (const std::exception& e) { | 
|  | const std::string error = c10::str( | 
|  | "[Rank 0]: Rank ", | 
|  | work.first, | 
|  | " failed to pass monitoredBarrier in ", | 
|  | monitoredBarrierTimeout.count(), | 
|  | " ms"); | 
|  | if (waitAllRanks) { | 
|  | LOG(ERROR) << error; | 
|  | } else { | 
|  | logAndThrow( | 
|  | error, c10::str(error, "\n Original exception: \n", e.what())); | 
|  | } | 
|  | } | 
|  | if (rankResponded) { | 
|  | processedRanks.push_back(work.first); | 
|  | } | 
|  | } | 
|  | // If we are collecting all failed ranks, check if we need to throw if | 
|  | // some ranks have not responded. | 
|  | // Ensure all ranks from 1, ... WORLD_SIZE -1 have been successfully | 
|  | // processed. | 
auto rankFailure = (processedRanks.size() != static_cast<size_t>(size_ - 1));
|  | if (waitAllRanks && rankFailure) { | 
|  | std::vector<int> failedRanks; | 
|  | for (const auto i : c10::irange(1, size_)) { | 
|  | if (std::find(processedRanks.begin(), processedRanks.end(), i) == | 
|  | processedRanks.end()) { | 
|  | failedRanks.push_back(i); | 
|  | } | 
|  | } | 
|  |  | 
|  | TORCH_INTERNAL_ASSERT(!failedRanks.empty()); | 
|  | const std::string ranksStr = c10::Join(", ", failedRanks); | 
|  | const std::string error = c10::str( | 
|  | "[Rank 0]: Ranks ", | 
|  | ranksStr, | 
|  | " failed to pass monitoredBarrier in ", | 
|  | monitoredBarrierTimeout.count(), | 
|  | " ms"); | 
|  | logAndThrow(error, error); | 
|  | } | 
|  | }; | 
|  |  | 
|  | waitLoop(recvWorkMap); | 
// If we've reached here successfully, this means all ranks have acked in
// monitoredBarrier. Unblock all ranks now by responding to their recv(). This
// ensures that this is a true barrier in that all ranks exit it successfully
// or none of them do.
|  | for (const auto dstRank : c10::irange(1, worldSize)) { | 
|  | sendWorkMap.insert({dstRank, send(commTensor, dstRank, t2)}); | 
|  | } | 
|  |  | 
|  | waitLoop(sendWorkMap); | 
|  | } | 
|  |  | 
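// Rank 0 seeds a random sequence number and publishes it through the store;
// all other ranks block on the store key and adopt rank 0's value.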
|  | void ProcessGroupGloo::setSequenceNumberForGroup() { | 
|  | if (rank_ == 0) { | 
|  | // Create and broadcast sequence number | 
|  | auto seq = 1 + rand(); | 
|  | sequenceNum_ = c10d::SequenceNum(seq); | 
|  | std::vector<char> values = c10d::toVec<char>(seq, kBytes); | 
|  | store_->set(kSeqNumStoreKey, values); | 
|  | } else { | 
|  | // Read rank 0's sequence number from store. | 
|  | sequenceNum_ = c10d::SequenceNum(); | 
|  | store_->wait({kSeqNumStoreKey}, options_->timeout); | 
|  | std::vector<char> values = store_->get(kSeqNumStoreKey); | 
|  | uint64_t num = c10d::fromVec<char>(values); | 
|  | sequenceNum_->set(num); | 
|  | } | 
|  | } | 
|  |  | 
|  | uint64_t ProcessGroupGloo::getSequenceNumberForGroup() { | 
|  | if (sequenceNum_ == c10::nullopt) { | 
|  | return 0; | 
|  | } | 
|  | return sequenceNum_->get(); | 
|  | } | 
|  |  | 
|  | } // namespace c10d | 
|  |  | 
|  | #endif // USE_C10D_GLOO |