caffe2/plan_executor: propagate exceptions from reporter substeps (#46424)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46424

Currently, if an exception occurs in a reporter thread, the process is killed via std::terminate. This change adds support for handling the reporter exception when FLAGS_caffe2_handle_executor_threads_exceptions is set to true.

Test Plan: buck test mode/opt -c python.package_style=inplace //caffe2/caffe2/python:hypothesis_test //caffe2/caffe2:caffe2_test_cpu -- --stress-runs 100

Reviewed By: dahsh

Differential Revision: D24345027

fbshipit-source-id: 0659495c9e27680ebae41fe5a3cf26ce2f455cb3
diff --git a/caffe2/core/plan_executor.cc b/caffe2/core/plan_executor.cc
index 7802cb0..97c309c 100644
--- a/caffe2/core/plan_executor.cc
+++ b/caffe2/core/plan_executor.cc
@@ -68,7 +68,8 @@
 // correctly grab it on exit.
 class ExceptionWrapperTerminate {
  public:
-  explicit ExceptionWrapperTerminate(ExceptionWrapper&& ew) : ew_(std::move(ew)) {}
+  explicit ExceptionWrapperTerminate(ExceptionWrapper&& ew)
+      : ew_(std::move(ew)) {}
 
   ~ExceptionWrapperTerminate() {
     ew_.rethrowException();
@@ -100,44 +101,6 @@
 
 using NetDefMap = std::unordered_map<std::string, NetDefInfo>;
 
-struct Reporter {
-  struct ReporterInstance {
-    std::mutex report_mutex;
-    std::condition_variable report_cv;
-    std::thread report_thread;
-    ReporterInstance(int intervalMillis, bool* done, std::function<void()> f) {
-      auto interval = std::chrono::milliseconds(intervalMillis);
-      auto reportWorker = [=]() {
-        std::unique_lock<std::mutex> lk(report_mutex);
-        do {
-          report_cv.wait_for(lk, interval, [&]() { return *done; });
-          f();
-        } while (!*done);
-      };
-      report_thread = std::thread(reportWorker);
-    }
-  };
-
-  void start(int64_t intervalMillis, std::function<void()> f) {
-    instances_.emplace_back(new ReporterInstance(intervalMillis, &done, f));
-  }
-
-  ~Reporter() {
-    done = true;
-    for (auto& instance : instances_) {
-      if (!instance->report_thread.joinable()) {
-        continue;
-      }
-      instance->report_cv.notify_all();
-      instance->report_thread.join();
-    }
-  }
-
- private:
-  std::vector<std::unique_ptr<ReporterInstance>> instances_;
-  bool done{false};
-};
-
 // Returns a function that returns `true` if we should continue
 // iterating, given the current iteration count.
 std::function<bool(int64_t)> getContinuationTest(
@@ -394,6 +357,23 @@
     };
   }
 
+  void Fail(const std::exception& ex) {
+    {
+      std::lock_guard<std::mutex> guard(exception_mutex_);
+      if (!first_exception_) {
+        LOG(ERROR) << "Substep exception:\n" << c10::GetExceptionString(ex);
+        first_exception_ = ExceptionWrapper(ex);
+      }
+      gotFailure = true;
+    }
+    Cancel();
+  }
+
+  ExceptionWrapper FirstException() {
+    std::lock_guard<std::mutex> guard(exception_mutex_);
+    return first_exception_;
+  }
+
   // Cancel attempts to cancel the running nets in a best effort way. If the net
   // or op type does IO and doesn't implement cancellation it may not be
   // possible to cancel leading to execution getting stuck on error.
@@ -426,6 +406,9 @@
 
  private:
   std::unique_ptr<Workspace> localWorkspace_;
+
+  std::mutex exception_mutex_; // protects first_exception_
+  ExceptionWrapper first_exception_;
 };
 
 void ExecutionStepWrapper::Cancel() {
@@ -443,6 +426,65 @@
       ws_id_injector_));
 }
 
+struct Reporter {
+  struct ReporterInstance {
+    std::mutex report_mutex;
+    std::condition_variable report_cv;
+    std::thread report_thread;
+    ExceptionWrapper exception;
+
+    ReporterInstance(
+        int intervalMillis,
+        std::atomic<bool>* done,
+        std::function<void()> f,
+        ExecutionStepWrapper::CompiledGuard* compiledStep) {
+      auto interval = std::chrono::milliseconds(intervalMillis);
+      auto reportWorker = [=]() {
+        std::unique_lock<std::mutex> lk(report_mutex);
+        do {
+          report_cv.wait_for(lk, interval, [&]() { return done->load(); });
+          try {
+            f();
+          } catch (const std::exception& ex) {
+            LOG(ERROR) << "Reporter instance exception:\n"
+                       << c10::GetExceptionString(ex);
+            if (!FLAGS_caffe2_handle_executor_threads_exceptions) {
+              throw;
+            }
+            (*compiledStep)->Fail(ex);
+            done->store(true);
+          }
+        } while (!done->load());
+      };
+      report_thread = std::thread(reportWorker);
+    }
+  };
+
+  explicit Reporter(ExecutionStepWrapper::CompiledGuard* compiledStep)
+      : compiledStep_(compiledStep) {}
+
+  void start(int64_t intervalMillis, std::function<void()> f) {
+    instances_.emplace_back(
+        new ReporterInstance(intervalMillis, &done_, f, compiledStep_));
+  }
+
+  ~Reporter() {
+    done_ = true;
+    for (auto& instance : instances_) {
+      if (!instance->report_thread.joinable()) {
+        continue;
+      }
+      instance->report_cv.notify_all();
+      instance->report_thread.join();
+    }
+  }
+
+ private:
+  std::vector<std::unique_ptr<ReporterInstance>> instances_;
+  std::atomic<bool> done_{false};
+  ExecutionStepWrapper::CompiledGuard* compiledStep_;
+};
+
 #define CHECK_SHOULD_STOP(step, shouldStop)                       \
   if (getShouldStop(shouldStop)) {                                \
     VLOG(1) << "Execution step " << step.name() << " stopped by " \
@@ -458,7 +500,7 @@
 
   std::unique_ptr<Reporter> reporter;
   if (step.has_report_net() || compiledStep->reportSubsteps.size() > 0) {
-    reporter = std::make_unique<Reporter>();
+    reporter = std::make_unique<Reporter>(&compiledStep);
     auto* reportNet = compiledStep->reportNet;
     if (reportNet) {
       VLOG(1) << "Starting reporter net";
@@ -500,9 +542,8 @@
 
         std::atomic<int> next_substep{0};
         std::condition_variable cv;
-        std::mutex exception_mutex; // exception_mutex protects done and first_exception
+        std::mutex exception_mutex; // protects done
         int done{0};
-        ExceptionWrapper first_exception;
         auto worker = [&]() {
           ScopeExitGuard on_exit([&] {
             std::lock_guard<std::mutex> guard(exception_mutex);
@@ -521,14 +562,7 @@
               compiledStep->gotFailure = true;
             }
           } catch (const std::exception& ex) {
-            std::lock_guard<std::mutex> guard(exception_mutex);
-            if (!first_exception) {
-              first_exception = ExceptionWrapper(ex);
-              LOG(ERROR) << "Parallel worker exception:\n"
-                         << c10::GetExceptionString(ex);
-            }
-            compiledStep->gotFailure = true;
-            compiledStep->Cancel();
+            compiledStep->Fail(ex);
             if (!FLAGS_caffe2_handle_executor_threads_exceptions) {
               // In complex plans other threads might get stuck if another
               // one fails. So we let exception to go out of thread which
@@ -554,11 +588,13 @@
 
         // If we get an exception, try to wait for all threads to stop
         // gracefully.
-        cv.wait(guard, [&] { return workersDone() || first_exception; });
+        cv.wait(
+            guard, [&] { return workersDone() || compiledStep->gotFailure; });
         cv.wait_for(
             guard,
             std::chrono::seconds(FLAGS_caffe2_plan_executor_exception_timeout),
             [&] { return workersDone(); });
+        auto first_exception = compiledStep->FirstException();
         if (!workersDone() && first_exception) {
           LOG(ERROR) << "failed to stop concurrent workers after exception: "
                      << first_exception.what();
@@ -592,7 +628,11 @@
       }
     }
   }
-  return true;
+
+  if (auto first_exception = compiledStep->FirstException()) {
+    first_exception.rethrowException();
+  }
+  return !compiledStep->gotFailure;
 }
 
 #undef CHECK_SHOULD_STOP
diff --git a/caffe2/core/plan_executor_test.cc b/caffe2/core/plan_executor_test.cc
index 45df7fc..39288ee 100644
--- a/caffe2/core/plan_executor_test.cc
+++ b/caffe2/core/plan_executor_test.cc
@@ -185,6 +185,41 @@
   return plan_def;
 }
 
+PlanDef reporterErrorPlanWithCancellableStuckNet() {
+  // Sets up a plan with a concurrent substep and a reporter substep: one
+  // stuck net whose blocking operator never returns, and one reporter net
+  // with an error op that throws.
+  PlanDef plan_def;
+
+  auto* stuck_blocking_net = plan_def.add_network();
+  stuck_blocking_net->set_name("stuck_blocking_net");
+  {
+    auto* op = stuck_blocking_net->add_op();
+    op->set_type("StuckBlocking");
+  }
+
+  auto* error_net = plan_def.add_network();
+  error_net->set_name("error_net");
+  {
+    auto* op = error_net->add_op();
+    op->set_type("Error");
+  }
+
+  auto* execution_step = plan_def.add_execution_step();
+  execution_step->set_concurrent_substeps(true);
+  {
+    auto* substep = execution_step->add_substep();
+    substep->add_network(stuck_blocking_net->name());
+  }
+  {
+    auto* substep = execution_step->add_substep();
+    substep->set_run_every_ms(1);
+    substep->add_network(error_net->name());
+  }
+
+  return plan_def;
+}
+
 struct HandleExecutorThreadExceptionsGuard {
   HandleExecutorThreadExceptionsGuard(int timeout = 60) {
     globalInit({
@@ -280,6 +315,17 @@
   ASSERT_EQ(cancelCount, 1);
 }
 
+TEST(PlanExecutorTest, ReporterErrorPlanWithCancellableStuckNet) {
+  HandleExecutorThreadExceptionsGuard guard;
+
+  cancelCount = 0;
+  PlanDef plan_def = reporterErrorPlanWithCancellableStuckNet();
+  Workspace ws;
+
+  ASSERT_THROW(ws.RunPlan(plan_def), TestError);
+  ASSERT_EQ(cancelCount, 1);
+}
+
 } // namespace caffe2
 
 #endif