| #include <chrono> |
| #include <filesystem> |
| #include <fstream> |
| #include <thread> |
| |
| #include <c10/util/irange.h> |
| #include <torch/csrc/cuda/nccl.h> |
| #include <torch/csrc/distributed/c10d/FileStore.hpp> |
| #include <torch/csrc/distributed/c10d/NCCLUtils.hpp> |
| #include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp> |
| #include "CUDATest.hpp" |
| #include "TestUtils.hpp" |
| |
| #include <gtest/gtest.h> |
| |
| using namespace c10d::test; |
| |
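| // Minimum NCCL version, as reported by torch::cuda::nccl::version(), for which |
| // the error-handling tests below are run; older versions are skipped. |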
| constexpr int kNcclErrorHandlingVersion = 2400; |
| |
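| // A WorkNCCL that can pretend a NCCL error occurred, so the error-handling |
| // path can be exercised without a real NCCL failure. |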
| class WorkNCCLSimulateErrors : public c10d::ProcessGroupNCCL::WorkNCCL { |
| public: |
| WorkNCCLSimulateErrors( |
| at::Device& device, |
| bool simulate_error, |
| int rank, |
| c10d::OpType opType, |
| uint64_t seq) |
| : WorkNCCL(device, rank, opType, seq), simulateError_(simulate_error) {} |
| |
| std::exception_ptr checkForNCCLErrors() override { |
| if (simulateError_) { |
| return std::make_exception_ptr(std::runtime_error("Error")); |
| } |
| return c10d::ProcessGroupNCCL::WorkNCCL::checkForNCCLErrors(); |
| } |
| |
| private: |
| bool simulateError_; |
| }; |
| |
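| // A ProcessGroupNCCL whose error-checking hook reports a simulated error on |
| // demand and whose collectives produce WorkNCCLSimulateErrors work items. |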
| class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL { |
| public: |
| ProcessGroupNCCLSimulateErrors( |
| const c10::intrusive_ptr<c10d::Store>& store, |
| int rank, |
| int size, |
| c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> opts) |
| : ProcessGroupNCCL(store, rank, size, opts), simulateError_(false) {} |
| |
| std::exception_ptr checkForNCCLErrors( |
| std::shared_ptr<c10d::NCCLComm>& ncclComm) override { |
| if (simulateError_) { |
| return std::make_exception_ptr(std::runtime_error("Error")); |
| } |
| return c10d::ProcessGroupNCCL::checkForNCCLErrors(ncclComm); |
| } |
| |
| std::chrono::duration<int64_t, std::milli> getWatchdogSleepInterval() { |
| return std::chrono::milliseconds( |
| ProcessGroupNCCLSimulateErrors::kWatchdogThreadSleepMillis); |
| } |
| |
| c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> initWork( |
| at::Device& device, |
| int rank, |
| c10d::OpType opType, |
| const char* profilingTitle, |
| const std::vector<at::Tensor>& inputs = {}, |
| const std::vector<at::Tensor>& outputs = {}, |
| bool record = false) override { |
| return c10::make_intrusive<WorkNCCLSimulateErrors>( |
| device, simulateError_, rank, opType, seq_); |
| } |
| |
| size_t getNCCLCommCacheSize() { |
| return devNCCLCommMap_.size(); |
| } |
| |
| void simulateError() { |
| simulateError_ = true; |
| } |
| |
| void resetError() { |
| simulateError_ = false; |
| } |
| |
| private: |
| bool simulateError_; |
| }; |
| |
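| // A WorkNCCL that can pretend to never complete, so the watchdog treats the |
| // operation as timed out. |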
| class WorkNCCLTimedoutErrors : public c10d::ProcessGroupNCCL::WorkNCCL { |
| public: |
| WorkNCCLTimedoutErrors( |
| at::Device& device, |
| bool set_timedout_error, |
| int rank, |
| c10d::OpType opType, |
| uint64_t seq) |
| : WorkNCCL(device, rank, opType, seq), |
| setTimedoutError_(set_timedout_error) {} |
| |
| private: |
| bool isCompleted() override { |
| if (setTimedoutError_) { |
| return false; |
| } |
| return c10d::ProcessGroupNCCL::WorkNCCL::isCompleted(); |
| } |
| |
| private: |
| bool setTimedoutError_; |
| }; |
| |
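| // A process group whose collectives can be forced to time out and whose |
| // watchdog debug-info collection is instrumented so tests can observe it. |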
| class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors { |
| public: |
| ProcessGroupNCCLTimedOutErrors( |
| const c10::intrusive_ptr<c10d::Store>& store, |
| int rank, |
| int size, |
| c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> opts) |
| : ProcessGroupNCCLSimulateErrors(store, rank, size, opts), |
| watchDogDebugInfoFinished_(false), |
| setTimedoutError_(false) {} |
| |
| c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> initWork( |
| at::Device& device, |
| int rank, |
| c10d::OpType opType, |
| const char* profilingTitle, |
| const std::vector<at::Tensor>& inputs = {}, |
| const std::vector<at::Tensor>& outputs = {}, |
| bool record = false) override { |
| return c10::make_intrusive<WorkNCCLTimedoutErrors>( |
| device, setTimedoutError_, rank, opType, seq_); |
| } |
| |
| void setTimedoutError() { |
| setTimedoutError_ = true; |
| } |
| |
| void resetTimedoutError() { |
| setTimedoutError_ = false; |
| } |
| |
| bool getWatchDogDebugInfoFinishedFlag() { |
| return watchDogDebugInfoFinished_; |
| } |
| |
| // In the constructor of ProcessGroupNCCL, the watchdog thread is not allowed |
| // to run any error handling or desync report while the main thread is doing a |
| // blocking wait; even if users enable error handling and the desyncDebug flag, |
| // they get reset. For ease of unit testing we want the main thread to block on |
| // wait, so this hack manually sets the desync debug flag after the process |
| // group has been created. |
| void forceSetDesyncDebugFlag() { |
| desyncDebug_ = true; |
| } |
| |
| protected: |
| std::string getNCCLWatchdogDebugInfo() override { |
| LOG(INFO) << "overridden getNCCLWatchdogDebugInfo called"; |
| watchDogDebugInfoFinished_ = true; |
| return ""; |
| } |
| bool watchDogDebugInfoFinished_; |
| |
| private: |
| bool setTimedoutError_; |
| }; |
| |
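| // A process group whose heartbeat monitor records, instead of aborting on, |
| // the error raised when the watchdog stops heartbeating. |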
| class ProcessGroupNCCLNoHeartbeatCaught |
| : public ProcessGroupNCCLTimedOutErrors { |
| public: |
| ProcessGroupNCCLNoHeartbeatCaught( |
| const c10::intrusive_ptr<c10d::Store>& store, |
| int rank, |
| int size, |
| c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> opts) |
| : ProcessGroupNCCLTimedOutErrors(store, rank, size, opts), |
| hasMonitorThreadCaughtError_(false) {} |
| |
| std::mutex& getWatchdogMutex() { |
| return workMetaListMutex_; |
| } |
| |
| bool getErrorCaughtFlag() { |
| return hasMonitorThreadCaughtError_; |
| } |
| |
| void forceTryWriteDebugInfo() { |
| std::future<bool> asyncDebugDump = std::async( |
| std::launch::async, [this]() { return this->dumpDebuggingInfo(); }); |
| asyncDebugDump.wait(); |
| } |
| |
| protected: |
| // Override the heartbeat monitor function so that the exception is captured |
| // inside the monitor thread (we cannot try-catch it from the main thread) and |
| // a flag is set for the main thread to check. |
| void heartbeatMonitor() override { |
| try { |
| c10d::ProcessGroupNCCL::heartbeatMonitor(); |
| } catch (std::runtime_error& e) { |
| hasMonitorThreadCaughtError_ = true; |
| } |
| } |
| |
| // std::abort is really hard to unit test, so we override terminateProcess |
| // instead. With this override commented out, the process does abort with a |
| // core dump, as expected. |
| void terminateProcess(std::string errMsg) override { |
| throw std::runtime_error(errMsg); |
| } |
| |
| bool hasMonitorThreadCaughtError_; |
| }; |
| |
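| // A process group whose debug-info collection never finishes in time, to |
| // mimic the watchdog getting stuck while the heartbeat monitor is running. |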
| class ProcessGroupNCCLDebugInfoStuck |
| : public ProcessGroupNCCLNoHeartbeatCaught { |
| public: |
| ProcessGroupNCCLDebugInfoStuck( |
| const c10::intrusive_ptr<c10d::Store>& store, |
| int rank, |
| int size, |
| c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> opts) |
| : ProcessGroupNCCLNoHeartbeatCaught(store, rank, size, opts) {} |
| |
| protected: |
| // Override the debug info collection with a long sleep to mimic the watchdog |
| // getting stuck while gathering debug info. |
| std::string getNCCLWatchdogDebugInfo() override { |
| std::this_thread::sleep_for( |
| std::chrono::seconds(heartbeatTimeoutInSec_ * 20)); |
| watchDogDebugInfoFinished_ = true; |
| return ""; |
| } |
| }; |
| |
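| // Common fixture: skips when no GPU is available (or NCCL is too old) and |
| // sets up a single-rank FileStore plus one CUDA tensor per rank. |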
| class ProcessGroupNCCLErrorsTest : public ::testing::Test { |
| protected: |
| bool skipTest() { |
| if (cudaNumDevices() == 0) { |
| LOG(INFO) << "Skipping test since CUDA is not available"; |
| return true; |
| } |
| #ifdef USE_C10D_NCCL |
| if (torch::cuda::nccl::version() < kNcclErrorHandlingVersion) { |
| LOG(INFO) << "Skipping test since NCCL version is too old"; |
| return true; |
| } |
| #endif |
| return false; |
| } |
| |
| void SetUp() override { |
| // Enable LOG(INFO) messages. |
| c10::initLogging(); |
| // This check needs to be in SetUp so that the test, including the |
| // initialization below, only runs when GPUs are available. |
| if (skipTest()) { |
| GTEST_SKIP() << "Skipping ProcessGroupNCCLErrorsTest because system " |
| << "requirement is not met (no CUDA or GPU)."; |
| } |
| |
| size_t numDevices = 1; // One device per rank (thread) |
| TemporaryFile file; |
| store_ = c10::make_intrusive<::c10d::FileStore>(file.path, 1); |
| |
| tensors_.resize(numDevices); |
| tensors_[0] = at::empty({3, 3}, at::kCUDA); |
| } |
| |
| void TearDown() override { |
| ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "0", 1) == 0); |
| } |
| |
| std::vector<at::Tensor> tensors_; |
| c10::intrusive_ptr<::c10d::FileStore> store_; |
| }; |
| |
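| // With blocking wait enabled, a simulated NCCL error should surface as an |
| // exception from wait(). |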
| TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsBlocking) { |
| ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "1", 1) == 0); |
| auto options = c10d::ProcessGroupNCCL::Options::create(); |
| options->timeout = std::chrono::milliseconds(1000); |
| ProcessGroupNCCLSimulateErrors pg(store_, 0, 1, options); |
| |
| auto work = pg.allreduce(tensors_); |
| work->wait(); |
| EXPECT_EQ(1, pg.getNCCLCommCacheSize()); |
| |
| // Now run allreduce with a simulated error. |
| pg.simulateError(); |
| work = pg.allreduce(tensors_); |
| EXPECT_THROW(work->wait(), std::runtime_error); |
| |
| // Verify the work item failed. |
| EXPECT_TRUE(work->isCompleted()); |
| EXPECT_THROW(work->wait(), std::runtime_error); |
| |
| // Communicators might be aborted here; further operations would fail. |
| } |
| |
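| // With blocking wait enabled, a collective that never completes should time |
| // out and surface as a c10::DistBackendError from wait(). |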
| TEST_F(ProcessGroupNCCLErrorsTest, testNCCLTimedoutErrorsBlocking) { |
| ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "1", 1) == 0); |
| auto options = c10d::ProcessGroupNCCL::Options::create(); |
| options->timeout = std::chrono::milliseconds(3000); |
| ProcessGroupNCCLTimedOutErrors pg(store_, 0, 1, options); |
| |
| auto work = pg.allreduce(tensors_); |
| work->wait(); |
| EXPECT_EQ(1, pg.getNCCLCommCacheSize()); |
| |
| // Now run allreduce with a simulated timeout. |
| pg.setTimedoutError(); |
| work = pg.allreduce(tensors_); |
| EXPECT_THROW(work->wait(), c10::DistBackendError); |
| |
| // Communicators might be aborted here; further operations would fail. |
| } |
| |
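| // Without blocking wait, a simulated NCCL error should not cause wait() to |
| // throw. |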
| TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNonBlocking) { |
| auto options = c10d::ProcessGroupNCCL::Options::create(); |
| options->timeout = std::chrono::milliseconds(3000); |
| ProcessGroupNCCLSimulateErrors pg(store_, 0, 1, options); |
| |
| auto work = pg.allreduce(tensors_); |
| pg.barrier()->wait(); |
| EXPECT_EQ(1, pg.getNCCLCommCacheSize()); |
| |
| // Now run allreduce with a simulated error. |
| pg.simulateError(); |
| work = pg.allreduce(tensors_); |
| |
| // Should not throw exceptions. |
| work->wait(); |
| pg.barrier()->wait(); |
| |
| EXPECT_TRUE(work->isCompleted()); |
| // Communicators might be aborted here; further operations would fail. |
| } |
| |
| // Function to read what we wrote to the local disk for validation. |
| std::string readTraceFromFile(const std::string& filename, size_t size) { |
| std::ifstream file(filename, std::ios::binary); |
| // Read the strings from the file |
| if (file) { // If the file stream is in a good state |
| std::string str(size, '\0'); |
| file.read(&str[0], size); |
| if (file) { |
| return str; |
| } |
| } |
| return ""; |
| } |
| |
| // A DebugInfoWriter subclass that also keeps the written traces in memory for |
| // validation. |
| class TestDebugInfoWriter : public c10d::DebugInfoWriter { |
| public: |
| TestDebugInfoWriter(std::string namePrefix) |
| : DebugInfoWriter(namePrefix, 0) {} |
| |
| void write(const std::string& ncclTrace) override { |
| traces_.assign(ncclTrace.begin(), ncclTrace.end()); |
| c10d::DebugInfoWriter::write(ncclTrace); |
| } |
| |
| std::vector<uint8_t>& getTraces() { |
| return traces_; |
| } |
| |
| private: |
| std::vector<uint8_t> traces_; |
| }; |
| |
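| // Blocks the watchdog so heartbeats stop, then checks that the heartbeat |
| // monitor caught the resulting error and that the flight-recorder trace |
| // written to disk matches the in-memory copy. |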
| TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) { |
| int heartBeatIntervalInSec = 2; |
| std::string timeInterval = std::to_string(heartBeatIntervalInSec); |
| ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "1", 1) == 0); |
| ASSERT_TRUE( |
| setenv( |
| c10d::TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC[0].c_str(), |
| timeInterval.c_str(), |
| 1) == 0); |
| ASSERT_TRUE( |
| setenv(c10d::TORCH_NCCL_ENABLE_MONITORING[0].c_str(), "1", 1) == 0); |
| auto tempFilename = c10::str( |
| std::filesystem::temp_directory_path().string(), "/nccl_trace_rank_"); |
| ASSERT_TRUE( |
| setenv("TORCH_NCCL_DEBUG_INFO_TEMP_FILE", tempFilename.c_str(), 1) == 0); |
| // Enable nccl flight recorder. |
| ASSERT_TRUE(setenv("TORCH_NCCL_TRACE_BUFFER_SIZE", "10", 1) == 0); |
| auto options = c10d::ProcessGroupNCCL::Options::create(); |
| // Set a long watchdog timeout so that we have enough time to lock the |
| // watchdog and let the heartbeat monitor thread kick in. |
| options->timeout = std::chrono::milliseconds(30000); |
| ProcessGroupNCCLNoHeartbeatCaught pg(store_, 0, 1, options); |
| // The writer used here is very similar to the default (fallback) writer. |
| // The only difference is that we also keep the traces in memory for |
| // validation. |
| std::string fileNamePrefix = c10d::getCvarString( |
| {"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_"); |
| std::unique_ptr<TestDebugInfoWriter> writerForTestPtr = |
| std::make_unique<TestDebugInfoWriter>(fileNamePrefix); |
| std::vector<uint8_t>& traces = writerForTestPtr->getTraces(); |
| c10d::DebugInfoWriter::registerWriter(std::move(writerForTestPtr)); |
| |
| // Normal collective case. |
| auto work = pg.allreduce(tensors_); |
| work->wait(); |
| |
| work = pg.allreduce(tensors_); |
| { |
| // Hold the watchdog mutex so the watchdog cannot heartbeat or make progress. |
| std::lock_guard<std::mutex> lock(pg.getWatchdogMutex()); |
| LOG(INFO) << "Lock watchdog thread."; |
| // Wait long enough for the monitor thread to detect the missing heartbeat |
| // and throw. |
| std::this_thread::sleep_for( |
| std::chrono::seconds(heartBeatIntervalInSec * 3)); |
| // Check that the monitoring thread launched and its exception was caught. |
| EXPECT_TRUE(pg.getErrorCaughtFlag()); |
| } |
| work->wait(); |
| EXPECT_TRUE(traces.size() > 0); |
| auto filename = c10::str(tempFilename, 0); |
| auto traceFromStorage = readTraceFromFile(filename, traces.size()); |
| // Check that the traces read from storage match the original NCCL trace. |
| EXPECT_TRUE(traceFromStorage == std::string(traces.begin(), traces.end())); |
| std::filesystem::remove(filename); |
| } |
| |
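| // Fixture for the watchdog-timeout tests: short collective and heartbeat |
| // timeouts with desync debug enabled (currently skipped, see the TODO below). |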
| class ProcessGroupNCCLWatchdogTimeoutTest : public ProcessGroupNCCLErrorsTest { |
| protected: |
| void SetUp() override { |
| // TODO (kwen2501) |
| GTEST_SKIP() << "Skipping tests under ProcessGroupNCCLWatchdogTimeoutTest; " |
| << "will rewrite them after refactoring Work queues."; |
| ProcessGroupNCCLErrorsTest::SetUp(); |
| std::string timeInterval = std::to_string(heartBeatIntervalInSec); |
| ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "1", 1) == 0); |
| ASSERT_TRUE( |
| setenv( |
| c10d::TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC[0].c_str(), |
| timeInterval.c_str(), |
| 1) == 0); |
| ASSERT_TRUE( |
| setenv(c10d::TORCH_NCCL_ENABLE_MONITORING[0].c_str(), "1", 1) == 0); |
| ASSERT_TRUE(setenv(c10d::TORCH_NCCL_DESYNC_DEBUG[0].c_str(), "1", 1) == 0); |
| // We cannot capture the exception thrown in the watchdog thread without |
| // making lots of changes to the code, so we don't let the watchdog throw an |
| // exception. |
| ASSERT_TRUE( |
| setenv(c10d::TORCH_NCCL_ASYNC_ERROR_HANDLING[0].c_str(), "0", 1) == 0); |
| options_ = c10d::ProcessGroupNCCL::Options::create(); |
| // Set a super short watchdog timeout. |
| options_->timeout = std::chrono::milliseconds(100); |
| } |
| |
| void watchdogTimeoutTestCommon( |
| ProcessGroupNCCLNoHeartbeatCaught& pg, |
| int multiplier) { |
| pg.forceSetDesyncDebugFlag(); |
| pg.setTimedoutError(); |
| auto work = pg.allreduce(tensors_); |
| std::this_thread::sleep_for( |
| std::chrono::seconds(heartBeatIntervalInSec * multiplier)); |
| EXPECT_THROW(work->wait(), c10::DistBackendError); |
| } |
| |
| const int heartBeatIntervalInSec = 2; |
| c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> options_; |
| }; |
| |
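| // When collecting debug info finishes quickly, the watchdog should complete |
| // the desync report and the monitor should not abort the process. |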
| TEST_F(ProcessGroupNCCLWatchdogTimeoutTest, testNCCLTimedoutDebugInfoFinished) { |
| ProcessGroupNCCLNoHeartbeatCaught pg(store_, 0, 1, options_); |
| // Writing debug info makes the watchdog thread wait for 30 seconds, and that |
| // wait is hard to override, so we trigger the write beforehand. Otherwise we |
| // would have to set a long heartbeat timeout, which would make the test much |
| // slower. |
| pg.forceTryWriteDebugInfo(); |
| watchdogTimeoutTestCommon(pg, 2); |
| |
| // The flag being true shows that the heartbeat monitor thread does not kill |
| // the watchdog thread while it is collecting debug info such as the desync |
| // report. |
| EXPECT_TRUE(pg.getWatchDogDebugInfoFinishedFlag()); |
| // The flag being false shows that the heartbeat monitor thread does not |
| // trigger a process abort when collecting debug info and destroying the PG |
| // finish quickly. |
| EXPECT_FALSE(pg.getErrorCaughtFlag()); |
| |
| // Communicators might be aborted here; further operations would fail. |
| } |
| |
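| // When collecting debug info gets stuck, the monitor should abort (here: |
| // throw) before the watchdog ever finishes the report. |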
| TEST_F(ProcessGroupNCCLWatchdogTimeoutTest, testNCCLTimedoutDebugInfoStuck) { |
| ProcessGroupNCCLDebugInfoStuck pg(store_, 0, 1, options_); |
| // Keep the main thread sleeping longer so that the heartbeat monitor thread |
| // can finish the extra wait and flip the flag. |
| watchdogTimeoutTestCommon(pg, 4); |
| // The flag being false shows that the watchdog thread got stuck while |
| // collecting debug info such as the desync report. |
| EXPECT_FALSE(pg.getWatchDogDebugInfoFinishedFlag()); |
| // The flag being true shows that the heartbeat monitor thread does trigger a |
| // process abort when collecting debug info gets stuck. |
| EXPECT_TRUE(pg.getErrorCaughtFlag()); |
| |
| // Communicators might be aborted here; further operations would fail. |
| } |