Deprecate prof_dag (#11956)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11956

Deprecate the prof_dag net type and redirect it to the unified executor
(async_scheduling), which now collects the per-operator profiling stats itself.

Reviewed By: aazzolini

Differential Revision: D9983992

fbshipit-source-id: 16821628a99a5683dc39cbb345ddab56e9d8721c
diff --git a/caffe2/contrib/prof/prof_dag_stats_op.cc b/caffe2/contrib/prof/prof_dag_stats_op.cc
index 70f4c73..684f764 100644
--- a/caffe2/contrib/prof/prof_dag_stats_op.cc
+++ b/caffe2/contrib/prof/prof_dag_stats_op.cc
@@ -18,9 +18,9 @@
         "op will be calculated separately")
     .Arg(
         "partial_net_name",
-        "(string) default to empty; describes the partial name of the ProfDAGNet")
+        "(string) default to empty; describes the partial name of the net")
     .Arg(
         "net_name",
-        "(string) default to empty; describes the name of the ProfDAGNet");
+        "(string) default to empty; describes the name of the net");
 } // namespace
 } // namespace caffe2
diff --git a/caffe2/contrib/prof/prof_dag_stats_op.h b/caffe2/contrib/prof/prof_dag_stats_op.h
index 1b90781..578ec44 100644
--- a/caffe2/contrib/prof/prof_dag_stats_op.h
+++ b/caffe2/contrib/prof/prof_dag_stats_op.h
@@ -3,6 +3,7 @@
 
 #include "caffe2/contrib/prof/prof_dag_net.h"
 #include "caffe2/core/context.h"
+#include "caffe2/core/net_async_base.h"
 #include "caffe2/core/operator.h"
 #include "caffe2/utils/math.h"
 
@@ -53,14 +54,16 @@
           " as part of its name");
     }
 
-    auto prof_dag_net = dynamic_cast_if_rtti<ProfDAGNet*>(net);
-    CAFFE_ENFORCE(prof_dag_net);
-
     ProfDAGProtos stats;
-    if (per_op_) {
-      stats = prof_dag_net->GetPerOperatorCost();
+    auto async_net = dynamic_cast_if_rtti<AsyncNetBase*>(net);
+    if (async_net) {
+      LOG(INFO) << "Using AsyncNetBase to collect stats";
+      stats = getProtos(async_net);
     } else {
-      stats = prof_dag_net->GetOperatorStats();
+      auto prof_dag_net = dynamic_cast_if_rtti<ProfDAGNet*>(net);
+      CAFFE_ENFORCE(prof_dag_net);
+      LOG(INFO) << "Using ProfDAGNet to collect stats";
+      stats = getProtos(prof_dag_net);
     }
 
     // Write protobuf message to the output blob
@@ -72,6 +75,17 @@
     return true;
   }
 
+  template <typename Net>
+  ProfDAGProtos getProtos(Net* net) {
+    ProfDAGProtos stats;
+    if (per_op_) {
+      stats = net->GetPerOperatorCost();
+    } else {
+      stats = net->GetOperatorStats();
+    }
+    return stats;
+  }
+
  protected:
   std::string net_name_;
   std::string partial_net_name_;
diff --git a/caffe2/core/net.cc b/caffe2/core/net.cc
index 37f38ce..f51d336 100644
--- a/caffe2/core/net.cc
+++ b/caffe2/core/net.cc
@@ -107,6 +107,7 @@
 const std::unordered_map<std::string, std::string>& defaultOverrides() {
   static const std::unordered_map<std::string, std::string> overrides = {
       {"dag", "async_scheduling"},
+      {"prof_dag", "async_scheduling"},
       {"async_dag", "async_scheduling"},
       {"async_polling", "async_scheduling"},
       {"async_simple", "simple"},
diff --git a/caffe2/core/net_async_base.cc b/caffe2/core/net_async_base.cc
index ba5ca01..c2728c7af 100644
--- a/caffe2/core/net_async_base.cc
+++ b/caffe2/core/net_async_base.cc
@@ -68,7 +68,9 @@
 AsyncNetBase::AsyncNetBase(
     const std::shared_ptr<const NetDef>& net_def,
     Workspace* ws)
-    : NetBase(net_def, ws) {
+    : NetBase(net_def, ws), counters_(net_def) {
+  computeExecutionModeFlags();
+
   operator_nodes_ = dag_utils::prepareOperatorNodes(net_def, ws);
   helper_ = caffe2::make_unique<AsyncNetExecutorHelper>(this);
   operators_.reserve(operator_nodes_.size());
@@ -93,12 +95,15 @@
   for (const auto& chain : chains_) {
     const auto& last_op = operators_[chain.back()];
     events_.push_back(&last_op->event());
-    for (const auto& op_id : chain) {
-      if (op_id == chain.back() || op_id == chain.front()) {
-        continue;
+    // keep events for inner chain ops in case of profiling
+    if (!report_stats_) {
+      for (const auto& op_id : chain) {
+        if (op_id == chain.back() || op_id == chain.front()) {
+          continue;
+        }
+        const auto& op = operators_[op_id];
+        op->DisableEvent();
       }
-      const auto& op = operators_[op_id];
-      op->DisableEvent();
     }
   }
 
@@ -108,8 +113,6 @@
   if (tracer_) {
     LOG(INFO) << "Tracing net: " << net_def->name();
   }
-
-  computeExecutionModeFlags();
 }
 
 bool AsyncNetBase::handleRunError() {
@@ -294,14 +297,28 @@
   return chains_[task_id].size();
 }
 
+int AsyncNetBase::firstTaskOpId(int task_id) const {
+  return chains_[task_id].front();
+}
+
+int AsyncNetBase::lastTaskOpId(int task_id) const {
+  return chains_[task_id].back();
+}
+
 const OperatorBase* AsyncNetBase::firstTaskOp(int task_id) const {
-  auto op_id = chains_[task_id].front();
-  return operator_nodes_[op_id].operator_.get();
+  return operator_nodes_[firstTaskOpId(task_id)].operator_.get();
 }
 
 const OperatorBase* AsyncNetBase::lastTaskOp(int task_id) const {
-  auto op_id = chains_[task_id].back();
-  return operator_nodes_[op_id].operator_.get();
+  return operator_nodes_[lastTaskOpId(task_id)].operator_.get();
+}
+
+OperatorBase* AsyncNetBase::firstTaskOp(int task_id) {
+  return operator_nodes_[firstTaskOpId(task_id)].operator_.get();
+}
+
+OperatorBase* AsyncNetBase::lastTaskOp(int task_id) {
+  return operator_nodes_[lastTaskOpId(task_id)].operator_.get();
 }
 
 void AsyncNetBase::asyncWait(
@@ -364,14 +381,24 @@
     }
     for (auto& op_id : chains_[task_id]) {
       op = operators_[op_id];
-      TRACE_EVENT(
-          tracing::TRACE_OP,
-          op_id,
-          tracing::TRACE_TASK,
-          task_id,
-          tracing::TRACE_STREAM,
-          stream_id);
-      bool success = op->RunAsync(stream_id);
+      bool success = false;
+      if (!report_stats_) {
+        TRACE_EVENT(
+            tracing::TRACE_OP,
+            op_id,
+            tracing::TRACE_TASK,
+            task_id,
+            tracing::TRACE_STREAM,
+            stream_id);
+        success = op->RunAsync(stream_id);
+      } else {
+        counters_.AddPerOpStartTime(op_id);
+        success = op->RunAsync(stream_id);
+        if (success && op->device_option().device_type() != PROTO_CPU) {
+          op->Finish();
+        }
+        counters_.AddPerOpEndTime(op_id);
+      }
       if (!success) {
         auto err_msg = "Failed to execute an op: " +
             (op->has_debug_def() ? op->type() : " unknown");
@@ -425,7 +452,19 @@
   }
 }
 
-AsyncNetBase::~AsyncNetBase() {}
+ProfDAGProtos AsyncNetBase::GetOperatorStats() const {
+  return counters_.GetOperatorStats();
+}
+
+ProfDAGProtos AsyncNetBase::GetPerOperatorCost() const {
+  return counters_.GetPerOperatorCost();
+}
+
+AsyncNetBase::~AsyncNetBase() {
+  if (report_stats_) {
+    counters_.PrintStats();
+  }
+}
 
 C10_DEFINE_SHARED_REGISTRY(
     ThreadPoolRegistry,
@@ -459,6 +498,7 @@
     use_single_pool_ = true;
     use_per_net_pools_ = true;
     is_blocking_ = true;
+    report_stats_ = (net_type == kProfDag);
   } else if (net_type == kAsyncDag) {
     streams_per_gpu_ = 1;
     finish_chain_ = false;
@@ -467,6 +507,7 @@
     use_single_pool_ = true;
     use_per_net_pools_ = true;
     is_blocking_ = true;
+    report_stats_ = false;
   } else {
     streams_per_gpu_ = c10::FLAGS_caffe2_streams_per_gpu;
     finish_chain_ = c10::FLAGS_caffe2_net_async_finish_chain;
@@ -475,6 +516,16 @@
     use_single_pool_ = c10::FLAGS_caffe2_net_async_use_single_pool;
     use_per_net_pools_ = c10::FLAGS_caffe2_net_async_use_per_net_pools;
     is_blocking_ = false;
+    report_stats_ = false;
+  }
+
+  for (int arg_idx = 0; arg_idx < net_def_->arg_size(); ++arg_idx) {
+    auto& arg = net_def_->arg(arg_idx);
+    if (arg.has_name() && arg.name() == "enable_profiling") {
+      CAFFE_ENFORCE(arg.has_i(), "enable_profiling should be an int");
+      report_stats_ = arg.i() == 1;
+      break;
+    }
   }
 }
 
diff --git a/caffe2/core/net_async_base.h b/caffe2/core/net_async_base.h
index f8d61f7..8dccaf4 100644
--- a/caffe2/core/net_async_base.h
+++ b/caffe2/core/net_async_base.h
@@ -6,10 +6,12 @@
 #include "caffe2/core/net.h"
 #include "caffe2/core/net_async_base.h"
 #include "caffe2/core/net_dag_utils.h"
+#include "caffe2/core/prof_dag_counters.h"
 #include "caffe2/core/stats.h"
 #include "caffe2/core/timer.h"
 #include "caffe2/core/workspace.h"
 #include "caffe2/proto/caffe2_pb.h"
+#include "caffe2/proto/prof_dag.pb.h"
 #include "caffe2/utils/proto_utils.h"
 #include "caffe2/utils/thread_pool.h"
 
@@ -50,6 +52,9 @@
     return execution_chains_;
   }
 
+  ProfDAGProtos GetOperatorStats() const;
+  ProfDAGProtos GetPerOperatorCost() const;
+
  protected:
   bool canSchedule(
       int chain_id,
@@ -66,8 +71,13 @@
   int getParentCount(int child_id);
   bool testAndSetScheduled(int task_id);
   int numOps(int task_id) const;
+
+  int firstTaskOpId(int task_id) const;
+  int lastTaskOpId(int task_id) const;
   const OperatorBase* firstTaskOp(int task_id) const;
   const OperatorBase* lastTaskOp(int task_id) const;
+  OperatorBase* firstTaskOp(int task_id);
+  OperatorBase* lastTaskOp(int task_id);
 
   void asyncWait(
       int task_id,
@@ -126,6 +136,9 @@
   bool use_single_pool_;
   bool use_per_net_pools_;
   bool is_blocking_;
+  bool report_stats_;
+
+  ProfDAGCounters counters_;
 
   C10_DISABLE_COPY_AND_ASSIGN(AsyncNetBase);
 
diff --git a/caffe2/core/net_async_scheduling.cc b/caffe2/core/net_async_scheduling.cc
index 369a8e7..1db51f2 100644
--- a/caffe2/core/net_async_scheduling.cc
+++ b/caffe2/core/net_async_scheduling.cc
@@ -61,6 +61,16 @@
       }
     }
 
+    if (report_stats_) {
+      auto last_op_id = lastTaskOpId(task_id);
+      auto* last_op = lastTaskOp(task_id);
+      if (last_op->device_option().device_type() == PROTO_CPU &&
+          last_op->HasAsyncPart()) {
+        last_op->event().SetCallback(
+            [this, last_op_id] { counters_.AddPerOpAsyncEndTime(last_op_id); });
+      }
+    }
+
     for (auto child_id : children(task_id)) {
       int parent_count = updateParentCount(child_id);
       if (parent_count == 0) {
@@ -209,6 +219,9 @@
   std::unique_lock<std::mutex> lock(running_mutex_);
   // wait for scheduled ops and make sure all events are marked as finished
   finalizeEvents();
+  if (report_stats_) {
+    counters_.ReportRunEnd();
+  }
   // notify observers and waiters
   StopAllObservers();
   running_ = false;
@@ -228,7 +241,11 @@
 
       StartAllObservers();
       tracing::startIter(tracer_);
+      if (report_stats_) {
+        counters_.ReportRunStart();
+      }
     }
+
     for (auto task_id = 0; task_id < tasksNum(); ++task_id) {
       if (parents(task_id).empty()) {
         schedule(task_id);
diff --git a/caffe2/core/prof_dag_counters.cc b/caffe2/core/prof_dag_counters.cc
new file mode 100644
index 0000000..f256880
--- /dev/null
+++ b/caffe2/core/prof_dag_counters.cc
@@ -0,0 +1,175 @@
+#include "caffe2/core/prof_dag_counters.h"
+
+#include <iomanip>
+#include <ostream>
+#include <sstream>
+
+namespace caffe2 {
+
+// Captures the net's name and the type of every operator; op_types_ is
+// indexed by the operator's position in the net definition.
+ProfDAGCounters::ProfDAGCounters(const std::shared_ptr<const NetDef>& net_def)
+    : net_name_(net_def->name()), num_runs_(0) {
+  op_types_.reserve(net_def->op_size());
+  for (const auto& op_def : net_def->op()) {
+    op_types_.push_back(op_def.type());
+  }
+  time_per_op_total_.resize(op_types_.size());
+}
+
+// Begins a new run: bumps the run counter, restarts the timer and resets all
+// per-run timestamps to -1, which marks them as "not recorded".
+void ProfDAGCounters::ReportRunStart() {
+  num_runs_ += 1;
+  timer_.Start();
+
+  op_start_times_run_.assign(op_types_.size(), -1.0f);
+  op_end_times_run_.assign(op_types_.size(), -1.0f);
+  op_async_end_times_run_.assign(op_types_.size(), -1.0f);
+}
+
+// Stores the current timer reading as the start time of operator op_id.
+// Nothing is recorded during the first run (num_runs_ <= 1).
+void ProfDAGCounters::AddPerOpStartTime(size_t op_id) {
+  if (num_runs_ <= 1) {
+    return;
+  }
+
+  // op_id is unsigned, so only the upper bound needs to be checked.
+  CAFFE_ENFORCE(op_id < op_start_times_run_.size());
+  op_start_times_run_[op_id] = timer_.MilliSeconds();
+}
+
+// Stores the current timer reading as the end time of operator op_id.
+// Nothing is recorded during the first run (num_runs_ <= 1).
+void ProfDAGCounters::AddPerOpEndTime(size_t op_id) {
+  if (num_runs_ <= 1) {
+    return;
+  }
+
+  CAFFE_ENFORCE(op_id < op_end_times_run_.size());
+  op_end_times_run_[op_id] = timer_.MilliSeconds();
+}
+
+// Stores the current timer reading as the asynchronous end time of operator
+// op_id. Nothing is recorded during the first run (num_runs_ <= 1).
+void ProfDAGCounters::AddPerOpAsyncEndTime(size_t op_id) {
+  if (num_runs_ <= 1) {
+    return;
+  }
+
+  CAFFE_ENFORCE(op_id < op_async_end_times_run_.size());
+  op_async_end_times_run_[op_id] = timer_.MilliSeconds();
+}
+
+// Folds the timestamps of the run that just ended into the cumulative
+// whole-net, per-operator and per-operator-type statistics. The first run
+// (num_runs_ <= 1) is not accounted for.
+void ProfDAGCounters::ReportRunEnd() {
+  if (num_runs_ <= 1) {
+    return;
+  }
+
+  auto runtime = timer_.MilliSeconds();
+  runtime_stats_ += ProfDAGStats(runtime);
+
+  CaffeMap<std::string, float> type_time_run;
+  CaffeMap<std::string, float> type_invocations_run;
+  for (size_t op_id = 0; op_id < op_types_.size(); ++op_id) {
+    // Timestamps are reset to -1 in ReportRunStart, so a non-negative value
+    // (including exactly 0) means it was recorded during this run.
+    const float start_time = op_start_times_run_[op_id];
+    CAFFE_ENFORCE(start_time >= 0);
+    // Prefer the async end time when one was recorded for this operator.
+    const float async_end_time = op_async_end_times_run_[op_id];
+    const float end_time =
+        async_end_time >= 0 ? async_end_time : op_end_times_run_[op_id];
+    const float op_time = end_time - start_time;
+    CAFFE_ENFORCE_GE(op_time, 0.0);
+
+    time_per_op_total_[op_id] += ProfDAGStats(op_time);
+
+    const std::string& op_type = op_types_[op_id];
+    type_time_run[op_type] += op_time;
+    type_invocations_run[op_type] += 1;
+  }
+
+  for (const auto& kv : type_time_run) {
+    time_per_op_type_total_[kv.first] += ProfDAGStats(kv.second);
+    times_per_run_per_type_total_[kv.first] +=
+        ProfDAGStats(type_invocations_run[kv.first]);
+  }
+}
+
+// Packs the mean and stddev of stats into a ProfDAGProto labelled name.
+ProfDAGProto ProfDAGCounters::statsProto(
+    const std::string& name,
+    const ProfDAGStats& stats) const {
+  ProfDAGProto stats_proto;
+  const auto moments = stats.computeMoments();
+  stats_proto.set_mean(moments.first);
+  stats_proto.set_stddev(moments.second);
+  stats_proto.set_name(name);
+  return stats_proto;
+}
+
+// Returns one entry per operator type with the mean/stddev of the time spent
+// in that type per run. Enforces that more than one run has been started.
+ProfDAGProtos ProfDAGCounters::GetOperatorStats() const {
+  CAFFE_ENFORCE_GT(num_runs_, 1, "Insufficient number of runs");
+  ProfDAGProtos prof_dag_protos;
+  for (const auto& item : time_per_op_type_total_) {
+    auto buf = prof_dag_protos.add_stats();
+    buf->CopyFrom(statsProto(item.first, item.second));
+  }
+  return prof_dag_protos;
+}
+
+// Returns one entry per operator, named <net>___<op index>___<op type>, with
+// the mean/stddev of that operator's time. Enforces that more than one run
+// has been started.
+ProfDAGProtos ProfDAGCounters::GetPerOperatorCost() const {
+  CAFFE_ENFORCE_GT(num_runs_, 1, "Insufficient number of runs");
+  ProfDAGProtos prof_dag_protos;
+  const int num_ops = static_cast<int>(op_types_.size());
+  for (int op_id = 0; op_id < num_ops; ++op_id) {
+    const std::string& op_type = op_types_[op_id];
+    auto buf = prof_dag_protos.add_stats();
+    std::string op_output_name =
+        net_name_ + "___" + to_string(op_id) + "___" + op_type;
+    buf->CopyFrom(statsProto(op_output_name, time_per_op_total_[op_id]));
+  }
+  return prof_dag_protos;
+}
+
+// Logs a human-readable summary of the per-operator-type timings and of the
+// whole-net runtime, or a short notice when at most one run was started.
+void ProfDAGCounters::PrintStats() {
+  if (num_runs_ <= 1) {
+    LOG(INFO) << "Insufficient number of runs";
+    return;
+  }
+
+  std::ostringstream debug_out;
+  debug_out << "Measured operators over " << num_runs_ << " net runs ("
+            << net_name_ << "), #ops: " << op_types_.size() << std::endl;
+
+  debug_out << "Mean time in operator type per run (stddev):" << std::endl;
+  for (const auto& item : time_per_op_type_total_) {
+    const auto moments = item.second.computeMoments();
+    const auto times_moments =
+        times_per_run_per_type_total_[item.first].computeMoments();
+    debug_out << std::setw(10) << std::setfill(' ') << moments.first
+              << " ms/run (" << std::setw(10) << std::setfill(' ')
+              << moments.second << " ms/run) "
+              << " Op count per run: " << times_moments.first << "  "
+              << item.first << std::endl;
+  }
+  const auto runtime_moments = runtime_stats_.computeMoments();
+  debug_out << net_name_ << " runtime: " << runtime_moments.first << " ms ("
+            << runtime_moments.second << " ms)" << std::endl;
+
+  LOG(INFO) << debug_out.str();
+}
+
+} // namespace caffe2
diff --git a/caffe2/core/prof_dag_counters.h b/caffe2/core/prof_dag_counters.h
new file mode 100644
index 0000000..49d7305
--- /dev/null
+++ b/caffe2/core/prof_dag_counters.h
@@ -0,0 +1,127 @@
+#ifndef PROF_DAG_COUNTERS_H
+#define PROF_DAG_COUNTERS_H
+
+#include "caffe2/core/common.h"
+#include "caffe2/core/logging.h"
+#include "caffe2/core/timer.h"
+#include "caffe2/proto/caffe2_pb.h"
+#include "caffe2/proto/prof_dag.pb.h"
+
+#include <cmath>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+namespace caffe2 {
+
+/**
+ * Accumulates samples as a running sum, sum of squares and count, from which
+ * the mean and standard deviation can be derived.
+ */
+class ProfDAGStats {
+ public:
+  ProfDAGStats() : sum_(0.0), sqrsum_(0.0), cnt_(0) {}
+  // Builds an accumulator holding the single sample time_ms.
+  explicit ProfDAGStats(float time_ms)
+      : sum_(time_ms), sqrsum_(time_ms * time_ms), cnt_(1) {}
+
+  // Merges the samples accumulated in rhs into this accumulator.
+  ProfDAGStats& operator+=(const ProfDAGStats& rhs) {
+    sum_ += rhs.sum_;
+    sqrsum_ += rhs.sqrsum_;
+    cnt_ += rhs.cnt_;
+    return *this;
+  }
+
+  // Returns {mean, stddev} of the accumulated samples, with stddev computed
+  // as sqrt(E[x^2] - E[x]^2). Enforces that at least one sample exists.
+  std::pair<float, float> computeMoments() const {
+    CAFFE_ENFORCE_GT(cnt_, 0);
+    const float mean = sum_ / cnt_;
+    // Float rounding can push E[x^2] - E[x]^2 slightly below zero when all
+    // samples are (nearly) equal; clamp so stddev is 0 instead of NaN.
+    const float variance = sqrsum_ / cnt_ - mean * mean;
+    const float stddev = variance > 0.0f ? std::sqrt(variance) : 0.0f;
+    return {mean, stddev};
+  }
+
+  float sum() const {
+    return sum_;
+  }
+
+  float sqrsum() const {
+    return sqrsum_;
+  }
+
+  size_t cnt() const {
+    return cnt_;
+  }
+
+ private:
+  float sum_;
+  float sqrsum_;
+  size_t cnt_;
+};
+
+/**
+ * A simple wrapper around prof_dag's counters: collects per-operator
+ * timestamps for each net run and aggregates them into cumulative stats.
+ * Nothing is recorded for the first run.
+ */
+class ProfDAGCounters {
+ public:
+  explicit ProfDAGCounters(const std::shared_ptr<const NetDef>& net_def);
+
+  // Collects the execution time per each operator type
+  ProfDAGProtos GetOperatorStats() const;
+
+  // Collects the execution time of each operator, the output is
+  // formatted as a map: (netName___opIndex___opType, cost)
+  ProfDAGProtos GetPerOperatorCost() const;
+
+  // ReportRunStart/End are called at the beginning and at the end of
+  // each net's run
+  void ReportRunStart();
+  void ReportRunEnd();
+
+  // Record the start / end / asynchronous end timestamp of operator op_id
+  // within the current run
+  void AddPerOpStartTime(size_t op_id);
+  void AddPerOpEndTime(size_t op_id);
+  void AddPerOpAsyncEndTime(size_t op_id);
+
+  // Logs a summary of the collected stats
+  void PrintStats();
+
+ private:
+  ProfDAGProto statsProto(const std::string& name, const ProfDAGStats& stats)
+      const;
+
+  std::vector<std::string> op_types_;
+
+  // Cumulative stats per operator instance of the net
+  std::vector<ProfDAGStats> time_per_op_total_;
+
+  // Cumulative stats per unique operator type
+  CaffeMap<std::string, ProfDAGStats> time_per_op_type_total_;
+
+  // Number of invocations of each operator type per run
+  CaffeMap<std::string, ProfDAGStats> times_per_run_per_type_total_;
+
+  std::string net_name_;
+
+  int num_runs_;
+  Timer timer_;
+  ProfDAGStats runtime_stats_;
+
+  // Per-run timestamps indexed by operator id; -1 means "not recorded"
+  std::vector<float> op_start_times_run_;
+  std::vector<float> op_end_times_run_;
+  std::vector<float> op_async_end_times_run_;
+};
+
+} // namespace caffe2
+
+#endif