#include <chrono>
#include <filesystem>
#include <fstream>
#include <thread>
#include <c10/util/irange.h>
#include <torch/csrc/cuda/nccl.h>
#include <torch/csrc/distributed/c10d/FileStore.hpp>
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#include "CUDATest.hpp"
#include "TestUtils.hpp"
#include <gtest/gtest.h>
using namespace c10d::test;
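// Minimum NCCL version (as reported by torch::cuda::nccl::version()) needed
// for the error-handling paths exercised below; skipTest() skips these tests
// on older builds.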
constexpr int kNcclErrorHandlingVersion = 2400;
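// Work subclass used to inject failures: when simulate_error is set,
// checkForNCCLErrors() reports a fabricated std::runtime_error instead of
// querying the NCCL communicator.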
class WorkNCCLSimulateErrors : public c10d::ProcessGroupNCCL::WorkNCCL {
public:
WorkNCCLSimulateErrors(
at::Device& device,
bool simulate_error,
int rank,
c10d::OpType opType,
uint64_t seq)
: WorkNCCL(device, rank, opType, seq), simulateError_(simulate_error) {}
std::exception_ptr checkForNCCLErrors() override {
if (simulateError_) {
return std::make_exception_ptr(std::runtime_error("Error"));
}
return c10d::ProcessGroupNCCL::WorkNCCL::checkForNCCLErrors();
}
private:
bool simulateError_;
};
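// ProcessGroup subclass that creates the error-injecting work items above and
// exposes test hooks to toggle the simulated error and inspect the
// communicator cache size.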
class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL {
public:
ProcessGroupNCCLSimulateErrors(
const c10::intrusive_ptr<c10d::Store>& store,
int rank,
int size,
c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> opts)
: ProcessGroupNCCL(store, rank, size, opts), simulateError_(false) {}
std::exception_ptr checkForNCCLErrors(
std::shared_ptr<c10d::NCCLComm>& ncclComm) override {
if (simulateError_) {
return std::make_exception_ptr(std::runtime_error("Error"));
}
return c10d::ProcessGroupNCCL::checkForNCCLErrors(ncclComm);
}
std::chrono::duration<int64_t, std::milli> getWatchdogSleepInterval() {
return std::chrono::milliseconds(
ProcessGroupNCCLSimulateErrors::kWatchdogThreadSleepMillis);
}
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
at::Device& device,
int rank,
c10d::OpType opType,
const char* profilingTitle,
const std::vector<at::Tensor>& inputs = {},
const std::vector<at::Tensor>& outputs = {},
bool record = false) override {
return c10::make_intrusive<WorkNCCLSimulateErrors>(
device, simulateError_, rank, opType, seq_);
}
size_t getNCCLCommCacheSize() {
return devNCCLCommMap_.size();
}
void simulateError() {
simulateError_ = true;
}
void resetError() {
simulateError_ = false;
}
private:
bool simulateError_;
};
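// Work subclass that simulates a hang: when set_timedout_error is set,
// isCompleted() always returns false, so the collective appears to never
// finish and eventually times out.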
class WorkNCCLTimedoutErrors : public c10d::ProcessGroupNCCL::WorkNCCL {
public:
WorkNCCLTimedoutErrors(
at::Device& device,
bool set_timedout_error,
int rank,
c10d::OpType opType,
uint64_t seq)
: WorkNCCL(device, rank, opType, seq),
setTimedoutError_(set_timedout_error) {}
private:
bool isCompleted() override {
if (setTimedoutError_) {
return false;
}
return c10d::ProcessGroupNCCL::WorkNCCL::isCompleted();
}
private:
bool setTimedoutError_;
};
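// ProcessGroup subclass that issues the never-completing work items above and
// records whether the watchdog finished collecting debug info.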
class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors {
public:
ProcessGroupNCCLTimedOutErrors(
const c10::intrusive_ptr<c10d::Store>& store,
int rank,
int size,
c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> opts)
: ProcessGroupNCCLSimulateErrors(store, rank, size, opts),
watchDogDebugInfoFinished_(false),
setTimedoutError_(false) {}
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
at::Device& device,
int rank,
c10d::OpType opType,
const char* profilingTitle,
const std::vector<at::Tensor>& inputs = {},
const std::vector<at::Tensor>& outputs = {},
bool record = false) override {
return c10::make_intrusive<WorkNCCLTimedoutErrors>(
device, setTimedoutError_, rank, opType, seq_);
}
void setTimedoutError() {
setTimedoutError_ = true;
}
void resetTimedoutError() {
setTimedoutError_ = false;
}
bool getWatchDogDebugInfoFinishedFlag() {
return watchDogDebugInfoFinished_;
}
// In the constructor of ProcessGroupNCCL, we don't allow the watchdog thread
// to run any handling or desync report while the main thread is in blocking
// wait. Even if users enable handling and turn on the desyncDebug flag, they
// get reset. For ease of unit testing we want the main thread to block-wait,
// so we use this hack to manually set the desync debug flag after PG
// creation.
void forceSetDesyncDebugFlag() {
desyncDebug_ = true;
}
protected:
std::string getNCCLWatchdogDebugInfo() override {
LOG(INFO) << "overridden getNCCLWatchdogDebugInfo called";
watchDogDebugInfoFinished_ = true;
return "";
}
bool watchDogDebugInfoFinished_;
private:
bool setTimedoutError_;
};
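// ProcessGroup subclass used to exercise the heartbeat monitor: it exposes the
// watchdog mutex so a test can block the watchdog, and it records whether the
// monitor thread raised an error instead of aborting the process.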
class ProcessGroupNCCLNoHeartbeatCaught
: public ProcessGroupNCCLTimedOutErrors {
public:
ProcessGroupNCCLNoHeartbeatCaught(
const c10::intrusive_ptr<c10d::Store>& store,
int rank,
int size,
c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> opts)
: ProcessGroupNCCLTimedOutErrors(store, rank, size, opts),
hasMonitorThreadCaughtError_(false) {}
std::mutex& getWatchdogMutex() {
return workMetaListMutex_;
}
bool getErrorCaughtFlag() {
return hasMonitorThreadCaughtError_;
}
void forceTryWriteDebugInfo() {
std::future<bool> asyncDebugDump = std::async(
std::launch::async, [this]() { return this->dumpDebuggingInfo(); });
asyncDebugDump.wait();
}
protected:
// Override the heartbeat monitor function so that we capture the exception in
// the monitor thread (we cannot try-catch it in the main thread) and set a
// flag for the main thread to check.
void heartbeatMonitor() override {
try {
c10d::ProcessGroupNCCL::heartbeatMonitor();
} catch (std::runtime_error& e) {
hasMonitorThreadCaughtError_ = true;
}
}
// It's really hard to unit test std::abort, so we override terminateProcess
// instead. Without this override the process does abort with a core dump
// (verified by commenting the override out).
void terminateProcess(std::string errMsg) override {
throw std::runtime_error(errMsg);
}
bool hasMonitorThreadCaughtError_;
};
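// ProcessGroup subclass whose debug-info collection sleeps far past the
// heartbeat timeout, simulating a watchdog stuck while gathering debug info.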
class ProcessGroupNCCLDebugInfoStuck
: public ProcessGroupNCCLNoHeartbeatCaught {
public:
ProcessGroupNCCLDebugInfoStuck(
const c10::intrusive_ptr<c10d::Store>& store,
int rank,
int size,
c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> opts)
: ProcessGroupNCCLNoHeartbeatCaught(store, rank, size, opts) {}
protected:
// Override the debug-info collection to sleep far past the heartbeat timeout,
// mimicking getting stuck while gathering debug info.
std::string getNCCLWatchdogDebugInfo() override {
std::this_thread::sleep_for(
std::chrono::seconds(heartbeatTimeoutInSec_ * 20));
watchDogDebugInfoFinished_ = true;
return "";
}
};
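// Test fixture: skips when no GPU is available or NCCL is too old, and
// prepares a single-rank FileStore plus one CUDA tensor for the collectives
// run in the tests below.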
class ProcessGroupNCCLErrorsTest : public ::testing::Test {
protected:
bool skipTest() {
if (cudaNumDevices() == 0) {
LOG(INFO) << "Skipping test since CUDA is not available";
return true;
}
#ifdef USE_C10D_NCCL
if (torch::cuda::nccl::version() < kNcclErrorHandlingVersion) {
LOG(INFO) << "Skipping test since NCCL version is too old";
return true;
}
#endif
return false;
}
void SetUp() override {
// Enable LOG(INFO) messages.
c10::initLogging();
// Need to have this check in SetUp to make sure we only run the test --
// including the init -- when there are GPUs available.
if (skipTest()) {
GTEST_SKIP() << "Skipping ProcessGroupNCCLErrorsTest because system "
<< "requirement is not met (no CUDA or GPU).";
}
size_t numDevices = 1; // One device per rank (thread)
TemporaryFile file;
store_ = c10::make_intrusive<::c10d::FileStore>(file.path, 1);
tensors_.resize(numDevices);
tensors_[0] = at::empty({3, 3}, at::kCUDA);
}
void TearDown() override {
ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "0", 1) == 0);
}
std::vector<at::Tensor> tensors_;
c10::intrusive_ptr<::c10d::FileStore> store_;
};
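// With blocking wait enabled, a simulated NCCL error surfaces as an exception
// from wait(), and the failed work still reports completion.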
TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsBlocking) {
ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "1", 1) == 0);
auto options = c10d::ProcessGroupNCCL::Options::create();
options->timeout = std::chrono::milliseconds(1000);
ProcessGroupNCCLSimulateErrors pg(store_, 0, 1, options);
auto work = pg.allreduce(tensors_);
work->wait();
EXPECT_EQ(1, pg.getNCCLCommCacheSize());
// Now run all reduce with errors.
pg.simulateError();
work = pg.allreduce(tensors_);
EXPECT_THROW(work->wait(), std::runtime_error);
// Verify the work item failed.
EXPECT_TRUE(work->isCompleted());
EXPECT_THROW(work->wait(), std::runtime_error);
// Communicators might be aborted here, further operations would fail.
}
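// With blocking wait enabled, a work item that never completes hits the
// collective timeout and wait() throws c10::DistBackendError.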
TEST_F(ProcessGroupNCCLErrorsTest, testNCCLTimedoutErrorsBlocking) {
ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "1", 1) == 0);
auto options = c10d::ProcessGroupNCCL::Options::create();
options->timeout = std::chrono::milliseconds(3000);
ProcessGroupNCCLTimedOutErrors pg(store_, 0, 1, options);
auto work = pg.allreduce(tensors_);
work->wait();
EXPECT_EQ(1, pg.getNCCLCommCacheSize());
// Now run all reduce with errors.
pg.setTimedoutError();
work = pg.allreduce(tensors_);
EXPECT_THROW(work->wait(), c10::DistBackendError);
// Communicators might be aborted here, further operations would fail.
}
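// Without blocking wait, wait() on a failed collective does not throw; the
// test only checks that the work eventually reports completion.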
TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNonBlocking) {
auto options = c10d::ProcessGroupNCCL::Options::create();
options->timeout = std::chrono::milliseconds(3000);
ProcessGroupNCCLSimulateErrors pg(store_, 0, 1, options);
auto work = pg.allreduce(tensors_);
pg.barrier()->wait();
EXPECT_EQ(1, pg.getNCCLCommCacheSize());
// Now run all reduce with errors.
pg.simulateError();
work = pg.allreduce(tensors_);
// Should not throw exceptions.
work->wait();
pg.barrier()->wait();
EXPECT_TRUE(work->isCompleted());
// Communicators might be aborted here, further operations would fail.
}
// Function to read what we wrote to the local disk for validation.
std::string readTraceFromFile(const std::string& filename, size_t size) {
std::ifstream file(filename, std::ios::binary);
// Read the strings from the file
if (file) { // If the file stream is in a good state
std::string str(size, '\0');
file.read(&str[0], size);
if (file) {
return str;
}
}
return "";
}
// Extend DebugInfoWriter outside the parent class so the test can also keep
// the written traces in memory for validation.
class TestDebugInfoWriter : public c10d::DebugInfoWriter {
public:
TestDebugInfoWriter(std::string namePrefix)
: DebugInfoWriter(namePrefix, 0) {}
void write(const std::string& ncclTrace) override {
traces_.assign(ncclTrace.begin(), ncclTrace.end());
c10d::DebugInfoWriter::write(ncclTrace);
}
std::vector<uint8_t>& getTraces() {
return traces_;
}
private:
std::vector<uint8_t> traces_;
};
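// Blocks the watchdog thread long enough for the heartbeat monitor to fire,
// then verifies that the flight-recorder traces captured by the registered
// DebugInfoWriter match what was written to disk.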
TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) {
int heartBeatIntervalInSec = 2;
std::string timeInterval = std::to_string(heartBeatIntervalInSec);
ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "1", 1) == 0);
ASSERT_TRUE(
setenv(
c10d::TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC[0].c_str(),
timeInterval.c_str(),
1) == 0);
ASSERT_TRUE(
setenv(c10d::TORCH_NCCL_ENABLE_MONITORING[0].c_str(), "1", 1) == 0);
auto tempFilename = c10::str(
std::filesystem::temp_directory_path().string(), "/nccl_trace_rank_");
ASSERT_TRUE(
setenv("TORCH_NCCL_DEBUG_INFO_TEMP_FILE", tempFilename.c_str(), 1) == 0);
// Enable nccl flight recorder.
ASSERT_TRUE(setenv("TORCH_NCCL_TRACE_BUFFER_SIZE", "10", 1) == 0);
auto options = c10d::ProcessGroupNCCL::Options::create();
// Set a long watchdog timeout so that we have enough time to lock the
// watchdog and let the heartbeat monitor thread kick in.
options->timeout = std::chrono::milliseconds(30000);
ProcessGroupNCCLNoHeartbeatCaught pg(store_, 0, 1, options);
// The writer here is very similar to the default writer.
// The only difference is that we also store the traces in memory for
// validation.
std::string fileNamePrefix = c10d::getCvarString(
{"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_");
std::unique_ptr<TestDebugInfoWriter> writerForTestPtr =
std::make_unique<TestDebugInfoWriter>(fileNamePrefix);
std::vector<uint8_t>& traces = writerForTestPtr->getTraces();
c10d::DebugInfoWriter::registerWriter(std::move(writerForTestPtr));
// Normal collective case.
auto work = pg.allreduce(tensors_);
work->wait();
work = pg.allreduce(tensors_);
{
// Hold the watchdog lock so the watchdog thread cannot make progress.
std::lock_guard<std::mutex> lock(pg.getWatchdogMutex());
LOG(INFO) << "Lock watchdog thread.";
// Wait long enough for the monitor thread to throw its exception.
std::this_thread::sleep_for(
std::chrono::seconds(heartBeatIntervalInSec * 3));
// Check that the monitoring thread launched and the exception was thrown.
EXPECT_TRUE(pg.getErrorCaughtFlag());
}
work->wait();
EXPECT_TRUE(traces.size() > 0);
auto filename = c10::str(tempFilename, 0);
auto traceFromStorage = readTraceFromFile(filename, traces.size());
// Check that the traces read from storage match the original NCCL trace.
EXPECT_TRUE(traceFromStorage == std::string(traces.begin(), traces.end()));
std::filesystem::remove(filename);
}
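// Fixture for watchdog-timeout scenarios: enables heartbeat monitoring and
// desync debug, disables async error handling, and uses a very short
// collective timeout so timed-out work is detected quickly.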
class ProcessGroupNCCLWatchdogTimeoutTest : public ProcessGroupNCCLErrorsTest {
protected:
void SetUp() override {
// TODO (kwen2501)
GTEST_SKIP() << "Skipping tests under ProcessGroupNCCLWatchdogTimeoutTest; "
<< "will rewrite them after refactoring Work queues.";
ProcessGroupNCCLErrorsTest::SetUp();
std::string timeInterval = std::to_string(heartBeatIntervalInSec);
ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "1", 1) == 0);
ASSERT_TRUE(
setenv(
c10d::TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC[0].c_str(),
timeInterval.c_str(),
1) == 0);
ASSERT_TRUE(
setenv(c10d::TORCH_NCCL_ENABLE_MONITORING[0].c_str(), "1", 1) == 0);
ASSERT_TRUE(setenv(c10d::TORCH_NCCL_DESYNC_DEBUG[0].c_str(), "1", 1) == 0);
// We cannot capture the exception thrown in the watchdog thread without
// making lots of changes to the code, so we don't let the watchdog throw an
// exception.
ASSERT_TRUE(
setenv(c10d::TORCH_NCCL_ASYNC_ERROR_HANDLING[0].c_str(), "0", 1) == 0);
options_ = c10d::ProcessGroupNCCL::Options::create();
// Set a super short watchdog timeout.
options_->timeout = std::chrono::milliseconds(100);
}
void watchdogTimeoutTestCommon(
ProcessGroupNCCLNoHeartbeatCaught& pg,
int multiplier) {
pg.forceSetDesyncDebugFlag();
pg.setTimedoutError();
auto work = pg.allreduce(tensors_);
std::this_thread::sleep_for(
std::chrono::seconds(heartBeatIntervalInSec * multiplier));
EXPECT_THROW(work->wait(), c10::DistBackendError);
}
const int heartBeatIntervalInSec = 2;
c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> options_;
};
TEST_F(ProcessGroupNCCLWatchdogTimeoutTest, testNCCLTimedoutDebugInfoFinished) {
ProcessGroupNCCLNoHeartbeatCaught pg(store_, 0, 1, options_);
// Writing debug info makes the watchdog thread wait for 30 seconds, and that
// wait is hard to override, so we trigger the dump beforehand. Otherwise we
// would need to set a long heartbeat timeout, which would make the test much
// slower.
pg.forceTryWriteDebugInfo();
watchdogTimeoutTestCommon(pg, 2);
// The flag being true shows that the heartbeat monitor thread does not kill
// the watchdog thread while it is collecting debug info such as the desync
// report.
EXPECT_TRUE(pg.getWatchDogDebugInfoFinishedFlag());
// The flag being false shows that the heartbeat monitor thread does not
// trigger a process abort when collecting debug info and destroying the PG
// are fast.
EXPECT_FALSE(pg.getErrorCaughtFlag());
// Communicators might be aborted here, further operations would fail.
}
TEST_F(ProcessGroupNCCLWatchdogTimeoutTest, testNCCLTimedoutDebugInfoStuck) {
ProcessGroupNCCLDebugInfoStuck pg(store_, 0, 1, options_);
// Keep the main thread sleeping longer so the heartbeat monitor thread can
// finish the extra wait and flip the flag.
watchdogTimeoutTestCommon(pg, 4);
// The flag being false shows that the watchdog thread got stuck while
// collecting debug info such as the desync report.
EXPECT_FALSE(pg.getWatchDogDebugInfoFinishedFlag());
// The flag being true shows that the heartbeat monitor thread does trigger a
// process abort when collecting debug info gets stuck.
EXPECT_TRUE(pg.getErrorCaughtFlag());
// Communicators might be aborted here, further operations would fail.
}