caffe2/plan_executor: propagate exceptions from reporter substeps (#46424)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46424
Currently, if an exception occurs in a reporter thread, the process is killed via std::terminate. This change adds support for handling the reporter exception when FLAGS_caffe2_handle_executor_threads_exceptions is set to true.
Test Plan: buck test mode/opt -c python.package_style=inplace //caffe2/caffe2/python:hypothesis_test //caffe2/caffe2:caffe2_test_cpu -- --stress-runs 100
Reviewed By: dahsh
Differential Revision: D24345027
fbshipit-source-id: 0659495c9e27680ebae41fe5a3cf26ce2f455cb3
diff --git a/caffe2/core/plan_executor.cc b/caffe2/core/plan_executor.cc
index 7802cb0..97c309c 100644
--- a/caffe2/core/plan_executor.cc
+++ b/caffe2/core/plan_executor.cc
@@ -68,7 +68,8 @@
// correctly grab it on exit.
class ExceptionWrapperTerminate {
public:
- explicit ExceptionWrapperTerminate(ExceptionWrapper&& ew) : ew_(std::move(ew)) {}
+ explicit ExceptionWrapperTerminate(ExceptionWrapper&& ew)
+ : ew_(std::move(ew)) {}
~ExceptionWrapperTerminate() {
ew_.rethrowException();
@@ -100,44 +101,6 @@
using NetDefMap = std::unordered_map<std::string, NetDefInfo>;
-struct Reporter {
- struct ReporterInstance {
- std::mutex report_mutex;
- std::condition_variable report_cv;
- std::thread report_thread;
- ReporterInstance(int intervalMillis, bool* done, std::function<void()> f) {
- auto interval = std::chrono::milliseconds(intervalMillis);
- auto reportWorker = [=]() {
- std::unique_lock<std::mutex> lk(report_mutex);
- do {
- report_cv.wait_for(lk, interval, [&]() { return *done; });
- f();
- } while (!*done);
- };
- report_thread = std::thread(reportWorker);
- }
- };
-
- void start(int64_t intervalMillis, std::function<void()> f) {
- instances_.emplace_back(new ReporterInstance(intervalMillis, &done, f));
- }
-
- ~Reporter() {
- done = true;
- for (auto& instance : instances_) {
- if (!instance->report_thread.joinable()) {
- continue;
- }
- instance->report_cv.notify_all();
- instance->report_thread.join();
- }
- }
-
- private:
- std::vector<std::unique_ptr<ReporterInstance>> instances_;
- bool done{false};
-};
-
// Returns a function that returns `true` if we should continue
// iterating, given the current iteration count.
std::function<bool(int64_t)> getContinuationTest(
@@ -394,6 +357,23 @@
};
}
+ void Fail(const std::exception& ex) {
+ {
+ std::lock_guard<std::mutex> guard(exception_mutex_);
+ if (!first_exception_) {
+ LOG(ERROR) << "Substep exception:\n" << c10::GetExceptionString(ex);
+ first_exception_ = ExceptionWrapper(ex);
+ }
+ gotFailure = true;
+ }
+ Cancel();
+ }
+
+ ExceptionWrapper FirstException() {
+ std::lock_guard<std::mutex> guard(exception_mutex_);
+ return first_exception_;
+ }
+
// Cancel attempts to cancel the running nets in a best effort way. If the net
// or op type does IO and doesn't implement cancellation it may not be
// possible to cancel leading to execution getting stuck on error.
@@ -426,6 +406,9 @@
private:
std::unique_ptr<Workspace> localWorkspace_;
+
+ std::mutex exception_mutex_; // protects first_exception_
+ ExceptionWrapper first_exception_;
};
void ExecutionStepWrapper::Cancel() {
@@ -443,6 +426,65 @@
ws_id_injector_));
}
+struct Reporter {
+ struct ReporterInstance {
+ std::mutex report_mutex;
+ std::condition_variable report_cv;
+ std::thread report_thread;
+ ExceptionWrapper exception;
+
+ ReporterInstance(
+ int intervalMillis,
+ std::atomic<bool>* done,
+ std::function<void()> f,
+ ExecutionStepWrapper::CompiledGuard* compiledStep) {
+ auto interval = std::chrono::milliseconds(intervalMillis);
+ auto reportWorker = [=]() {
+ std::unique_lock<std::mutex> lk(report_mutex);
+ do {
+ report_cv.wait_for(lk, interval, [&]() { return done->load(); });
+ try {
+ f();
+ } catch (const std::exception& ex) {
+ LOG(ERROR) << "Reporter instance exception:\n"
+ << c10::GetExceptionString(ex);
+ if (!FLAGS_caffe2_handle_executor_threads_exceptions) {
+ throw;
+ }
+ (*compiledStep)->Fail(ex);
+ done->store(true);
+ }
+ } while (!done->load());
+ };
+ report_thread = std::thread(reportWorker);
+ }
+ };
+
+ explicit Reporter(ExecutionStepWrapper::CompiledGuard* compiledStep)
+ : compiledStep_(compiledStep) {}
+
+ void start(int64_t intervalMillis, std::function<void()> f) {
+ instances_.emplace_back(
+ new ReporterInstance(intervalMillis, &done_, f, compiledStep_));
+ }
+
+ ~Reporter() {
+ done_ = true;
+ for (auto& instance : instances_) {
+ if (!instance->report_thread.joinable()) {
+ continue;
+ }
+ instance->report_cv.notify_all();
+ instance->report_thread.join();
+ }
+ }
+
+ private:
+ std::vector<std::unique_ptr<ReporterInstance>> instances_;
+ std::atomic<bool> done_{false};
+ ExecutionStepWrapper::CompiledGuard* compiledStep_;
+};
+
#define CHECK_SHOULD_STOP(step, shouldStop) \
if (getShouldStop(shouldStop)) { \
VLOG(1) << "Execution step " << step.name() << " stopped by " \
@@ -458,7 +500,7 @@
std::unique_ptr<Reporter> reporter;
if (step.has_report_net() || compiledStep->reportSubsteps.size() > 0) {
- reporter = std::make_unique<Reporter>();
+ reporter = std::make_unique<Reporter>(&compiledStep);
auto* reportNet = compiledStep->reportNet;
if (reportNet) {
VLOG(1) << "Starting reporter net";
@@ -500,9 +542,8 @@
std::atomic<int> next_substep{0};
std::condition_variable cv;
- std::mutex exception_mutex; // exception_mutex protects done and first_exception
+ std::mutex exception_mutex; // protects done
int done{0};
- ExceptionWrapper first_exception;
auto worker = [&]() {
ScopeExitGuard on_exit([&] {
std::lock_guard<std::mutex> guard(exception_mutex);
@@ -521,14 +562,7 @@
compiledStep->gotFailure = true;
}
} catch (const std::exception& ex) {
- std::lock_guard<std::mutex> guard(exception_mutex);
- if (!first_exception) {
- first_exception = ExceptionWrapper(ex);
- LOG(ERROR) << "Parallel worker exception:\n"
- << c10::GetExceptionString(ex);
- }
- compiledStep->gotFailure = true;
- compiledStep->Cancel();
+ compiledStep->Fail(ex);
if (!FLAGS_caffe2_handle_executor_threads_exceptions) {
// In complex plans other threads might get stuck if another
// one fails. So we let exception to go out of thread which
@@ -554,11 +588,13 @@
// If we get an exception, try to wait for all threads to stop
// gracefully.
- cv.wait(guard, [&] { return workersDone() || first_exception; });
+ cv.wait(
+ guard, [&] { return workersDone() || compiledStep->gotFailure; });
cv.wait_for(
guard,
std::chrono::seconds(FLAGS_caffe2_plan_executor_exception_timeout),
[&] { return workersDone(); });
+ auto first_exception = compiledStep->FirstException();
if (!workersDone() && first_exception) {
LOG(ERROR) << "failed to stop concurrent workers after exception: "
<< first_exception.what();
@@ -592,7 +628,11 @@
}
}
}
- return true;
+
+ if (auto first_exception = compiledStep->FirstException()) {
+ first_exception.rethrowException();
+ }
+ return !compiledStep->gotFailure;
}
#undef CHECK_SHOULD_STOP
diff --git a/caffe2/core/plan_executor_test.cc b/caffe2/core/plan_executor_test.cc
index 45df7fc..39288ee 100644
--- a/caffe2/core/plan_executor_test.cc
+++ b/caffe2/core/plan_executor_test.cc
@@ -185,6 +185,41 @@
return plan_def;
}
+PlanDef reporterErrorPlanWithCancellableStuckNet() {
+ // Set a plan with a concurrent net and a reporter net: one stuck net with
+ // blocking operator that never returns; one reporter net with error op
+ // that throws.
+ PlanDef plan_def;
+
+ auto* stuck_blocking_net = plan_def.add_network();
+ stuck_blocking_net->set_name("stuck_blocking_net");
+ {
+ auto* op = stuck_blocking_net->add_op();
+ op->set_type("StuckBlocking");
+ }
+
+ auto* error_net = plan_def.add_network();
+ error_net->set_name("error_net");
+ {
+ auto* op = error_net->add_op();
+ op->set_type("Error");
+ }
+
+ auto* execution_step = plan_def.add_execution_step();
+ execution_step->set_concurrent_substeps(true);
+ {
+ auto* substep = execution_step->add_substep();
+ substep->add_network(stuck_blocking_net->name());
+ }
+ {
+ auto* substep = execution_step->add_substep();
+ substep->set_run_every_ms(1);
+ substep->add_network(error_net->name());
+ }
+
+ return plan_def;
+}
+
struct HandleExecutorThreadExceptionsGuard {
HandleExecutorThreadExceptionsGuard(int timeout = 60) {
globalInit({
@@ -280,6 +315,17 @@
ASSERT_EQ(cancelCount, 1);
}
+TEST(PlanExecutorTest, ReporterErrorPlanWithCancellableStuckNet) {
+ HandleExecutorThreadExceptionsGuard guard;
+
+ cancelCount = 0;
+ PlanDef plan_def = reporterErrorPlanWithCancellableStuckNet();
+ Workspace ws;
+
+ ASSERT_THROW(ws.RunPlan(plan_def), TestError);
+ ASSERT_EQ(cancelCount, 1);
+}
+
} // namespace caffe2
#endif